Main Content

このページの内容は最新ではありません。最新版の英語を参照するには、ここをクリックします。

学習後の浅層ニューラル ネットワークの性能分析

このトピックでは、典型的な浅層ニューラル ネットワークのワークフローの一部について説明します。詳細とその他のステップについては、浅層の多層ニューラル ネットワークと逆伝播学習を参照してください。深層学習の進行状況を監視する方法については、深層学習における学習の進行状況の監視を参照してください。

浅層の多層ニューラル ネットワークの学習と適用の学習が完了すると、ネットワーク性能を確認し、学習プロセス、ネットワーク アーキテクチャ、またはデータセットに対して何らかの変更が必要かどうかを判断できます。まず、学習関数から 2 番目の引数として返される学習記録 tr を確認します。

tr
tr = struct with fields:
        trainFcn: 'trainlm'
      trainParam: [1x1 struct]
      performFcn: 'mse'
    performParam: [1x1 struct]
        derivFcn: 'defaultderiv'
       divideFcn: 'dividerand'
      divideMode: 'sample'
     divideParam: [1x1 struct]
        trainInd: [2 3 5 6 9 10 11 13 14 15 18 19 20 22 23 24 25 29 30 31 33 35 36 38 39 40 41 44 45 46 47 48 49 50 51 52 54 55 56 57 58 59 62 64 65 66 68 70 73 76 77 79 80 81 84 85 86 88 90 91 92 93 94 95 96 97 98 99 100 101 102 103 ... ] (1x176 double)
          valInd: [1 8 17 21 27 28 34 43 63 71 72 74 75 83 106 124 125 134 140 155 157 158 162 165 166 175 177 181 187 191 196 201 205 212 233 243 245 250]
         testInd: [4 7 12 16 26 32 37 42 53 60 61 67 69 78 82 87 89 104 105 110 111 112 133 135 149 151 153 163 170 189 203 216 217 222 226 235 246 247]
            stop: 'Training finished: Met validation criterion'
      num_epochs: 9
       trainMask: {[NaN 1 1 NaN 1 1 NaN NaN 1 1 1 NaN 1 1 1 NaN NaN 1 1 1 NaN 1 1 1 1 NaN NaN NaN 1 1 1 NaN 1 NaN 1 1 NaN 1 1 1 1 NaN NaN 1 1 1 1 1 1 1 1 1 NaN 1 1 1 1 1 1 NaN NaN 1 NaN 1 1 1 NaN 1 NaN 1 NaN NaN 1 NaN NaN 1 1 NaN 1 1 ... ] (1x252 double)}
         valMask: {[1 NaN NaN NaN NaN NaN NaN 1 NaN NaN NaN NaN NaN NaN NaN NaN 1 NaN NaN NaN 1 NaN NaN NaN NaN NaN 1 1 NaN NaN NaN NaN NaN 1 NaN NaN NaN NaN NaN NaN NaN NaN 1 NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN ... ] (1x252 double)}
        testMask: {[NaN NaN NaN 1 NaN NaN 1 NaN NaN NaN NaN 1 NaN NaN NaN 1 NaN NaN NaN NaN NaN NaN NaN NaN NaN 1 NaN NaN NaN NaN NaN 1 NaN NaN NaN NaN 1 NaN NaN NaN NaN 1 NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN 1 NaN NaN NaN NaN NaN ... ] (1x252 double)}
      best_epoch: 3
            goal: 0
          states: {'epoch'  'time'  'perf'  'vperf'  'tperf'  'mu'  'gradient'  'val_fail'}
           epoch: [0 1 2 3 4 5 6 7 8 9]
            time: [2.1601 2.2525 2.2583 2.2840 2.2894 2.3046 2.3254 2.3310 2.3365 2.3429]
            perf: [672.2031 94.8128 43.7489 12.3078 9.7063 8.9212 8.0412 7.3500 6.7890 6.3064]
           vperf: [675.3788 76.9621 74.0752 16.6857 19.9424 23.4096 26.6791 29.1562 31.1592 32.9227]
           tperf: [599.2224 97.7009 79.1240 24.1796 31.6290 38.4484 42.7637 44.4194 44.8848 44.3171]
              mu: [1.0000e-03 0.0100 0.0100 0.1000 0.1000 0.1000 0.1000 0.1000 0.1000 0.1000]
        gradient: [2.4114e+03 867.8889 301.7333 142.1049 12.4011 85.0504 49.4147 17.4011 15.7749 14.6346]
        val_fail: [0 0 0 0 1 2 3 4 5 6]
       best_perf: 12.3078
      best_vperf: 16.6857
      best_tperf: 24.1796

この構造体には、ネットワークの学習に関するすべての情報が含まれます。たとえば、tr.trainIndtr.valInd、および tr.testInd には、学習セット、検証セット、およびテスト セットにそれぞれ使用されたデータ点のインデックスが含まれています。同じデータ分割でネットワークの再学習を行う場合は、net.divideFcn'divideInd' に、net.divideParam.trainIndtr.trainInd に、net.divideParam.valIndtr.valInd に、net.divideParam.testIndtr.testInd に設定します。

