このページの内容は最新ではありません。最新版の英語を参照するには、ここをクリックします。
fit
説明
は、指定されたクエリ点 (newExplainer
= fit(explainer
,queryPoints
)queryPoints
) のシャープレイ値を計算し、計算したシャープレイ値を newExplainer
の ShapleyValues
プロパティに格納します。shapley
オブジェクト explainer
には、機械学習モデルとシャープレイ値の計算オプションが格納されています。
fit
は、いつ explainer
を作成するかを指定する、シャープレイ値計算オプションを使用します。このオプションは、関数 fit
の名前と値の引数を使用して変更できます。この関数は、新しく計算したシャープレイ値を含む shapley
オブジェクト newExplainer
を返します。
では、1 つ以上の名前と値の引数を使用して追加オプションを指定します。たとえば、newExplainer
= fit(explainer
,queryPoints
,Name=Value
)UseParallel=true
と指定してシャープレイ値を並列計算します。
例
shapley
オブジェクトの作成と fit
を使用したシャープレイ値の計算
回帰モデルの学習を行い、shapley
オブジェクトを作成します。shapley
オブジェクトを作成するときに、クエリ点を指定しなかった場合、シャープレイ値は計算されません。オブジェクト関数 fit
を使用して、指定したクエリ点のシャープレイ値を計算します。次に、オブジェクト関数 plot
を使用して、シャープレイ値の棒グラフを作成します。
carbig
データ セットを読み込みます。このデータ セットには、1970 年代と 1980 年代初期に製造された自動車の測定値が格納されています。
load carbig
Acceleration
、Cylinders
などの予測子変数と応答変数 MPG
が格納された table を作成します。
tbl = table(Acceleration,Cylinders,Displacement, ...
Horsepower,Model_Year,Weight,MPG);
学習セットの欠損値を削除すると、メモリ消費量を減らして関数 fitrkernel
の学習速度を向上させることができます。tbl
の欠損値を削除します。
tbl = rmmissing(tbl);
関数fitrkernel
を使用して MPG
の blackbox モデルの学習を行います。変数 Cylinders
と Model_Year
をカテゴリカル予測子として指定します。残りの予測子を標準化します。
rng("default") % For reproducibility mdl = fitrkernel(tbl,"MPG",CategoricalPredictors=[2 5], ... Standardize=true);
shapley
オブジェクトを作成します。mdl
には学習データが含まれないため、データ セット tbl
を指定します。
explainer = shapley(mdl,tbl)
explainer = BlackboxModel: [1x1 RegressionKernel] QueryPoints: [] BlackboxFitted: [] ShapleyValues: [] X: [392x7 table] CategoricalPredictors: [2 5] Method: "interventional-kernel" Intercept: 22.7326 NumSubsets: 64
explainer
は、学習データ tbl
を X
プロパティに格納します。
tbl
の最初の観測値についてすべての予測子変数のシャープレイ値を計算します。
queryPoint = tbl(1,:)
queryPoint=1×7 table
Acceleration Cylinders Displacement Horsepower Model_Year Weight MPG
____________ _________ ____________ __________ __________ ______ ___
12 8 307 130 70 3504 18
explainer = fit(explainer,queryPoint);
回帰モデルの場合、shapley
は予測応答を使用してシャープレイ値を計算し、ShapleyValues
プロパティに格納します。ShapleyValues
プロパティの値を表示します。
explainer.ShapleyValues
ans=6×2 table
Predictor ShapleyValue
______________ ____________
"Acceleration" -0.23731
"Cylinders" -0.87423
"Displacement" -1.0224
"Horsepower" -0.56975
"Model_Year" -0.055414
"Weight" -0.86088
関数 plot
を使用して、クエリ点のシャープレイ値をプロットします。
plot(explainer)
横棒グラフは、絶対値で並べ替えられた、すべての変数のシャープレイ値を示します。各シャープレイ値は、クエリ点についての予測に関して対応する変数が原因で生じた平均からの偏差を説明します。
2 つのクエリ点のシャープレイ値の計算
分類モデルの学習を行い、shapley
オブジェクトを作成します。その後、2 つのクエリ点のシャープレイ値を計算します。
CreditRating_Historical
データ セットを読み込みます。データ セットには、顧客 ID、顧客の財務比率、業種ラベル、および信用格付けが格納されています。
tbl = readtable("CreditRating_Historical.dat");
関数fitcecoc
を使用して、信用格付けの blackbox モデルに学習させます。tbl
内の 2 ~ 7 列目の変数を予測子変数として使用します。
blackbox = fitcecoc(tbl,"Rating", ... PredictorNames=tbl.Properties.VariableNames(2:7), ... CategoricalPredictors="Industry");
blackbox
モデルを使用して、shapley
オブジェクトを作成します。計算速度を向上するには、tbl
の観測値の 25% を階層的にサブサンプリングし、その標本を使用してシャープレイ値を計算します。Kernel SHAP アルゴリズムの拡張機能の使用を指定します。
rng("default") % For reproducibility c = cvpartition(tbl.Rating,"Holdout",0.25); sampleTbl = tbl(test(c),:); explainer = shapley(blackbox,sampleTbl,Method="conditional");
真の信用格付け値がそれぞれ AAA
および B
となる 2 つのクエリ点を見つけます。
queryPoints(1,:) = sampleTbl(find(strcmp(sampleTbl.Rating,"AAA"),1),:); queryPoints(2,:) = sampleTbl(find(strcmp(sampleTbl.Rating,"B"),1),:)
queryPoints=2×8 table
ID WC_TA RE_TA EBIT_TA MVE_BVTD S_TA Industry Rating
_____ ______ ______ _______ ________ _____ ________ _______
58258 0.511 0.869 0.106 8.538 0.732 2 {'AAA'}
82367 -0.078 -0.042 0.011 0.262 0.167 7 {'B' }
最初のクエリ点についてのシャープレイ値を計算してプロットします。
explainer1 = fit(explainer,queryPoints(1,:)); plot(explainer1)
2 番目のクエリ点についてのシャープレイ値を計算してプロットします。
explainer2 = fit(explainer,queryPoints(2,:)); plot(explainer2)
2 番目のクエリ点に関する実際の格付けは B
ですが、予測された格付けは BB
です。プロットには、予測された格付けのシャープレイ値が表示されます。
explainer1
と explainer2
にはそれぞれ、最初のクエリ点と 2 番目のクエリ点についてのシャープレイ値が含まれています。
回帰モデルのシャープレイ値の粒子群チャート
回帰モデルの学習を行い、shapley
オブジェクトを作成します。オブジェクト関数 fit
を使用して、指定したクエリ点のシャープレイ値を計算します。その後、オブジェクト関数 swarmchart
を使用して複数のクエリ点のシャープレイ値をプロットします。
carbig
データ セットを読み込みます。このデータ セットには、1970 年代と 1980 年代初期に製造された自動車の測定値が格納されています。
load carbig
Acceleration
、Cylinders
などの予測子変数と応答変数 MPG
が格納された table を作成します。
tbl = table(Acceleration,Cylinders,Displacement, ...
Horsepower,Model_Year,Weight,MPG);
学習セットの欠損値を削除すると、メモリ消費量が減り、関数 fitrkernel
の学習速度の向上に役立ちます。tbl
の欠損値を削除します。
tbl = rmmissing(tbl);
関数 fitrkernel
を使用して MPG
の blackbox モデルに学習させます。変数 Cylinders
と Model_Year
をカテゴリカル予測子として指定します。残りの予測子を標準化します。
rng("default") % For reproducibility mdl = fitrkernel(tbl,"MPG",CategoricalPredictors=[2 5], ... Standardize=true);
shapley
オブジェクトを作成します。mdl
には学習データが含まれていないため、データ セット tbl
を指定します。
explainer = shapley(mdl,tbl)
explainer = BlackboxModel: [1×1 RegressionKernel] QueryPoints: [] BlackboxFitted: [] ShapleyValues: [] X: [392×7 table] CategoricalPredictors: [2 5] Method: "interventional-kernel" Intercept: 22.7326 NumSubsets: 64
explainer
は、学習データ tbl
を X
プロパティに格納します。
tbl
のすべての観測値のシャープレイ値を計算します。Parallel Computing Toolbox™ のライセンスがある場合は、名前と値の引数 UseParallel
を使用して計算を高速化します。
explainer = fit(explainer,tbl,UseParallel=true);
Starting parallel pool (parpool) using the 'Processes' profile ... 10-Jan-2024 14:09:35: Job Queued. Waiting for parallel pool job with ID 5 to start ... Connected to parallel pool with 6 workers.
回帰モデルの場合、shapley
は予測応答を使用してシャープレイ値を計算し、ShapleyValues
プロパティに格納します。explainer
に複数のクエリ点のシャープレイ値が格納されているため、代わりに平均絶対シャープレイ値を表示します。
explainer.MeanAbsoluteShapley
ans=6×2 table
Predictor ShapleyValue
______________ ____________
"Acceleration" 0.52233
"Cylinders" 1.0412
"Displacement" 0.80485
"Horsepower" 0.7589
"Model_Year" 0.82285
"Weight" 0.98453
それぞれの予測子について、すべてのクエリ点で平均化したシャープレイ値の絶対値が平均絶対シャープレイ値になります。予測子 Cylinders
の平均絶対シャープレイ値が最も大きく、予測子 Acceleration
の平均絶対シャープレイ値が最も小さくなっています。
オブジェクト関数 swarmchart
を使用してシャープレイ値を可視化します。"copper"
カラーマップを使用するように指定します。
swarmchart(explainer,ColorMap="copper")
それぞれの予測子について、クエリ点のシャープレイ値が関数によって表示されます。対応する粒子群チャートにシャープレイ値の分布が表示されます。予測子の順序は、平均絶対シャープレイ値を使用して関数で決定されます。
Weight
の値が小さいクエリ点は、シャープレイ値が大きい正の値になっているように見えます。つまり、それらのクエリ点については、予測子 Weight
は MPG
の予測される値の平均からの差が大きくなるのに寄与しています。同様に、Weight
の値が大きいクエリ点は、シャープレイ値が大きい負の値になっているように見えます。つまり、それらのクエリ点については、予測子 Weight
は MPG
の予測される値の平均からの差が小さくなるのに寄与しています。これらの結果は、自動車の重量は MPG の値と逆の相関があるという考え方に一致しています。
入力引数
explainer
— blackbox モデルを説明するオブジェクト
shapley
オブジェクト
blackbox モデルを説明するオブジェクト。shapley
オブジェクトとして指定します。
queryPoints
— クエリ点
数値行列 | テーブル
fit
が予測を説明するクエリ点。数値行列または table として指定します。queryPoints
の各行が 1 つのクエリ点に対応します。
数値行列の場合
テーブルの場合
予測子データ
explainer.X
が table の場合、queryPoints
内のすべての予測子変数は変数名およびデータ型がexplainer.X
内の変数と同じでなければなりません。ただし、queryPoints
の列の順序がexplainer.X
の列の順序に対応する必要はありません。予測子データ
explainer.X
が数値行列の場合、explainer.BlackboxModel.PredictorNames
内の予測子名とqueryPoints
内の対応する予測子変数名が同じでなければなりません。学習時に予測子の名前を指定するには、名前と値の引数PredictorNames
を使用します。queryPoints
内の予測子変数はすべて数値ベクトルでなければなりません。queryPoints
に追加の変数 (応答変数や観測値の重みなど) を含めることができますが、fit
はこれらを無視します。fit
は、文字ベクトルの cell 配列ではない cell 配列や複数列の変数をサポートしません。
連続予測子について queryPoints
に NaN
が含まれており、Method
が "conditional"
である場合、返されたオブジェクトに含まれているシャープレイ値 (ShapleyValues
) は NaN
になります。ガウス過程回帰 (GPR)、カーネル、線形、ニューラル ネットワーク、またはサポート ベクター マシン (SVM) のモデルである回帰モデルを使用する場合、欠損値がある予測子や学習時にはなかったカテゴリを含むクエリ点については、fit
はシャープレイ値として NaN
を返します。それ以外のすべてのモデルについては、fit
は欠損値を explainer.BlackboxModel
(explainer.BlackboxModel
のオブジェクト関数 predict
または blackbox
で指定された関数ハンドル) と同じ方法で処理します。
R2024a より前: クエリ点が 1 つだけの場合は、数値の行ベクトルまたは単一行 table を使用して指定できます。
例: explainer.X(1,:)
は、explainer
の予測子データ X
の最初の観測値としてクエリ点を指定します。
データ型: single
| double
| table
名前と値の引数
オプションの引数のペアを Name1=Value1,...,NameN=ValueN
として指定します。ここで Name
は引数名、Value
は対応する値です。名前と値の引数は他の引数の後ろにする必要がありますが、ペアの順序は関係ありません。
例: fit(explainer,q,Method="conditional",UseParallel=true)
は、Kernel SHAP アルゴリズムの拡張機能を使用してクエリ点 q
のシャープレイ値を計算し、その計算を並列実行します。
MaxNumSubsets
— 予測子サブセットの最大数
explainer.NumSubsets
(既定値) | 正の整数
シャープレイ値の計算に使用する予測子サブセットの最大数。正の整数を指定します。
fit
が使用するサブセットを選択する方法の詳細については、計算コストを参照してください。
この引数は、関数 fit
で Kernel SHAP アルゴリズムまたは Kernel SHAP アルゴリズムの拡張機能を使用する場合に有効です。Method
が "interventional"
の場合に引数 MaxNumSubsets
を設定すると、Kernel SHAP アルゴリズムが使用されます。詳細については、アルゴリズムを参照してください。
例: MaxNumSubsets=100
データ型: single
| double
Method
— シャープレイ値の計算アルゴリズム
"interventional"
| "conditional"
シャープレイ値の計算アルゴリズム。"interventional"
または "conditional"
として指定します。
選択されたアルゴリズムの名前が newExplainer
の Method
プロパティに格納されます。詳細については、アルゴリズムを参照してください。
既定では、関数 fit
は、explainer
の Method
プロパティで指定されたアルゴリズムを使用します。
R2023a より前: この引数は "interventional-kernel"
または "conditional-kernel"
として指定できます。fit
は、Kernel SHAP アルゴリズムと Kernel SHAP アルゴリズムの拡張機能をサポートしています。
例: Method="conditional"
データ型: char
| string
OutputFcn
— 各クエリ点の評価後に呼び出される関数
[]
(既定値) | 関数ハンドル
R2024a 以降
各クエリ点の評価後に呼び出される関数。関数ハンドルとして指定します。シャープレイ値の計算の停止、変数の作成、結果のプロットなど、さまざまなタスクを出力関数で実行できます。独自の出力関数を記述する方法の詳細と例については、Shapley Output Functionsを参照してください。
この引数は、関数 fit
で複数のクエリ点のシャープレイ値を計算する場合で、UseParallel
の値が false
の場合のみ有効です。
データ型: function_handle
UseParallel
— 並列実行のフラグ
false
(既定値) | true
並列実行のフラグ。数値または logical の 1
(true
) または 0
(false
) として指定します。UseParallel=true
を指定した場合、関数 fit
は parfor
を使用して for
ループの反復を実行します。Parallel Computing Toolbox™ がある場合、ループが並列に実行されます。
この引数は、関数 fit
で複数のクエリ点のシャープレイ値を計算する場合、または木のアンサンブル用の Tree SHAP アルゴリズム、Kernel SHAP アルゴリズム、または Kernel SHAP アルゴリズムの拡張機能を使用して 1 つのクエリ点のシャープレイ値を計算する場合のみ有効です。
例: UseParallel=true
データ型: logical
出力引数
newExplainer
— blackbox モデルを説明するオブジェクト
shapley
オブジェクト
blackbox モデルを説明するオブジェクト。shapley
オブジェクトとして返されます。オブジェクトの ShapleyValues
プロパティには、計算されたシャープレイ値が含まれています。
入力引数 explainer
を上書きするには、fit
の出力を explainer
に代入します。
explainer = fit(explainer,queryPoints);
詳細
シャープレイ値
ゲーム理論におけるプレーヤーのシャープレイ値とは、協力ゲームでのプレーヤーの平均限界貢献度です。機械学習予測のコンテキストでは、クエリ点の特徴量のシャープレイ値は、指定したクエリ点での予測 (回帰の場合は応答、分類の場合は各クラスのスコア) に対する特徴量の寄与を説明します。
クエリ点の特徴量のシャープレイ値は、平均予測からの偏差に対する特徴量の寄与です。クエリ点について、すべての特徴量に関するシャープレイ値の合計は、予測の平均からの合計偏差に対応します。つまり、平均予測とすべての特徴量に関するシャープレイ値の合計は、クエリ点についての予測に対応します。
詳細については、機械学習モデルのシャープレイ値を参照してください。
参照
[1] Lundberg, Scott M., and S. Lee. "A Unified Approach to Interpreting Model Predictions." Advances in Neural Information Processing Systems 30 (2017): 4765–774.
[2] Lundberg, Scott M., G. Erion, H. Chen, et al. "From Local Explanations to Global Understanding with Explainable AI for Trees." Nature Machine Intelligence 2 (January 2020): 56–67.
[3] Aas, Kjersti, Martin Jullum, and Anders Løland. "Explaining Individual Predictions When Features Are Dependent: More Accurate Approximations to Shapley Values." Artificial Intelligence 298 (September 2021).
拡張機能
自動並列サポート
Parallel Computing Toolbox™ を使用して自動的に並列計算を実行することで、コードを高速化します。
並列実行するには、この関数を呼び出すときに名前と値の引数 UseParallel
を true
に設定します。
並列計算の全般的な情報については、自動並列サポートを使用した MATLAB 関数の実行 (Parallel Computing Toolbox)を参照してください。
バージョン履歴
R2021a で導入R2024a: 複数のクエリ点のシャープレイ値の計算
引数 queryPoints
を使用して複数のクエリ点のシャープレイ値を計算できるようになりました。複数のクエリ点を扱うときは、シャープレイ値の計算の停止、変数の作成、結果のプロットなど、さまざまなタスクを出力関数を使用して実行できます。これを行うには、名前と値の引数 OutputFcn
を指定します。
R2023b: 介入型の Tree SHAP アルゴリズムで予測子に欠損値があるデータをサポート
入力予測子データ (
) の観測値またはクエリ点 (explainer
.XqueryPoint
) の値に欠損値があり、Method
の値が "interventional"
の場合、関数 fit
では、木モデルおよび木学習器のアンサンブル モデル用の Tree SHAP アルゴリズムを使用できます。以前のリリースでは、このような条件の場合、関数 fit
では常に木ベースのモデル用の Kernel SHAP アルゴリズムが使用されていました。木ベースのモデル用に Tree SHAP ではなく Kernel SHAP が引き続き使用される場合を含む詳細については、介入型アルゴリズムを参照してください。
R2023a: fit
で Linear SHAP アルゴリズムと Tree SHAP アルゴリズムをサポート
R2023a: 名前と値の引数 Method
の値の変更
名前と値の引数 Method
のサポートされる値が、'interventional-kernel'
と 'conditional-kernel'
から 'interventional'
と 'conditional'
にそれぞれ変更されました。
MATLAB コマンド
次の MATLAB コマンドに対応するリンクがクリックされました。
コマンドを MATLAB コマンド ウィンドウに入力して実行してください。Web ブラウザーは MATLAB コマンドをサポートしていません。
Select a Web Site
Choose a web site to get translated content where available and see local events and offers. Based on your location, we recommend that you select: .
You can also select a web site from the following list:
How to Get Best Site Performance
Select the China site (in Chinese or English) for best site performance. Other MathWorks country sites are not optimized for visits from your location.
Americas
- América Latina (Español)
- Canada (English)
- United States (English)
Europe
- Belgium (English)
- Denmark (English)
- Deutschland (Deutsch)
- España (Español)
- Finland (English)
- France (Français)
- Ireland (English)
- Italia (Italiano)
- Luxembourg (English)
- Netherlands (English)
- Norway (English)
- Österreich (Deutsch)
- Portugal (English)
- Sweden (English)
- Switzerland
- United Kingdom (English)