カスタム学習ループ、損失関数、およびネットワークの定義
ほとんどの深層学習タスクでは、事前学習済みのニューラル ネットワークを使用して独自のデータに適応させることができます。転移学習を使用して、畳み込みニューラル ネットワークの再学習を行い、新しい一連のイメージを分類する方法を示す例については、Retrain Neural Network to Classify New Imagesを参照してください。または、関数 trainnet
と関数 trainingOptions
を使用してニューラル ネットワークを作成し、これにゼロから学習させることができます。
タスクに必要な学習オプションが関数 trainingOptions
に用意されていない場合、自動微分を使用してカスタム学習ループを作成できます。詳細については、カスタム学習ループを使用したネットワークの学習を参照してください。
タスクに必要な損失関数が関数 trainnet
に用意されていない場合、カスタム損失関数を関数ハンドルとして trainnet
に指定できます。損失関数が予測とターゲットよりも多くの入力を必要とする場合 (たとえば、損失関数がニューラル ネットワークまたは追加の入力にアクセスする必要がある場合)、カスタム学習ループを使用してモデルに学習させます。詳細については、カスタム学習ループを使用したネットワークの学習を参照してください。
タスクに必要な層が Deep Learning Toolbox™ に用意されていない場合、カスタム層を作成できます。詳細については、カスタム深層学習層の定義を参照してください。層のネットワークとして指定できないモデルの場合は、モデルを関数として定義できます。詳細については、モデル関数を使用したネットワークの学習を参照してください。
どのタスクでどの学習手法を使用するかについての詳細は、MATLAB による深層学習モデルの学習を参照してください。
カスタム損失関数の定義
関数 trainnet
には、学習に使用できるいくつかの組み込み損失関数が用意されています。たとえば、損失関数の引数として "crossentropy"
と "mse"
をそれぞれ指定することで、分類にはクロスエントロピー損失を指定し、回帰には平均二乗誤差損失を指定することができます。
タスクに必要な損失関数が関数 trainnet
に用意されていない場合、カスタム損失関数を関数ハンドルとして trainnet
に指定できます。関数の構文は loss = f(Y,T)
でなければなりません。ここで、Y
と T
はそれぞれ予測とターゲットです。
カスタム損失関数の作成を支援するために、次の表の深層学習関数を使用できます。これらの関数を関数ハンドルとして直接関数 trainnet
に渡すこともできます。
関数 | 説明 |
---|---|
softmax | ソフトマックス活性化演算は、入力データのチャネルの次元にソフトマックス関数を適用します。 |
sigmoid | シグモイド活性化演算は、入力データにシグモイド関数を適用します。 |
crossentropy | クロスエントロピー演算は、単一ラベルおよび複数ラベルの分類タスクについて、ネットワーク予測とバイナリのターゲット値または one-hot 符号化されたターゲット値との間のクロスエントロピー損失を計算します。 |
indexcrossentropy | インデックス クロスエントロピー演算は、単一ラベル分類タスクについて、ネットワーク予測と整数クラス インデックスとして指定されたターゲットとの間のクロスエントロピー損失を計算します。 |
l1loss | L1 損失演算は、ネットワーク予測とターゲット値を指定して L1 損失を計算します。Reduction オプションが "sum" で、NormalizationFactor オプションが "batch-size" のときの計算値は平均絶対誤差 (MAE) と呼ばれます。 |
l2loss | L2 損失演算は、ネットワーク予測とターゲット値を指定して L2 損失を (L2 ノルムの 2 乗に基づいて) 計算します。Reduction オプションが "sum" で、NormalizationFactor オプションが "batch-size" のときの計算値は平均二乗誤差 (MSE) と呼ばれます。 |
huber | Huber 演算は、回帰タスクのネットワーク予測とターゲット値の間の Huber 損失を計算します。'TransitionPoint' オプションが 1 の場合、これは "滑らかな L1 損失" とも呼ばれます。 |
ctc | CTC 演算は、非整列シーケンス間のコネクショニスト時間分類 (CTC) 損失を計算します。 |
mse | 半平均二乗誤差演算は、回帰タスクのネットワーク予測とターゲット値の間の半平均二乗誤差損失を計算します。 |
損失関数が予測とターゲットよりも多くの入力を必要とする場合 (たとえば、損失関数がニューラル ネットワークまたは追加の入力にアクセスする必要がある場合)、カスタム学習ループを使用してモデルに学習させます。詳細については、カスタム学習ループの損失関数の定義を参照してください。例については、カスタム学習ループを使用したネットワークの学習を参照してください。
カスタム学習ループ用の深層学習モデルの定義
ほとんどのタスクでは、関数 trainingOptions
と関数 trainnet
を使用して学習アルゴリズムの詳細を制御できます。trainingOptions
関数がタスクに必要なオプション (カスタム ソルバーなど) を提供しない場合は、独自のカスタム学習ループを定義できます。
ニューラル ネットワークとしてのモデルの定義
配列または層のニューラル ネットワークとして指定できるモデルの場合は、モデルを dlnetwork
オブジェクトとして指定します。たとえば、カスタム学習ループ用のシンプルな LSTM ニューラル ネットワークを定義するには、次を使用します。
layers = [
sequenceInputLayer(3)
lstmLayer(100,OutputMode="last")
fullyConnectedLayer(4)
softmaxLayer];
net = dlnetwork(layers);
カスタム学習ループを使用してニューラル ネットワークに学習させるには、ネットワークを初期化しなければなりません。ニューラル ネットワークを初期化するには、関数 initialize
を使用します。
net = initialize(net);
カスタム学習ループを使用してニューラル ネットワークに学習させる方法の例については、カスタム学習ループを使用したネットワークの学習を参照してください。
関数としてのモデルの定義
配列または層のネットワークを使用して作成できないアーキテクチャの場合は、[Y1,...,YM] = model(parameters,X1,...,XN)
という形式の関数としてモデルを定義できます。ここで、parameters
にはネットワーク パラメーターが含まれ、X1,...,XN
は N
個のモデル入力の入力データに対応し、Y1,...,YM
は M
個のモデル出力に対応します。関数として定義される深層学習モデルに学習させるには、カスタム学習ループを使用します。例については、モデル関数を使用したネットワークの学習を参照してください。
関数として深層学習モデルを定義する場合は、学習可能なパラメーターを手動で初期化しなければなりません。詳細については、モデル関数の学習可能パラメーターの初期化を参照してください。
カスタム ネットワークを関数として定義するには、モデル関数が自動微分をサポートしていなければなりません。次の表に示す深層学習演算を使用できます。ここには一部の関数のみを示します。dlarray
入力をサポートしている関数の詳細な一覧については、dlarray をサポートする関数の一覧を参照してください。
関数 | 説明 |
---|---|
attention | attention 演算は、重み付き乗算を使用して入力の一部に焦点を当てます。 |
avgpool | 平均プーリング演算は、入力をプーリング領域に分割し、各領域の平均値を計算することによって、ダウンサンプリングを実行します。 |
batchnorm | バッチ正規化演算は、観測値全体における入力データの正規化を、各チャネルについて個別に行います。畳み込みニューラル ネットワークの学習速度を上げ、ネットワークの初期化に対する感度を下げるには、relu など、畳み込み演算と非線形演算の間のバッチ正規化を使用します。 |
crossentropy | クロスエントロピー演算は、単一ラベルおよび複数ラベルの分類タスクについて、ネットワーク予測とバイナリのターゲット値または one-hot 符号化されたターゲット値との間のクロスエントロピー損失を計算します。 |
indexcrossentropy (R2024b 以降) | インデックス クロスエントロピー演算は、単一ラベル分類タスクについて、ネットワーク予測と整数クラス インデックスとして指定されたターゲットとの間のクロスエントロピー損失を計算します。 |
crosschannelnorm | クロスチャネル正規化演算は、異なるチャネルの局所応答を使用して各活性化を正規化します。通常、クロスチャネル正規化は relu 演算に続きます。クロスチャネル正規化は局所応答正規化とも呼ばれます。 |
ctc | CTC 演算は、非整列シーケンス間のコネクショニスト時間分類 (CTC) 損失を計算します。 |
dlconv | 畳み込み演算は、入力データにスライディング フィルターを適用します。関数 dlconv は、深層学習畳み込み、グループ化された畳み込み、チャネル方向に分離可能な畳み込みで使用します。 |
dlode45 | ニューラル常微分方程式 (ODE) 演算は、指定された ODE の解を返します。 |
dltranspconv | 転置畳み込み演算は、特徴マップをアップサンプリングします。 |
embed | 組み込み演算は、数値インデックスを数値ベクトルに変換します。ここで、インデックスは離散データに対応します。埋め込みを使用して、categorical 値や単語などの離散データを数値ベクトルにマッピングします。 |
fullyconnect | 全結合演算は、入力に重み行列を乗算してから、バイアス ベクトルを加算します。 |
gelu | ガウス誤差線形単位 (GELU) 活性化演算は、ガウス確率分布に従って入力を重み付けします。 |
groupnorm | グループ正規化演算は、グループ化されたチャネル サブセット全体における入力データの正規化を、各観測値について個別に行います。畳み込みニューラル ネットワークの学習速度を上げ、ネットワークの初期化に対する感度を下げるには、relu など、畳み込み演算と非線形演算の間のグループ正規化を使用します。 |
gru | ゲート付き回帰型ユニット (GRU) 演算では、時系列データとシーケンス データのタイム ステップ間の依存関係をネットワークに学習させることができます。 |
huber | Huber 演算は、回帰タスクのネットワーク予測とターゲット値の間の Huber 損失を計算します。'TransitionPoint' オプションが 1 の場合、これは "滑らかな L1 損失" とも呼ばれます。 |
instancenorm | インスタンス正規化演算は、各チャネルにおける入力データの正規化を、各観測値について個別に行います。畳み込みニューラル ネットワークの学習の収束性能を上げ、ネットワークのハイパーパラメーターに対する感度を下げるには、relu など、畳み込み演算と非線形演算の間のインスタンス正規化を使用します。 |
l1loss | L1 損失演算は、ネットワーク予測とターゲット値を指定して L1 損失を計算します。Reduction オプションが "sum" で、NormalizationFactor オプションが "batch-size" のときの計算値は平均絶対誤差 (MAE) と呼ばれます。 |
l2loss | L2 損失演算は、ネットワーク予測とターゲット値を指定して L2 損失を (L2 ノルムの 2 乗に基づいて) 計算します。Reduction オプションが "sum" で、NormalizationFactor オプションが "batch-size" のときの計算値は平均二乗誤差 (MSE) と呼ばれます。 |
layernorm | レイヤー正規化演算は、チャネル全体における入力データの正規化を、各観測値について個別に行います。再帰型多層パーセプトロン ニューラル ネットワークの学習速度を上げ、ネットワークの初期化に対する感度を下げるには、LSTM 演算や全結合演算などの学習可能な演算を行った後に、レイヤー正規化を使用します。 |
leakyrelu | 漏洩 (leaky) 正規化線形ユニット (ReLU) 活性化演算は、非線形のしきい値処理を実行し、ゼロよりも小さい入力値を固定スケール係数で乗算します。 |
lstm | 長短期記憶 (LSTM) 演算では、時系列データおよびシーケンス データのタイム ステップ間の長期的な依存関係をネットワークに学習させることができます。 |
maxpool | 最大プーリング演算は、入力をプーリング領域に分割し、各領域の最大値を計算することによって、ダウンサンプリングを実行します。 |
maxunpool | 最大逆プーリング演算は、アップサンプリングとゼロを使ったパディングによって、最大プーリング演算の出力を逆プーリングします。 |
mse | 半平均二乗誤差演算は、回帰タスクのネットワーク予測とターゲット値の間の半平均二乗誤差損失を計算します。 |
onehotdecode | one-hot 復号化演算は、分類ネットワークの出力などの確率ベクトルを分類ラベルに復号化します。 入力 |
relu | 正規化線形ユニット (ReLU) 活性化演算は、非線形のしきい値処理を実行し、ゼロよりも小さい入力値をゼロに設定します。 |
sigmoid | シグモイド活性化演算は、入力データにシグモイド関数を適用します。 |
softmax | ソフトマックス活性化演算は、入力データのチャネルの次元にソフトマックス関数を適用します。 |
カスタム学習ループの損失関数の定義
深層ニューラル モデルの学習は最適化タスクです。深層学習モデルを関数 f(X;θ) と見なすことにより (ここで、X はモデルの入力、θ は学習可能なパラメーターのセット)、θ を最適化して学習データに基づく損失値を最小化できます。たとえば、与えらえた入力 X と対応するターゲット T に対して、予測 Y=f(X;θ) と T の間の誤差が最小になるように、学習可能なパラメーター θ を最適化します。
カスタム学習ループを使用して深層学習モデルに学習させるには、勾配降下法ベースの方法を使用して損失を最小化できます。たとえば、損失を最小化するように、モデルの学習可能なパラメーターを反復的に更新できます。たとえば、関数 lbfgsupdate
、adamupdate
、rmspropupdate
、および sgdmupdate
を使用して学習可能パラメーターを更新できます。これには、損失に対する学習可能パラメーターの勾配が必要です。これらの勾配を計算するには、自動微分を使用できます。モデルと学習データを取得し、学習可能パラメーターに対する損失と損失の勾配を返すカスタム損失関数を作成します。
dlnetwork
オブジェクトとして指定されるモデルでは、[loss,gradients] = modelLoss(net,X,T)
という形式の関数を作成します。ここで、net
はネットワーク、X
はネットワークの入力で、T
にはターゲットが格納され、loss
と gradients
にはそれぞれ損失と勾配が返されます。オプションで、損失関数が必要とする追加情報などを追加引数として勾配関数に渡すことや、更新されたネットワークの状態などを追加引数として返すことができます。
関数として指定されるモデルでは、[loss,gradients] = modelLoss(parameters,X,T)
という形式の関数を作成します。ここで、parameters
には学習可能なパラメーターが格納され、X
はモデルの入力で、T
にはターゲットが格納され、loss
と gradients
にはそれぞれ損失と勾配が返されます。オプションで、損失関数が必要とする追加情報などを追加引数として勾配関数に渡すことや、更新されたモデルの状態などを追加引数として返すことができます。
関数 modelLoss
で勾配を計算するには、関数 dlgradient
を使用します。
カスタム学習ループに対するモデル損失関数の定義の詳細については、カスタム学習ループのモデル損失関数の定義を参照してください。
カスタム損失関数を使用してイメージを生成する敵対的生成ネットワーク (GAN) に学習させる方法を示す例については、敵対的生成ネットワーク (GAN) の学習を参照してください。
自動微分を使用した学習可能なパラメーターの更新
自動微分を使用してモデル損失関数を評価するには、自動微分を有効にして関数を評価する関数 dlfeval
を使用します。dlfeval
の最初の入力では、関数ハンドルとして指定されるモデル損失関数を渡します。続く入力では、モデル損失関数に必要な変数を渡します。関数 dlfeval
の出力では、モデル損失関数と同じ出力を指定します。
学習可能なパラメーターを更新するには、次の関数を使用できます。
関数 | 説明 |
---|---|
adamupdate | 適応モーメント推定 (Adam) を使用してパラメーターを更新する |
rmspropupdate | 平方根平均二乗伝播 (RMSProp) を使用してパラメーターを更新する |
sgdmupdate | モーメンタム項付き確率的勾配降下法 (SGDM) を使用してパラメーターを更新する |
lbfgsupdate | メモリ制限 BFGS (L-BFGS) を使用してパラメーターを更新する |
dlupdate | カスタム関数を使用してパラメーターを更新する |
たとえば、SGDM を使用して学習可能なパラメーターを更新するには、カスタム学習ループの各反復で次を使用します。
[loss,gradients] = dlfeval(@modelLoss,net,X,T); [net,velocity] = sgdmupdate(net,gradients,velocity,learnRate,momentum);
参考
dlarray
| dlgradient
| dlfeval
| dlnetwork
| dljacobian
| dldivergence
| dllaplacian
関連するトピック
- 敵対的生成ネットワーク (GAN) の学習
- カスタム学習ループを使用したネットワークの学習
- カスタム学習ループでの学習オプションの指定
- カスタム学習ループのモデル損失関数の定義
- カスタム学習ループでのバッチ正規化統計量の更新
- モデル関数を使用したバッチ正規化統計量の更新
- dlnetwork オブジェクトを使用した予測の実行
- モデル関数を使用した予測の実行
- モデル関数を使用したネットワークの学習
- モデル関数の学習可能パラメーターの初期化
- MATLAB による深層学習モデルの学習
- カスタム深層学習層の定義
- dlarray をサポートする関数の一覧
- 自動微分の背景
- Deep Learning Toolbox での自動微分の使用