tr 構造体には学習中に、性能関数の値、勾配の大きさなど、いくつかの変数も記録されます。plotperf コマンドを使用して、この学習記録を使って性能の進行状況をプロットできます。

plotperf(tr)

Figure Training Record contains an axes object. The axes object with title Performance is 6.3064, xlabel 9 Epochs, ylabel Performance contains 4 objects of type line. These objects represent Test, Validation, Train.

tr.best_epoch プロパティは検証性能が最小値に到達したときの反復を示します。この学習は、さらに 6 回の反復を続けてから、学習を停止しました。

この図では、学習中に大きな問題が発生したかどうかはわかりません。検証とテストの曲線は非常に似ています。テストの曲線は検証の曲線が上昇する前に大幅に上昇しました。これは、何らかの過適合が発生した可能性を示しています。

ネットワークの検証の次のステップは、ネットワークの出力とターゲットとの関係を示す回帰プロットを作成することです。学習が完璧であれば、ネットワーク出力とターゲットは厳密に等しくなりますが、実際には関係がぴったり一致することはほとんどありません。体脂肪の例では、以下のコマンドを使って回帰プロットを作成できます。最初のコマンドでは、データセットのすべての入力に対する学習済みネットワークの応答が計算されます。続く 6 つのコマンドでは、学習、検証、およびテストの各サブセットに属する出力とターゲットが抽出されます。最後のコマンドでは、学習、テスト、検証に対する 3 つの回帰プロットが作成されます。

bodyfatOutputs = net(bodyfatInputs);
trOut = bodyfatOutputs(tr.trainInd);
vOut = bodyfatOutputs(tr.valInd);
tsOut = bodyfatOutputs(tr.testInd);
trTarg = bodyfatTargets(tr.trainInd);
vTarg = bodyfatTargets(tr.valInd);
tsTarg = bodyfatTargets(tr.testInd);
plotregression(trTarg, trOut, 'Train', vTarg, vOut, 'Validation', tsTarg, tsOut, 'Testing')

Figure Regression (plotregression) contains 3 axes objects. Axes object 1 with title Train: R=0.91107, xlabel Target, ylabel Output ~= 0.82*Target + 2.7 contains 3 objects of type line. One or more of the lines displays its values using only markers These objects represent Y = T, Fit, Data. Axes object 2 with title Validation: R=0.8456, xlabel Target, ylabel Output ~= 0.82*Target + 3.8 contains 3 objects of type line. One or more of the lines displays its values using only markers These objects represent Y = T, Fit, Data. Axes object 3 with title Testing: R=0.87068, xlabel Target, ylabel Output ~= 0.93*Target + 1.8 contains 3 objects of type line. One or more of the lines displays its values using only markers These objects represent Y = T, Fit, Data.

3 つのプロットは、学習データ、検証データ、テスト データを表します。各プロットの破線は、出力 = ターゲットとなる完璧な結果を表します。実線は、出力とターゲットとの間に最も当てはまる線形回帰直線を表します。R 値は、出力とターゲットとの関係を示します。R = 1 の場合、出力とターゲットの間に厳密な線形関係があることを示します。R がゼロに近づくと、出力とターゲットの間に線形関係はありません。

この例では、学習データがよく当てはめられていることがわかります。検証とテストの結果でも、大きな R 値を示しています。散布図は、特定のデータ点の適合度が低いことを示すのに便利です。たとえば、ネットワーク出力が 35 に近いデータ点がテスト セットにあり、その対応するターゲット値は約 12 であるとします。次の手順として、このデータ点を調査して外挿を表しているか (つまり、学習データセットの外側であるか) どうかを判定します。その場合は、学習セットに含める必要があり、テスト セットで使用する追加のデータを収集する必要があります。

結果の改良

ネットワークの精度が十分ではない場合、ネットワークを初期化して、再学習させることができます。フィードフォワード ネットワークを初期化するたびに、ネットワーク パラメーターが変わり、異なる解が生成される可能性があります。

net = init(net);
net = train(net, bodyfatInputs, bodyfatTargets);

2 つ目の方法として、隠れニューロンの数を 20 より多くすることができます。隠れ層のニューロンの数が多いほど、ネットワークで最適化できるパラメーターの数が増えるため、ネットワークの柔軟性が高くなります (層のサイズは徐々に大きくします。作成する隠れ層が大きすぎると、問題の特徴付けが不十分になる可能性があり、ネットワークはパラメーターを制約するデータ ベクトルよりも多い数のパラメーターを最適化しなければなりません。

3 つ目の方法は、別の学習関数を試すことです。trainbr によるベイズ正則化学習では、早期停止を使用する場合よりも優れた汎化機能が得られることがあります。

最後は、追加の学習データを使用してみることです。ネットワークに追加データを与えると、新しいデータに対して適切に汎化を行うネットワークを生成できる可能性が高くなります。