freezeParameters
説明
params = freezeParameters(
は、params
,names
)ONNXParameters
オブジェクト params
の names
で指定されたネットワーク パラメーターを凍結します。関数は、指定されたパラメーターを入力引数 params
の params.Learnables
から出力引数 params
の params.Nonlearnables
に移動します。
例
カスタム学習ループを使用したインポート済み ONNX 関数の学習
squeezenet
畳み込みニューラル ネットワークを関数としてインポートし、事前学習済みのネットワークを転移学習で微調整し、新しいイメージ コレクションで分類を実行します。
この例では、いくつかの補助関数を使用しています。これらの関数のコードを見るには、補助関数を参照してください。
新しいイメージを解凍してイメージ データストアとして読み込みます。imageDatastore
は、フォルダー名に基づいてイメージに自動的にラベルを付け、データを ImageDatastore
オブジェクトとして格納します。イメージ データストアを使用すると、メモリに収まらないデータなどの大きなイメージ データを格納し、畳み込みニューラル ネットワークの学習中にイメージをバッチ単位で効率的に読み取ることができます。ミニバッチのサイズを指定します。
unzip("MerchData.zip"); miniBatchSize = 8; imds = imageDatastore("MerchData", ... IncludeSubfolders=true, ... LabelSource="foldernames", ... ReadSize=miniBatchSize);
このデータ セットは小さく、75 個の学習イメージが含まれています。いくつかのサンプル イメージを表示します。
numImages = numel(imds.Labels); idx = randperm(numImages,16); figure for i = 1:16 subplot(4,4,i) I = readimage(imds,idx(i)); imshow(I) end
学習セットを抽出し、カテゴリカル分類ラベルを one-hot 符号化します。
XTrain = readall(imds); XTrain = single(cat(4,XTrain{:})); YTrain_categ = categorical(imds.Labels); YTrain = onehotencode(YTrain_categ,2)';
データ内のクラスの数を判定します。
classes = categories(YTrain_categ); numClasses = numel(classes)
numClasses = 5
squeezenet
は、ImageNet データベースの 100 万個を超えるイメージで学習済みの畳み込みニューラル ネットワークです。結果として、このネットワークは広範囲のイメージに対する豊富な特徴表現を学習しています。このネットワークは、イメージを 1000 個のオブジェクト カテゴリ (キーボード、マウス、鉛筆、多くの動物など) に分類できます。
事前学習済みの squeezenet
ネットワークを関数としてインポートします。
squeezenetONNX() params = importONNXFunction("squeezenet.onnx","squeezenetFcn")
Function containing the imported ONNX network architecture was saved to the file squeezenetFcn.m. To learn how to use this function, type: help squeezenetFcn.
params = ONNXParameters with properties: Learnables: [1×1 struct] Nonlearnables: [1×1 struct] State: [1×1 struct] NumDimensions: [1×1 struct] NetworkFunctionName: 'squeezenetFcn'
params
は、ネットワーク パラメーターを含む ONNXParameters
オブジェクトです。squeezenetFcn
は、ネットワーク アーキテクチャを含むモデル関数です。importONNXFunction
は、squeezenetFcn
を現在のフォルダーに保存します。
新しい学習セットに対する事前学習済みネットワークの分類精度を計算します。
accuracyBeforeTraining = getNetworkAccuracy(XTrain,YTrain,params);
fprintf("%.2f accuracy before transfer learning\n",accuracyBeforeTraining);
0.01 accuracy before transfer learning
非常に低い精度です。
params.Learnables
と入力して、ネットワークの学習可能なパラメーターを表示します。重み (W
) やバイアス (B
) など、畳み込み層や全結合層のこれらのパラメーターは、学習時にネットワークによって更新されます。学習不能パラメーターは、学習時に一定のままとなります。
事前学習済みネットワークの最後の 2 つの学習可能パラメーターは、1000 個のクラスに対して構成されています。
conv10_W: [1×1×512×1000 dlarray]
conv10_B: [1000×1 dlarray]
パラメーター conv10_W
と conv10_B
は、新しい分類問題に対して微調整しなければなりません。パラメーターを初期化して、5 つのクラスを分類するようにパラメーターを移動します。
params.Learnables.conv10_W = rand(1,1,512,5); params.Learnables.conv10_B = rand(5,1);
ネットワークのすべてのパラメーターを凍結して、それらを学習不能なパラメーターに変換します。多くの初期層の重みを凍結することで、凍結層の勾配を計算する必要をなくし、ネットワークの学習を大幅に高速化できます。
params = freezeParameters(params,"all");
ネットワークの最後の 2 つのパラメーターを凍結解除して、それらを学習可能なパラメーターに変換します。
params = unfreezeParameters(params,"conv10_W"); params = unfreezeParameters(params,"conv10_B");
ネットワークの学習準備が整いました。学習オプションを指定します。
velocity = []; numEpochs = 5; miniBatchSize = 16; initialLearnRate = 0.01; momentum = 0.9; decay = 0.01;
学習の進行状況モニター用に合計反復回数を計算します。
numObservations = size(YTrain,2); numIterationsPerEpoch = floor(numObservations./miniBatchSize); numIterations = numEpochs*numIterationsPerEpoch;
TrainingProgressMonitor
オブジェクトを初期化します。監視オブジェクトを作成するとタイマーが開始されるため、学習ループの直後でオブジェクトを作成するようにしてください。
monitor = trainingProgressMonitor(Metrics="Loss",Info="Epoch",XLabel="Iteration");
ネットワークに学習をさせます。
epoch = 0; iteration = 0; executionEnvironment = "cpu"; % Change to "gpu" to train on a GPU. % Loop over epochs. while epoch < numEpochs && ~monitor.Stop epoch = epoch + 1; % Shuffle data. idx = randperm(numObservations); XTrain = XTrain(:,:,:,idx); YTrain = YTrain(:,idx); % Loop over mini-batches. i = 0; while i < numIterationsPerEpoch && ~monitor.Stop i = i + 1; iteration = iteration + 1; % Read mini-batch of data. idx = (i-1)*miniBatchSize+1:i*miniBatchSize; X = XTrain(:,:,:,idx); Y = YTrain(:,idx); % If training on a GPU, then convert data to gpuArray. if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu" X = gpuArray(X); end % Evaluate the model gradients and loss using dlfeval and the % modelGradients function. [gradients,loss,state] = dlfeval(@modelGradients,X,Y,params); params.State = state; % Determine the learning rate for the time-based decay learning rate schedule. learnRate = initialLearnRate/(1 + decay*iteration); % Update the network parameters using the SGDM optimizer. [params.Learnables,velocity] = sgdmupdate(params.Learnables,gradients,velocity); % Update the training progress monitor. recordMetrics(monitor,iteration,Loss=loss); updateInfo(monitor,Epoch=epoch); monitor.Progress = 100 * iteration/numIterations; end end
微調整後のネットワークの分類精度を計算します。
accuracyAfterTraining = getNetworkAccuracy(XTrain,YTrain,params);
fprintf("%.2f accuracy after transfer learning\n",accuracyAfterTraining);
1.00 accuracy after transfer learning
補助関数
このセクションでは、この例で使用されている補助関数のコードを示します。
関数 getNetworkAccuracy
は、分類精度を計算することによってネットワーク性能を評価します。
function accuracy = getNetworkAccuracy(X,Y,onnxParams) N = size(X,4); Ypred = squeezenetFcn(X,onnxParams,Training=false); [~,YIdx] = max(Y,[],1); [~,YpredIdx] = max(Ypred,[],1); numIncorrect = sum(abs(YIdx-YpredIdx) > 0); accuracy = 1 - numIncorrect/N; end
関数 modelGradients
は、損失と勾配を計算します。
function [grad, loss, state] = modelGradients(X,Y,onnxParams) [y,state] = squeezenetFcn(X,onnxParams,Training=true); loss = crossentropy(y,Y,DataFormat="CB"); grad = dlgradient(loss,onnxParams.Learnables); end
関数 squeezenetONNX
は、squeezenet
ネットワークの ONNX モデルを生成します。
function squeezenetONNX() exportONNXNetwork(squeezenet,"squeezenet.onnx"); end
入力引数
params
— ネットワーク パラメーター
ONNXParameters
オブジェクト
ネットワーク パラメーター。ONNXParameters
オブジェクトとして指定します。params
には、インポートされた ONNX™ モデルのネットワーク パラメーターが格納されます。
names
— 凍結するパラメーターの名前
'all'
| string 配列
凍結するパラメーターの名前。'all'
または string 配列として指定します。names
を 'all'
に設定すると、すべての学習可能パラメーターを凍結します。1 行 k
列の string 配列 names
でパラメーター名を定義することにより、k
個の学習可能なパラメーターを凍結します。
例: 'all'
例: ["gpu_0_sl_pred_b_0", "gpu_0_sl_pred_w_0"]
データ型: char
| string
出力引数
params
— ネットワーク パラメーター
ONNXParameters
オブジェクト
ネットワーク パラメーター。ONNXParameters
オブジェクトとして返されます。params
には、freezeParameters
によって更新されたネットワーク パラメーターが格納されます。
バージョン履歴
R2020b で導入
MATLAB コマンド
次の MATLAB コマンドに対応するリンクがクリックされました。
コマンドを MATLAB コマンド ウィンドウに入力して実行してください。Web ブラウザーは MATLAB コマンドをサポートしていません。
Select a Web Site
Choose a web site to get translated content where available and see local events and offers. Based on your location, we recommend that you select: .
You can also select a web site from the following list:
How to Get Best Site Performance
Select the China site (in Chinese or English) for best site performance. Other MathWorks country sites are not optimized for visits from your location.
Americas
- América Latina (Español)
- Canada (English)
- United States (English)
Europe
- Belgium (English)
- Denmark (English)
- Deutschland (Deutsch)
- España (Español)
- Finland (English)
- France (Français)
- Ireland (English)
- Italia (Italiano)
- Luxembourg (English)
- Netherlands (English)
- Norway (English)
- Österreich (Deutsch)
- Portugal (English)
- Sweden (English)
- Switzerland
- United Kingdom (English)