Main Content

このページの内容は最新ではありません。最新版の英語を参照するには、ここをクリックします。

カスタム学習ループ、損失関数、およびネットワークの定義

ほとんどの深層学習タスクでは、事前学習済みのニューラル ネットワークを使用して独自のデータに適応させることができます。転移学習を使用して、畳み込みニューラル ネットワークの再学習を行い、新しい一連のイメージを分類する方法を示す例については、新しいイメージを分類するための深層学習ネットワークの学習を参照してください。または、関数 trainnet、関数 trainNetwork、および関数 trainingOptions を使用して、ニューラル ネットワークを作成し、これにゼロから学習させることができます。

タスクに必要な学習オプションが関数 trainingOptions に用意されていない場合、自動微分を使用してカスタム学習ループを作成できます。詳細については、カスタム学習ループ向けの深層学習ネットワークの定義を参照してください。

タスクに必要な層 (損失関数を指定する出力層を含む) が Deep Learning Toolbox™ に用意されていない場合、カスタム層を作成できます。詳細については、カスタム深層学習層の定義を参照してください。出力層を使用して指定できない損失関数の場合、カスタム学習ループで損失を指定できます。詳細については、損失関数の指定を参照してください。層グラフを使用して作成できないネットワークの場合、カスタム ネットワークを関数として定義できます。詳細については、モデル関数としてのネットワークの定義を参照してください。

どのタスクでどの学習手法を使用するかについての詳細は、MATLAB による深層学習モデルの学習を参照してください。

カスタム学習ループ向けの深層学習ネットワークの定義

dlnetwork オブジェクトとしてのネットワークの定義

ほとんどのタスクでは、関数 trainingOptionstrainnet、および trainNetwork を使用して学習アルゴリズムの詳細を制御できます。タスク (たとえば、カスタム学習率スケジュール) に必要なオプションが関数 trainingOptions に用意されていない場合、dlnetwork オブジェクトを使用して自分でカスタム学習ループを定義できます。dlnetwork オブジェクトを使用すると、自動微分を使用して、層グラフとして指定したネットワークに学習させることができます。

層グラフとして指定するネットワークの場合、関数 dlnetwork を直接使用して、層グラフから dlnetwork オブジェクトを作成できます。

net = dlnetwork(lgraph);

カスタム学習率スケジュールでネットワークに学習させる方法を示す例については、カスタム学習ループを使用したネットワークの学習を参照してください。

モデル関数としてのネットワークの定義

層グラフを使用して作成できないアーキテクチャ (重みの共有が必要なツイン ニューラル ネットワークなど) の場合、[Y1,...,YM] = model(parameters,X1,...,XN) という形式の関数としてモデルを定義できます。ここで、parameters にはネットワーク パラメーターが含まれ、X1,...,XNN 個のモデル入力の入力データに対応し、Y1,...,YMM 個のモデル出力に対応します。関数として定義される深層学習モデルに学習させるには、カスタム学習ループを使用します。例については、モデル関数を使用したネットワークの学習を参照してください。

関数として深層学習モデルを定義する場合は、層の重みを手動で初期化しなければなりません。詳細については、モデル関数の学習可能パラメーターの初期化を参照してください。

カスタム ネットワークを関数として定義するには、モデル関数が自動微分をサポートしていなければなりません。使用できる深層学習演算は次のとおりです。ここには一部の関数のみを示します。dlarray 入力をサポートしている関数の詳細な一覧については、dlarray をサポートする関数の一覧を参照してください。

