Main Content

最新のリリースでは、このページがまだ翻訳されていません。 このページの最新版は英語でご覧になれます。

カスタムの学習ループ、損失関数、およびネットワークの定義

ほとんどの深層学習タスクでは、事前学習済みのネットワークを使用して独自のデータに適応させることができます。転移学習を使用して、畳み込みニューラル ネットワークの再学習を行い、新しい一連のイメージを分類する方法を示す例については、新しいイメージを分類するための深層学習ネットワークの学習を参照してください。または、layerGraph オブジェクトを関数 trainNetwork および関数 trainingOptions と共に使用して、ネットワークを作成してゼロから学習させることができます。

タスクに必要な学習オプションが関数 trainingOptions に用意されていない場合、自動微分を使用してカスタム学習ループを作成できます。詳細については、カスタム学習ループの定義を参照してください。

タスクに必要な層 (損失関数を指定する出力層を含む) が Deep Learning Toolbox™ に用意されていない場合、カスタム層を作成できます。詳細については、カスタム深層学習層の定義を参照してください。出力層を使用して指定できない損失関数の場合、カスタム学習ループで損失を指定できます。詳細については、損失関数の指定を参照してください。層グラフを使用して作成できないネットワークの場合、カスタム ネットワークを関数として定義できます。詳細については、カスタム ネットワークの定義を参照してください。

カスタムの学習ループ、損失関数、およびネットワークでは、自動微分を使用してモデル勾配を自動的に計算します。詳細については、自動微分の背景を参照してください。

カスタム学習ループの定義

ほとんどのタスクでは、関数 trainingOptions および trainNetwork を使用して学習アルゴリズムの詳細を制御できます。タスク (たとえば、カスタム学習率スケジュール) に必要なオプションが関数 trainingOptions に用意されていない場合、自動微分を使用して独自のカスタム学習ループを定義できます。

カスタム学習率スケジュールでネットワークに学習させる方法を示す例については、カスタム学習ループを使用したネットワークの学習を参照してください。

自動微分を使用した学習可能なパラメーターの更新

学習可能なパラメーターを更新するには、まず学習可能なパラメーターについての損失の勾配を計算しなければなりません。

gradients = modelGradients(dlnet,dlX,dlT) という形式の関数を作成します。ここで、dlnet はネットワーク、dlX は入力予測子、dlT はターゲット、gradients は返される勾配です。オプションで、(損失関数に追加情報が必要な場合などに) 勾配関数に追加引数を渡したり、(学習の進行状況をプロットするためのメトリクスなどの) 追加引数を返すことができます。モデルが関数として定義されている場合、ネットワークを入力引数として渡す必要はありません。

自動微分を使用するには、dlgradient を呼び出して関数の勾配を計算し、dlfeval を呼び出して計算グラフを設定または更新します。これらの関数は dlarray を使用してデータ構造を管理し、評価のトレースを有効にします。

ネットワークの重みを更新するには、次の関数を使用できます。

関数説明
adamupdate 適応モーメント推定 (Adam) を使用してパラメーターを更新する
rmspropupdate 平方根平均二乗伝播 (RMSProp) を使用してパラメーターを更新する
sgdmupdate モーメンタム項付き確率的勾配降下法 (SGDM) を使用してパラメーターを更新する
dlupdate カスタム関数を使用してパラメーターを更新する

モデル勾配関数を作成し、イメージを生成する敵対的生成ネットワーク (GAN) に学習させる方法を示す例については、敵対的生成ネットワーク (GAN) の学習を参照してください。

損失関数の指定

dlnetwork オブジェクトを使用する場合、出力層を使用する代わりに、モデル勾配関数で損失を手動で計算しなければなりません。次の関数を使用して損失を計算できます。

関数説明
softmax ソフトマックス活性化演算は、入力データのチャネルの次元にソフトマックス関数を適用します。
sigmoid シグモイド活性化演算は、入力データにシグモイド関数を適用します。
crossentropy 交差エントロピー演算は、単一ラベルおよび複数ラベルの分類タスクにおけるネットワーク予測とターゲット値の間の交差エントロピー損失を計算します。
mse 半平均二乗誤差演算は、回帰タスクのネットワーク予測とターゲット値の間の半平均二乗誤差損失を計算します。

