freezeParameters
ONNXParameters
の学習可能なネットワーク パラメーターを学習不能に変換する
説明
params = freezeParameters(
は、params
,names
)ONNXParameters
オブジェクト params
の names
で指定されたネットワーク パラメーターを凍結します。関数は、指定されたパラメーターを入力引数 params
の params.Learnables
から出力引数 params
の params.Nonlearnables
に移動します。
例
SqueezeNet 畳み込みニューラル ネットワークを関数としてインポートし、事前学習済みのネットワークを転移学習で微調整し、新しいイメージ コレクションで分類を実行します。
この例では、いくつかの補助関数を使用しています。これらの関数のコードを見るには、補助関数を参照してください。
新しいイメージを解凍してイメージ データストアとして読み込みます。imageDatastore
は、フォルダー名に基づいてイメージに自動的にラベルを付け、データを ImageDatastore
オブジェクトとして格納します。イメージ データストアを使用すると、メモリに収まらないデータなどの大きなイメージ データを格納し、畳み込みニューラル ネットワークの学習中にイメージをバッチ単位で効率的に読み取ることができます。ミニバッチのサイズを指定します。
unzip("MerchData.zip"); miniBatchSize = 8; imds = imageDatastore("MerchData", ... IncludeSubfolders=true, ... LabelSource="foldernames", ... ReadSize=miniBatchSize);
このデータ セットは小さく、75 個の学習イメージが含まれています。いくつかのサンプル イメージを表示します。
numImages = numel(imds.Labels); idx = randperm(numImages,16); figure for i = 1:16 subplot(4,4,i) I = readimage(imds,idx(i)); imshow(I) end
学習セットを抽出し、カテゴリカル分類ラベルを one-hot 符号化します。
XTrain = readall(imds); XTrain = single(cat(4,XTrain{:})); YTrain_categ = categorical(imds.Labels); YTrain = onehotencode(YTrain_categ,2)';
データ内のクラスの数を判定します。
classes = categories(YTrain_categ); numClasses = numel(classes)
numClasses = 5
SqueezeNet は、ImageNet データベースの 100 万個を超えるイメージで学習を行った畳み込みニューラル ネットワークです。結果として、このネットワークは広範囲のイメージに対する豊富な特徴表現を学習しています。このネットワークは、イメージを 1000 個のオブジェクト カテゴリ (キーボード、マウス、鉛筆、多くの動物など) に分類できます。
事前学習済みの SqueezeNet ネットワークを関数としてインポートします。
squeezenetONNX() params = importONNXFunction("squeezenet.onnx","squeezenetFcn")
Function containing the imported ONNX network architecture was saved to the file squeezenetFcn.m. To learn how to use this function, type: help squeezenetFcn.
params = ONNXParameters with properties: Learnables: [1×1 struct] Nonlearnables: [1×1 struct] State: [1×1 struct] NumDimensions: [1×1 struct] NetworkFunctionName: 'squeezenetFcn'
params
は、ネットワーク パラメーターを含む ONNXParameters
オブジェクトです。squeezenetFcn
は、ネットワーク アーキテクチャを含むモデル関数です。importONNXFunction
は、squeezenetFcn
を現在のフォルダーに保存します。
新しい学習セットに対する事前学習済みネットワークの分類精度を計算します。
accuracyBeforeTraining = getNetworkAccuracy(XTrain,YTrain,params);
fprintf("%.2f accuracy before transfer learning\n",accuracyBeforeTraining);
0.01 accuracy before transfer learning
非常に低い精度です。
params.Learnables
と入力して、ネットワークの学習可能なパラメーターを表示します。重み (W
) やバイアス (B
) など、畳み込み層や全結合層のこれらのパラメーターは、学習時にネットワークによって更新されます。学習不能なパラメーターは、学習時に一定のままとなります。
事前学習済みネットワークの最後の 2 つの学習可能なパラメーターは、1000 個のクラスに対して構成されています。
conv10_W: [1×1×512×1000 dlarray]
conv10_B: [1000×1 dlarray]
パラメーター conv10_W
と conv10_B
は、新しい分類問題に対して微調整しなければなりません。パラメーターを初期化して、5 つのクラスを分類するようにパラメーターを移動します。
params.Learnables.conv10_W = rand(1,1,512,5); params.Learnables.conv10_B = rand(5,1);
ネットワークのすべてのパラメーターを凍結して、それらを学習不能なパラメーターに変換します。多くの初期層の重みを凍結することで、凍結層の勾配を計算する必要をなくし、ネットワークの学習を大幅に高速化できます。
params = freezeParameters(params,"all");
ネットワークの最後の 2 つのパラメーターを凍結解除して、それらを学習可能なパラメーターに変換します。
params = unfreezeParameters(params,"conv10_W"); params = unfreezeParameters(params,"conv10_B");
ネットワークの学習準備が整いました。学習オプションを指定します。
velocity = []; numEpochs = 5; miniBatchSize = 16; initialLearnRate = 0.01; momentum = 0.9; decay = 0.01;
学習の進行状況モニター用に合計反復回数を計算します。
numObservations = size(YTrain,2); numIterationsPerEpoch = floor(numObservations./miniBatchSize); numIterations = numEpochs*numIterationsPerEpoch;
TrainingProgressMonitor
オブジェクトを初期化します。監視オブジェクトを作成するとタイマーが開始されるため、学習ループの直後でオブジェクトを作成するようにしてください。
monitor = trainingProgressMonitor(Metrics="Loss"); monitor.Info = ["LearningRate","Epoch","Iteration"]; monitor.XLabel = "Iteration";
ネットワークに学習をさせます。
epoch = 0; iteration = 0; executionEnvironment = "cpu"; % Change to "gpu" to train on a GPU. % Loop over epochs. while epoch < numEpochs && ~monitor.Stop epoch = epoch + 1; % Shuffle data. idx = randperm(numObservations); XTrain = XTrain(:,:,:,idx); YTrain = YTrain(:,idx); % Loop over mini-batches. i = 0; while i < numIterationsPerEpoch && ~monitor.Stop i = i + 1; iteration = iteration + 1; % Read mini-batch of data. idx = (i-1)*miniBatchSize+1:i*miniBatchSize; X = XTrain(:,:,:,idx); Y = YTrain(:,idx); % If training on a GPU, then convert data to gpuArray. if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu" X = gpuArray(X); end % Evaluate the model gradients and loss using dlfeval and the % modelGradients function. [gradients,loss,state] = dlfeval(@modelGradients,X,Y,params); params.State = state; % Determine the learning rate for the time-based decay learning rate schedule. learnRate = initialLearnRate/(1 + decay*iteration); % Update the network parameters using the SGDM optimizer. [params.Learnables,velocity] = sgdmupdate(params.Learnables,gradients,velocity,learnRate); % Update the training progress monitor. recordMetrics(monitor,iteration,Loss=loss); updateInfo(monitor,Epoch=epoch,LearningRate=learnRate); monitor.Progress = 100 * iteration/numIterations; end end
微調整後のネットワークの分類精度を計算します。
accuracyAfterTraining = getNetworkAccuracy(XTrain,YTrain,params);
fprintf("%.2f accuracy after transfer learning\n",accuracyAfterTraining);
0.99 accuracy after transfer learning
補助関数
このセクションでは、この例で使用されている補助関数のコードを示します。
関数 getNetworkAccuracy
は、分類精度を計算することによってネットワーク性能を評価します。
function accuracy = getNetworkAccuracy(X,Y,onnxParams) N = size(X,4); Ypred = squeezenetFcn(X,onnxParams,Training=false); [~,YIdx] = max(Y,[],1); [~,YpredIdx] = max(Ypred,[],1); numIncorrect = sum(abs(YIdx-YpredIdx) > 0); accuracy = 1 - numIncorrect/N; end
関数 modelGradients
は、損失と勾配を計算します。
function [grad, loss, state] = modelGradients(X,Y,onnxParams) [y,state] = squeezenetFcn(X,onnxParams,Training=true); loss = crossentropy(y,Y,DataFormat="CB"); grad = dlgradient(loss,onnxParams.Learnables); end
関数 squeezenetONNX
は、SqueezeNet ネットワークの ONNX モデルを生成します。
function squeezenetONNX() exportONNXNetwork(squeezenet,"squeezenet.onnx"); end
入力引数
ネットワーク パラメーター。ONNXParameters
オブジェクトとして指定します。params
には、インポートされた ONNX™ モデルのネットワーク パラメーターが格納されます。
凍結するパラメーターの名前。'all'
または string 配列として指定します。names
を 'all'
に設定すると、すべての学習可能なパラメーターを凍結します。1 行 k
列の string 配列 names
でパラメーター名を定義することにより、k
個の学習可能なパラメーターを凍結します。
例: 'all'
例: ["gpu_0_sl_pred_b_0", "gpu_0_sl_pred_w_0"]
データ型: char
| string
出力引数
ネットワーク パラメーター。ONNXParameters
オブジェクトとして返されます。params
には、freezeParameters
によって更新されたネットワーク パラメーターが格納されます。
バージョン履歴
R2020b で導入
MATLAB Command
You clicked a link that corresponds to this MATLAB command:
Run the command by entering it in the MATLAB Command Window. Web browsers do not support MATLAB commands.
Web サイトの選択
Web サイトを選択すると、翻訳されたコンテンツにアクセスし、地域のイベントやサービスを確認できます。現在の位置情報に基づき、次のサイトの選択を推奨します:
また、以下のリストから Web サイトを選択することもできます。
最適なサイトパフォーマンスの取得方法
中国のサイト (中国語または英語) を選択することで、最適なサイトパフォーマンスが得られます。その他の国の MathWorks のサイトは、お客様の地域からのアクセスが最適化されていません。
南北アメリカ
- América Latina (Español)
- Canada (English)
- United States (English)
ヨーロッパ
- Belgium (English)
- Denmark (English)
- Deutschland (Deutsch)
- España (Español)
- Finland (English)
- France (Français)
- Ireland (English)
- Italia (Italiano)
- Luxembourg (English)
- Netherlands (English)
- Norway (English)
- Österreich (Deutsch)
- Portugal (English)
- Sweden (English)
- Switzerland
- United Kingdom (English)