Main Content

このページの翻訳は最新ではありません。ここをクリックして、英語の最新版を参照してください。

機械学習モデルのシャープレイ値

shapley オブジェクトを使用して、機械学習モデルのシャープレイ値を計算できます。この値を使用して、クエリ点についての予測に対するモデル内の個々の特徴量の寄与を解釈します。シャープレイ値を計算するには、次の 2 つの方法があります。

  • 関数 shapley を使用して、クエリ点を指定して機械学習モデル用の shapley オブジェクトを作成します。そのクエリ点についてモデル内のすべての特徴量のシャープレイ値が計算されます。

  • 関数 shapley を使用して機械学習モデル用の shapley オブジェクトを作成してから、関数 fit を使用して指定したクエリ点のシャープレイ値を計算します。

このトピックでは、シャープレイ値について定義し、シャープレイ値の計算に使用できる 2 つのアルゴリズムについて説明して、それぞれの例を示します。また、シャープレイ値の計算コストを削減する方法を示します。

シャープレイ値とは

ゲーム理論におけるプレイヤーのシャープレイ値とは、協力ゲームでのプレイヤーの平均限界貢献度です。つまり、シャープレイ値は、協力ゲームから得られた合計ゲインを個々のプレイヤーに公平に分配した値です。機械学習予測のコンテキストでは、クエリ点の特徴量のシャープレイ値は、指定したクエリ点での予測 (回帰の場合は応答、分類の場合は各クラスのスコア) に対する特徴量の寄与を説明します。シャープレイ値は、クエリ点についての予測に関して特徴量が原因で生じた平均予測からの偏差に対応します。各クエリ点について、すべての特徴に関するシャープレイ値の合計は、予測の平均からの合計偏差に対応します。

クエリ点 x についての i 番目の特徴量のシャープレイ値は、次の価値関数 v によって定義されます。

φi(vx)=1MS\{i}vx(S{i})vx(S)(M1)!|S|!(M|S|1)!(1)
  • M は、すべての特徴量の数です。

  • は、すべての特徴量のセットです。

  • |S| は、セット S のカーディナリティ、つまりセット S の要素数です。

  • vx(S) は、クエリ点 x についてのセット S 内の特徴量の価値関数です。この関数の値は、クエリ点 x についての予測に対する S 内の特徴量の期待される寄与を示します。

シャープレイ値の計算アルゴリズム

shapley には 2 つのアルゴリズムが用意されています。価値関数に介入型分布を使用する kernelSHAP [1]と、価値関数に条件付き分布を使用する kernelSHAP の拡張機能[2]です。使用するアルゴリズムは、関数 shapley または関数 fit の名前と値の引数 'Method' を設定して指定できます。

2 つのアルゴリズムの違いは、価値関数の定義です。どちらのアルゴリズムでも、すべての特徴量にわたるクエリ点のシャープレイ値の合計がクエリ点についての予測に関する平均からの合計偏差に対応するように価値関数が定義されます。

i=1Mφi(vx)=f(x)E[f(x)].

そのため、価値関数 vx(S) は、クエリ点 x についての予測 (f) に対する S 内の特徴量の期待される寄与に対応しなければなりません。2 つのアルゴリズムでは、指定のデータ (X) から作成された人為的な標本を使用して、期待される寄与を計算します。X は、shapley オブジェクトの作成時に機械学習モデルの入力または個別のデータ入力引数によって指定しなければなりません。人為的な標本では、S 内の特徴量の値はクエリ点から得られます。残りの特徴量 (S の補数である Sc 内の特徴量) については、kernelSHAP アルゴリズムが介入型分布を使用して標本を生成するのに対し、kernelSHAP アルゴリズムの拡張機能は条件付き分布を使用して標本を生成します。

KernelSHAP ('Method','interventional-kernel')

shapley は、既定では kernelSHAP アルゴリズムを使用します。

