回帰用の畳み込みニューラル ネットワークの学習
この例では、畳み込みニューラル ネットワークに学習させ、手書きの数字の回転角度を予測する方法を示します。
回帰タスクには、離散クラス ラベルの代わりに連続的な数値の予測が含まれます。この例では、回帰用の畳み込みニューラル ネットワーク アーキテクチャを構築し、ネットワークに学習させ、学習済みネットワークを使用して手書きの数字の回転角度を予測します。
次の図は、回帰ニューラル ネットワークを通るイメージ データの流れを示しています。

データの読み込み
データ セットには、手書きの数字の合成イメージと各イメージに対応する回転角度 (度単位) が含まれています。
MAT ファイル DigitsDataTrain.mat および DigitsDataTest.mat から学習データとテスト データをそれぞれ読み込みます。変数 anglesTrain および anglesTest は回転角度 (度単位) です。学習データ セットとテスト データ セットにはそれぞれ、5000 個のイメージが含まれています。
load DigitsDataTrain load DigitsDataTest
学習イメージをいくつか表示します。
numObservations = size(XTrain,4); idx = randperm(numObservations,49); I = imtile(XTrain(:,:,:,idx)); figure imshow(I);

この例にサポート ファイルとして添付されている関数 trainingPartitions を使用して、XTrain と anglesTrain を学習区画と検証区画に分割します。この関数にアクセスするには、例をライブ スクリプトとして開きます。学習データの 15% を検証用に残しておきます。
[idxTrain,idxValidation] = trainingPartitions(numObservations,[0.85 0.15]); XValidation = XTrain(:,:,:,idxValidation); anglesValidation = anglesTrain(idxValidation); XTrain = XTrain(:,:,:,idxTrain); anglesTrain = anglesTrain(idxTrain);
ニューラル ネットワーク アーキテクチャの定義
ニューラル ネットワーク アーキテクチャを定義します。
イメージ入力用に、イメージ入力層を指定する。
フィルター数を増やして 4 つの convolution-batchnorm-ReLU ブロックを指定する。
各ブロック間に、プーリング領域とサイズ 2 のストライドをもつ平均プーリング層を指定する。
回帰の場合、応答の数と一致する出力サイズをもつ全結合層を含めます。
この例では、学習プロセスで
NormalizeTargets学習オプション (R2026a で導入) を使用し、学習ターゲットを自動的に正規化します。正規化されたターゲットを使用することで、学習が安定し、正規化されたターゲットにほぼ一致する学習予測が得られます。予測時にのみ、正規化されていない値の空間でニューラル ネットワークの出力予測を行うには、逆正規化層 (R2026a で導入) を含めます。R2026a より前: 学習を安定させるために、ニューラル ネットワークに学習させる前にターゲットを手動で正規化します。
numResponses = size(anglesTrain,2);
layers = [
imageInputLayer([28 28 1])
convolution2dLayer(3,8,Padding="same")
batchNormalizationLayer
reluLayer
averagePooling2dLayer(2,Stride=2)
convolution2dLayer(3,16,Padding="same")
batchNormalizationLayer
reluLayer
averagePooling2dLayer(2,Stride=2)
convolution2dLayer(3,32,Padding="same")
batchNormalizationLayer
reluLayer
convolution2dLayer(3,32,Padding="same")
batchNormalizationLayer
reluLayer
fullyConnectedLayer(numResponses)
inverseNormalizationLayer];学習オプションの指定
学習オプションを指定します。オプションの中から選択するには、経験的解析が必要です。実験を実行してさまざまな学習オプションの構成を調べるには、実験マネージャーアプリを使用できます。
ミニバッチ サイズを 128 として学習させます。
NormalizeTargets引数 (R2026a で導入) を使用して、学習ターゲットを自動的に正規化します。R2026a より前: 学習を安定させるために、ニューラル ネットワークに学習させる前にターゲットを手動で正規化します。初期学習率として 0.001 を使用し、20 エポックごとに係数 0.1 を使用して学習率を下げる区分的学習率スケジュールを使用して、学習率を下げます。
エポックごとに、検証データを使用してニューラル ネットワークを検証します。
学習の進行状況をプロットに表示します。
詳細出力を無効にします。
miniBatchSize = 128; schedule = piecewiseLearnRate( ... DropFactor=0.1, ... Period=20); numIterationsPerEpoch = floor(numel(anglesTrain)/miniBatchSize); options = trainingOptions("sgdm", ... NormalizeTargets=true, ... MiniBatchSize=miniBatchSize, ... InitialLearnRate=1e-3, ... LearnRateSchedule=schedule, ... Shuffle="every-epoch", ... ValidationData={XValidation,anglesValidation}, ... ValidationFrequency=numIterationsPerEpoch, ... Plots="training-progress", ... Verbose=false);
ニューラル ネットワークの学習
関数trainnetを使用してニューラル ネットワークに学習させます。回帰の場合は、平均二乗誤差損失を使用します。既定では、関数 trainnet は利用可能な GPU がある場合にそれを使用します。GPU を使用するには、Parallel Computing Toolbox™ ライセンスとサポートされている GPU デバイスが必要です。サポートされているデバイスの詳細については、GPU 計算の要件 (Parallel Computing Toolbox)を参照してください。そうでない場合、関数は CPU を使用します。実行環境を指定するには、ExecutionEnvironment 学習オプションを使用します。
net = trainnet(XTrain,anglesTrain,layers,"mse",options);
ネットワークのテスト
testnet関数を使用してニューラル ネットワークをテストします。回帰の場合、平方根平均二乗誤差 (RMSE) を評価します。既定では、testnet 関数は利用可能な GPU がある場合にそれを使用します。実行環境を手動で選択するには、testnet 関数の ExecutionEnvironment 引数を使用します。
rmse = testnet(net,XTest,anglesTest,"rmse")rmse = 7.6861
テスト データを使用して予測を行い、その予測をターゲットと比較したときの精度をプロットに表示します。関数minibatchpredictを使用して予測を行います。既定では、関数 minibatchpredict は利用可能な GPU がある場合にそれを使用します。
YTest = minibatchpredict(net,XTest);
ターゲットに対する予測値をプロットします。
figure scatter(YTest,anglesTest,"+") xlabel("Prediction") ylabel("Target") hold on plot([-60 60], [-60 60],"r--")

新しいデータでの予測の実行
ニューラル ネットワークを使用して、最初のテスト イメージで予測を行います。単一のイメージを使用して予測を行うには、関数predictを使用します。GPU を使用するには、まずデータを gpuArray に変換します。
X = XTest(:,:,:,1); if canUseGPU X = gpuArray(X); end Y = predict(net,X)
Y = single
33.0647
figure
imshow(X)
title("Angle: " + gather(Y))
参考
trainnet | trainingOptions | dlnetwork