このページの翻訳は最新ではありません。ここをクリックして、英語の最新版を参照してください。
predict
構文
説明
深層学習層によっては学習時と推論時 (予測時) の動作が異なる場合があります。たとえば、学習時には過適合を防ぐためにドロップアウト層によって入力要素がランダムに 0 に設定されますが、推論時にはドロップアウト層によって入力が変更されることはありません。
推論用のネットワーク出力を計算するには、関数 predict
を使用します。学習用のネットワーク出力を計算するには、関数 forward
を使用します。SeriesNetwork
オブジェクトおよび DAGNetwork
オブジェクトを使用した予測については、predict
を参照してください。
ヒント
SeriesNetwork
オブジェクトおよび DAGNetwork
オブジェクトを使用した予測については、predict
を参照してください。
[Y1,...,YN] = predict(___)
は、前述のいずれかの構文を使用して、N
個の出力をもつネットワークについて、N
個の推論時の出力 Y1
, …, YN
を返します。
[Y1,...,YK] = predict(___,'Outputs',
は、前述のいずれかの構文を使用して、指定された層について、推論時の出力 layerNames
)Y1
, …, YK
を返します。
[___] = predict(___,'Acceleration',
は、前述の構文の入力引数に加え、推論時に使用するパフォーマンスの最適化も指定します。 acceleration
)
[___,
は、更新されたネットワークの状態も返します。state
] = predict(___)
例
dlnetwork
オブジェクトを使用した予測の実行
この例では、データをミニバッチに分割することにより、dlnetwork
オブジェクトを使用して予測を行う方法を示します。
データセットが大きい場合、またはメモリが限られたハードウェアで予測を行う場合、データをミニバッチに分割して予測を行います。SeriesNetwork
または DAGNetwork
オブジェクトで予測を行う場合、関数 predict
は入力データをミニバッチに自動的に分割します。dlnetwork
オブジェクトでは、データを手動でミニバッチに分割しなければなりません。
dlnetwork
オブジェクトの読み込み
学習済み dlnetwork
オブジェクトと対応するクラスを読み込みます。
s = load("digitsCustom.mat");
dlnet = s.dlnet;
classes = s.classes;
予測用データの読み込み
予測用の数字データを読み込みます。
digitDatasetPath = fullfile(matlabroot,'toolbox','nnet','nndemos', ... 'nndatasets','DigitDataset'); imds = imageDatastore(digitDatasetPath, ... 'IncludeSubfolders',true);
予測の実行
テスト データのミニバッチをループ処理して、カスタム予測ループを使って予測を行います。
minibatchqueue
を使用して、イメージのミニバッチを処理および管理します。ミニバッチ サイズとして 128 を指定します。イメージ データストアの読み取りサイズ プロパティをミニバッチ サイズに設定します。
各ミニバッチで次を行います。
カスタム ミニバッチ前処理関数
preprocessMiniBatch
(この例の最後に定義) を使用して、データをバッチに連結し、イメージを正規化。イメージを次元
'SSCB'
(spatial、spatial、channel、batch) で書式設定。既定では、minibatchqueue
オブジェクトは、基となる型がsingle
のdlarray
オブジェクトにデータを変換します。GPU が利用できる場合、GPU で予測を実行。既定では、
minibatchqueue
オブジェクトは、GPU が利用可能な場合、出力をgpuArray
に変換します。GPU を使用するには、Parallel Computing Toolbox™ とサポートされている GPU デバイスが必要です。サポートされているデバイスについては、GPU 計算の要件 (Parallel Computing Toolbox)を参照してください。
miniBatchSize = 128; imds.ReadSize = miniBatchSize; mbq = minibatchqueue(imds,... "MiniBatchSize",miniBatchSize,... "MiniBatchFcn", @preprocessMiniBatch,... "MiniBatchFormat","SSCB");
データのミニバッチをループ処理し、関数 predict
を使用して予測を行います。関数 onehotdecode
を使用して、クラス ラベルを決定します。予測クラス ラベルを保存します。
numObservations = numel(imds.Files); YPred = strings(1,numObservations); predictions = []; % Loop over mini-batches. while hasdata(mbq) % Read mini-batch of data. dlX = next(mbq); % Make predictions using the predict function. dlYPred = predict(dlnet,dlX); % Determine corresponding classes. predBatch = onehotdecode(dlYPred,classes,1); predictions = [predictions predBatch]; end
予測の一部を可視化します。
idx = randperm(numObservations,9); figure for i = 1:9 subplot(3,3,i) I = imread(imds.Files{idx(i)}); label = predictions(idx(i)); imshow(I) title("Label: " + string(label)) end
ミニバッチ前処理関数
関数 preprocessMiniBatch
は、次の手順でデータを前処理します。
入力 cell 配列からデータを抽出し、数値配列に連結します。4 番目の次元で連結することにより、3 番目の次元が各イメージに追加されます。この次元は、シングルトン チャネル次元として使用されることになります。
0
と1
の間のピクセル値を正規化します。
function X = preprocessMiniBatch(data) % Extract image data from cell and concatenate X = cat(4,data{:}); % Normalize the images. X = X/255; end
入力引数
net
— カスタム学習ループまたはカスタム枝刈りループのためのネットワーク
dlnetwork
オブジェクト | TaylorPrunableNetwork
オブジェクト
この引数は、次のいずれかを表すことができます。
カスタム学習ループのネットワーク。
dlnetwork
オブジェクトとして指定します。カスタム枝刈りループのネットワーク。
TaylorPrunableNetwork
オブジェクトとして指定します。
深層ニューラル ネットワークの枝刈りを行うには、Deep Learning Toolbox™ Model Quantization Library サポート パッケージが必要です。このサポート パッケージは無料のアドオンで、アドオン エクスプローラーを使用してダウンロードできます。または、Deep Learning Toolbox Model Quantization Library を参照してください。
layerNames
— 出力の抽出元の層
string 配列 | 文字ベクトルの cell 配列
出力の抽出元の層。層の名前を含む string 配列、または層の名前を含む文字ベクトルの cell 配列として指定します。
layerNames(i)
が 1 つの出力をもつ層に対応する場合、layerNames(i)
は層の名前です。layerNames(i)
が複数の出力をもつ層に対応する場合、layerNames(i)
はまず層の名前、その後に文字 "/
"、さらに層出力の名前が続きます ('layerName/outputName'
)。
acceleration
— パフォーマンスの最適化
'auto'
(既定値) | 'mex'
| 'none'
パフォーマンスの最適化。'Acceleration'
と次のいずれかで構成されるコンマ区切りのペアとして指定します。
'auto'
— 入力ネットワークとハードウェア リソースに適した最適化の回数を自動的に適用します。'mex'
— MEX 関数をコンパイルして実行します。このオプションは GPU の使用時にのみ利用できます。入力データ、またはネットワークの学習可能なパラメーターは、gpuArray
オブジェクトとして格納しなければなりません。GPU を使用するには Parallel Computing Toolbox™ とサポートされている GPU デバイスが必要です。サポートされているデバイスについては、GPU 計算の要件 (Parallel Computing Toolbox)を参照してください。Parallel Computing Toolbox または適切な GPU が利用できない場合、エラーが返されます。'none'
— すべての高速化を無効にします。
既定のオプションは 'auto'
です。'auto'
が指定されている場合、MATLAB® は互換性のある最適化を複数適用します。'auto'
オプションを使用すると、MATLAB は MEX 関数を生成しません。
'Acceleration'
オプション 'auto'
および 'mex'
を使用すると、パフォーマンス上のメリットが得られますが、初期実行時間が長くなります。互換性のあるパラメーターを使用した後続の呼び出しは、より高速になります。新しい入力データを使用して関数を複数回呼び出す場合は、パフォーマンスの最適化を使用してください。
'mex'
オプションは、関数の呼び出しに使用されたネットワークとパラメーターに基づいて MEX 関数を生成し、実行します。複数の MEX 関数を一度に 1 つのネットワークに関連付けることができます。ネットワークの変数をクリアすると、そのネットワークに関連付けられている MEX 関数もクリアされます。
'mex'
オプションは、GPU の使用時にのみ利用できます。C/C++ コンパイラがインストールされ、GPU Coder™ Interface for Deep Learning Libraries サポート パッケージがなければなりません。MATLAB でアドオン エクスプローラーを使用してサポート パッケージをインストールします。設定手順については、MEX の設定 (GPU Coder)を参照してください。GPU Coder は不要です。
'mex'
オプションには以下の制限があります。
出力引数
state
はサポートされていません。single
の精度のみがサポートされています。入力データ、またはネットワークの学習可能なパラメーターは、基となる型がsingle
でなければなりません。入力層に接続されていない入力が存在するネットワークはサポートされていません。
トレースされた
dlarray
オブジェクトはサポートされていません。これは、dlfeval
の呼び出し内では'mex'
オプションがサポートされていないことを意味します。一部の層はサポートされていません。サポートされている層の一覧については、サポートされている層 (GPU Coder)を参照してください。
'mex'
オプションを使用した場合、MATLAB Compiler™ を使用してネットワークを展開することはできません。
例: 'Acceleration','mex'
出力引数
state
— 更新されたネットワークの状態
table
更新されたネットワークの状態。table として返されます。
ネットワークの状態は、次の 3 つの列をもつ table です。
Layer
– 層の名前。string スカラーとして指定します。Parameter
– 状態パラメーターの名前。string スカラーとして指定します。Value
– 状態パラメーターの値。dlarray
オブジェクトとして指定します。
層の状態には、層処理中に計算された情報が格納されます。この情報は、層の後続のフォワード パスで使用するために保持されます。たとえば、LSTM 層のセル状態と隠れ状態、またはバッチ正規化層の実行中の統計が格納されます。
LSTM 層などの再帰層の場合、HasStateInputs
プロパティを 1
(true) に設定すると、その層の状態に関するエントリはステート table に格納されません。
拡張機能
C/C++ コード生成
MATLAB® Coder™ を使用して C および C++ コードを生成します。
使用上の注意および制限:
C++ コード生成は、以下の構文をサポートします。
Y = predict(net,X)
Y = predict(net,X1,...,XM)
[Y1,...,YN] = predict(__)
[Y1,...,YK] = predict(__,'Outputs',layerNames)
入力データ
X
のサイズは可変であってはなりません。サイズはコード生成時に固定しなければなりません。predict
メソッドへのdlarray
入力はsingle
データ型でなければなりません。
GPU コード生成
GPU Coder™ を使用して NVIDIA® GPU のための CUDA® コードを生成します。
使用上の注意および制限:
GPU コード生成は、以下の構文をサポートします。
Y = predict(net,X)
Y = predict(net,X1,...,XM)
[Y1,...,YN] = predict(__)
[Y1,...,YK] = predict(__,'Outputs',layerNames)
入力データ
X
のサイズは可変であってはなりません。サイズはコード生成時に固定しなければなりません。TensorRT ライブラリ用のコード生成では、構文
[Y1,...,YK] = predict(__,'Outputs',layerNames)
を使用して入力層を出力としてマークすることができません。predict
メソッドへのdlarray
入力はsingle
データ型でなければなりません。
GPU 配列
Parallel Computing Toolbox™ を使用してグラフィックス処理装置 (GPU) 上で実行することにより、コードを高速化します。
使用上の注意および制限:
次のいずれかまたは両方の条件が満たされる場合、この関数は GPU で実行されます。
net.Learnables.Value
に含まれるネットワークの学習可能なパラメーター値のいずれかが、gpuArray
型の基になるデータをもつdlarray
オブジェクトである入力引数
X
がgpuArray
型の基になるデータをもつdlarray
である
詳細については、GPU での MATLAB 関数の実行 (Parallel Computing Toolbox)を参照してください。
バージョン履歴
R2019b で導入R2021a: predict
は dlarray
オブジェクトとして状態値を返す
dlnetwork
オブジェクトの場合、関数 predict
によって返される出力引数 state
は、ネットワーク内の各層に関する状態パラメーターの名前と値が格納された table になります。
R2021a 以降、状態値は dlarray
オブジェクトになっています。この変更によって、AcceleratedFunction
オブジェクトを使用する際のサポートが強化されています。頻繁に変更される入力値 (ネットワークの状態を含む入力など) をもつ深層学習関数を高速化するには、頻繁に変更される値を dlarray
オブジェクトとして指定しなければなりません。
以前のバージョンでは、状態値が数値配列になっています。
多くの場合、コードを更新する必要はありません。状態値を数値配列にする必要があるコードの場合、以前の動作を再現するには、関数 extractdata
と関数 dlupdate
を併用して状態値からデータを手動で抽出します。
state = dlupdate(@extractdata,net.State);
MATLAB コマンド
次の MATLAB コマンドに対応するリンクがクリックされました。
コマンドを MATLAB コマンド ウィンドウに入力して実行してください。Web ブラウザーは MATLAB コマンドをサポートしていません。
Select a Web Site
Choose a web site to get translated content where available and see local events and offers. Based on your location, we recommend that you select: .
You can also select a web site from the following list:
How to Get Best Site Performance
Select the China site (in Chinese or English) for best site performance. Other MathWorks country sites are not optimized for visits from your location.
Americas
- América Latina (Español)
- Canada (English)
- United States (English)
Europe
- Belgium (English)
- Denmark (English)
- Deutschland (Deutsch)
- España (Español)
- Finland (English)
- France (Français)
- Ireland (English)
- Italia (Italiano)
- Luxembourg (English)
- Netherlands (English)
- Norway (English)
- Österreich (Deutsch)
- Portugal (English)
- Sweden (English)
- Switzerland
- United Kingdom (English)