関数説明
attentionattention 演算は、重み付き乗算を使用して入力の一部に焦点を当てます。
avgpool平均プーリング演算は、入力をプーリング領域に分割し、各領域の平均値を計算することによって、ダウンサンプリングを実行します。
batchnormバッチ正規化演算は、観測値全体における入力データの正規化を、各チャネルについて個別に行います。畳み込みニューラル ネットワークの学習速度を上げ、ネットワークの初期化に対する感度を下げるには、relu など、畳み込み演算と非線形演算の間のバッチ正規化を使用します。
crossentropy交差エントロピー演算は、単一ラベルおよび複数ラベルの分類タスクについて、ネットワーク予測とターゲット値の間の交差エントロピー損失を計算します。
crosschannelnormクロスチャネル正規化演算は、異なるチャネルの局所応答を使用して各活性化を正規化します。通常、クロスチャネル正規化は relu 演算に続きます。クロスチャネル正規化は局所応答正規化とも呼ばれます。
ctcCTC 演算は、非整列シーケンス間のコネクショニスト時間分類 (CTC) 損失を計算します。
dlconv畳み込み演算は、入力データにスライディング フィルターを適用します。関数 dlconv は、深層学習畳み込み、グループ化された畳み込み、チャネル方向に分離可能な畳み込みで使用します。
dlode45ニューラル常微分方程式 (ODE) 演算は、指定された ODE の解を返します。
dltranspconv転置畳み込み演算は、特徴マップをアップサンプリングします。
embed組み込み演算は、数値インデックスを数値ベクトルに変換します。ここで、インデックスは離散データに対応します。埋め込みを使用して、categorical 値や単語などの離散データを数値ベクトルにマッピングします。
fullyconnect全結合演算は、入力に重み行列を乗算してから、バイアス ベクトルを加算します。
geluガウス誤差線形単位 (GELU) 活性化演算は、ガウス確率分布に従って入力を重み付けします。
groupnormグループ正規化演算は、グループ化されたチャネル サブセット全体における入力データの正規化を、各観測値について個別に行います。畳み込みニューラル ネットワークの学習速度を上げ、ネットワークの初期化に対する感度を下げるには、relu など、畳み込み演算と非線形演算の間のグループ正規化を使用します。
gruゲート付き回帰型ユニット (GRU) 演算では、時系列データとシーケンス データのタイム ステップ間の依存関係をネットワークに学習させることができます。
huberHuber 演算は、回帰タスクのネットワーク予測とターゲット値の間の Huber 損失を計算します。'TransitionPoint' オプションが 1 の場合、これは "滑らかな L1 損失" とも呼ばれます。
instancenormインスタンス正規化演算は、各チャネルにおける入力データの正規化を、各観測値について個別に行います。畳み込みニューラル ネットワークの学習の収束性能を上げ、ネットワークのハイパーパラメーターに対する感度を下げるには、relu など、畳み込み演算と非線形演算の間のインスタンス正規化を使用します。
l1lossL1 損失演算は、ネットワーク予測とターゲット値を指定して L1 損失を計算します。Reduction オプションが "sum" で、NormalizationFactor オプションが "batch-size" のときの計算値は平均絶対誤差 (MAE) と呼ばれます。
l2lossL2 損失演算は、ネットワーク予測とターゲット値を指定して L2 損失を (L2 ノルムの 2 乗に基づいて) 計算します。Reduction オプションが "sum" で、NormalizationFactor オプションが "batch-size" のときの計算値は平均二乗誤差 (MSE) と呼ばれます。
layernormレイヤー正規化演算は、チャネル全体における入力データの正規化を、各観測値について個別に行います。再帰型多層パーセプトロン ニューラル ネットワークの学習速度を上げ、ネットワークの初期化に対する感度を下げるには、LSTM 演算や全結合演算などの学習可能な演算を行った後に、レイヤー正規化を使用します。
leakyrelu漏洩 (leaky) 正規化線形ユニット (ReLU) 活性化演算は、非線形のしきい値処理を実行し、ゼロよりも小さい入力値を固定スケール係数で乗算します。
lstm長短期記憶 (LSTM) 演算では、時系列データおよびシーケンス データのタイム ステップ間の長期的な依存関係をネットワークに学習させることができます。
maxpool最大プーリング演算は、入力をプーリング領域に分割し、各領域の最大値を計算することによって、ダウンサンプリングを実行します。
maxunpool最大逆プーリング演算は、アップサンプリングとゼロを使ったパディングによって、最大プーリング演算の出力を逆プーリングします。
mse半平均二乗誤差演算は、回帰タスクのネットワーク予測とターゲット値の間の半平均二乗誤差損失を計算します。
onehotdecode

one-hot 復号化演算は、分類ネットワークの出力などの確率ベクトルを分類ラベルに復号化します。

入力 Adlarray にすることができます。A が書式化されている場合、関数はデータ形式を無視します。

relu正規化線形ユニット (ReLU) 活性化演算は、非線形のしきい値処理を実行し、ゼロよりも小さい入力値をゼロに設定します。
sigmoidシグモイド活性化演算は、入力データにシグモイド関数を適用します。
softmaxソフトマックス活性化演算は、入力データのチャネルの次元にソフトマックス関数を適用します。

損失関数の指定

