Main Content

機械学習モデルのシャープレイ値

このトピックでは、シャープレイ値について定義し、シャープレイ値を計算する Statistics and Machine Learning Toolbox™ の機能で使用できる 2 つのアルゴリズムについて説明して、それぞれの例を示します。また、計算コストを削減する方法を示します。

シャープレイ値とは

ゲーム理論におけるプレーヤーのシャープレイ値とは、協力ゲームでのプレーヤーの平均限界貢献度です。つまり、シャープレイ値は、協力ゲームから得られた合計ゲインを個々のプレーヤーに公平に分配した値です。機械学習予測のコンテキストでは、クエリ点の特徴量のシャープレイ値は、指定したクエリ点での予測 (回帰の場合は応答、分類の場合は各クラスのスコア) に対する特徴量の寄与を説明します。シャープレイ値は、クエリ点についての予測に関して特徴量が原因で生じた平均予測からの偏差に対応します。各クエリ点について、すべての特徴に関するシャープレイ値の合計は、予測の平均からの合計偏差に対応します。

クエリ点 x についての i 番目の特徴量のシャープレイ値は、次の価値関数 v によって定義されます。

φi(vx)=1MS\{i}vx(S{i})vx(S)(M1)!|S|!(M|S|1)!(1)
  • M は、すべての特徴量の数です。

  • は、すべての特徴量のセットです。

  • |S| は、セット S のカーディナリティ、つまりセット S の要素数です。

  • vx(S) は、クエリ点 x についてのセット S 内の特徴量の価値関数です。この関数の値は、クエリ点 x についての予測に対する S 内の特徴量の期待される寄与を示します。

Statistics and Machine Learning Toolbox でのシャープレイ値

shapley オブジェクトを使用して、機械学習モデルのシャープレイ値を計算できます。この値を使用して、クエリ点についての予測に対するモデル内の個々の特徴量の寄与を解釈します。シャープレイ値は 2 つの方法で計算できます。

  • 関数 shapley を使用して、クエリ点を指定して機械学習モデル用の shapley オブジェクトを作成します。そのクエリ点についてモデル内のすべての特徴量のシャープレイ値が計算されます。

  • 関数 shapley を使用して機械学習モデル用の shapley オブジェクトを作成してから、関数 fit を使用して指定したクエリ点のシャープレイ値を計算します。

複数のクエリ点のシャープレイ値を計算できることに注意してください。 (R2024a 以降)

アルゴリズム

shapley には 2 つのタイプのアルゴリズムが用意されています。価値関数に介入型分布を使用する介入型と、価値関数に条件付き分布を使用する条件付きです。使用するアルゴリズムのタイプは、関数 shapley または関数 fit の名前と値の引数 Method を設定して指定できます。

2 つのタイプのアルゴリズムの違いは、価値関数の定義です。どちらのタイプでも、すべての特徴量にわたるクエリ点のシャープレイ値の合計がクエリ点についての予測に関する平均からの合計偏差に対応するように価値関数が定義されます。

i=1Mφi(vx)=f(x)E[f(x)].

そのため、価値関数 vx(S) は、クエリ点 x についての予測 (f) に対する S 内の特徴量の期待される寄与に対応しなければなりません。アルゴリズムでは、指定のデータ (X) から作成された人為的な標本を使用して、期待される寄与を計算します。X は、shapley オブジェクトの作成時に機械学習モデルの入力または個別のデータ入力引数によって指定しなければなりません。人為的な標本では、S 内の特徴量の値はクエリ点から得られます。残りの特徴量 (S の補数である Sc 内の特徴量) については、介入型アルゴリズムが介入型分布を使用して標本を生成するのに対し、条件付きアルゴリズムは条件付き分布を使用して標本を生成します。

介入型アルゴリズム

既定では、shapley は次の介入型アルゴリズムのいずれかを使用します。Kernel SHAP [1]、Linear SHAP [1]、または Tree SHAP [2]です。

