Main Content

このページの翻訳は最新ではありません。ここをクリックして、英語の最新版を参照してください。

カスタム学習ループ、損失関数、およびネットワークの定義

ほとんどの深層学習タスクでは、事前学習済みのネットワークを使用して独自のデータに適応させることができます。転移学習を使用して、畳み込みニューラル ネットワークの再学習を行い、新しい一連のイメージを分類する方法を示す例については、新しいイメージを分類するための深層学習ネットワークの学習を参照してください。または、layerGraph オブジェクトを関数 trainNetwork および関数 trainingOptions と共に使用して、ネットワークを作成してゼロから学習させることができます。

タスクに必要な学習オプションが関数 trainingOptions に用意されていない場合、自動微分を使用してカスタム学習ループを作成できます。詳細については、カスタム学習ループ向けの深層学習ネットワークの定義を参照してください。

タスクに必要な層 (損失関数を指定する出力層を含む) が Deep Learning Toolbox™ に用意されていない場合、カスタム層を作成できます。詳細については、カスタム深層学習層の定義を参照してください。出力層を使用して指定できない損失関数の場合、カスタム学習ループで損失を指定できます。詳細については、損失関数の指定を参照してください。層グラフを使用して作成できないネットワークの場合、カスタム ネットワークを関数として定義できます。詳細については、モデル関数としてのネットワークの定義を参照してください。

どのタスクでどの学習手法を使用するかについての詳細は、MATLAB による深層学習モデルの学習を参照してください。

カスタム学習ループ向けの深層学習ネットワークの定義

dlnetwork オブジェクトとしてのネットワークの定義

ほとんどのタスクでは、関数 trainingOptions および trainNetwork を使用して学習アルゴリズムの詳細を制御できます。タスク (たとえば、カスタム学習率スケジュール) に必要なオプションが関数 trainingOptions に用意されていない場合、dlnetwork オブジェクトを使用して自分でカスタム学習ループを定義できます。dlnetwork オブジェクトを使用すると、自動微分を使用して、層グラフとして指定したネットワークに学習させることができます。

層グラフとして指定するネットワークの場合、関数 dlnetwork を直接使用して、層グラフから dlnetwork オブジェクトを作成できます。dlnetwork オブジェクトによってサポートされている層の一覧については、dlnetwork ページのSupported Layersの節を参照してください。

dlnet = dlnetwork(lgraph);

カスタム学習率スケジュールでネットワークに学習させる方法を示す例については、カスタム学習ループを使用したネットワークの学習を参照してください。

モデル関数としてのネットワークの定義

層グラフを使用して作成できないアーキテクチャ (重みの共有が必要なシャム ネットワークなど) の場合、[dlY1,...,dlYM] = model(parameters,dlX1,...,dlXN) という形式の関数としてモデルを定義できます。ここで、parameters にはネットワーク パラメーターが含まれ、dlX1,...,dlXNN モデル入力の入力データに対応し、dlY1,...,dlYMM モデル出力に対応します。関数として定義される深層学習モデルに学習させるには、カスタム学習ループを使用します。例については、モデル関数を使用したネットワークの学習を参照してください。

関数として深層学習モデルを定義する場合は、層の重みを手動で初期化しなければなりません。詳細については、モデル関数の学習可能なパラメーターの初期化を参照してください。

カスタム ネットワークを関数として定義するには、モデル関数が自動微分をサポートしていなければなりません。使用できる深層学習演算は次のとおりです。ここには一部の関数のみを示します。dlarray 入力をサポートしている関数の詳細な一覧については、dlarray をサポートする関数の一覧を参照してください。

関数説明
avgpool平均プーリング演算は、入力をプーリング領域に分割し、各領域の平均値を計算することによって、ダウンサンプリングを実行します。
batchnormバッチ正規化演算は、ミニバッチ全体で各入力チャネルを正規化します。畳み込みニューラル ネットワークの学習速度を上げ、ネットワークの初期化に対する感度を下げるには、relu など、畳み込み演算と非線形演算の間のバッチ正規化を使用します。
crossentropy交差エントロピー演算は、単一ラベルおよび複数ラベルの分類タスクにおけるネットワーク予測とターゲット値の間の交差エントロピー損失を計算します。
crosschannelnormクロスチャネル正規化演算は、異なるチャネルの局所応答を使用して各活性化を正規化します。通常、クロスチャネル正規化は relu 演算に続きます。クロスチャネル正規化は局所応答正規化とも呼ばれます。
dlconv畳み込み演算は、入力データにスライディング フィルターを適用します。1 次元および 2 次元フィルターはグループ化されていない畳み込みまたはグループ化された畳み込みに使用し、3 次元フィルターはグループ化されていない畳み込みに使用します。
dltranspconv転置畳み込み演算は、特徴マップをアップサンプリングします。
embed組み込み演算は、数値インデックスを数値ベクトルに変換します。ここで、インデックスは離散データに対応します。埋め込みを使用して、categorical 値や単語などの離散データを数値ベクトルにマッピングします。
fullyconnect全結合演算は、入力に重み行列を乗算してから、バイアス ベクトルを加算します。
groupnormグループ正規化演算は、入力データのチャネルをグループに分割し、各グループのアクティベーションを正規化します。畳み込みニューラル ネットワークの学習速度を上げ、ネットワークの初期化に対する感度を下げるには、relu など、畳み込み演算と非線形演算の間のグループ正規化を使用します。グループの数を適切に設定することで、インスタンスの正規化や層の正規化を実行できます。
gruゲート付き回帰型ユニット (GRU) 演算では、時系列データとシーケンス データのタイム ステップの間の依存関係をネットワークに学習させることができます。
leakyrelu漏洩 (leaky) 正規化線形ユニット (ReLU) 活性化演算は、非線形のしきい値処理を実行し、入力値がゼロよりも小さい場合は固定スケール係数で乗算します。
lstm長短期記憶 (LSTM) 演算では、時系列データおよびシーケンス データのタイム ステップ間の長期的な依存関係をネットワークに学習させることができます。
maxpool最大プーリング演算は、入力をプーリング領域に分割し、各領域の最大値を計算することによって、ダウンサンプリングを実行します。
maxunpool最大逆プーリング演算は、ゼロでアップサンプリングとパディングを行うことによって、最大プーリング演算の出力を逆プーリングします。
mse半平均二乗誤差演算は、回帰タスクのネットワーク予測とターゲット値の間の半平均二乗誤差損失を計算します。
relu正規化線形ユニット (ReLU) 活性化演算は、非線形のしきい値処理を実行し、入力値がゼロよりも小さい場合はゼロに設定します。
onehotdecode

