このページの内容は最新ではありません。最新版の英語を参照するには、ここをクリックします。
fit
説明
は、指定されたクエリ点 (newExplainer
= fit(explainer
,queryPoint
)queryPoint
) のシャープレイ値を計算し、計算したシャープレイ値を newExplainer
の ShapleyValues
プロパティに格納します。shapley
オブジェクト explainer
には、機械学習モデルとシャープレイ値の計算オプションが格納されています。
fit
は、いつ explainer
を作成するかを指定する、シャープレイ値計算オプションを使用します。このオプションは、関数 fit
の名前と値の引数を使用して変更できます。この関数は、新しく計算したシャープレイ値を含む shapley
オブジェクト newExplainer
を返します。
では、1 つ以上の名前と値の引数を使用して追加オプションを指定します。たとえば、newExplainer
= fit(explainer
,queryPoint
,Name,Value
)'UseParallel',true
と指定してシャープレイ値を並列計算します。
例
shapley
オブジェクトの作成と fit
を使用したシャープレイ値の計算
回帰モデルの学習を行い、shapley
オブジェクトを作成します。shapley
オブジェクトを作成するときに、クエリ点を指定しなかった場合、シャープレイ値は計算されません。オブジェクト関数 fit
を使用して、指定したクエリ点のシャープレイ値を計算します。次に、オブジェクト関数 plot
を使用して、シャープレイ値の棒グラフを作成します。
carbig
データセットを読み込みます。このデータセットには、1970 年代と 1980 年代初期に製造された自動車の測定値が格納されています。
load carbig
Acceleration
、Cylinders
などの予測子変数と応答変数 MPG
が格納された table を作成します。
tbl = table(Acceleration,Cylinders,Displacement,Horsepower,Model_Year,Weight,MPG);
学習セットの欠損値を削除すると、メモリ消費量を減らして関数 fitrkernel
の学習速度を向上させることができます。tbl
の欠損値を削除します。
tbl = rmmissing(tbl);
関数fitrkernel
を使用して MPG
の blackbox モデルの学習を行います。
rng('default') % For reproducibility mdl = fitrkernel(tbl,'MPG','CategoricalPredictors',[2 5]);
shapley
オブジェクトを作成します。mdl
には学習データが含まれないため、データ セット tbl
を指定します。
explainer = shapley(mdl,tbl)
explainer = shapley with properties: BlackboxModel: [1x1 RegressionKernel] QueryPoint: [] BlackboxFitted: [] ShapleyValues: [] X: [392x7 table] CategoricalPredictors: [2 5] Method: 'interventional-kernel' Intercept: 22.6202 NumSubsets: 64
explainer
は、学習データ tbl
を X
プロパティに格納します。
tbl
の最初の観測値についてすべての予測子変数のシャープレイ値を計算します。
queryPoint = tbl(1,:)
queryPoint=1×7 table
Acceleration Cylinders Displacement Horsepower Model_Year Weight MPG
____________ _________ ____________ __________ __________ ______ ___
12 8 307 130 70 3504 18
explainer = fit(explainer,queryPoint);
回帰モデルの場合、shapley
は予測応答を使用してシャープレイ値を計算し、ShapleyValues
プロパティに格納します。ShapleyValues
プロパティの値を表示します。
explainer.ShapleyValues
ans=6×2 table
Predictor ShapleyValue
______________ ____________
"Acceleration" -0.1561
"Cylinders" -0.18306
"Displacement" -0.34203
"Horsepower" -0.27291
"Model_Year" -0.2926
"Weight" -0.32402
関数 plot
を使用して、クエリ点のシャープレイ値をプロットします。
plot(explainer)
横棒グラフは、絶対値で並べ替えられた、すべての変数のシャープレイ値を示します。各シャープレイ値は、クエリ点についての予測に関して対応する変数が原因で生じた平均からの偏差を説明します。
複数のクエリ点のシャープレイ値の計算
分類モデルの学習を行い、shapley
オブジェクトを作成します。次に、複数のクエリ点のシャープレイ値を計算します。
CreditRating_Historical
データセットを読み込みます。データ セットには、顧客 ID、顧客の財務比率、業種ラベル、および信用格付けが格納されています。
tbl = readtable('CreditRating_Historical.dat');
関数fitcecoc
を使用して、信用格付けの blackbox モデルに学習させます。tbl
内の 2 ~ 7 列目の変数を予測子変数として使用します。
blackbox = fitcecoc(tbl,'Rating', ... 'PredictorNames',tbl.Properties.VariableNames(2:7), ... 'CategoricalPredictors','Industry');
blackbox
モデルを使用して、shapley
オブジェクトを作成します。計算速度を向上するには、tbl
の観測値の 25% を階層的にサブサンプリングし、その標本を使用してシャープレイ値を計算します。Kernel SHAP アルゴリズムの拡張機能の使用を指定します。
rng('default') % For reproducibility c = cvpartition(tbl.Rating,'Holdout',0.25); tbl_s = tbl(test(c),:); explainer = shapley(blackbox,tbl_s,'Method','conditional');
真の信用格付け値がそれぞれ AAA
および B
となる 2 つのクエリ点を見つけます。
queryPoint(1,:) = tbl_s(find(strcmp(tbl_s.Rating,'AAA'),1),:); queryPoint(2,:) = tbl_s(find(strcmp(tbl_s.Rating,'B'),1),:)
queryPoint=2×8 table
ID WC_TA RE_TA EBIT_TA MVE_BVTD S_TA Industry Rating
_____ ______ ______ _______ ________ _____ ________ _______
58258 0.511 0.869 0.106 8.538 0.732 2 {'AAA'}
82367 -0.078 -0.042 0.011 0.262 0.167 7 {'B' }
最初のクエリ点についてのシャープレイ値を計算してプロットします。
explainer1 = fit(explainer,queryPoint(1,:)); plot(explainer1)
2 番目のクエリ点についてのシャープレイ値を計算してプロットします。
explainer2 = fit(explainer,queryPoint(2,:)); plot(explainer2)
2 番目のクエリ点に関する実際の格付けは B
ですが、予測された格付けは BB
です。プロットには、予測された格付けのシャープレイ値が表示されます。
explainer1
と explainer2
にはそれぞれ、最初のクエリ点と 2 番目のクエリ点についてのシャープレイ値が含まれています。
入力引数
explainer
— blackbox モデルを説明するオブジェクト
shapley
オブジェクト
blackbox モデルを説明するオブジェクト。shapley
オブジェクトとして指定します。
queryPoint
— クエリ点
数値の行ベクトル | 単一行テーブル
fit
が予測を説明するクエリ点。数値の行ベクトルまたは単一行 table として指定します。
数値の行ベクトルの場合:
単一行 table の場合:
予測子データ
explainer.X
が table の場合、queryPoint
内のすべての予測子変数は変数名およびデータ型がexplainer.X
内の変数と同じでなければなりません。ただし、queryPoint
の列の順序がexplainer.X
の列の順序に対応する必要はありません。予測子データ
explainer.X
が数値行列の場合、explainer.BlackboxModel.PredictorNames
内の予測子名とqueryPoint
内の対応する予測子変数名が同じでなければなりません。学習時に予測子の名前を指定するには、名前と値の引数'PredictorNames'
を使用します。queryPoint
内の予測子変数はすべて数値ベクトルでなければなりません。queryPoint
に追加の変数 (応答変数や観測値の重みなど) を含めることができますが、fit
はこれらを無視します。fit
は、文字ベクトルの cell 配列ではない cell 配列や複数列の変数をサポートしません。
連続予測子について queryPoint
に NaN
が含まれており、'Method'
が 'conditional'
である場合、返されたオブジェクトに含まれているシャープレイ値 (ShapleyValues
) は NaN
になります。それ以外の場合、fit
は NaN
値を explainer.BlackboxModel
(explainer.BlackboxModel
のオブジェクト関数 predict
または blackbox
で指定された関数ハンドル) と同じ方法で処理します。
例: explainer.X(1,:)
は、explainer
の予測子データ X
の最初の観測値としてクエリ点を指定します。
データ型: single
| double
| table
名前と値の引数
オプションの引数のペアを Name1=Value1,...,NameN=ValueN
として指定します。ここで Name
は引数名、Value
は対応する値です。名前と値の引数は他の引数の後ろにする必要がありますが、ペアの順序は関係ありません。
R2021a より前では、名前と値をそれぞれコンマを使って区切り、Name
を引用符で囲みます。
例: fit(explainer,q,'Method','conditional','UseParallel',true)
は、Kernel SHAP アルゴリズムの拡張機能を使用してクエリ点 q
のシャープレイ値を計算し、その計算を並列実行します。
MaxNumSubsets
— 予測子サブセットの最大数
explainer.NumSubsets
(既定値) | 正の整数
シャープレイ値の計算に使用する予測子サブセットの最大数。正の整数を指定します。
fit
が使用するサブセットを選択する方法の詳細については、計算コストを参照してください。
この引数は、関数 fit
で Kernel SHAP アルゴリズムまたは Kernel SHAP アルゴリズムの拡張機能を使用する場合に有効です。Method
が 'interventional'
の場合に引数 MaxNumSubsets
を設定すると、Kernel SHAP アルゴリズムが使用されます。詳細については、アルゴリズムを参照してください。
例: 'MaxNumSubsets',100
データ型: single
| double
Method
— シャープレイ値の計算アルゴリズム
'interventional'
| 'conditional'
R2023a 以降
シャープレイ値の計算アルゴリズム。'interventional'
または 'conditional'
を指定します。
選択されたアルゴリズムの名前が newExplainer
の Method
プロパティに格納されます。詳細については、アルゴリズムを参照してください。
既定では、関数 fit
は、explainer
の Method
プロパティで指定されたアルゴリズムを使用します。
R2023a より前: この引数は 'interventional-kernel'
または 'conditional-kernel'
として指定できます。fit
は、Kernel SHAP アルゴリズムと Kernel SHAP アルゴリズムの拡張機能をサポートしています。
例: 'Method','conditional'
データ型: char
| string
UseParallel
— 並列実行のフラグ
false
(既定値) | true
並列実行のフラグ。数値または logical の 1
(true
) または 0
(false
) として指定します。UseParallel=true
を指定した場合、関数 fit
は parfor
を使用して for
ループの反復を実行します。Parallel Computing Toolbox™ がある場合、ループが並列に実行されます。
この引数は、関数 fit
で木のアンサンブル用の Tree SHAP アルゴリズム、Kernel SHAP アルゴリズム、または Kernel SHAP アルゴリズムの拡張機能を使用する場合に有効です。
例: 'UseParallel',true
データ型: logical
出力引数
newExplainer
— blackbox モデルを説明するオブジェクト
shapley
オブジェクト
blackbox モデルを説明するオブジェクト。shapley
オブジェクトとして返されます。オブジェクトの ShapleyValues
プロパティには、計算されたシャープレイ値が含まれています。
入力引数 explainer
を上書きするには、fit
の出力を explainer
に代入します。
explainer = fit(explainer,queryPoint);
詳細
シャープレイ値
ゲーム理論におけるプレーヤーのシャープレイ値とは、協力ゲームでのプレーヤーの平均限界貢献度です。機械学習予測のコンテキストでは、クエリ点の特徴量のシャープレイ値は、指定したクエリ点での予測 (回帰の場合は応答、分類の場合は各クラスのスコア) に対する特徴量の寄与を説明します。
クエリ点の特徴量のシャープレイ値は、平均予測からの偏差に対する特徴量の寄与です。クエリ点について、すべての特徴量に関するシャープレイ値の合計は、予測の平均からの合計偏差に対応します。つまり、平均予測とすべての特徴量に関するシャープレイ値の合計は、クエリ点についての予測に対応します。
詳細については、機械学習モデルのシャープレイ値を参照してください。
参照
[1] Lundberg, Scott M., and S. Lee. "A Unified Approach to Interpreting Model Predictions." Advances in Neural Information Processing Systems 30 (2017): 4765–774.
[2] Lundberg, Scott M., G. Erion, H. Chen, et al. "From Local Explanations to Global Understanding with Explainable AI for Trees." Nature Machine Intelligence 2 (January 2020): 56–67.
[3] Aas, Kjersti, Martin Jullum, and Anders Løland. "Explaining Individual Predictions When Features Are Dependent: More Accurate Approximations to Shapley Values." Artificial Intelligence 298 (September 2021).
拡張機能
自動並列サポート
Parallel Computing Toolbox™ を使用して自動的に並列計算を実行することで、コードを高速化します。
並列実行するには、この関数を呼び出すときに名前と値の引数 UseParallel
を true
に設定します。
並列計算の全般的な情報については、自動並列サポートを使用した MATLAB 関数の実行 (Parallel Computing Toolbox)を参照してください。
バージョン履歴
R2021a で導入R2023b: 介入型の Tree SHAP アルゴリズムで予測子に欠損値があるデータをサポート
入力予測子データ (
) の観測値またはクエリ点 (explainer
.XqueryPoint
) の値に欠損値があり、Method
の値が "interventional"
の場合、関数 fit
では、木モデルおよび木学習器のアンサンブル モデル用の Tree SHAP アルゴリズムを使用できます。以前のリリースでは、このような条件の場合、関数 fit
では常に木ベースのモデル用の Kernel SHAP アルゴリズムが使用されていました。木ベースのモデル用に Tree SHAP ではなく Kernel SHAP が引き続き使用される場合を含む詳細については、介入型アルゴリズムを参照してください。
R2023a: fit
で Linear SHAP アルゴリズムと Tree SHAP アルゴリズムをサポート
R2023a: 名前と値の引数 Method
の値の変更
名前と値の引数 Method
のサポートされる値が、'interventional-kernel'
と 'conditional-kernel'
から 'interventional'
と 'conditional'
にそれぞれ変更されました。
MATLAB コマンド
次の MATLAB コマンドに対応するリンクがクリックされました。
コマンドを MATLAB コマンド ウィンドウに入力して実行してください。Web ブラウザーは MATLAB コマンドをサポートしていません。
Select a Web Site
Choose a web site to get translated content where available and see local events and offers. Based on your location, we recommend that you select: .
You can also select a web site from the following list:
How to Get Best Site Performance
Select the China site (in Chinese or English) for best site performance. Other MathWorks country sites are not optimized for visits from your location.
Americas
- América Latina (Español)
- Canada (English)
- United States (English)
Europe
- Belgium (English)
- Denmark (English)
- Deutschland (Deutsch)
- España (Español)
- Finland (English)
- France (Français)
- Ireland (English)
- Italia (Italiano)
- Luxembourg (English)
- Netherlands (English)
- Norway (English)
- Österreich (Deutsch)
- Portugal (English)
- Sweden (English)
- Switzerland
- United Kingdom (English)