Main Content

loss

説明

L = loss(tree,Tbl,ResponseVarName) は、真の応答 Tbl.ResponseVarName に対する、tree の予測間の平均二乗誤差を Tbl 内のデータに返します。

L = loss(tree,Tbl,Y) は、真の応答 Y に対する、tree の予測間の平均二乗誤差を Tbl 内のデータに返します。

L = loss(tree,X,Y) は、真の応答 Y に対する、tree の予測間の平均二乗誤差を X 内のデータに返します。

L = loss(___,Name,Value) は、前の構文のいずれかを使用し、1 つ以上の Name,Value ペア引数で指定されたオプションを追加して、予測の誤差を計算します。

[L,se,NLeaf,bestlevel] = loss(___) は、損失の標準誤差 (se)、木における葉 (終端ノード) の数 (NLeaf)、および tree の最適な枝刈りレベル (bestlevel) も返します。

入力引数

すべて展開する

学習済みの回帰木。fitrtree で構築した RegressionTree オブジェクトまたは compact で構築した CompactRegressionTree オブジェクトとして指定します。

標本データ。テーブルとして指定します。Tbl の各行は 1 つの観測値に、各列は 1 つの予測子変数に対応します。Tbl には、tree を学習させるために使用したすべての予測子が含まれていなければなりません。必要に応じて、Tbl に応答変数用および観測値の重み用の追加列を含めることができます。文字ベクトルの cell 配列ではない cell 配列と複数列の変数は使用できません。

tree を学習させるために使用した応答変数が Tbl に含まれている場合、ResponseVarName または Y を指定する必要はありません。

table に含まれている標本データを使用して tree を学習させた場合、このメソッドの入力データも table に格納されていなければなりません。

データ型: table

予測子の値。数値行列として指定します。X の各列が 1 つの変数を表し、各行が 1 つの観測値を表します。

X の列数は、tree を学習させるために使用したデータ数と同じでなければなりません。X の行数は、Y の要素数と同じでなければなりません。

データ型: single | double

応答変数の名前。Tbl 内の変数の名前で指定します。tree を学習させるために使用した応答変数が Tbl に含まれている場合、ResponseVarName を指定する必要はありません。

ResponseVarName を指定する場合は、文字ベクトルまたは string スカラーとして指定しなければなりません。たとえば、応答変数が Tbl.Response として格納されている場合、'Response' として指定します。それ以外の場合、Tbl の列は Tbl.ResponseVarName を含めてすべて予測子として扱われます。

データ型: char | string

応答データ。X と同じ行数の数値列ベクトルとして指定します。Y の各エントリは X の対応する行のデータに対する応答です。

データ型: single | double

名前と値の引数

オプションの引数のペアを Name1=Value1,...,NameN=ValueN として指定します。ここで Name は引数名、Value は対応する値です。名前と値の引数は他の引数の後ろにする必要がありますが、ペアの順序は関係ありません。

R2021a より前では、名前と値をそれぞれコンマを使って区切り、Name を引用符で囲みます。

損失関数。'LossFun' と、損失関数のハンドルまたは平均二乗誤差を表す 'mse' から構成されるコンマ区切りのペアとして指定します。関数ハンドル fun を渡す場合、lossfun を次のように呼び出します。

fun(Y,Yfit,W)
  • Y は観測された応答のベクトルです。

  • Yfit は予測された応答のベクトルです。

  • W は観測の重みです。W を渡す場合、要素を正規化して、合計が 1 になるようにします。

すべてのベクトルには、Y と同じ行数が含まれます。

例: 'LossFun','mse'

データ型: function_handle | char | string

枝刈りレベル。'Subtrees' と昇順の非負の整数のベクトルまたは 'all' から構成されるコンマ区切りのペアとして指定します。

ベクトルを指定する場合、すべての要素が 0 から max(tree.PruneList) の範囲になければなりません。0 は枝刈りしない完全な木を、max(tree.PruneList) は完全に枝刈りした木 (つまり、ルート ノードのみ) を表します。

'all' を指定した場合、loss はすべての部分木 (枝刈り順序全体) に作用します。これは、0:max(tree.PruneList) を指定することと同じです。

loss では、Subtrees で指定された各レベルまで tree の枝刈りを行ってから、対応する出力引数を推定します。Subtrees のサイズにより、一部の出力引数のサイズが決まります。

Subtrees を呼び出すために、treePruneList プロパティまたは PruneAlpha プロパティを空にすることはできません。言い換えると、'Prune','on' を設定して tree を成長させるか、prune を使用して tree の枝刈りを行います。

例: 'Subtrees','all'

データ型: single | double | char | string

木のサイズ。'TreeSize' と次のいずれかから構成されるコンマ区切りのペアとして指定します。

  • 'se'loss は、平均二乗誤差 (MSE) が、最小 MSE の 1 標準誤差内にある最小ツリーに対応する bestlevel を返します。

  • 'min'loss は、最小 MSE ツリーに対応する bestlevel を返します。

例: 'TreeSize','min'

観測値の重み。'Weights' とスカラー値のベクトルで構成されるコンマ区切りのペアとして指定します。X または Tbl の各行に含まれている観測値には、Weights の対応する値で重みが付けられます。Weights のサイズは、X または Tbl の行数と同じでなければなりません。