kernelSHAP アルゴリズムでは、クエリ点 x における S 内の特徴量の価値関数を介入型分布 D (Sc 内の特徴量の同時分布) に対する起こり得る予測として定義します。

vx(S)=ED[f(xS,XSc)],

ここで、xS は S 内の特徴量のクエリ点値、XSc は Sc 内の特徴量です。

特徴量間の相関性が高くないと仮定して、クエリ点 x における価値関数 vx(S) を評価するには、shapley は、データ X の値を Sc 内の特徴量に関する介入型分布 D の標本として使用します。

vx(S)=ED[f(xS,XSc)]1Nj=1Nf(xS,(XSc)j),

ここで、N は観測値の数であり、(XSc)j には j 番目の観測値に関する Sc 内の特徴量の値が格納されています。

たとえば、X に 3 つの特徴量があり、4 つの観測値 (x11,x12,x13)、(x21,x22,x23)、(x31,x32,x33) および (x41,x42,x43) があるとします。S に最初の特徴量が含まれており、Sc に残りの特徴量が含まれていると仮定します。この場合、クエリ点 (x41,x42,x43) で評価された最初の特徴量の価値関数は次のようになります。

vx(S)=14[f(x41,x12,x13)+f(x41,x22,x23)+f(x41,x32,x33)+f(x41,x42,x43)].

kernelSHAP アルゴリズムは kernelSHAP アルゴリズムの拡張機能よりも計算コストが低く、順序付きのカテゴリカル予測子をサポートするほか、X の欠損値を処理できます。ただし、このアルゴリズムでは特徴量が独立していると仮定しなければならず、分布外の標本を使用します[3]。クエリ点とデータ X を組み合わせて作成した人為的な標本には、非現実的な観測値が含まれる可能性があります。たとえば、(x41,x12,x13) は、3 つの特徴量の完全な同時分布では出現しない標本である可能性があります。

KernelSHAP の拡張機能 ('Method','conditional-kernel')

kernelSHAP アルゴリズムの拡張機能を使用するには、'Method','conditional-kernel' を指定します。

kernelSHAP アルゴリズムの拡張機能では、XS にクエリ点値が含まれていることを前提に、XSc の条件付き分布を使用してクエリ点 x における S 内の特徴量の価値関数を定義します。

vx(S)=EXSc|XS=xS[f(xS,XSc)].

クエリ点 x における価値関数 vx(S) を評価するには、shapley は、データ X の観測値の 10% に対応するクエリ点の最近傍を使用します。この方法では、kernelSHAP アルゴリズムよりも現実的な標本が使用され、特徴量が独立していると仮定する必要はありません。ただし、このアルゴリズムは計算コストが高く、順序付きのカテゴリカル予測子をサポートしないほか、連続的特徴量の NaN を処理できません。また、このアルゴリズムでは、予測に寄与しないダミー特徴量が重要な特徴量と相関している場合に、そのダミー特徴量に非ゼロのシャープレイ値が割り当てられる可能性があります[3]

シャープレイ値の計算アルゴリズムの指定

この例では、線形分類モデルに学習をさせ、kernelSHAP アルゴリズム ('Method','interventional-kernel') と kernelSHAP アルゴリズムの拡張機能 ('Method','conditional-kernel') の両方を使用してシャープレイ値を計算します。

線形分類モデルの学習

ionosphere データセットを読み込みます。このデータセットには、レーダー反射についての 34 個の予測子と、不良 ('b') または良好 ('g') という 351 個の二項反応が含まれています。

load ionosphere

線形分類モデルに学習させます。線形係数の精度を向上させるため、目的関数の最小化手法 (名前と値の引数 'Solver') としてメモリ制限 Broyden-Fletcher-Goldfarb-Shanno 準ニュートン アルゴリズム ('lbfgs') を指定します。