one-hot 復号化演算は、分類ネットワークの出力などの確率ベクトルを分類ラベルに復号化します。

入力 Adlarray にすることができます。A が書式化されている場合、関数はデータ形式を無視します。

sigmoidシグモイド活性化演算は、入力データにシグモイド関数を適用します。
softmaxソフトマックス活性化演算は、入力データのチャネルの次元にソフトマックス関数を適用します。

損失関数の指定

カスタム学習ループを使用する場合、モデル勾配関数で損失を計算しなければなりません。ネットワークの重みを更新するための勾配を計算する際には、損失値を使用します。損失を計算するには、次の関数を使用できます。

関数説明
softmaxソフトマックス活性化演算は、入力データのチャネルの次元にソフトマックス関数を適用します。
sigmoidシグモイド活性化演算は、入力データにシグモイド関数を適用します。
crossentropy交差エントロピー演算は、単一ラベルおよび複数ラベルの分類タスクにおけるネットワーク予測とターゲット値の間の交差エントロピー損失を計算します。
mse半平均二乗誤差演算は、回帰タスクのネットワーク予測とターゲット値の間の半平均二乗誤差損失を計算します。

または、loss = myLoss(Y,T) という形式の関数を作成して、カスタム損失関数を使用できます。ここで、Y はネットワーク予測、T はターゲット、loss は返される損失です。

カスタム損失関数を使用してイメージを生成する敵対的生成ネットワーク (GAN) に学習させる方法を示す例については、敵対的生成ネットワーク (GAN) の学習を参照してください。

自動微分を使用した学習可能なパラメーターの更新

カスタム学習ループを使用して深層学習モデルに学習させる場合、ソフトウェアは、学習可能なパラメーターについての損失を最小化します。損失を最小化するために、ソフトウェアは、学習可能なパラメーターについての損失の勾配を使用します。自動微分を使用してこれらの勾配を計算するには、モデル勾配関数を定義しなければなりません。

モデル勾配関数の定義

dlnetwork オブジェクトとして指定されるモデルでは、gradients = modelGradients(dlnet,dlX,T) という形式の関数を作成します。ここで、dlnet はネットワークであり、dlX には入力予測子、T にはターゲット、gradients には返される勾配が含まれます。オプションで、(損失関数に追加情報が必要な場合などに) 勾配関数に追加引数を渡したり、(学習の進行状況をプロットするためのメトリクスなどの) 追加引数を返すことができます。

関数として指定されるモデルでは、gradients = modelGradients(parameters,dlX,T) という形式の関数を作成します。ここで、parameters には学習可能なパラメーター、dlX には入力予測子、T にはターゲット、gradients には返される勾配が含まれます。オプションで、(損失関数に追加情報が必要な場合などに) 勾配関数に追加引数を渡したり、(学習の進行状況をプロットするためのメトリクスなどの) 追加引数を返すことができます。

カスタム学習ループに対するモデル勾配関数の定義の詳細については、カスタム学習ループのモデル勾配関数の定義を参照してください。

学習可能なパラメーターの更新

自動微分を使用してモデル勾配を評価するには、自動微分を有効にして関数を評価する関数 dlfeval を使用します。dlfeval の最初の入力では、関数ハンドルとして指定されるモデル勾配関数を渡し、次の入力では、モデル勾配関数に必要な変数を渡します。関数 dlfeval の出力では、モデル勾配関数と同じ出力を指定します。

勾配を使用して学習可能なパラメーターを更新するには、次の関数を使用できます。

関数説明
adamupdate適応モーメント推定 (Adam) を使用してパラメーターを更新する
rmspropupdate平方根平均二乗伝播 (RMSProp) を使用してパラメーターを更新する
sgdmupdateモーメンタム項付き確率的勾配降下法 (SGDM) を使用してパラメーターを更新する
dlupdateカスタム関数を使用してパラメーターを更新する

参考

| | |

関連するトピック