resubLoss
回帰木モデルの再代入損失
説明
L = resubLoss(
は、1 つ以上の tree
,Name=Value
)name-value
の引数を使用して追加のオプションを指定します。たとえば、resubLoss
が損失の計算に使用する損失関数、枝刈りレベル、木のサイズを指定できます。
例
carsmall
データ セットを読み込みます。Displacement
、Horsepower
および Weight
が応答 MPG
の予測子であると考えます。
load carsmall
X = [Displacement Horsepower Weight];
すべての観測値を使用して回帰木を成長させます。
Mdl = fitrtree(X,MPG);
再代入の MSE を計算します。
resubLoss(Mdl)
ans = 4.8952
枝刈りをしていない決定木は、過適合になる傾向があります。モデルの複雑さと標本外性能のバランスをとる方法の 1 つとして、標本内性能と標本外性能が十分高くなるように木の枝刈りを行います (つまり木の成長を制限します)。
carsmall
データ セットを読み込みます。Displacement
、Horsepower
および Weight
が応答 MPG
の予測子であると考えます。
load carsmall
X = [Displacement Horsepower Weight];
Y = MPG;
データを学習セット (50%) と検証セット (50%) に分割します。
n = size(X,1); rng(1) % For reproducibility idxTrn = false(n,1); idxTrn(randsample(n,round(0.5*n))) = true; % Training set logical indices idxVal = idxTrn == false; % Validation set logical indices
学習セットを使用して回帰木を成長させます。
Mdl = fitrtree(X(idxTrn,:),Y(idxTrn));
回帰木を表示します。
view(Mdl,Mode="graph");
この回帰木には 7 つの枝刈りレベルがあります。レベル 0 は、(表示のように) 枝刈りされていない完全な木です。レベル 7 はルート ノードのみ (分割なし) です。
最上位レベルを除く各部分木 (枝刈りレベル) について、学習標本の MSE を確認します。
m = max(Mdl.PruneList) - 1; trnLoss = resubLoss(Mdl,SubTrees=0:m)
trnLoss = 7×1
5.9789
6.2768
6.8316
7.5209
8.3951
10.7452
14.8445
枝刈りされていない完全な木の MSE は約 6 単位です。
レベル 1 まで枝刈りされた木の MSE は約 6.3 単位です。
レベル 6 (切り株) まで枝刈りされた木の MSE は約 14.8 単位です。
最上位を除く各レベルで検証標本の MSE を確認します。
valLoss = loss(Mdl,X(idxVal,:),Y(idxVal),Subtrees=0:m)
valLoss = 7×1
32.1205
31.5035
32.0541
30.8183
26.3535
30.0137
38.4695
枝刈りされていない完全な木 (レベル 0) の MSE は約 32.1 単位です。
レベル 4 まで枝刈りされた木の MSE は約 26.4 単位です。
レベル 5 まで枝刈りされた木の MSE は約 30.0 単位です。
レベル 6 (切り株) まで枝刈りされた木の MSE は約 38.5 単位です。
モデルの複雑さと標本外性能のバランスをとるには、Mdl
をレベル 4 まで枝刈りすることを検討します。
pruneMdl = prune(Mdl,Level=4);
view(pruneMdl,Mode="graph")
入力引数
回帰木モデル。fitrtree
で学習させた RegressionTree
モデル オブジェクトとして指定します。
名前と値の引数
オプションの引数のペアを Name1=Value1,...,NameN=ValueN
として指定します。ここで、Name
は引数名で、Value
は対応する値です。名前と値の引数は他の引数の後に指定しなければなりませんが、ペアの順序は重要ではありません。
R2021a より前では、名前と値をそれぞれコンマを使って区切り、Name
を引用符で囲みます。
例: L = resubloss(tree,Subtrees="all")
は、すべての部分木を枝刈りします。
損失関数。"mse"
(平均二乗誤差) または関数ハンドルとして指定します。関数ハンドル fun
を渡す場合、resubLoss
でこの関数を次のように呼び出します。
fun(Y,Yfit,W)
ここで、Y
、Yfit
、W
は、すべて同じ長さの数値ベクトルです。
Y
は、観測された応答です。Yfit
は予測された応答です。W
は観測の重みです。
fun(Y,Yfit,W)
の戻り値はスカラーでなければなりません。
例: LossFun="mse"
例: LossFun=@
Lossfun
データ型: char
| string
| function_handle
枝刈りレベル。昇順の非負の整数のベクトルまたは "all"
として指定します。
ベクトルを指定する場合、すべての要素が 0
から max(tree.PruneList)
の範囲になければなりません。0
は枝刈りしない完全な木を、max(tree.PruneList)
は完全に枝刈りした木 (つまり、ルート ノードのみ) を表します。
"all"
を指定した場合、resubLoss
はすべての部分木、つまり枝刈り順序全体に作用します。これは、0:max(tree.PruneList)
を指定することと同じです。
resubLoss
では、Subtrees
で指定された各レベルまで tree
の枝刈りを行ってから、対応する出力引数を推定します。Subtrees
のサイズにより、一部の出力引数のサイズが決まります。
関数で Subtrees
を呼び出すために、tree
の PruneList
プロパティと PruneAlpha
プロパティは空以外でなければなりません。言い換えると、fitrtree
を使用するときに Prune="on"
を設定して tree
を成長させるか、prune
を使用して tree
を枝刈りすることで成長させます。
例: Subtrees="all"
データ型: single
| double
| char
| string
木のサイズ。次の値のいずれかとして指定します。
"se"
—resubLoss
は、損失が最小値 (L
+se
、ここでL
とse
はSubtrees
における最小値) の 1 標準偏差以内である最も高い枝刈りレベルを最適な枝刈りレベル (BestLevel
) として返します。"min"
—resubLoss
は、損失が最も小さいSubtrees
の要素を最適な枝刈りレベルとして返します。通常、この要素はSubtrees
の最小要素です。
例: TreeSize="min"
データ型: char
| string
出力引数
詳細
組み込み損失関数は、平均二乗誤差を意味する "mse"
です。
ユーザー独自の損失関数を作成するには、次の形式の関数ファイルを作成します。
function loss = lossfun(Y,Yfit,W)
N
は、tree
.X
の行数です。Y
は、観測された応答を表す、N
要素のベクトルです。Yfit
は、予測された応答を表す、N
要素のベクトルです。W
は、観測の重みを表す、N
要素のベクトルです。出力
loss
はスカラーでなければなりません。
関数ハンドル @
を名前と値の引数 lossfun
LossFun
の値として渡します。
拡張機能
この関数は、GPU 配列を完全にサポートします。詳細は、GPU での MATLAB 関数の実行 (Parallel Computing Toolbox)を参照してください。
バージョン履歴
R2011a で導入
参考
MATLAB Command
You clicked a link that corresponds to this MATLAB command:
Run the command by entering it in the MATLAB Command Window. Web browsers do not support MATLAB commands.
Web サイトの選択
Web サイトを選択すると、翻訳されたコンテンツにアクセスし、地域のイベントやサービスを確認できます。現在の位置情報に基づき、次のサイトの選択を推奨します:
また、以下のリストから Web サイトを選択することもできます。
最適なサイトパフォーマンスの取得方法
中国のサイト (中国語または英語) を選択することで、最適なサイトパフォーマンスが得られます。その他の国の MathWorks のサイトは、お客様の地域からのアクセスが最適化されていません。
南北アメリカ
- América Latina (Español)
- Canada (English)
- United States (English)
ヨーロッパ
- Belgium (English)
- Denmark (English)
- Deutschland (Deutsch)
- España (Español)
- Finland (English)
- France (Français)
- Ireland (English)
- Italia (Italiano)
- Luxembourg (English)
- Netherlands (English)
- Norway (English)
- Österreich (Deutsch)
- Portugal (English)
- Sweden (English)
- Switzerland
- United Kingdom (English)