Mdl = fitclinear(X,Y,'Solver','lbfgs')
Mdl = 
  ClassificationLinear
      ResponseName: 'Y'
        ClassNames: {'b'  'g'}
    ScoreTransform: 'none'
              Beta: [34x1 double]
              Bias: -3.7100
            Lambda: 0.0028
           Learner: 'svm'


  Properties, Methods

介入型分布を使用したシャープレイ値

価値関数の評価に介入型分布を使用する kernelSHAP アルゴリズムを使用して、最初の観測値のシャープレイ値を計算します。'interventional-kernel' が既定であるため、'Method' の値を指定する必要はありません。

queryPoint = X(1,:);
explainer1 = shapley(Mdl,X,'QueryPoint',queryPoint); 

分類モデルの場合、shapley は各クラスの予測クラス スコアを使用してシャープレイ値を計算します。関数 plot を使用して予測クラスのシャープレイ値をプロットします。

plot(explainer1)

Figure contains an axes. The axes contains an object of type bar. This object represents g.

横棒グラフは、絶対値で並べ替えられた、上位 10 個の重要な変数のシャープレイ値を示します。各シャープレイ値は、クエリ点についてのスコアに関して対応する変数が原因で生じた予測クラスの平均スコアからの偏差を説明します。

特徴量が相互に独立していると仮定される線形モデルの場合、推定係数 (Mdl.Beta) から陽性クラス (Mdl.ClassNames の 2 番目のクラス 'g') の介入型のシャープレイ値を計算できます[1]

