このページの翻訳は最新ではありません。ここをクリックして、英語の最新版を参照してください。
fit
説明
は、指定されたクエリ点 (newExplainer
= fit(explainer
,queryPoint
)queryPoint
) のシャープレイ値を計算し、計算したシャープレイ値を newExplainer
の ShapleyValues
プロパティに格納します。shapley
オブジェクト explainer
には、機械学習モデルとシャープレイ値の計算オプションが格納されています。
fit
は、いつ explainer
を作成するかを指定する、シャープレイ値計算オプションを使用します。このオプションは、関数 fit
の名前と値の引数を使用して変更できます。この関数は、新しく計算したシャープレイ値を含む shapley
オブジェクト newExplainer
を返します。
では、1 つ以上の名前と値の引数を使用して追加オプションを指定します。たとえば、newExplainer
= fit(explainer
,queryPoint
,Name,Value
)'UseParallel',true
と指定してシャープレイ値を並列計算します。
例
shapley
オブジェクトの作成と fit
を使用したシャープレイ値の計算
回帰モデルの学習を行い、shapley
オブジェクトを作成します。shapley
オブジェクトを作成するときに、クエリ点を指定しなかった場合、シャープレイ値は計算されません。オブジェクト関数 fit
を使用して、指定したクエリ点のシャープレイ値を計算します。次に、オブジェクト関数 plot
を使用して、シャープレイ値の棒グラフを作成します。
carbig
データセットを読み込みます。このデータセットには、1970 年代と 1980 年代初期に製造された自動車の測定値が格納されています。
load carbig
Acceleration
、Cylinders
などの予測子変数と応答変数 MPG
が格納された table を作成します。
tbl = table(Acceleration,Cylinders,Displacement,Horsepower,Model_Year,Weight,MPG);
学習セットの欠損値を削除すると、メモリ消費量を減らして関数 fitrkernel
の学習速度を向上させることができます。tbl
の欠損値を削除します。
tbl = rmmissing(tbl);
関数fitrkernel
を使用して MPG
の blackbox モデルの学習を行います。
rng('default') % For reproducibility mdl = fitrkernel(tbl,'MPG','CategoricalPredictors',[2 5]);
shapley
オブジェクトを作成します。mdl
には学習データが含まれないため、データ セット tbl
を指定します。
explainer = shapley(mdl,tbl)
explainer = shapley with properties: BlackboxModel: [1x1 RegressionKernel] QueryPoint: [] BlackboxFitted: [] ShapleyValues: [] NumSubsets: 64 X: [392x7 table] CategoricalPredictors: [2 5] Method: 'interventional-kernel' Intercept: 22.6202
explainer
は、学習データ tbl
を X
プロパティに格納します。
tbl
の最初の観測値についてすべての予測子変数のシャープレイ値を計算します。
queryPoint = tbl(1,:)
queryPoint=1×7 table
Acceleration Cylinders Displacement Horsepower Model_Year Weight MPG
____________ _________ ____________ __________ __________ ______ ___
12 8 307 130 70 3504 18
explainer = fit(explainer,queryPoint);
回帰モデルの場合、shapley
は予測応答を使用してシャープレイ値を計算し、ShapleyValues
プロパティに格納します。ShapleyValues
プロパティの値を表示します。
explainer.ShapleyValues
ans=6×2 table
Predictor ShapleyValue
______________ ____________
"Acceleration" -0.1561
"Cylinders" -0.18306
"Displacement" -0.34203
"Horsepower" -0.27291
"Model_Year" -0.2926
"Weight" -0.32402
関数 plot
を使用して、クエリ点のシャープレイ値をプロットします。
plot(explainer)
横棒グラフは、絶対値で並べ替えられた、すべての変数のシャープレイ値を示します。各シャープレイ値は、クエリ点についての予測に関して対応する変数が原因で生じた平均からの偏差を説明します。
複数のクエリ点のシャープレイ値の計算
分類モデルの学習を行い、shapley
オブジェクトを作成します。次に、複数のクエリ点のシャープレイ値を計算します。
CreditRating_Historical
データセットを読み込みます。データ セットには、顧客 ID、顧客の財務比率、業種ラベル、および信用格付けが格納されています。
tbl = readtable('CreditRating_Historical.dat');
関数fitcecoc
を使用して、信用格付けの blackbox モデルに学習させます。tbl
内の 2 ~ 7 列目の変数を予測子変数として使用します。
blackbox = fitcecoc(tbl,'Rating', ... 'PredictorNames',tbl.Properties.VariableNames(2:7), ... 'CategoricalPredictors','Industry');
blackbox
モデルを使用して、shapley
オブジェクトを作成します。計算速度を向上するには、tbl
の観測値の 25% を階層的にサブサンプリングし、その標本を使用してシャープレイ値を計算します。kernelSHAP アルゴリズムの拡張機能の使用を指定します。
rng('default') % For reproducibility c = cvpartition(tbl.Rating,'Holdout',0.25); tbl_s = tbl(test(c),:); explainer = shapley(blackbox,tbl_s,'Method','conditional-kernel');
真の信用格付け値がそれぞれ AAA
および B
となる 2 つのクエリ点を見つけます。
queryPoint(1,:) = tbl_s(find(strcmp(tbl_s.Rating,'AAA'),1),:); queryPoint(2,:) = tbl_s(find(strcmp(tbl_s.Rating,'B'),1),:)
queryPoint=2×8 table
ID WC_TA RE_TA EBIT_TA MVE_BVTD S_TA Industry Rating
_____ ______ ______ _______ ________ _____ ________ _______
58258 0.511 0.869 0.106 8.538 0.732 2 {'AAA'}
82367 -0.078 -0.042 0.011 0.262 0.167 7 {'B' }
最初のクエリ点についてのシャープレイ値を計算してプロットします。
explainer1 = fit(explainer,queryPoint(1,:)); plot(explainer1)
2 番目のクエリ点についてのシャープレイ値を計算してプロットします。
explainer2 = fit(explainer,queryPoint(2,:)); plot(explainer2)
2 番目のクエリ点に関する実際の格付けは B
ですが、予測された格付けは BB
です。プロットには、予測された格付けのシャープレイ値が表示されます。
explainer1
と explainer2
にはそれぞれ、最初のクエリ点と 2 番目のクエリ点についてのシャープレイ値が含まれています。
入力引数
explainer
— blackbox モデルを説明するオブジェクト
shapley
オブジェクト
blackbox モデルを説明するオブジェクト。shapley
オブジェクトとして指定します。
queryPoint
— クエリ点
数値の行ベクトル | 単一行テーブル
fit
が予測を説明するクエリ点。数値の行ベクトルまたは単一行 table として指定します。
数値の行ベクトルの場合:
単一行 table の場合:
予測子データ
explainer.X
が table の場合、queryPoint
内のすべての予測子変数は変数名およびデータ型がexplainer.X
内の変数と同じでなければなりません。ただし、queryPoint
の列の順序がexplainer.X
の列の順序に対応する必要はありません。予測子データ
explainer.X
が数値行列の場合、explainer.BlackboxModel.PredictorNames
内の予測子名とqueryPoint
内の対応する予測子変数名が同じでなければなりません。学習時に予測子の名前を指定するには、名前と値の引数'PredictorNames'
を使用します。queryPoint
内の予測子変数はすべて数値ベクトルでなければなりません。queryPoint
に追加の変数 (応答変数や観測値の重みなど) を含めることができますが、fit
はこれらを無視します。fit
は、文字ベクトルの cell 配列ではない cell 配列や複数列の変数をサポートしません。
連続予測子について queryPoint
に NaN
が含まれており、'Method'
が 'conditional-kernel'
である場合、返されたオブジェクトに含まれているシャープレイ値 (ShapleyValues
) は NaN
になります。それ以外の場合、fit
は NaN
値を explainer.BlackboxModel
(explainer.BlackboxModel
のオブジェクト関数 predict
または blackbox
で指定された関数ハンドル) と同じ方法で処理します。
例: explainer.X(1,:)
は、explainer
の予測子データ X
の最初の観測値としてクエリ点を指定します。
データ型: single
| double
| table
名前と値の引数
オプションの引数のペアを Name1=Value1,...,NameN=ValueN
として指定します。ここで Name
は引数名、Value
は対応する値です。名前と値の引数は他の引数の後ろにする必要がありますが、ペアの順序は関係ありません。
R2021a より前では、名前と値をそれぞれコンマを使って区切り、Name
を引用符で囲みます。
例: fit(explainer,q,'Method','conditional-kernel','UseParallel',true)
は、kernelSHAP アルゴリズムの拡張機能を使用してクエリ点 q
のシャープレイ値を計算し、その計算を並列実行します。
MaxNumSubsets
— 予測子サブセットの最大数
explainer.NumSubsets
(既定値) | 正の整数
シャープレイ値の計算に使用する予測子サブセットの最大数。正の整数を指定します。
fit
が使用するサブセットを選択する方法の詳細については、計算コストを参照してください。
例: 'MaxNumSubsets',100
データ型: single
| double
Method
— シャープレイ値の計算アルゴリズム
explainer.Method
(既定値) | 'interventional-kernel'
| 'conditional-kernel'
シャープレイ値の計算アルゴリズム。'interventional-kernel'
または 'conditional-kernel'
を指定します。
これらのアルゴリズムの詳細については、アルゴリズムを参照してください。
例: 'Method','conditional-kernel'
データ型: char
| string
UseParallel
— 並列実行のフラグ
false
(既定値) | true
並列実行のフラグ。true
または false
として指定します。'UseParallel',true
を指定した場合、関数 fit
は parfor
を使用して for ループの反復を並列実行します。このオプションには Parallel Computing Toolbox™ が必要です。
例: 'UseParallel',true
データ型: logical
出力引数
newExplainer
— blackbox モデルを説明するオブジェクト
shapley
オブジェクト
blackbox モデルを説明するオブジェクト。shapley
オブジェクトとして返されます。オブジェクトの ShapleyValues
プロパティには、計算されたシャープレイ値が含まれています。
入力引数 explainer
を上書きするには、fit
の出力を explainer
に代入します。
explainer = fit(explainer,queryPoint);
詳細
シャープレイ値
ゲーム理論におけるプレーヤーのシャープレイ値とは、協力ゲームでのプレーヤーの平均限界貢献度です。機械学習予測のコンテキストでは、クエリ点の特徴量のシャープレイ値は、指定したクエリ点での予測 (回帰の場合は応答、分類の場合は各クラスのスコア) に対する特徴量の寄与を説明します。
クエリ点の特徴量のシャープレイ値は、平均予測からの偏差に対する特徴量の寄与です。クエリ点について、すべての特徴量に関するシャープレイ値の合計は、予測の平均からの合計偏差に対応します。つまり、平均予測とすべての特徴量に関するシャープレイ値の合計は、クエリ点についての予測に対応します。
詳細については、機械学習モデルのシャープレイ値を参照してください。
参照
[1] Lundberg, Scott M., and S. Lee. "A Unified Approach to Interpreting Model Predictions." Advances in Neural Information Processing Systems 30 (2017): 4765–774.
[2] Aas, Kjersti, Martin. Jullum, and Anders Løland. "Explaining Individual Predictions When Features Are Dependent: More Accurate Approximations to Shapley Values." arXiv:1903.10464 (2019).
拡張機能
自動並列サポート
Parallel Computing Toolbox™ を使用して自動的に並列計算を実行することで、コードを高速化します。
並列実行するには、この関数を呼び出すときに名前と値の引数 UseParallel
を true
に設定します。
並列計算の全般的な情報については、自動並列サポートを使用した MATLAB 関数の実行 (Parallel Computing Toolbox)を参照してください。
バージョン履歴
R2021a で導入
MATLAB コマンド
次の MATLAB コマンドに対応するリンクがクリックされました。
コマンドを MATLAB コマンド ウィンドウに入力して実行してください。Web ブラウザーは MATLAB コマンドをサポートしていません。
Select a Web Site
Choose a web site to get translated content where available and see local events and offers. Based on your location, we recommend that you select: .
You can also select a web site from the following list:
How to Get Best Site Performance
Select the China site (in Chinese or English) for best site performance. Other MathWorks country sites are not optimized for visits from your location.
Americas
- América Latina (Español)
- Canada (English)
- United States (English)
Europe
- Belgium (English)
- Denmark (English)
- Deutschland (Deutsch)
- España (Español)
- Finland (English)
- France (Français)
- Ireland (English)
- Italia (Italiano)
- Luxembourg (English)
- Netherlands (English)
- Norway (English)
- Österreich (Deutsch)
- Portugal (English)
- Sweden (English)
- Switzerland
- United Kingdom (English)