shapley で使用可能なすべてのサブセット S を使用すると、正確なシャープレイ値の計算にかかる計算コストが高くなる可能性があります。そのため、shapley では、Kernel SHAP アルゴリズムに使用するサブセットの最大数を制限してシャープレイ値を推定します。詳細については、計算コストを参照してください。

線形モデルと木ベースのモデル向けに、shapley には Linear SHAP アルゴリズムと Tree SHAP アルゴリズムがそれぞれ用意されています。これらのアルゴリズムは、少ない計算コストで正確なシャープレイ値を計算します。これらのアルゴリズムでは、Kernel SHAP アルゴリズムで使用可能なすべてのサブセットを使用した場合に返されるのと同じシャープレイ値が返されます。Linear SHAP アルゴリズム、Tree SHAP アルゴリズム、および Kernel SHAP アルゴリズムは次の点が異なります。

  • Linear SHAP アルゴリズムと Tree SHAP アルゴリズムでは、機械学習モデルの ResponseTransform プロパティ (回帰の場合) と ScoreTransform プロパティ (分類の場合) を無視します。つまり、それぞれのアルゴリズムで、応答変換またはスコア変換を適用せずに、生の応答または生のスコアに基づいてシャープレイ値が計算されます。Kernel SHAP アルゴリズムでは、モデルの ResponseTransform プロパティまたは ScoreTransform プロパティで変換が指定されている場合、変換された値が使用されます。

  • Kernel SHAP アルゴリズムと Tree SHAP アルゴリズムでは、欠損値を含む観測値を使用できます。Linear SHAP アルゴリズムでは、いずれのモデルについても、欠損値を含む観測値は処理できません。

shapley では、機械学習モデルのタイプとその他のオプションの指定に基づいてアルゴリズムが選択されます。

  • 次の線形モデルの場合は Linear SHAP アルゴリズム:

  • 次の木モデルおよび木学習器を含むアンサンブル モデルの場合は Tree SHAP アルゴリズム:

  • その他のすべてのモデル タイプの場合と次の場合は Kernel SHAP アルゴリズム:

    • 木モデルおよび上記の木のアンサンブルで、モデルでの予測に代理分岐 (Surrogate) が使用されていて、入力予測子データの観測値またはクエリ点の値に欠損値がある場合、Tree SHAP の代わりに Kernel SHAP が使用されることがあります。

      R2023b より前: 木モデルおよび木学習器を含むアンサンブル モデルで、入力予測子データの観測値またはクエリ点の値に欠損値がある場合、Tree SHAP の代わりに Kernel SHAP が常に使用されます。

    • shapley または fit の名前と値の引数 MaxNumSubsets (シャープレイ値の計算に使用する予測子サブセットの最大数) を指定した場合、Kernel SHAP が使用されます。

    • 場合によっては、Kernel SHAP の方が Tree SHAP よりも計算コストが低くなることがあります。たとえば、低次元データ向けの深い木を含むモデルの場合、Kernel SHAP の方が効率的なことがあります。効率的なアルゴリズムが経験則的に選択されます。

介入型アルゴリズムでは、クエリ点 x における S 内の特徴量の価値関数を介入型分布 D (Sc 内の特徴量の同時分布) に対する起こり得る予測として定義します。

vx(S)=ED[f(xS,XSc)].

xS は S 内の特徴量のクエリ点値、XSc は Sc 内の特徴量です。

特徴量間の相関性が高くないと仮定して、クエリ点 x における価値関数 vx(S) を評価するには、shapley は、データ X の値を Sc 内の特徴量に関する介入型分布 D の標本として使用します。

vx(S)=ED[f(xS,XSc)]1Nj=1Nf(xS,(XSc)j).

N は観測値の数であり、(XSc)j には j 番目の観測値に関する Sc 内の特徴量の値が格納されています。

