Main Content

このページの翻訳は最新ではありません。ここをクリックして、英語の最新版を参照してください。

fit

クエリ点のシャープレイ値の計算

R2021a 以降

    説明

    newExplainer = fit(explainer,queryPoint) は、指定されたクエリ点 (queryPoint) のシャープレイ値を計算し、計算したシャープレイ値を newExplainerShapleyValues プロパティに格納します。shapley オブジェクト explainer には、機械学習モデルとシャープレイ値の計算オプションが格納されています。

    fit は、いつ explainer を作成するかを指定する、シャープレイ値計算オプションを使用します。このオプションは、関数 fit の名前と値の引数を使用して変更できます。この関数は、新しく計算したシャープレイ値を含む shapley オブジェクト newExplainer を返します。

    newExplainer = fit(explainer,queryPoint,Name,Value) では、1 つ以上の名前と値の引数を使用して追加オプションを指定します。たとえば、'UseParallel',true と指定してシャープレイ値を並列計算します。

    すべて折りたたむ

    回帰モデルの学習を行い、shapley オブジェクトを作成します。shapley オブジェクトを作成するときに、クエリ点を指定しなかった場合、シャープレイ値は計算されません。オブジェクト関数 fit を使用して、指定したクエリ点のシャープレイ値を計算します。次に、オブジェクト関数 plot を使用して、シャープレイ値の棒グラフを作成します。

    carbig データセットを読み込みます。このデータセットには、1970 年代と 1980 年代初期に製造された自動車の測定値が格納されています。

    load carbig

    AccelerationCylinders などの予測子変数と応答変数 MPG が格納された table を作成します。

    tbl = table(Acceleration,Cylinders,Displacement,Horsepower,Model_Year,Weight,MPG);

    学習セットの欠損値を削除すると、メモリ消費量を減らして関数 fitrkernel の学習速度を向上させることができます。tbl の欠損値を削除します。

    tbl = rmmissing(tbl);

    関数fitrkernelを使用して MPG の blackbox モデルの学習を行います。

    rng('default') % For reproducibility
    mdl = fitrkernel(tbl,'MPG','CategoricalPredictors',[2 5]);

    shapley オブジェクトを作成します。mdl には学習データが含まれないため、データ セット tbl を指定します。

    explainer = shapley(mdl,tbl)
    explainer = 
      shapley with properties:
    
                BlackboxModel: [1x1 RegressionKernel]
                   QueryPoint: []
               BlackboxFitted: []
                ShapleyValues: []
                   NumSubsets: 64
                            X: [392x7 table]
        CategoricalPredictors: [2 5]
                       Method: 'interventional-kernel'
                    Intercept: 22.6202
    
    

    explainer は、学習データ tblX プロパティに格納します。

    tbl の最初の観測値についてすべての予測子変数のシャープレイ値を計算します。

    queryPoint = tbl(1,:)
    queryPoint=1×7 table
        Acceleration    Cylinders    Displacement    Horsepower    Model_Year    Weight    MPG
        ____________    _________    ____________    __________    __________    ______    ___
    
             12             8            307            130            70         3504     18 
    
    
    explainer = fit(explainer,queryPoint);

    回帰モデルの場合、shapley は予測応答を使用してシャープレイ値を計算し、ShapleyValues プロパティに格納します。ShapleyValues プロパティの値を表示します。

    explainer.ShapleyValues
    ans=6×2 table
          Predictor       ShapleyValue
        ______________    ____________
    
        "Acceleration"       -0.1561  
        "Cylinders"         -0.18306  
        "Displacement"      -0.34203  
        "Horsepower"        -0.27291  
        "Model_Year"         -0.2926  
        "Weight"            -0.32402  
    
    

    関数 plot を使用して、クエリ点のシャープレイ値をプロットします。

    plot(explainer)

    Figure contains an axes object. The axes object with title Shapley Explanation contains an object of type bar.

    横棒グラフは、絶対値で並べ替えられた、すべての変数のシャープレイ値を示します。各シャープレイ値は、クエリ点についての予測に関して対応する変数が原因で生じた平均からの偏差を説明します。

    分類モデルの学習を行い、shapley オブジェクトを作成します。次に、複数のクエリ点のシャープレイ値を計算します。

    CreditRating_Historical データセットを読み込みます。データ セットには、顧客 ID、顧客の財務比率、業種ラベル、および信用格付けが格納されています。

    tbl = readtable('CreditRating_Historical.dat');

    関数fitcecocを使用して、信用格付けの blackbox モデルに学習させます。tbl 内の 2 ~ 7 列目の変数を予測子変数として使用します。

    blackbox = fitcecoc(tbl,'Rating', ...
        'PredictorNames',tbl.Properties.VariableNames(2:7), ...
        'CategoricalPredictors','Industry');

    blackbox モデルを使用して、shapley オブジェクトを作成します。計算速度を向上するには、tbl の観測値の 25% を階層的にサブサンプリングし、その標本を使用してシャープレイ値を計算します。kernelSHAP アルゴリズムの拡張機能の使用を指定します。

    rng('default') % For reproducibility
    c = cvpartition(tbl.Rating,'Holdout',0.25);
    tbl_s = tbl(test(c),:);
    explainer = shapley(blackbox,tbl_s,'Method','conditional-kernel');

    真の信用格付け値がそれぞれ AAA および B となる 2 つのクエリ点を見つけます。

    queryPoint(1,:) = tbl_s(find(strcmp(tbl_s.Rating,'AAA'),1),:);
    queryPoint(2,:) = tbl_s(find(strcmp(tbl_s.Rating,'B'),1),:)
    queryPoint=2×8 table
         ID      WC_TA     RE_TA     EBIT_TA    MVE_BVTD    S_TA     Industry    Rating 
        _____    ______    ______    _______    ________    _____    ________    _______
    
        58258     0.511     0.869     0.106      8.538      0.732       2        {'AAA'}
        82367    -0.078    -0.042     0.011      0.262      0.167       7        {'B'  }
    
    

    最初のクエリ点についてのシャープレイ値を計算してプロットします。

    explainer1 = fit(explainer,queryPoint(1,:));
    plot(explainer1)

    Figure contains an axes object. The axes object with title Shapley Explanation contains an object of type bar.

    2 番目のクエリ点についてのシャープレイ値を計算してプロットします。

    explainer2 = fit(explainer,queryPoint(2,:));
    plot(explainer2)

    Figure contains an axes object. The axes object with title Shapley Explanation contains an object of type bar.

    2 番目のクエリ点に関する実際の格付けは B ですが、予測された格付けは BB です。プロットには、予測された格付けのシャープレイ値が表示されます。

    explainer1explainer2 にはそれぞれ、最初のクエリ点と 2 番目のクエリ点についてのシャープレイ値が含まれています。

    入力引数

    すべて折りたたむ

    blackbox モデルを説明するオブジェクト。shapley オブジェクトとして指定します。

    fit が予測を説明するクエリ点。数値の行ベクトルまたは単一行 table として指定します。

    • 数値の行ベクトルの場合:

      • queryPoint の列を構成する変数の順序は、explainer の予測子データ X の順序と同じでなければなりません。

      • 予測子データ explainer.X が table の場合、table に含まれている変数がすべて数値変数であれば、queryPoint を数値ベクトルにすることができます。

    • 単一行 table の場合:

      • 予測子データ explainer.X が table の場合、queryPoint 内のすべての予測子変数は変数名およびデータ型が explainer.X 内の変数と同じでなければなりません。ただし、queryPoint の列の順序が explainer.X の列の順序に対応する必要はありません。

      • 予測子データ explainer.X が数値行列の場合、explainer.BlackboxModel.PredictorNames 内の予測子名と queryPoint 内の対応する予測子変数名が同じでなければなりません。学習時に予測子の名前を指定するには、名前と値の引数 'PredictorNames' を使用します。queryPoint 内の予測子変数はすべて数値ベクトルでなければなりません。

      • queryPoint に追加の変数 (応答変数や観測値の重みなど) を含めることができますが、fit はこれらを無視します。

      • fit は、文字ベクトルの cell 配列ではない cell 配列や複数列の変数をサポートしません。

    連続予測子について queryPointNaN が含まれており、'Method''conditional-kernel' である場合、返されたオブジェクトに含まれているシャープレイ値 (ShapleyValues) は NaN になります。それ以外の場合、fitNaN 値を explainer.BlackboxModel (explainer.BlackboxModel のオブジェクト関数 predict または blackbox で指定された関数ハンドル) と同じ方法で処理します。

    例: explainer.X(1,:) は、explainer の予測子データ X の最初の観測値としてクエリ点を指定します。

    データ型: single | double | table

    名前と値の引数

    オプションの引数のペアを Name1=Value1,...,NameN=ValueN として指定します。ここで Name は引数名、Value は対応する値です。名前と値の引数は他の引数の後ろにする必要がありますが、ペアの順序は関係ありません。

    R2021a より前では、名前と値をそれぞれコンマを使って区切り、Name を引用符で囲みます。

    例: fit(explainer,q,'Method','conditional-kernel','UseParallel',true) は、kernelSHAP アルゴリズムの拡張機能を使用してクエリ点 q のシャープレイ値を計算し、その計算を並列実行します。

    シャープレイ値の計算に使用する予測子サブセットの最大数。正の整数を指定します。

    fit が使用するサブセットを選択する方法の詳細については、計算コストを参照してください。

    例: 'MaxNumSubsets',100

    データ型: single | double

    シャープレイ値の計算アルゴリズム。'interventional-kernel' または 'conditional-kernel' を指定します。

    • 'interventional-kernel'fit は、介入型の価値関数を伴う kernelSHAP アルゴリズム[1]を使用します。

    • 'conditional-kernel'fit は、条件付きの価値関数を伴う kernelSHAP アルゴリズムの拡張機能[2]を使用します。

    これらのアルゴリズムの詳細については、アルゴリズムを参照してください。

    例: 'Method','conditional-kernel'

    データ型: char | string

    並列実行のフラグ。true または false として指定します。'UseParallel',true を指定した場合、関数 fitparfor を使用して for ループの反復を並列実行します。このオプションには Parallel Computing Toolbox™ が必要です。

    例: 'UseParallel',true

    データ型: logical

    出力引数

    すべて折りたたむ

    blackbox モデルを説明するオブジェクト。shapley オブジェクトとして返されます。オブジェクトの ShapleyValues プロパティには、計算されたシャープレイ値が含まれています。

    入力引数 explainer を上書きするには、fit の出力を explainer に代入します。

    explainer = fit(explainer,queryPoint);

    詳細

    すべて折りたたむ

    シャープレイ値

    ゲーム理論におけるプレーヤーのシャープレイ値とは、協力ゲームでのプレーヤーの平均限界貢献度です。機械学習予測のコンテキストでは、クエリ点の特徴量のシャープレイ値は、指定したクエリ点での予測 (回帰の場合は応答、分類の場合は各クラスのスコア) に対する特徴量の寄与を説明します。

    クエリ点の特徴量のシャープレイ値は、平均予測からの偏差に対する特徴量の寄与です。クエリ点について、すべての特徴量に関するシャープレイ値の合計は、予測の平均からの合計偏差に対応します。つまり、平均予測とすべての特徴量に関するシャープレイ値の合計は、クエリ点についての予測に対応します。

    詳細については、機械学習モデルのシャープレイ値を参照してください。

    参照

    [1] Lundberg, Scott M., and S. Lee. "A Unified Approach to Interpreting Model Predictions." Advances in Neural Information Processing Systems 30 (2017): 4765–774.

    [2] Aas, Kjersti, Martin. Jullum, and Anders Løland. "Explaining Individual Predictions When Features Are Dependent: More Accurate Approximations to Shapley Values." arXiv:1903.10464 (2019).

    拡張機能

    バージョン履歴

    R2021a で導入