または、loss = myLoss(Y,T) という形式の関数を作成して、カスタム損失関数を使用できます。ここで、Y はネットワーク予測、T はターゲット、loss は返される損失です。

ネットワークの重みを更新するための勾配を計算する際には、損失値を使用します。

モデル勾配関数を作成し、カスタム損失関数を使用してイメージを生成する敵対的生成ネットワーク (GAN) に学習させる方法を示す例については、敵対的生成ネットワーク (GAN) の学習を参照してください。

カスタム ネットワークの定義

ほとんどのタスクでは、事前学習済みのネットワークを使用するか、または独自のネットワークを層グラフとして定義できます。事前学習済みのネットワークの詳細は、事前学習済みの深層ニューラル ネットワークを参照してください。dlnetwork オブジェクトによってサポートされている層の一覧については、Supported Layersを参照してください。

層グラフを使用して作成できないアーキテクチャの場合、[dlY1,...,dlYM] = model(dlX1,...,dlXN,parameters) という形式の関数としてカスタム モデルを定義できます。ここで、dlX1,...,dlXNN モデル入力の入力データに対応し、parameters にはネットワーク パラメーターが含まれ、dlY1,...,dlYMM モデル出力に対応します。カスタム ネットワークに学習させるには、カスタム学習ループを使用します。

カスタム ネットワークを関数として定義するには、モデル関数が自動微分をサポートしていなければなりません。使用できる深層学習演算は次のとおりです。ここには一部の関数のみを示します。dlarray 入力をサポートしている関数の詳細な一覧については、dlarray をサポートする関数の一覧を参照してください。

関数説明
avgpool平均プーリング演算は、入力をプーリング領域に分割し、各領域の平均値を計算することによって、ダウンサンプリングを実行します。
batchnormバッチ正規化演算は、ミニバッチ全体で各入力チャネルを正規化します。畳み込みニューラル ネットワークの学習速度を上げ、ネットワークの初期化に対する感度を下げるには、relu など、畳み込み演算と非線形演算の間のバッチ正規化を使用します。
crossentropy交差エントロピー演算は、単一ラベルおよび複数ラベルの分類タスクにおけるネットワーク予測とターゲット値の間の交差エントロピー損失を計算します。
crosschannelnormクロスチャネル正規化演算は、異なるチャネルの局所応答を使用して各活性化を正規化します。通常、クロスチャネル正規化は relu 演算に続きます。クロスチャネル正規化は局所応答正規化とも呼ばれます。
dlconv畳み込み演算は、入力データにスライディング フィルターを適用します。1 次元および 2 次元フィルターはグループ化されていない畳み込みまたはグループ化された畳み込みに使用し、3 次元フィルターはグループ化されていない畳み込みに使用します。
dltranspconv転置畳み込み演算は、特徴マップをアップサンプリングします。
fullyconnect全結合演算は、入力に重み行列を乗算してから、バイアス ベクトルを加算します。
gruゲート付き回帰型ユニット (GRU) 演算では、時系列データとシーケンス データのタイム ステップの間の依存関係をネットワークに学習させることができます。
leakyrelu漏洩正規化線形ユニット (ReLU) 活性化演算は、非線形のしきい値処理を実行し、入力値がゼロよりも小さい場合は固定スケール係数で乗算します。
lstm長短期記憶 (LSTM) 演算では、時系列データおよびシーケンス データのタイム ステップ間の長期的な依存関係をネットワークに学習させることができます。
maxpool最大プーリング演算は、入力をプーリング領域に分割し、各領域の最大値を計算することによって、ダウンサンプリングを実行します。
maxunpool最大逆プーリング演算は、ゼロでアップサンプリングとパディングを行うことによって、最大プーリング演算の出力を逆プーリングします。
mse半平均二乗誤差演算は、回帰タスクのネットワーク予測とターゲット値の間の半平均二乗誤差損失を計算します。
relu正規化線形ユニット (ReLU) 活性化演算は、非線形のしきい値処理を実行し、入力値がゼロよりも小さい場合はゼロに設定します。
sigmoidシグモイド活性化演算は、入力データにシグモイド関数を適用します。
softmaxソフトマックス活性化演算は、入力データのチャネルの次元にソフトマックス関数を適用します。

参考

| | |

関連するトピック