Main Content

このページの翻訳は最新ではありません。ここをクリックして、英語の最新版を参照してください。

loss

回帰ニューラル ネットワークの損失

    説明

    L = loss(Mdl,Tbl,ResponseVarName) は、table Tbl 内の予測子データと table 変数 ResponseVarName 内の応答値を使用して、学習させた回帰ニューラル ネットワーク Mdl の回帰損失を返します。

    L は、既定で平均二乗誤差 (MSE) を表すスカラー値として返されます。

    L = loss(Mdl,Tbl,Y) は、table Tbl 内の予測子データとベクトル Y 内の応答値を使用して、モデル Mdl の回帰損失を返します。

    L = loss(Mdl,X,Y) は、予測子データ XY 内の対応する応答値を使用して、学習させた回帰ニューラル ネットワーク Mdl の回帰損失を返します。

    L = loss(___,Name,Value) では、前の構文におけるいずれかの入力引数の組み合わせに加えて、1 つ以上の名前と値の引数を使用してオプションを指定します。たとえば、予測子データの列が観測値に対応するように指定したり、損失関数を指定したり、観測値の重みを与えることができます。

    すべて折りたたむ

    回帰ニューラル ネットワーク モデルの検定セットの平均二乗誤差 (MSE) を計算します。

    patients データセットを読み込みます。データ セットから table を作成します。各行が 1 人の患者に対応し、各列が診断の変数に対応します。変数 Systolic を応答変数として使用し、残りの変数を予測子として使用します。

    load patients
    tbl = table(Age,Diastolic,Gender,Height,Smoker,Weight,Systolic);

    非層化ホールドアウト分割を使用して、データを学習セット tblTrain と検定セット tblTest に分割します。観測値の約 30% が検定データ用に予約され、残りの観測値が学習データ セットに使用されます。

    rng("default") % For reproducibility of the partition
    c = cvpartition(size(tbl,1),"Holdout",0.30);
    trainingIndices = training(c);
    testIndices = test(c);
    tblTrain = tbl(trainingIndices,:);
    tblTest = tbl(testIndices,:);

    学習セットを使用して回帰ニューラル ネットワーク モデルに学習させます。tblTrain の列 Systolic を応答変数として指定します。数値予測子を標準化するための指定を行います。

    Mdl = fitrnet(tblTrain,"Systolic", ...
        "Standardize",true);

    検定セットの MSE を計算します。MSE の値が小さいほど、パフォーマンスが優れていることを示します。

    testMSE = loss(Mdl,tblTest,"Systolic")
    testMSE = 49.9595
    

    検定セットの損失と予測を比較することにより、特徴選択を実行します。すべての予測子を使用して学習させた回帰ニューラル ネットワーク モデルの検定セット メトリクスを予測子のサブセットのみを使用して学習させたモデルの検定セット メトリクスと比較します。

    標本ファイル fisheriris.csv を読み込みます。これには、アヤメについてのがく片の長さ、がく片の幅、花弁の長さ、花弁の幅、種の種類などのデータが格納されています。ファイルを table に読み込みます。

    fishertable = readtable('fisheriris.csv');

    非層化ホールドアウト分割を使用して、データを学習セット trainTbl と検定セット testTbl に分割します。観測値の約 30% が検定データ用に予約され、残りの観測値が学習データ セットに使用されます。

    rng("default")
    c = cvpartition(size(fishertable,1),"Holdout",0.3);
    trainTbl = fishertable(training(c),:);
    testTbl = fishertable(test(c),:);

    学習セット内のすべての予測子を使用して 1 つの回帰ニューラル ネットワーク モデルに学習させ、PetalWidth を除くすべての予測子を使用してもう 1 つのモデルに学習させます。両方のモデルについて、PetalLength を応答変数として指定し、予測子を標準化します。

    allMdl = fitrnet(trainTbl,"PetalLength","Standardize",true);
    subsetMdl = fitrnet(trainTbl,"PetalLength ~ SepalLength + SepalWidth + Species", ...
        "Standardize",true);

    2 つのモデルの検定セットの平均二乗誤差 (MSE) を比較します。MSE の値が小さいほど、パフォーマンスが優れていることを示します。

    allMSE = loss(allMdl,testTbl)
    allMSE = 0.0834
    
    subsetMSE = loss(subsetMdl,testTbl)
    subsetMSE = 0.0887
    

    各モデルについて、検定セットの予測される花弁の長さと実際の花弁の長さを比較します。予測される花弁の長さを縦軸に、実際の花弁の長さを横軸に沿ってプロットします。基準線上にある点は予測が正しいことを示します。

    tiledlayout(2,1)
    
    % Top axes
    ax1 = nexttile;
    allPredictedY = predict(allMdl,testTbl);
    plot(ax1,testTbl.PetalLength,allPredictedY,".")
    hold on
    plot(ax1,testTbl.PetalLength,testTbl.PetalLength)
    hold off
    xlabel(ax1,"True Petal Length")
    ylabel(ax1,"Predicted Petal Length")
    title(ax1,"All Predictors")
    
    % Bottom axes
    ax2 = nexttile;
    subsetPredictedY = predict(subsetMdl,testTbl);
    plot(ax2,testTbl.PetalLength,subsetPredictedY,".")
    hold on
    plot(ax2,testTbl.PetalLength,testTbl.PetalLength)
    hold off
    xlabel(ax2,"True Petal Length")
    ylabel(ax2,"Predicted Petal Length")
    title(ax2,"Subset of Predictors")

    予測が基準線の近くに分布しており、両方のモデルが適切に機能しているようなので、PetalWidth を除くすべての予測子を使用して学習させたモデルを使用することを検討します。

    入力引数

    すべて折りたたむ

    学習させた回帰ニューラル ネットワーク。fitrnet によって返される RegressionNeuralNetwork モデル オブジェクト、または compact によって返される CompactRegressionNeuralNetwork モデル オブジェクトとして指定します。

    標本データ。テーブルとして指定します。Tbl の各行は 1 つの観測値に、各列は 1 つの予測子変数に対応します。オプションとして、応答変数用の追加列を Tbl に含めることができます。Tbl には、Mdl を学習させるために使用したすべての予測子が含まれていなければなりません。文字ベクトルの cell 配列ではない cell 配列と複数列の変数は使用できません。

    • Mdl を学習させるために使用した応答変数が Tbl に含まれている場合、ResponseVarName または Y を指定する必要はありません。

    • table に格納されている標本データを使用して Mdl に学習をさせた場合、loss の入力データも table に含まれていなければなりません。

    • Mdl に学習させるときに fitrnet'Standardize',true を設定した場合、予測子データの数値列が対応する平均および標準偏差を使用して標準化されます。

    データ型: table

    応答変数の名前。Tbl 内の変数の名前で指定します。応答変数は、数値ベクトルでなければなりません。

    ResponseVarName を指定する場合は、文字ベクトルまたは string スカラーとして指定しなければなりません。たとえば、応答変数が Tbl.Y として格納されている場合、ResponseVarName として 'Y' を指定します。それ以外の場合、Tbl の列は Tbl.Y を含めてすべて予測子として扱われます。

    データ型: char | string

    応答データ。数値ベクトルとして指定します。Y の長さは X または Tbl の観測値の数と等しくなければなりません。

    データ型: single | double

    予測子データ。数値行列として指定します。既定では、loss は、X の各行が 1 つの観測値に、各列が 1 つの予測子変数に対応すると見なします。

    メモ

    観測値が列に対応するように予測子行列を配置して 'ObservationsIn','columns' を指定すると、計算時間が大幅に短縮される可能性があります。

    Y の長さと X の観測値数は同じでなければなりません。

    Mdl に学習させるときに fitrnet'Standardize',true を設定した場合、予測子データの数値列が対応する平均および標準偏差を使用して標準化されます。

    データ型: single | double

    名前と値のペアの引数

    オプションの Name,Value 引数のコンマ区切りペアを指定します。Name は引数名で、Value は対応する値です。Name は引用符で囲まなければなりません。Name1,Value1,...,NameN,ValueN のように、複数の名前と値のペアの引数を、任意の順番で指定できます。

    例: loss(Mdl,Tbl,"Response","Weights","W") は、table Tbl 内の変数 Response および W をそれぞれ応答値および観測値の重みとして使用するように指定します。

    損失関数。'mse' または関数ハンドルとして指定します。

    • 'mse' — 重み付けされた平均二乗誤差。

    • 関数ハンドル — カスタム損失関数を指定するには、関数ハンドルを使用します。関数は次の形式でなければなりません。

      lossval = lossfun(Y,YFit,W)

      • 出力引数 lossval は浮動小数点スカラーです。

      • 関数名 (lossfun) を指定します。

      • Y は、観測応答の長さ n の数値ベクトルです。ここで、n は Tbl または X に含まれている観測値の個数です。

      • YFit は、対応する予測応答の長さ n の数値ベクトルです。

      • W は、観測値の重みの n 行 1 列の数値ベクトルです。

    例: 'LossFun',@lossfun

    データ型: char | string | function_handle

    予測子データにおける観測値の次元。'rows' または 'columns' として指定します。

    メモ

    観測値が列に対応するように予測子行列を配置して 'ObservationsIn','columns' を指定すると、計算時間が大幅に短縮される可能性があります。table の予測子データに対して 'ObservationsIn','columns' を指定することはできません。

    データ型: char | string

    観測値の重み。非負の数値ベクトルまたは Tbl 内の変数の名前を指定します。ソフトウェアは、X または Tbl の各観測値に、Weights の対応する値で重みを付けます。Weights の長さは、X または Tbl の観測値の数と等しくなければなりません。

    入力データをテーブル Tbl として指定した場合、Weights は数値ベクトルが含まれている Tbl 内の変数の名前にすることができます。この場合、Weights には文字ベクトルまたは string スカラーを指定しなければなりません。たとえば、重みベクトル WTbl.W として格納されている場合、'W' として指定します。

    既定の設定では、Weightsones(n,1) です。nX または Tbl の観測値数です。

    重みを指定した場合、loss は加重回帰損失を計算し、合計が 1 になるように重みを正規化します。

    データ型: single | double | char | string

    R2021a で導入