入力データをテーブル Tbl として指定した場合、Weights は数値ベクトルが含まれている Tbl 内の変数の名前にすることができます。この場合、Weights には変数名を指定しなければなりません。たとえば、重みのベクトル WTbl.W として格納されている場合、Weights として 'W' を指定します。それ以外の場合、モデルを学習させるときに、Tbl の列は W を含めてすべて予測子として扱われます。

データ型: single | double | char | string

出力引数

すべて展開する

分類誤差。長さが Subtrees のベクトルとして返されます。各ツリーの誤差は、Weights で重み付けした平均二乗誤差です。LossFun を含める場合、L には、LossFun を使用して計算した損失が反映されます。

損失の標準誤差。長さが Subtrees のベクトルとして返されます。

枝刈りされた部分木における葉 (終端ノード) の数。長さが Subtrees のベクトルとして返されます。

名前と値のペア TreeSize で定義した最適な枝刈りレベル。スカラー値として返されます。値は、TreeSize の設定に応じて次のようになります。

  • TreeSize = 'se'loss は、最小の 1 標準偏差内の損失 (L+se、このとき L および se は、Subtrees での最小値に相関します) をもつ、最も高い枝刈りレベルを返します。

  • TreeSize = 'min'loss は、最も損失が少ない Subtrees の要素を返します。通常、これは Subtrees の最小要素です。

すべて展開する

carsmall データセットを読み込みます。DisplacementHorsepower および Weight が応答 MPG の予測子であると考えます。

load carsmall
X = [Displacement Horsepower Weight];

すべての観測値を使用して回帰木を成長させます。

tree = fitrtree(X,MPG);

標本内 MSE を推定します。

L = loss(tree,X,MPG)
L = 4.8952

carsmall データセットを読み込みます。DisplacementHorsepower および Weight が応答 MPG の予測子であると考えます。

load carsmall
X = [Displacement Horsepower Weight];

すべての観測値を使用して回帰木を成長させます。

Mdl = fitrtree(X,MPG);

回帰木を表示します。

view(Mdl,'Mode','graph');

Figure Regression tree viewer contains an axes object and other objects of type uimenu, uicontrol. The axes object contains 60 objects of type line, text.

標本内損失が最適になる枝刈りレベルを探索します。

[L,se,NLeaf,bestLevel] = loss(Mdl,X,MPG,'Subtrees','all');
bestLevel
bestLevel = 1

最適な枝刈りレベルはレベル 1 です。

木をレベル 1 まで枝刈りします。

pruneMdl = prune(Mdl,'Level',bestLevel);
view(pruneMdl,'Mode','graph');

Figure Regression tree viewer contains an axes object and other objects of type uimenu, uicontrol. The axes object contains 57 objects of type line, text.

枝刈りをしていない決定木は、過適合になる傾向があります。モデルの複雑さと標本外性能のバランスをとる方法の 1 つとして、標本内性能と標本外性能が十分高くなるように木の枝刈りを行います (つまり木の成長を制限します)。

carsmall データセットを読み込みます。DisplacementHorsepower および Weight が応答 MPG の予測子であると考えます。

load carsmall
X = [Displacement Horsepower Weight];
Y = MPG;

データを学習セット (50%) と検証セット (50%) に分割します。

n = size(X,1);
rng(1) % For reproducibility
idxTrn = false(n,1);
idxTrn(randsample(n,round(0.5*n))) = true; % Training set logical indices 
idxVal = idxTrn == false;                  % Validation set logical indices

学習セットを使用して回帰木を成長させます。

Mdl = fitrtree(X(idxTrn,:),Y(idxTrn));

回帰木を表示します。

view(Mdl,'Mode','graph');

Figure Regression tree viewer contains an axes object and other objects of type uimenu, uicontrol. The axes object contains 27 objects of type line, text.

この回帰木には 7 つの枝刈りレベルがあります。レベル 0 は、(表示のように) 枝刈りされていない完全な木です。レベル 7 はルート ノードのみ (分割なし) です。

最上位レベルを除く各部分木 (枝刈りレベル) について、学習標本の MSE を確認します。

m = max(Mdl.PruneList) - 1;
trnLoss = resubLoss(Mdl,'SubTrees',0:m)
trnLoss = 7×1

    5.9789
    6.2768
    6.8316
    7.5209
    8.3951
   10.7452
   14.8445

  • 枝刈りされていない完全な木の MSE は約 6 単位です。

  • レベル 1 まで枝刈りされた木の MSE は約 6.3 単位です。

  • レベル 6 (切り株) まで枝刈りされた木の MSE は約 14.8 単位です。

最上位を除く各レベルで検証標本の MSE を確認します。

valLoss = loss(Mdl,X(idxVal,:),Y(idxVal),'SubTrees',0:m)
valLoss = 7×1

   32.1205
   31.5035
   32.0541
   30.8183
   26.3535
   30.0137
   38.4695

  • 枝刈りされていない完全な木 (レベル 0) の MSE は約 32.1 単位です。

  • レベル 4 まで枝刈りされた木の MSE は約 26.4 単位です。

  • レベル 5 まで枝刈りされた木の MSE は約 30.0 単位です。

  • レベル 6 (切り株) まで枝刈りされた木の MSE は約 38.5 単位です。

モデルの複雑さと標本外性能のバランスをとるには、Mdl をレベル 4 まで枝刈りすることを検討します。

pruneMdl = prune(Mdl,'Level',4);
view(pruneMdl,'Mode','graph')

Figure Regression tree viewer contains an axes object and other objects of type uimenu, uicontrol. The axes object contains 15 objects of type line, text.

詳細

すべて展開する

拡張機能