カスタム学習ループを使用する場合、モデル勾配関数で損失を計算しなければなりません。ネットワークの重みを更新するための勾配を計算する際には、損失値を使用します。損失を計算するには、次の関数を使用できます。

関数説明
softmaxソフトマックス活性化演算は、入力データのチャネルの次元にソフトマックス関数を適用します。
sigmoidシグモイド活性化演算は、入力データにシグモイド関数を適用します。
crossentropy交差エントロピー演算は、単一ラベルおよび複数ラベルの分類タスクについて、ネットワーク予測とターゲット値の間の交差エントロピー損失を計算します。
l1lossL1 損失演算は、ネットワーク予測とターゲット値を指定して L1 損失を計算します。Reduction オプションが "sum" で、NormalizationFactor オプションが "batch-size" のときの計算値は平均絶対誤差 (MAE) と呼ばれます。
l2lossL2 損失演算は、ネットワーク予測とターゲット値を指定して L2 損失を (L2 ノルムの 2 乗に基づいて) 計算します。Reduction オプションが "sum" で、NormalizationFactor オプションが "batch-size" のときの計算値は平均二乗誤差 (MSE) と呼ばれます。
huberHuber 演算は、回帰タスクのネットワーク予測とターゲット値の間の Huber 損失を計算します。'TransitionPoint' オプションが 1 の場合、これは "滑らかな L1 損失" とも呼ばれます。
mse半平均二乗誤差演算は、回帰タスクのネットワーク予測とターゲット値の間の半平均二乗誤差損失を計算します。
ctcCTC 演算は、非整列シーケンス間のコネクショニスト時間分類 (CTC) 損失を計算します。

または、loss = myLoss(Y,T) という形式の関数を作成して、カスタム損失関数を使用できます。ここで、YT はそれぞれネットワーク予測とターゲットに対応し、loss は返される損失です。

カスタム損失関数を使用してイメージを生成する敵対的生成ネットワーク (GAN) に学習させる方法を示す例については、敵対的生成ネットワーク (GAN) の学習を参照してください。

自動微分を使用した学習可能なパラメーターの更新

カスタム学習ループを使用して深層学習モデルに学習させる場合、ソフトウェアは、学習可能なパラメーターについての損失を最小化します。損失を最小化するために、ソフトウェアは、学習可能なパラメーターについての損失の勾配を使用します。自動微分を使用してこれらの勾配を計算するには、モデル勾配関数を定義しなければなりません。

モデル損失関数の定義

dlnetwork オブジェクトとして指定されるモデルでは、[loss,gradients] = modelLoss(net,X,T) という形式の関数を作成します。ここで、net はネットワーク、X はネットワークの入力で、T にはターゲットが格納され、lossgradients にはそれぞれ損失と勾配が返されます。オプションで、損失関数が必要とする追加情報などを追加引数として勾配関数に渡すことや、更新されたネットワークの状態などを追加引数として返すことができます。

関数として指定されるモデルでは、[loss,gradients] = modelLoss(parameters,X,T) という形式の関数を作成します。ここで、parameters には学習可能なパラメーターが格納され、X はモデルの入力で、T にはターゲットが格納され、lossgradients にはそれぞれ損失と勾配が返されます。オプションで、損失関数が必要とする追加情報などを追加引数として勾配関数に渡すことや、更新されたモデルの状態などを追加引数として返すことができます。

カスタム学習ループに対するモデル損失関数の定義の詳細については、カスタム学習ループのモデル損失関数の定義を参照してください。

学習可能なパラメーターの更新

自動微分を使用してモデル損失関数を評価するには、自動微分を有効にして関数を評価する関数 dlfeval を使用します。dlfeval の最初の入力では、関数ハンドルとして指定されるモデル損失関数を渡します。続く入力では、モデル損失関数に必要な変数を渡します。関数 dlfeval の出力では、モデル損失関数と同じ出力を指定します。

勾配を使用して学習可能なパラメーターを更新するには、次の関数を使用できます。

関数説明
adamupdate適応モーメント推定 (Adam) を使用してパラメーターを更新する
rmspropupdate平方根平均二乗伝播 (RMSProp) を使用してパラメーターを更新する
sgdmupdateモーメンタム項付き確率的勾配降下法 (SGDM) を使用してパラメーターを更新する
lbfgsupdate記憶制限 BFGS (L-BFGS) を使用してパラメーターを更新する
dlupdateカスタム関数を使用してパラメーターを更新する

参考

| | |

関連するトピック