たとえば、X に 3 つの特徴量があり、4 つの観測値 (x11,x12,x13)、(x21,x22,x23)、(x31,x32,x33) および (x41,x42,x43) があるとします。S に最初の特徴量が含まれており、Sc に残りの特徴量が含まれていると仮定します。この場合、クエリ点 (x41,x42,x43) で評価された最初の特徴量の価値関数は次のようになります。

vx(S)=14[f(x41,x12,x13)+f(x41,x22,x23)+f(x41,x32,x33)+f(x41,x42,x43)].

介入型アルゴリズムは条件付きアルゴリズムよりも計算コストが低く、順序付きのカテゴリカル予測子をサポートします。ただし、介入型アルゴリズムでは特徴量が独立していると仮定しなければならず、分布外の標本を使用します[4]。クエリ点とデータ X を組み合わせて作成した人為的な標本には、非現実的な観測値が含まれる可能性があります。たとえば、(x41,x12,x13) は、3 つの特徴量の完全な同時分布では出現しない標本である可能性があります。

条件付きアルゴリズム

条件付きアルゴリズムである Kernel SHAP アルゴリズムの拡張機能[3]を使用するには、名前と値の引数 Method"conditional" として指定します。

条件付きアルゴリズムでは、XS にクエリ点値が含まれていることを前提に、XSc の条件付き分布を使用してクエリ点 x における S 内の特徴量の価値関数を定義します。

vx(S)=EXSc|XS=xS[f(xS,XSc)].

クエリ点 x における価値関数 vx(S) を評価するには、shapley は、データ X の観測値の 10% に対応するクエリ点の最近傍を使用します。この方法では、介入型アルゴリズムよりも現実的な標本が使用され、特徴量が独立していると仮定する必要はありません。ただし、条件付きアルゴリズムは計算コストが高く、順序付きのカテゴリカル予測子をサポートしないほか、連続的特徴量の NaN を処理できません。また、このアルゴリズムでは、予測に寄与しないダミー特徴量が重要な特徴量と相関している場合に、そのダミー特徴量に非ゼロのシャープレイ値が割り当てられる可能性があります[4]

計算アルゴリズムの指定

この例では、線形分類モデルに学習させ、介入型アルゴリズム (Method="interventional") と条件付きアルゴリズム (Method="conditional") を順番に使用してシャープレイ値を計算します。

線形分類モデルの学習

ionosphere データ セットを読み込みます。このデータ セットには、レーダー反射についての 34 個の予測子と、不良 ('b') または良好 ('g') という 351 個の二項反応が含まれています。

load ionosphere

線形分類モデルに学習させます。線形係数の精度を向上させるため、目的関数の最小化手法 (名前と値の引数 Solver) としてメモリ制限 Broyden-Fletcher-Goldfarb-Shanno 準ニュートン アルゴリズム ("lbfgs") を指定します。

Mdl = fitclinear(X,Y,Solver="lbfgs")
Mdl = 
  ClassificationLinear
      ResponseName: 'Y'
        ClassNames: {'b'  'g'}
    ScoreTransform: 'none'
              Beta: [34x1 double]
              Bias: -3.7100
            Lambda: 0.0028
           Learner: 'svm'


介入型アルゴリズムを使用したシャープレイ値の計算

介入型アルゴリズムである Linear SHAP アルゴリズムを使用して、最初の観測値のシャープレイ値を計算します。"interventional" が既定であるため、名前と値の引数 Method の値を指定する必要はありません。

R2024a より前: QueryPoints の代わりに、名前と値の引数 QueryPoint を使用してクエリ点を指定します。

queryPoint = X(1,:);
explainer1 = shapley(Mdl,X,QueryPoints=queryPoint);

分類モデルの場合、shapley は各クラスの予測クラス スコアを使用してシャープレイ値を計算します。関数 plot を使用して予測クラスのシャープレイ値をプロットします。

plot(explainer1)

