このページの内容は最新ではありません。最新版の英語を参照するには、ここをクリックします。
predict
構文
説明
深層学習層によっては学習時と推論時 (予測時) の動作が異なる場合があります。たとえば、学習時には過適合を防ぐためにドロップアウト層によって入力要素がランダムに 0 に設定されますが、推論時にはドロップアウト層によって入力が変更されることはありません。
推論用のネットワーク出力を計算するには、関数 predict
を使用します。学習用のネットワーク出力を計算するには、関数 forward
を使用します。SeriesNetwork
オブジェクトおよび DAGNetwork
オブジェクトを使用した予測については、predict
を参照してください。
ヒント
SeriesNetwork
オブジェクトおよび DAGNetwork
オブジェクトを使用した予測については、predict
を参照してください。
[Y1,...,YN] = predict(___)
は、前述のいずれかの構文を使用して、N
個の出力をもつネットワークについて、N
個の推論時の出力 Y1
, …, YN
を返します。
[Y1,...,YK] = predict(___,Outputs=
は、前述のいずれかの構文を使用して、指定された層について、推論時の出力 layerNames
)Y1
, …, YK
を返します。
[___] = predict(___,
は、1 つ以上の名前と値の引数を使用して追加のオプションを指定します。Name=Value
)
[___,
は、更新されたネットワークの状態も返します。state
] = predict(___)
例
dlnetwork
オブジェクトを使用した予測の実行
この例では、データをミニバッチに分割することにより、dlnetwork
オブジェクトを使用して予測を行う方法を示します。
データ セットが大きい場合、またはメモリが限られたハードウェアで予測を行う場合、データをミニバッチに分割して予測を行います。SeriesNetwork
または DAGNetwork
オブジェクトで予測を行う場合、関数 predict
は入力データをミニバッチに自動的に分割します。dlnetwork
オブジェクトでは、データを手動でミニバッチに分割しなければなりません。
dlnetwork
オブジェクトの読み込み
学習済み dlnetwork
オブジェクトと対応するクラスを読み込みます。
s = load("digitsCustom.mat");
dlnet = s.dlnet;
classes = s.classes;
予測用データの読み込み
予測用の数字データを読み込みます。
digitDatasetPath = fullfile(matlabroot,'toolbox','nnet','nndemos', ... 'nndatasets','DigitDataset'); imds = imageDatastore(digitDatasetPath, ... 'IncludeSubfolders',true);
予測の実行
テスト データのミニバッチをループ処理して、カスタム予測ループを使って予測を行います。
minibatchqueue
を使用して、イメージのミニバッチを処理および管理します。ミニバッチ サイズとして 128 を指定します。イメージ データストアの読み取りサイズ プロパティをミニバッチ サイズに設定します。
各ミニバッチで次を行います。
カスタム ミニバッチ前処理関数
preprocessMiniBatch
(この例の最後に定義) を使用して、データをバッチに連結し、イメージを正規化。イメージを次元
'SSCB'
(spatial、spatial、channel、batch) で書式設定。既定では、minibatchqueue
オブジェクトは、基となる型がsingle
のdlarray
オブジェクトにデータを変換します。GPU が利用できる場合、GPU で予測を実行。既定では、
minibatchqueue
オブジェクトは、GPU が利用可能な場合、出力をgpuArray
に変換します。GPU を使用するには、Parallel Computing Toolbox™ とサポートされている GPU デバイスが必要です。サポートされているデバイスについては、GPU 計算の要件 (Parallel Computing Toolbox)を参照してください。
miniBatchSize = 128; imds.ReadSize = miniBatchSize; mbq = minibatchqueue(imds,... "MiniBatchSize",miniBatchSize,... "MiniBatchFcn", @preprocessMiniBatch,... "MiniBatchFormat","SSCB");
データのミニバッチをループ処理し、関数 predict
を使用して予測を行います。関数 onehotdecode
を使用して、クラス ラベルを決定します。予測クラス ラベルを保存します。
numObservations = numel(imds.Files); YPred = strings(1,numObservations); predictions = []; % Loop over mini-batches. while hasdata(mbq) % Read mini-batch of data. dlX = next(mbq); % Make predictions using the predict function. dlYPred = predict(dlnet,dlX); % Determine corresponding classes. predBatch = onehotdecode(dlYPred,classes,1); predictions = [predictions predBatch]; end
予測の一部を可視化します。
idx = randperm(numObservations,9); figure for i = 1:9 subplot(3,3,i) I = imread(imds.Files{idx(i)}); label = predictions(idx(i)); imshow(I) title("Label: " + string(label)) end
ミニバッチ前処理関数
関数 preprocessMiniBatch
は、次の手順でデータを前処理します。
入力 cell 配列からデータを抽出し、数値配列に連結します。4 番目の次元で連結することにより、3 番目の次元が各イメージに追加されます。この次元は、シングルトン チャネル次元として使用されることになります。
0
と1
の間のピクセル値を正規化します。
function X = preprocessMiniBatch(data) % Extract image data from cell and concatenate X = cat(4,data{:}); % Normalize the images. X = X/255; end
入力引数
net
— カスタム学習ループまたはカスタム枝刈りループのためのネットワーク
dlnetwork
オブジェクト | TaylorPrunableNetwork
オブジェクト
この引数は、次のいずれかを表すことができます。
カスタム学習ループのネットワーク。
dlnetwork
オブジェクトとして指定します。カスタム枝刈りループのネットワーク。
TaylorPrunableNetwork
オブジェクトとして指定します。
深層ニューラル ネットワークの枝刈りを行うには、Deep Learning Toolbox™ Model Quantization Library サポート パッケージが必要です。このサポート パッケージは無料のアドオンで、アドオン エクスプローラーを使用してダウンロードできます。または、Deep Learning Toolbox Model Quantization Library を参照してください。
X
— 入力データ
数値配列 | dlarray
オブジェクト
入力データ。次のいずれかの値として指定します。
数値配列 (R2023b 以降)
dlarray
オブジェクト (R2023b 以降)書式化された
dlarray
オブジェクト
ヒント
ニューラル ネットワークには、特定のレイアウトをもつ入力データが必要です。たとえば、ベクトルシーケンス分類ネットワークは、通常、t 行 c 列の数値配列として表されたシーケンスを必要とします。ここで、t および c は、それぞれタイム ステップの数とシーケンスのチャネル数です。ニューラル ネットワークの入力層では、通常、必要とされるデータ レイアウトが指定されています。
ほとんどのデータストアと関数は、ネットワークで必要とされるレイアウトでデータを出力します。データのレイアウトがネットワークで必要とされるレイアウトと異なる場合、InputDataFormats
オプションを使用するか、書式化された dlarray
オブジェクトとして入力データを指定し、データのレイアウトが異なることを示します。通常、入力データを前処理するよりも、InputDataFormats
学習オプションを調整する方が簡単です。
入力層をもたないニューラル ネットワークの場合、InputDataFormats
オプションまたは書式化された dlarray
オブジェクトを使用しなければなりません。
詳細については、Deep Learning Data Formatsを参照してください。
layerNames
— 出力の抽出元の層
string 配列 | 文字ベクトルの cell 配列
出力の抽出元の層。層の名前を含む string 配列、または層の名前を含む文字ベクトルの cell 配列として指定します。
layerNames(i)
が 1 つの出力をもつ層に対応する場合、layerNames(i)
は層の名前です。layerNames(i)
が複数の出力をもつ層に対応する場合、layerNames(i)
はまず層の名前、その後に文字/
、さらに層出力の名前が続きます ("layerName/outputName"
)。
名前と値の引数
オプションの引数のペアを Name1=Value1,...,NameN=ValueN
として指定します。ここで、Name
は引数名で、Value
は対応する値です。名前と値の引数は他の引数の後に指定しなければなりませんが、ペアの順序は重要ではありません。
R2021a より前では、コンマを使用して名前と値をそれぞれ区切り、Name
を引用符で囲みます。
例: Y = predict(net,X,InputDataFormats="CBT")
は、形式が "CBT"
(チャネル、バッチ、時間) であるシーケンス データを使用して予測を行います。
InputDataFormats
— 入力データの次元の説明
"auto"
(既定値) | string 配列 | 文字ベクトルの cell 配列 | 文字ベクトル
R2023b 以降
入力データの次元の説明。string 配列、文字ベクトル、または文字ベクトルの cell 配列として指定します。
InputDataFormats
が "auto"
の場合、ソフトウェアは、ネットワークの入力で必要とされる形式を使用します。そうでない場合、ソフトウェアは、該当するネットワーク入力に対して指定された形式を使用します。
データの形式は文字列で、各文字はデータ内の対応する次元のタイプを表します。
各文字は以下のとおりです。
"S"
— 空間"C"
— チャネル"B"
— バッチ"T"
— 時間"U"
— 指定なし
たとえば、シーケンスのバッチを含み、1 番目、2 番目、および 3 番目の次元がそれぞれチャネル、観測値、およびタイム ステップに対応する配列の場合、"CBT"
の形式で指定できます。
"S"
または "U"
のラベルが付いた次元については、複数回指定できます。"C"
、"B"
、"T"
のラベルについては、1 回のみ使用できます。ソフトウェアは、2 番目の次元の後ろにある大きさ 1 の "U"
次元を無視します。
詳細については、Deep Learning Data Formatsを参照してください。
データ型: char
| string
| cell
OutputDataFormats
— 出力データの次元の説明
"auto"
(既定値) | string 配列 | 文字ベクトルの cell 配列 | 文字ベクトル
R2023b 以降
出力データの次元の説明。次のいずれかの値として指定します。
"auto"
— 出力データと入力データの次元の数が同じ場合、関数predict
はInputDataFormats
で指定された形式を使用します。出力データと入力データの次元の数が異なる場合、関数predict
は、ネットワークの入力層、InputDataFormats
オプション、または関数trainnet
で必要とされるターゲットと一致するように、出力データの次元を自動的に並べ替えます。データ形式 (string 配列、文字ベクトル、または文字ベクトルの cell 配列として指定) — 関数
predict
は指定されたデータ形式を使用します。
データの形式は文字列で、各文字はデータ内の対応する次元のタイプを表します。
各文字は以下のとおりです。
"S"
— 空間"C"
— チャネル"B"
— バッチ"T"
— 時間"U"
— 指定なし
たとえば、シーケンスのバッチを含み、1 番目、2 番目、および 3 番目の次元がそれぞれチャネル、観測値、およびタイム ステップに対応する配列の場合、"CBT"
の形式で指定できます。
"S"
または "U"
のラベルが付いた次元については、複数回指定できます。"C"
、"B"
、"T"
のラベルについては、1 回のみ使用できます。ソフトウェアは、2 番目の次元の後ろにある大きさ 1 の "U"
次元を無視します。
詳細については、Deep Learning Data Formatsを参照してください。
データ型: char
| string
| cell
Acceleration
— パフォーマンスの最適化
"auto"
(既定値) | "mex"
| "none"
パフォーマンスの最適化。次のいずれかの値として指定します。
"auto"
— 入力ネットワークとハードウェア リソースに適した最適化の回数を自動的に適用します。"mex"
— MEX 関数をコンパイルして実行します。このオプションは GPU の使用時にのみ利用できます。入力データ、またはネットワークの学習可能なパラメーターは、gpuArray
オブジェクトとして格納しなければなりません。GPU を使用するには Parallel Computing Toolbox™ とサポートされている GPU デバイスが必要です。サポートされているデバイスについては、GPU 計算の要件 (Parallel Computing Toolbox)を参照してください。Parallel Computing Toolbox または適切な GPU が利用できない場合、エラーが返されます。"none"
— すべての高速化を無効にします。
Acceleration
が "auto"
の場合、ソフトウェアは MEX 関数を生成しません。
"auto"
オプションまたは "mex"
オプションを使用した場合、ソフトウェアはパフォーマンス上のメリットを提供しますが、初期実行時間が長くなります。関数のそれ以降の呼び出しでは、通常、より高速になります。新しい入力データを使用して関数を複数回呼び出す場合は、パフォーマンスの最適化を使用してください。
"mex"
オプションは、関数の呼び出しで指定したモデルとパラメーターに基づいて MEX 関数を生成し、実行します。1 つのモデルに一度に複数の MEX 関数を関連付けることができます。モデルの変数をクリアすると、そのモデルに関連付けられている MEX 関数もクリアされます。
"mex"
オプションは GPU の使用時にのみ利用できます。C/C++ コンパイラがインストールされ、GPU Coder™ Interface for Deep Learning サポート パッケージがなければなりません。MATLAB® でアドオン エクスプローラーを使用してサポート パッケージをインストールします。設定手順については、MEX の設定 (GPU Coder)を参照してください。GPU Coder は不要です。
"mex"
オプションには以下の制限があります。
出力引数
state
はサポートされていません。single
の精度のみがサポートされています。入力データ、またはネットワークの学習可能なパラメーターは、基となる型がsingle
でなければなりません。入力層に接続されていない入力が存在するネットワークはサポートされていません。
トレースされた
dlarray
オブジェクトはサポートされていません。これは、dlfeval
の呼び出し内では'mex'
オプションがサポートされていないことを意味します。一部の層はサポートされていません。サポートされている層の一覧については、サポートされている層 (GPU Coder)を参照してください。
"mex"
オプションを使用した場合、MATLAB Compiler™ を使用してネットワークを展開することはできません。
量子化されたネットワークでは、"mex"
オプションには、Compute Capability 6.1、6.3、またはそれ以上の CUDA® 対応 NVIDIA® GPU が必要です。
出力引数
Y
— 出力データ
数値配列 | dlarray
オブジェクト
出力データ。次のいずれかの値として返されます。
数値配列 (R2023b 以降)
書式化されていない
dlarray
オブジェクト (R2023b 以降)書式化された
dlarray
オブジェクト
データ型は入力データのデータ型と同じです。
state
— 更新されたネットワークの状態
table
更新されたネットワークの状態。table として返されます。
ネットワークの状態は、次の 3 つの列をもつ table です。
Layer
– 層の名前。string スカラーとして指定します。Parameter
– 状態パラメーターの名前。string スカラーとして指定します。Value
– 状態パラメーターの値。dlarray
オブジェクトとして指定します。
層の状態には、層処理中に計算された情報が格納されます。この情報は、層の後続のフォワード パスで使用するために保持されます。たとえば、LSTM 層のセル状態と隠れ状態、またはバッチ正規化層の実行中の統計が格納されます。
LSTM 層などの再帰層の場合、HasStateInputs
プロパティを 1
(true) に設定すると、その層の状態に関するエントリはステート table に格納されません。
アルゴリズム
再現性
最高のパフォーマンスを提供するために、GPU を使用した MATLAB での深層学習は確定的であることを保証しません。ネットワーク アーキテクチャによっては、GPU を使用して 2 つの同一のネットワークに学習させたり、同じネットワークとデータを使用して 2 つの予測を行ったりする場合に、ある条件下で異なる結果が得られることがあります。
拡張機能
C/C++ コード生成
MATLAB® Coder™ を使用して C および C++ コードを生成します。
使用上の注意および制限:
C++ コード生成は、以下の構文をサポートします。
Y = predict(net,X)
Y = predict(net,X1,...,XM)
[Y1,...,YN] = predict(__)
[Y1,...,YK] = predict(__,'Outputs',layerNames)
入力データ
X
のサイズは可変であってはなりません。サイズはコード生成時に固定しなければなりません。predict
メソッドへのdlarray
入力は、single
データ型でなければなりません。
GPU コード生成
GPU Coder™ を使用して NVIDIA® GPU のための CUDA® コードを生成します。
使用上の注意および制限:
GPU コード生成は、以下の構文をサポートします。
Y = predict(net,X)
Y = predict(net,X1,...,XM)
[Y1,...,YN] = predict(__)
[Y1,...,YK] = predict(__,'Outputs',layerNames)
入力データ
X
のサイズは可変であってはなりません。サイズはコード生成時に固定しなければなりません。TensorRT ライブラリ用のコード生成では、構文
[Y1,...,YK] = predict(__,'Outputs',layerNames)
を使用して入力層を出力としてマークすることができません。predict
メソッドへのdlarray
入力は、single
データ型でなければなりません。
GPU 配列
Parallel Computing Toolbox™ を使用してグラフィックス処理装置 (GPU) 上で実行することにより、コードを高速化します。
使用上の注意および制限:
次のいずれかまたは両方の条件が満たされる場合、この関数は GPU で実行されます。
net.Learnables.Value
に含まれるネットワークの学習可能なパラメーター値のいずれかが、gpuArray
型の基になるデータをもつdlarray
オブジェクトである入力引数
X
がgpuArray
型の基になるデータをもつdlarray
である
詳細については、GPU での MATLAB 関数の実行 (Parallel Computing Toolbox)を参照してください。
バージョン履歴
R2019b で導入R2023b: 数値配列および書式化されていない dlarray
オブジェクトの指定
数値配列および書式化されていない dlarray
オブジェクトを使用して予測を実行します。
入力と出力のデータ形式は、それぞれ InputDataFormats
オプションおよび OutputDataFormats
オプションを使用して指定します。
R2021a: predict
は dlarray
オブジェクトとして状態値を返す
dlnetwork
オブジェクトの場合、関数 predict
によって返される出力引数 state
は、ネットワーク内の各層に関する状態パラメーターの名前と値が格納された table になります。
R2021a 以降、状態値は dlarray
オブジェクトになっています。この変更によって、AcceleratedFunction
オブジェクトを使用する際のサポートが強化されています。頻繁に変更される入力値 (ネットワークの状態を含む入力など) をもつ深層学習関数を高速化するには、頻繁に変更される値を dlarray
オブジェクトとして指定しなければなりません。
以前のバージョンでは、状態値が数値配列になっています。
多くの場合、コードを更新する必要はありません。状態値を数値配列にする必要があるコードの場合、以前の動作を再現するには、関数 extractdata
と関数 dlupdate
を併用して状態値からデータを手動で抽出します。
state = dlupdate(@extractdata,net.State);
MATLAB コマンド
次の MATLAB コマンドに対応するリンクがクリックされました。
コマンドを MATLAB コマンド ウィンドウに入力して実行してください。Web ブラウザーは MATLAB コマンドをサポートしていません。
Select a Web Site
Choose a web site to get translated content where available and see local events and offers. Based on your location, we recommend that you select: .
You can also select a web site from the following list:
How to Get Best Site Performance
Select the China site (in Chinese or English) for best site performance. Other MathWorks country sites are not optimized for visits from your location.
Americas
- América Latina (Español)
- Canada (English)
- United States (English)
Europe
- Belgium (English)
- Denmark (English)
- Deutschland (Deutsch)
- España (Español)
- Finland (English)
- France (Français)
- Ireland (English)
- Italia (Italiano)
- Luxembourg (English)
- Netherlands (English)
- Norway (English)
- Österreich (Deutsch)
- Portugal (English)
- Sweden (English)
- Switzerland
- United Kingdom (English)