過学習

過学習とは?

過学習とは、機械学習の振る舞いの 1 つで、モデルが学習データにあまりに適合しすぎて、新しいデータにどのように対応したらよいかわからなくなることです。過学習は、以下の理由で発生する可能性があります。

  • 機械学習モデルが複雑すぎる場合。学習データの非常に微妙なパターンを記憶してしまい、うまく一般化できなくなった。
  • モデルの複雑度に対して学習データサイズが小さすぎる、または関係がない情報が大量に含まれている。

モデルの複雑度を管理し、学習データセットを改善することで、過学習を回避できます。

過学習と未学習

未学習は、過学習の逆の概念で、モデルが学習データとうまく適合しないか、新しいデータに対してうまく一般化できないことを指します。過学習と未学習は、分類モデルと回帰モデルの両方に存在する可能性があります。次の図では、分類決定境界と回帰の線が、過学習モデルでは学習データに近づきすぎ、未学習モデルでは十分に近づいていない様子が示されています。

分類モデルと回帰モデルについて、過学習、正しい学習、および未学習を示すデータのプロット。

過学習の分類モデルと回帰モデルは、正しく学習したモデルと比較して、学習データを過剰に記憶してしまいます。

学習データに対する機械学習モデルの計算上の誤差だけを見ると、過学習は未学習より検出するのが困難です。そのため、過学習を回避するには、機械学習モデルを使用する前にテストデータで検証することが重要です。

誤差

過学習

正しい学習

未学習

学習

テスト

学習データに対する過学習モデルの計算上の誤差は小さいものの、テストデータに対する誤差は大きくなっています。

MATLAB®Statistics and Machine Learning Toolbox™ および Deep Learning Toolbox™ を使用すると、機械学習モデルおよびディープラーニング モデルの過学習を回避できます。MATLAB には、モデルの過学習を回避するために特別に設計された関数と手法が用意されています。これらのツールをモデルの学習や調整を行う際に使用することで、過学習からモデルを保護することができます。

モデルの複雑度を軽減して過学習を回避する方法

MATLAB を使用すると、機械学習モデルやディープラーニング モデル (CNN など) をゼロから学習したり、事前学習済みのディープラーニング モデルを利用したりすることができます。過学習を回避するには、モデル検証を行ってデータに適した複雑度のモデルを選択するか、正則化を用いてモデルの複雑度を軽減させます。

モデル検証

過学習されたモデルの誤差は、学習データに対して計算された場合、低くなります。新しいデータを導入する前に、別のデータセット (検証用データセット) でモデルを検証することが推奨されます。MATLAB の機械学習モデルの場合、関数 cvpartition を使用して、データセットを学習セットと検証セットにランダムに分割できます。ディープラーニング モデルの場合、学習中に検証精度を監視することができます。モデルの選択とハイパーパラメーターの調整により、モデルの検証精度をより適切に測定することで、モデルが新しいデータに対応する際の精度向上が期待されます。

交差検証は、機械学習アルゴリズムが学習していないデータセットに対して予測を行う場合の性能評価に用いられるモデル評価手法です。交差検証は、過学習を引き起こすほど複雑ではないアルゴリズムを選択するのに役立ちます。関数 crossval を使用して、k 分割 (データをほぼ同じ大きさのランダムに選んだ k 個の部分集合に分割する) やホールドアウト (データを指定した比率のちょうど 2 個の部分集合にランダムに分割する) などの一般的な交差検証手法を用いて、機械学習モデルの交差検証誤差の推定値を計算します。

正則化

正則化は、機械学習モデルにおける統計的な過学習を回避するために用いられる手法です。正則化アルゴリズムは通常、複雑さまたは粗さのいずれかに対してペナルティを適用することで機能します。正則化アルゴリズムでは、モデルに追加情報を導入し、モデルをより簡素で正確なものとすることにより、多重共線性や冗長な予測子を処理することができます。

機械学習では、一般的な正則化手法である LASSO (L1 ノルム)、Ridge (L2 ノルム)、Elastic Net の 3 種類と、数種類の線形機械学習モデルから選択できます。ディープラーニングの場合は、指定された学習オプションの L2 正則化係数を大きくするか、ネットワークにドロップアウト層を使用して過学習を回避することができます。

学習データセットを強化して過学習を回避する方法

交差検証と正則化ではモデルの複雑度を管理することで、過学習を回避します。もう 1 つの手法は、データセットの改良です。ディープラーニング モデルは特に、過学習を回避するために大量のデータが必要です。

データ拡張

データの可用性が限られている場合に、ランダム化した既存データをデータセットに追加することで、学習データセットのデータ点を人為的に拡張する手法をデータ拡張といいます。MATLAB を使用することで、画像音声などのデータを拡張できます。たとえば、既存の画像の縮尺や回転をランダム化することで、画像データを拡張します。

データ生成

合成データの生成も、データセットを拡張する手法の 1 つです。MATLAB では、敵対的生成ネットワーク (GAN)デジタルツイン (シミュレーションによるデータ生成) を用いて、合成データを生成できます。

データクリーンアップ

データのノイズは過学習の原因になります。望ましくないデータ点を削減するための一般的な手法には、関数 rmoutliers を用いてデータから外れ値を除去する方法があります。