横棒グラフは、絶対値で並べ替えられた、上位 10 個の重要な変数のシャープレイ値を示します。各値は、クエリ点についてのスコアに関して対応する変数が原因で生じた予測クラスの平均スコアからの偏差を説明します。

線形モデルの場合、shapley は特徴量が相互に独立していると仮定し、推定係数 (Mdl.Beta) からシャープレイ値を計算します[1]。陽性クラス (Mdl.ClassNames の 2 番目のクラス 'g') のシャープレイ値を推定係数から直接計算します。

linearSHAPValues = (Mdl.Beta'.*(queryPoint-mean(X)))';

shapley から計算されたシャープレイ値と係数から計算された値を格納する table を作成します。

t = table(explainer1.ShapleyValues.Predictor, ...
    explainer1.ShapleyValues.g,linearSHAPValues, ...
    VariableNames=["Predictor","Values from shapley", ...
    "Values from Coefficients"])
t=34×3 table
    Predictor    Values from shapley    Values from Coefficients
    _________    ___________________    ________________________

      "x1"              0.28789                  0.28789        
      "x2"                    0                        0        
      "x3"              0.20822                  0.20822        
      "x4"             -0.01998                 -0.01998        
      "x5"              0.20872                  0.20872        
      "x6"            -0.076991                -0.076991        
      "x7"              0.19188                  0.19188        
      "x8"             -0.64386                 -0.64386        
      "x9"              0.42348                  0.42348        
      "x10"           -0.030049                -0.030049        
      "x11"            -0.23132                 -0.23132        
      "x12"              0.1422                   0.1422        
      "x13"           -0.045973                -0.045973        
      "x14"            -0.29022                 -0.29022        
      "x15"             0.21051                  0.21051        
      "x16"             0.13382                  0.13382        
      ⋮

条件付きアルゴリズムを使用したシャープレイ値の計算

条件付きアルゴリズムである Kernel SHAP アルゴリズムの拡張機能を使用して、最初の観測値のシャープレイ値を計算します。

explainer2 = shapley(Mdl,X,QueryPoints=queryPoint, ...
    Method="conditional");

シャープレイ値をプロットします。

plot(explainer2)

2 つのアルゴリズムでは、10 個の最も重要な変数には異なるセットが特定されます。2 つの変数 x8 および x22 のみが両方のセットに含まれています。

計算コスト

観測値または特徴量の数が多い場合、クエリ点のシャープレイ値の計算コストが高くなります。

観測値の数が多い場合

観測値の数が多い場合 (1000 を超える場合など)、価値関数 (v) の計算コストが高くなる可能性があります。計算速度を向上するには、shapley オブジェクトを作成するときに観測値の標本を小さくするか、関数 shapley または fit を使用して値を計算するときに UseParalleltrue を指定して並列実行してください。UseParallel オプションは、関数で木のアンサンブル用の Tree SHAP アルゴリズム、Kernel SHAP アルゴリズム、または Kernel SHAP アルゴリズムの拡張機能を使用して 1 つのクエリ点のシャープレイ値を計算する場合に使用できます。(関数で複数のクエリ点のシャープレイ値を計算する場合にも UseParallel オプションを使用できます。)並列計算には Parallel Computing Toolbox™ が必要です。

特徴量の数が多い場合

Kernel SHAP アルゴリズムまたは Kernel SHAP アルゴリズムの拡張機能では、M (特徴量の数) が多い場合、使用可能なすべてのサブセット S の式 1における被加数の計算コストが高くなる可能性があります。考慮するサブセットの総数は 2M です。すべてのサブセットの被加数を計算する代わりに、名前と値の引数 MaxNumSubsets を使用してサブセットの最大数を指定できます。shapley は、重み値に基づいて使用するサブセットを選択します。サブセットの重みは 1/(被加数の分母) に比例し、これは二項係数分の 1、つまり 1/(M1|S|) に対応します。そのため、カーディナリティの値が高いまたは低いサブセットは重み値が大きくなります。shapley には、まず重みが最も大きいサブセットが含まれ、以降はその他のサブセットが重み値に基づいて降順で含まれます。

計算コストの削減

この例では、観測値と特徴量の両方の数が多い場合にシャープレイ値の計算コストを削減する方法を示します。

標本データ セット NYCHousing2015 を読み込みます。

load NYCHousing2015

データ セットには、2015 年のニューヨーク市における不動産の売上に関する情報を持つ 10 の変数の観測値が 91,446 個含まれます。この例では、これらの変数を使用して売価 (SALEPRICE) を解析します。

データ セットを前処理します。datetime 配列 (SALEDATE) を月番号に変換します。

NYCHousing2015.SALEDATE = month(NYCHousing2015.SALEDATE);

ニューラル ネットワーク回帰モデルに学習させます。

Mdl = fitrnet(NYCHousing2015,"SALEPRICE",Standardize=true);

最初の観測値についてすべての予測子変数のシャープレイ値を計算します。tictoc を使用して、計算に必要な時間を測定します。

R2024a より前: QueryPoints の代わりに、名前と値の引数 QueryPoint を使用してクエリ点を指定します。

tic
explainer1 = shapley(Mdl,QueryPoints=NYCHousing2015(1,:));
Warning: Computation can be slow because the predictor data has over 1000 observations. Use a smaller sample of the training set or specify 'UseParallel' as true for faster computation.
toc
Elapsed time is 207.796421 seconds.

警告メッセージが示すように、予測子データに 1000 個を超える観測値があるため、計算が遅くなる場合があります。

shapley には、観測値または特徴量の数が多い場合に計算コストを削減するためのオプションがいくつか用意されています。

  • 観測値の数が多い場合 — 学習データの標本を小さくし、UseParalleltrue を指定して並列実行します。

  • 特徴量の数が多い場合 — 名前と値の引数 MaxNumSubsets を指定して、計算に含めるサブセットの数を制限します。

並列プールを起動します。

parpool;
Starting parallel pool (parpool) using the 'Processes' profile ...
16-Nov-2023 12:24:28: Job Queued. Waiting for parallel pool job with ID 2 to start ...
16-Nov-2023 12:25:28: Job Queued. Waiting for parallel pool job with ID 2 to start ...
Connected to parallel pool with 6 workers.

学習データの標本を小さくし、並列計算オプションを使用して、シャープレイ値を再度計算します。また、サブセットの最大数を 2^5 に指定します。

NumSamples = 5e2;
Tbl = datasample(NYCHousing2015,NumSamples,Replace=false);
tic
explainer2 = shapley(Mdl,Tbl,QueryPoints=NYCHousing2015(1,:), ...
    UseParallel=true,MaxNumSubsets=2^5);
toc
Elapsed time is 1.252758 seconds.

追加オプションを指定することで、計算時間が短縮されています。

参照

[1] Lundberg, Scott M., and S. Lee. "A Unified Approach to Interpreting Model Predictions." Advances in Neural Information Processing Systems 30 (2017): 4765–774.

[2] Lundberg, Scott M., G. Erion, H. Chen, et al. "From Local Explanations to Global Understanding with Explainable AI for Trees." Nature Machine Intelligence 2 (January 2020): 56–67.

[3] Aas, Kjersti, Martin Jullum, and Anders Løland. "Explaining Individual Predictions When Features Are Dependent: More Accurate Approximations to Shapley Values." Artificial Intelligence 298 (September 2021).

[4] Kumar, I. Elizabeth, Suresh Venkatasubramanian, Carlos Scheidegger, and Sorelle Friedler. "Problems with Shapley-Value-Based Explanations as Feature Importance Measures." Proceedings of the 37th International Conference on Machine Learning 119 (July 2020): 5491–500.

参考

| |

関連するトピック