plot
説明
plot(
は、explainer
)shapley
オブジェクト explainer
のシャープレイ値を使用して横棒グラフを作成します。
explainer
にクエリ点が 1 つだけ格納されている場合は、棒グラフにシャープレイ値が表示されます。これらの値はオブジェクトのShapleyValues
プロパティに格納されています。各バーは、クエリ点 (explainer.
) についての blackbox モデル (QueryPoints
explainer.
) 内の各特徴量 (予測子) のシャープレイ値を示します。BlackboxModel
explainer
にクエリ点が複数格納されている場合は、棒グラフに平均絶対シャープレイ値が表示されます。これらの値はオブジェクトのMeanAbsoluteShapley
プロパティに格納されています。それぞれの予測子 (explainer.BlackboxModel
が分類モデルの場合はそれぞれのクラス) について、explainer.QueryPoints
のすべてのクエリ点で平均化したシャープレイ値の絶対値が平均絶対シャープレイ値になります。 (R2024a 以降)
plot(
では、1 つ以上の名前と値の引数を使用して追加オプションを指定します。たとえば、explainer
,Name=Value
)NumImportantPredictors=5
と指定すると、絶対シャープレイ値 (クエリ点が 1 つの場合) または平均絶対シャープレイ値 (クエリ点が複数の場合) が大きい上位 5 つの特徴量のシャープレイ値がプロットされます。
plot(
は、ターゲットの座標軸 ax
,___)ax
にプロットを表示します。ax
は、前の任意の構文で最初の引数として指定します。 (R2023b 以降)
は、前の構文におけるいずれかの入力引数の組み合わせを使用して、b
= plot(___)Bar
オブジェクトまたは Bar
オブジェクトの配列を返します。b
は、オブジェクトの作成後にそのプロパティ (Bar のプロパティ) をクエリまたは変更するのに使用します。
例
1 つのクエリ点についてのすべてのクラスのシャープレイ値のプロット
分類モデルの学習を行い、shapley
オブジェクトを作成します。次に、オブジェクト関数 plot
を使用して、シャープレイ値をプロットします。
CreditRating_Historical
データ セットを読み込みます。データ セットには、顧客 ID、顧客の財務比率、業種ラベル、および信用格付けが格納されています。
tbl = readtable("CreditRating_Historical.dat");
テーブルの最初の 3 行を表示します。
head(tbl,3)
ID WC_TA RE_TA EBIT_TA MVE_BVTD S_TA Industry Rating _____ _____ _____ _______ ________ _____ ________ ______ 62394 0.013 0.104 0.036 0.447 0.142 3 {'BB'} 48608 0.232 0.335 0.062 1.969 0.281 8 {'A' } 42444 0.311 0.367 0.074 1.935 0.366 1 {'A' }
関数fitcecoc
を使用して、信用格付けの blackbox モデルに学習させます。tbl
内の 2 ~ 7 列目の変数を予測子変数として使用します。クラス名を指定してクラスの順序を設定することが推奨されます。
blackbox = fitcecoc(tbl,"Rating", ... PredictorNames=tbl.Properties.VariableNames(2:7), ... CategoricalPredictors="Industry", ... ClassNames={'AAA','AA','A','BBB','BB','B','CCC'});
最後の観測値の予測を説明する shapley
オブジェクトを作成します。計算速度を向上するには、tbl
の観測値の 25% を階層的にサブサンプリングし、その標本を使用してシャープレイ値を計算します。
R2024a より前: QueryPoints
の代わりに、名前と値の引数 QueryPoint
を使用してクエリ点を指定します。
queryPoint = tbl(end,:)
queryPoint=1×8 table
ID WC_TA RE_TA EBIT_TA MVE_BVTD S_TA Industry Rating
_____ _____ _____ _______ ________ ____ ________ ______
73104 0.239 0.463 0.065 2.924 0.34 2 {'AA'}
rng("default") % For reproducibility c = cvpartition(tbl.Rating,"Holdout",0.25); sampleTbl = tbl(test(c),:); explainer = shapley(blackbox,sampleTbl,QueryPoints=queryPoint);
分類モデルの場合、shapley
は各クラスの予測クラス スコアを使用してシャープレイ値を計算します。ShapleyValues
プロパティの値を表示します。
explainer.ShapleyValues
ans=6×8 table
Predictor AAA AA A BBB BB B CCC
__________ _________ __________ ___________ __________ ___________ __________ __________
"WC_TA" 0.051045 0.022644 0.0096138 0.0015954 -0.027857 -0.04134 -0.039476
"RE_TA" 0.16729 0.09479 0.05308 -0.011178 -0.087689 -0.20847 -0.29204
"EBIT_TA" 0.0012015 0.00053338 0.00043344 0.00012321 -0.00066994 -0.0013388 -0.0011793
"MVE_BVTD" 1.3377 1.338 0.67839 -0.027654 -0.55142 -0.75327 -0.59578
"S_TA" -0.012484 -0.009098 -0.00074119 -0.0035582 -7.3462e-05 0.0014495 -0.0020609
"Industry" -0.099117 -0.046867 0.0031376 0.080071 0.089726 0.099699 0.15691
ShapleyValues
プロパティには、クラスごとにすべての特徴量のシャープレイ値が格納されています。
関数 plot
を使用して予測クラスのシャープレイ値をプロットします。
plot(explainer)
横棒グラフは、絶対値で並べ替えられた、すべての変数のシャープレイ値を示します。各シャープレイ値は、クエリ点についてのスコアに関して対応する変数が原因で生じた予測クラスの平均スコアからの偏差を説明します。
explainer.BlackboxModel
ですべてのクラス名を指定して、すべてのクラスのシャープレイ値をプロットします。
plot(explainer,ClassNames=explainer.BlackboxModel.ClassNames)
1 つのクエリ点についてのプロットする重要な予測子の数の指定
回帰モデルの学習を行い、shapley
オブジェクトを作成します。オブジェクト関数 fit
を使用して、指定したクエリ点のシャープレイ値を計算します。次に、オブジェクト関数 plot
を使用して、予測子のシャープレイ値をプロットします。関数 plot
を呼び出すときにプロットする重要な予測子の数を指定します。
carbig
データ セットを読み込みます。このデータ セットには、1970 年代と 1980 年代初期に製造された自動車の測定値が格納されています。
load carbig
Acceleration
、Cylinders
などの予測子変数と応答変数 MPG
が格納された table を作成します。
tbl = table(Acceleration,Cylinders,Displacement, ...
Horsepower,Model_Year,Weight,MPG);
学習セットの欠損値を削除すると、メモリ消費量を減らして関数 fitrkernel
の学習速度を向上させることができます。tbl
の欠損値を削除します。
tbl = rmmissing(tbl);
関数fitrkernel
を使用して MPG
の blackbox モデルの学習を行います。変数 Cylinders
と Model_Year
をカテゴリカル予測子として指定します。残りの予測子を標準化します。
rng("default") % For reproducibility mdl = fitrkernel(tbl,"MPG",CategoricalPredictors=[2 5], ... Standardize=true);
shapley
オブジェクトを作成します。mdl
には学習データが含まれていないため、データ セット tbl
を指定します。
explainer = shapley(mdl,tbl)
explainer = BlackboxModel: [1x1 RegressionKernel] QueryPoints: [] BlackboxFitted: [] ShapleyValues: [] X: [392x7 table] CategoricalPredictors: [2 5] Method: "interventional-kernel" Intercept: 22.7326 NumSubsets: 64
explainer
は、学習データ tbl
を X
プロパティに格納します。
tbl
の最初の観測値についてすべての予測子変数のシャープレイ値を計算します。
queryPoint = tbl(1,:)
queryPoint=1×7 table
Acceleration Cylinders Displacement Horsepower Model_Year Weight MPG
____________ _________ ____________ __________ __________ ______ ___
12 8 307 130 70 3504 18
explainer = fit(explainer,queryPoint);
回帰モデルの場合、shapley
は予測応答を使用してシャープレイ値を計算し、ShapleyValues
プロパティに格納します。ShapleyValues
プロパティの値を表示します。
explainer.ShapleyValues
ans=6×2 table
Predictor ShapleyValue
______________ ____________
"Acceleration" -0.23731
"Cylinders" -0.87423
"Displacement" -1.0224
"Horsepower" -0.56975
"Model_Year" -0.055414
"Weight" -0.86088
関数 plot
を使用して、クエリ点のシャープレイ値をプロットします。予測応答について上位 5 つの重要な予測子のみをプロットするように指定します。
plot(explainer,NumImportantPredictors=5)
横棒グラフは、絶対値で並べ替えられた、5 つの最も重要な予測子のシャープレイ値を示します。各シャープレイ値は、クエリ点についての予測に関して対応する変数が原因で生じた平均からの偏差を説明します。
複数のクエリ点のシャープレイ値のプロット
分類モデルの学習を行い、shapley
オブジェクトを作成します。オブジェクト関数 plot
を使用して、複数のクエリ点の平均絶対シャープレイ値をプロットします。その後、いずれかのクエリ点のシャープレイ値をプロットします。
CreditRating_Historical
データ セットを読み込みます。データ セットには、顧客 ID、顧客の財務比率、業種ラベル、および信用格付けが格納されています。
tbl = readtable("CreditRating_Historical.dat");
テーブルの最初の 3 行を表示します。
head(tbl,3)
ID WC_TA RE_TA EBIT_TA MVE_BVTD S_TA Industry Rating _____ _____ _____ _______ ________ _____ ________ ______ 62394 0.013 0.104 0.036 0.447 0.142 3 {'BB'} 48608 0.232 0.335 0.062 1.969 0.281 8 {'A' } 42444 0.311 0.367 0.074 1.935 0.366 1 {'A' }
関数 fitcecoc
を使用して、信用格付けの blackbox モデルに学習させます。tbl
内の 2 ~ 7 列目の変数を予測子変数として使用します。クラス名を指定してクラスの順序を設定することが推奨されます。
blackbox = fitcecoc(tbl,"Rating", ... PredictorNames=tbl.Properties.VariableNames(2:7), ... CategoricalPredictors="Industry", ... ClassNames={'AAA','AA','A','BBB','BB','B','CCC'});
複数のクエリ点の予測を説明する shapley
オブジェクトを作成します。計算速度を向上するには、tbl
の観測値の 10% を階層的にサブサンプリングし、その標本を使用してシャープレイ値を計算します。抽出された観測値をクエリ点として指定します。
rng("default") % For reproducibility c = cvpartition(tbl.Rating,"Holdout",0.10); sampleTbl = tbl(test(c),:); explainer = shapley(blackbox,sampleTbl, ... queryPoints=sampleTbl);
分類モデルの場合、shapley
は各クラスの予測クラス スコアを使用してシャープレイ値を計算します。複数のクエリ点を指定すると、各予測子と各クラスについての平均絶対シャープレイ値が関数で計算されます。
explainer.MeanAbsoluteShapley
ans=6×8 table
Predictor AAA AA A BBB BB B CCC
__________ _________ __________ _________ __________ _________ _________ _________
"WC_TA" 0.056246 0.034016 0.027208 0.02194 0.041348 0.060144 0.056189
"RE_TA" 0.1202 0.097136 0.099341 0.094155 0.10629 0.1799 0.25493
"EBIT_TA" 0.0014694 0.00086978 0.0010461 0.00088111 0.0011695 0.0020823 0.0018035
"MVE_BVTD" 0.81198 0.79496 1.0804 1.5952 2.0768 2.2893 1.7551
"S_TA" 0.025692 0.0098722 0.011002 0.01535 0.0015691 0.0075802 0.012961
"Industry" 0.073842 0.084015 0.066049 0.039714 0.062301 0.12082 0.11111
たとえば、explainer.MeanAbsoluteShapley.AAA(1)
の値は、予測子 WC_TA
とクラス AAA
の絶対シャープレイ値の sampleTbl
のすべての観測値での平均になります。
explainer.MeanAbsoluteShapley.AAA(1)
ans = 0.0562
オブジェクト関数 plot
を使用して、平均絶対シャープレイ値をプロットします。
plot(explainer)
それぞれのクラスについて、予測子 MVE_BVTD
の平均絶対シャープレイ値が最も大きくなっています。
最初のクエリ点を選択し、そのクエリ点についてのクラス予測を調べます。
queryPoint = explainer.QueryPoints(1,:)
queryPoint=1×8 table
ID WC_TA RE_TA EBIT_TA MVE_BVTD S_TA Industry Rating
_____ _____ _____ _______ ________ _____ ________ ______
48608 0.232 0.335 0.062 1.969 0.281 8 {'A'}
queryPointPrediction = explainer.BlackboxFitted(1)
queryPointPrediction = 1×1 cell array
{'A'}
名前と値の引数 QueryPointIndices
を使用して、クエリ点のシャープレイ値をプロットします。クエリ点の予測クラス (A
) の色と一致するようにバーの色を変更します。
b = plot(explainer,QueryPointIndices=1); b.FaceColor = [0.9290 0.6940 0.1250];
このクエリ点について、予測子 MVE_BVTD
は、クラス A
の予測スコアの平均からの最大偏差を説明しています。
入力引数
explainer
— blackbox モデルを説明するオブジェクト
shapley
オブジェクト
blackbox モデルを説明するオブジェクト。shapley
オブジェクトとして指定します。explainer
にシャープレイ値が格納されていなければならず、つまり explainer.ShapleyValues
が空であってはなりません。
ax
— プロットの座標軸
Axes
オブジェクト
R2023b 以降
プロットの座標軸。Axes
オブジェクトとして指定します。ax
を指定しない場合、plot
は、現在の座標軸を使用してプロットを作成します。Axes
オブジェクトを作成する方法の詳細については、axes
を参照してください。
名前と値の引数
オプションの引数のペアを Name1=Value1,...,NameN=ValueN
として指定します。ここで Name
は引数名、Value
は対応する値です。名前と値の引数は他の引数の後ろにする必要がありますが、ペアの順序は関係ありません。
例: plot(explainer,NumImportantPredictors=5,ClassNames=["AAA","AA","A"])
は、クラス AAA
、AA
、および A
について、上位 5 つの重要な予測子のシャープレイ値または平均絶対シャープレイ値を示す棒グラフを作成します。
NumImportantPredictors
— プロットする重要な予測子の数
min(M,10)
(M
は予測子の数) (既定値) | 正の整数
プロットする重要な予測子の数。正の整数を指定します。関数 plot
は、絶対シャープレイ値 (クエリ点が 1 つの場合) または平均絶対シャープレイ値 (クエリ点が複数の場合) が大きいものから、指定された数の上位の予測子の値をプロットします。
例: NumImportantPredictors=5
は、上位 5 つの重要な予測子をプロットするように指定します。関数 plot
は、絶対シャープレイ値 (クエリ点が 1 つの場合) または平均絶対シャープレイ値 (クエリ点が複数の場合) を使用して重要度の順序を決定します。
データ型: single
| double
ClassNames
— プロットするクラス ラベル
explainer.BlackboxFitted
(クエリ点が 1 つの場合) または explainer.BlackboxModel.ClassNames(1)
(クエリ点が複数の場合) (既定値) | 数値ベクトル | logical ベクトル | 文字配列 | string 配列 | 文字ベクトルの cell 配列 | categorical 配列
プロットするクラス ラベル。数値ベクトル、logical ベクトル、文字配列、string 配列、または文字ベクトルの cell 配列として指定します。ClassNames
値の値とデータ型は、explainer
における機械学習モデルの ClassNames
プロパティ (explainer.BlackboxModel.ClassNames
) のクラス名のものと一致しなければなりません。string 配列、文字ベクトルの cell 配列、および categorical 配列は相互交換可能なものとして受け入れられることに注意してください。
1 つ以上のラベルを指定できます。複数のクラス ラベルを指定すると、関数で色を使用してクラスが区別されます。
ClassNames
の既定値はクエリ点の数によって異なります。
explainer
にクエリ点が 1 つ格納されている場合、そのクエリ点についての予測クラス (explainer
のBlackboxFitted
プロパティ) が既定値になります。explainer
にクエリ点が複数格納されている場合、explainer
の機械学習モデルのClassNames
プロパティに含まれる最初のクラスが既定値になります。
この引数は、explainer
の機械学習モデル (BlackboxModel
) が分類モデルである場合のみ有効です。
例: ClassNames={'red','blue'}
例: ClassNames=explainer.BlackboxModel.ClassNames
は、ClassNames
として BlackboxModel
内のすべてのクラスを指定します。
データ型: single
| double
| logical
| char
| string
| cell
| categorical
QueryPointIndices
— プロットに使用するクエリ点のインデックス
1:N
(N
はクエリ点の数) (既定値) | 正の整数ベクトル
R2024a 以降
プロットに使用するクエリ点のインデックス。正の整数ベクトルとして指定します。
QueryPointIndices
の値がベクトルidx
の場合、関数plot
は、指定されたすべてのクエリ点 (explainer.QueryPoints(idx)
) で平均化した平均絶対シャープレイ値の棒グラフを返します。QueryPointIndices
の値がスカラーの場合、関数plot
は、指定されたクエリ点のシャープレイ値の棒グラフを返します。
この引数は、explainer
にクエリ点が複数格納されている場合のみ有効です。
例: QueryPointIndices=1:100
例: QueryPointIndices=50
データ型: single
| double
詳細
シャープレイ値
ゲーム理論におけるプレーヤーのシャープレイ値とは、協力ゲームでのプレーヤーの平均限界貢献度です。機械学習予測のコンテキストでは、クエリ点の特徴量のシャープレイ値は、指定したクエリ点での予測 (回帰の場合は応答、分類の場合は各クラスのスコア) に対する特徴量の寄与を説明します。
クエリ点の特徴量のシャープレイ値は、平均予測からの偏差に対する特徴量の寄与です。クエリ点について、すべての特徴量に関するシャープレイ値の合計は、予測の平均からの合計偏差に対応します。つまり、平均予測とすべての特徴量に関するシャープレイ値の合計は、クエリ点についての予測に対応します。
詳細については、機械学習モデルのシャープレイ値を参照してください。
参照
[1] Lundberg, Scott M., and S. Lee. "A Unified Approach to Interpreting Model Predictions." Advances in Neural Information Processing Systems 30 (2017): 4765–774.
[2] Aas, Kjersti, Martin Jullum, and Anders Løland. "Explaining Individual Predictions When Features Are Dependent: More Accurate Approximations to Shapley Values." Artificial Intelligence 298 (September 2021).
[3] Lundberg, Scott M., G. Erion, H. Chen, et al. "From Local Explanations to Global Understanding with Explainable AI for Trees." Nature Machine Intelligence 2 (January 2020): 56–67.
バージョン履歴
R2021a で導入R2024a: 複数のクエリ点の平均絶対シャープレイ値のプロット
R2023b: plot
で指定したターゲットの座標軸を使用
オブジェクト関数 plot
でターゲットの座標軸を指定できるようになりました。関数の最初の入力引数として Axes
オブジェクトを指定します。
R2021b: 目盛りラベル インタープリターの既定値が 'none'
シャープレイ値を figure オブジェクト b
で返す場合、関数 plot
は座標軸の TickLabelInterpreter
の値を既定で 'none'
に設定します。つまり、b.CurrentAxes.TickLabelInterpreter
は 'none'
になります。以前のリリースでは、座標軸の TickLabelInterpreter
の値は既定では 'tex'
でした。'none'
値と 'tex'
値の違いの詳細については、TickLabelInterpreter
を参照してください。
MATLAB コマンド
次の MATLAB コマンドに対応するリンクがクリックされました。
コマンドを MATLAB コマンド ウィンドウに入力して実行してください。Web ブラウザーは MATLAB コマンドをサポートしていません。
Select a Web Site
Choose a web site to get translated content where available and see local events and offers. Based on your location, we recommend that you select: .
You can also select a web site from the following list:
How to Get Best Site Performance
Select the China site (in Chinese or English) for best site performance. Other MathWorks country sites are not optimized for visits from your location.
Americas
- América Latina (Español)
- Canada (English)
- United States (English)
Europe
- Belgium (English)
- Denmark (English)
- Deutschland (Deutsch)
- España (Español)
- Finland (English)
- France (Français)
- Ireland (English)
- Italia (Italiano)
- Luxembourg (English)
- Netherlands (English)
- Norway (English)
- Österreich (Deutsch)
- Portugal (English)
- Sweden (English)
- Switzerland
- United Kingdom (English)