linearSHAPValues = (Mdl.Beta'.*(queryPoint-mean(X)))';

kernelSHAP アルゴリズムから計算されたシャープレイ値と係数から計算された値を格納する table を作成します。

t = table(explainer1.ShapleyValues.Predictor,explainer1.ShapleyValues.g,linearSHAPValues, ...
    'VariableNames',{'Predictor','KernelSHAP Value','LinearSHAP Value'})
t=34×3 table
    Predictor    KernelSHAP Value    LinearSHAP Value
    _________    ________________    ________________

      "x1"             0.28789            0.28789    
      "x2"         -2.6619e-15                  0    
      "x3"             0.20822            0.20822    
      "x4"            -0.01998           -0.01998    
      "x5"             0.20872            0.20872    
      "x6"           -0.076991          -0.076991    
      "x7"             0.19188            0.19188    
      "x8"            -0.64386           -0.64386    
      "x9"             0.42348            0.42348    
      "x10"          -0.030049          -0.030049    
      "x11"           -0.23132           -0.23132    
      "x12"             0.1422             0.1422    
      "x13"          -0.045973          -0.045973    
      "x14"           -0.29022           -0.29022    
      "x15"            0.21051            0.21051    
      "x16"            0.13382            0.13382    
      ⋮

条件付き分布を使用したシャープレイ値

価値関数の評価に条件付き分布を使用する kernelSHAP アルゴリズムの拡張機能を使用して、最初の観測値のシャープレイ値を計算します。

explainer2 = shapley(Mdl,X,'QueryPoint',queryPoint,'Method','conditional-kernel');

シャープレイ値をプロットします。

plot(explainer2)

Figure contains an axes. The axes contains an object of type bar. This object represents g.

2 つのアルゴリズムでは、10 個の最も重要な変数には異なるセットが特定されます。2 つの変数 x8 および x22 のみが両方のセットに共通しています。

シャープレイ値の計算の複雑度

観測値または特徴量の数が多い場合、シャープレイ値の計算コストは高くなります。

観測値の数が多い場合

観測値の数が多い場合 (1000 を超える場合など)、価値関数 (v) の計算コストが高くなる可能性があります。計算速度を向上するには、shapley オブジェクトを作成するときに観測値の標本を小さくするか、関数 shapley または fit を使用して値を計算するときに 'UseParallel'true を指定してシャープレイ値を並列計算してください。並列計算には Parallel Computing Toolbox™ が必要です。

特徴量の数が多い場合

M (特徴量の数) が多い場合、使用可能なすべてのサブセット S の式 1における被加数の計算コストが高くなる可能性があります。考慮するサブセットの総数は 2M です。すべてのサブセットの被加数を計算する代わりに、名前と値の引数 'MaxNumSubsets' を使用して、シャープレイ値の計算に使用するサブセットの最大数を指定できます。shapley は、重み値に基づいて使用するサブセットを選択します。サブセットの重みは 1/(被加数の分母) に比例し、これは二項係数分の 1、つまり 1/(M1|S|) に対応します。そのため、カーディナリティの値が高いまたは低いサブセットは重み値が大きくなります。shapley には、まず重みが最も大きいサブセットが含まれ、以降はその他のサブセットが重み値に基づいて降順で含まれます。

計算コストの削減

この例では、観測値と特徴量の両方の数が多い場合にシャープレイ値の計算コストを削減する方法を示します。

アンサンブル回帰に学習をさせる

標本データセット NYCHousing2015 を読み込みます。

load NYCHousing2015

データ セットには、2015 年のニューヨーク市における不動産の売上に関する情報を持つ 10 の変数の観測値が 55,246 個含まれます。この例では、これらの変数を使用して売価 (SALEPRICE) を解析します。

データ セットを前処理します。datetime 配列 (SALEDATE) を月番号に変換します。

NYCHousing2015.SALEDATE = month(NYCHousing2015.SALEDATE);

アンサンブル回帰に学習をさせます。

Mdl = fitrensemble(NYCHousing2015,'SALEPRICE');

既定のオプションを使用したシャープレイ値の計算

最初の観測値についてすべての予測子変数のシャープレイ値を計算します。tictoc を使用して、シャープレイ値の計算に必要な時間を測定します。

tic
explainer1 = shapley(Mdl,'QueryPoint',NYCHousing2015(1,:));
Warning: Computation can be slow because the predictor data has over 1000 observations. Use a smaller sample of the training set or specify 'UseParallel' as true for faster computation.
toc
Elapsed time is 492.365307 seconds.

警告メッセージが示すように、予測子データに 1000 個を超える観測値があるため、計算が遅くなる場合があります。

計算コストを削減するためのオプションの指定

shapley には、観測値または特徴量の数が多い場合に計算コストを削減するためのオプションがいくつか用意されています。

  • 観測値の数が多い場合 — 学習データの標本を小さくし、'UseParallel'true を指定してシャープレイ値を並列計算します。

  • 特徴量の数が多い場合 — 名前と値の引数 'MaxNumSubsets' を指定して、計算に含めるサブセットの数を制限します。

学習データの標本を小さくし、並列計算オプションを使用して、シャープレイ値を再度計算します。また、サブセットの最大数を 2^5 に指定します。

NumSamples = 5e2;
Tbl = datasample(NYCHousing2015,NumSamples,'Replace',false);
tic
explainer2 = shapley(Mdl,Tbl,'QueryPoint',NYCHousing2015(1,:), ...
    'UseParallel',true,'MaxNumSubsets',2^5);
Starting parallel pool (parpool) using the 'local' profile ...
Connected to the parallel pool (number of workers: 6).
toc
Elapsed time is 52.183287 seconds.

追加オプションを指定することで、シャープレイ値の計算に必要な時間が短縮されています。

参照

[1] Lundberg, Scott M., and S. Lee. "A Unified Approach to Interpreting Model Predictions." Advances in Neural Information Processing Systems 30 (2017): 4765–774.

[2] Aas, Kjersti, Martin. Jullum, and Anders Løland. "Explaining Individual Predictions When Features Are Dependent: More Accurate Approximations to Shapley Values." arXiv:1903.10464 (2019).

[3] Kumar, I. Elizabeth, Suresh Venkatasubramanian, Carlos Scheidegger, and Sorelle Friedler. "Problems with Shapley-Value-Based Explanations as Feature Importance Measures." arXiv:2002.11097 (2020).

参考

| |