Main Content

このページの内容は最新ではありません。最新版の英語を参照するには、ここをクリックします。

fit

クエリ点のシャープレイ値の計算

R2021a 以降

    説明

    newExplainer = fit(explainer,queryPoints) は、指定されたクエリ点 (queryPoints) のシャープレイ値を計算し、計算したシャープレイ値を newExplainerShapleyValues プロパティに格納します。shapley オブジェクト explainer には、機械学習モデルとシャープレイ値の計算オプションが格納されています。

    fit は、いつ explainer を作成するかを指定する、シャープレイ値計算オプションを使用します。このオプションは、関数 fit の名前と値の引数を使用して変更できます。この関数は、新しく計算したシャープレイ値を含む shapley オブジェクト newExplainer を返します。

    newExplainer = fit(explainer,queryPoints,Name=Value) では、1 つ以上の名前と値の引数を使用して追加オプションを指定します。たとえば、UseParallel=true と指定してシャープレイ値を並列計算します。

    すべて折りたたむ

    回帰モデルの学習を行い、shapley オブジェクトを作成します。shapley オブジェクトを作成するときに、クエリ点を指定しなかった場合、シャープレイ値は計算されません。オブジェクト関数 fit を使用して、指定したクエリ点のシャープレイ値を計算します。次に、オブジェクト関数 plot を使用して、シャープレイ値の棒グラフを作成します。

    carbig データ セットを読み込みます。このデータ セットには、1970 年代と 1980 年代初期に製造された自動車の測定値が格納されています。

    load carbig

    AccelerationCylinders などの予測子変数と応答変数 MPG が格納された table を作成します。

    tbl = table(Acceleration,Cylinders,Displacement, ...
        Horsepower,Model_Year,Weight,MPG);

    学習セットの欠損値を削除すると、メモリ消費量を減らして関数 fitrkernel の学習速度を向上させることができます。tbl の欠損値を削除します。

    tbl = rmmissing(tbl);

    関数fitrkernelを使用して MPG の blackbox モデルの学習を行います。変数 CylindersModel_Year をカテゴリカル予測子として指定します。残りの予測子を標準化します。

    rng("default") % For reproducibility
    mdl = fitrkernel(tbl,"MPG",CategoricalPredictors=[2 5], ...
        Standardize=true);

    shapley オブジェクトを作成します。mdl には学習データが含まれないため、データ セット tbl を指定します。

    explainer = shapley(mdl,tbl)
    explainer = 
                BlackboxModel: [1x1 RegressionKernel]
                  QueryPoints: []
               BlackboxFitted: []
                ShapleyValues: []
                            X: [392x7 table]
        CategoricalPredictors: [2 5]
                       Method: "interventional-kernel"
                    Intercept: 22.7326
                   NumSubsets: 64
    
    

    explainer は、学習データ tblX プロパティに格納します。

    tbl の最初の観測値についてすべての予測子変数のシャープレイ値を計算します。

    queryPoint = tbl(1,:)
    queryPoint=1×7 table
        Acceleration    Cylinders    Displacement    Horsepower    Model_Year    Weight    MPG
        ____________    _________    ____________    __________    __________    ______    ___
    
             12             8            307            130            70         3504     18 
    
    
    explainer = fit(explainer,queryPoint);

    回帰モデルの場合、shapley は予測応答を使用してシャープレイ値を計算し、ShapleyValues プロパティに格納します。ShapleyValues プロパティの値を表示します。

    explainer.ShapleyValues
    ans=6×2 table
          Predictor       ShapleyValue
        ______________    ____________
    
        "Acceleration"      -0.23731  
        "Cylinders"         -0.87423  
        "Displacement"       -1.0224  
        "Horsepower"        -0.56975  
        "Model_Year"       -0.055414  
        "Weight"            -0.86088  
    
    

    関数 plot を使用して、クエリ点のシャープレイ値をプロットします。

    plot(explainer)

    横棒グラフは、絶対値で並べ替えられた、すべての変数のシャープレイ値を示します。各シャープレイ値は、クエリ点についての予測に関して対応する変数が原因で生じた平均からの偏差を説明します。

    分類モデルの学習を行い、shapley オブジェクトを作成します。その後、2 つのクエリ点のシャープレイ値を計算します。

    CreditRating_Historical データ セットを読み込みます。データ セットには、顧客 ID、顧客の財務比率、業種ラベル、および信用格付けが格納されています。

    tbl = readtable("CreditRating_Historical.dat");

    関数fitcecocを使用して、信用格付けの blackbox モデルに学習させます。tbl 内の 2 ~ 7 列目の変数を予測子変数として使用します。

    blackbox = fitcecoc(tbl,"Rating", ...
        PredictorNames=tbl.Properties.VariableNames(2:7), ...
        CategoricalPredictors="Industry");

    blackbox モデルを使用して、shapley オブジェクトを作成します。計算速度を向上するには、tbl の観測値の 25% を階層的にサブサンプリングし、その標本を使用してシャープレイ値を計算します。Kernel SHAP アルゴリズムの拡張機能の使用を指定します。

    rng("default") % For reproducibility
    c = cvpartition(tbl.Rating,"Holdout",0.25);
    sampleTbl = tbl(test(c),:);
    explainer = shapley(blackbox,sampleTbl,Method="conditional");

    真の信用格付け値がそれぞれ AAA および B となる 2 つのクエリ点を見つけます。

    queryPoints(1,:) = sampleTbl(find(strcmp(sampleTbl.Rating,"AAA"),1),:);
    queryPoints(2,:) = sampleTbl(find(strcmp(sampleTbl.Rating,"B"),1),:)
    queryPoints=2×8 table
         ID      WC_TA     RE_TA     EBIT_TA    MVE_BVTD    S_TA     Industry    Rating 
        _____    ______    ______    _______    ________    _____    ________    _______
    
        58258     0.511     0.869     0.106      8.538      0.732       2        {'AAA'}
        82367    -0.078    -0.042     0.011      0.262      0.167       7        {'B'  }
    
    

    最初のクエリ点についてのシャープレイ値を計算してプロットします。

    explainer1 = fit(explainer,queryPoints(1,:));
    plot(explainer1)

    2 番目のクエリ点についてのシャープレイ値を計算してプロットします。

    explainer2 = fit(explainer,queryPoints(2,:));
    plot(explainer2)

    2 番目のクエリ点に関する実際の格付けは B ですが、予測された格付けは BB です。プロットには、予測された格付けのシャープレイ値が表示されます。

    explainer1explainer2 にはそれぞれ、最初のクエリ点と 2 番目のクエリ点についてのシャープレイ値が含まれています。

    回帰モデルの学習を行い、shapley オブジェクトを作成します。オブジェクト関数 fit を使用して、指定したクエリ点のシャープレイ値を計算します。その後、オブジェクト関数 swarmchart を使用して複数のクエリ点のシャープレイ値をプロットします。

    carbig データ セットを読み込みます。このデータ セットには、1970 年代と 1980 年代初期に製造された自動車の測定値が格納されています。

    load carbig

    AccelerationCylinders などの予測子変数と応答変数 MPG が格納された table を作成します。

    tbl = table(Acceleration,Cylinders,Displacement, ...
        Horsepower,Model_Year,Weight,MPG);

    学習セットの欠損値を削除すると、メモリ消費量が減り、関数 fitrkernel の学習速度の向上に役立ちます。tbl の欠損値を削除します。

    tbl = rmmissing(tbl);

    関数 fitrkernel を使用して MPG の blackbox モデルに学習させます。変数 CylindersModel_Year をカテゴリカル予測子として指定します。残りの予測子を標準化します。

    rng("default") % For reproducibility
    mdl = fitrkernel(tbl,"MPG",CategoricalPredictors=[2 5], ...
        Standardize=true);

    shapley オブジェクトを作成します。mdl には学習データが含まれていないため、データ セット tbl を指定します。

    explainer = shapley(mdl,tbl)
    explainer = 
                BlackboxModel: [1×1 RegressionKernel]
                  QueryPoints: []
               BlackboxFitted: []
                ShapleyValues: []
                            X: [392×7 table]
        CategoricalPredictors: [2 5]
                       Method: "interventional-kernel"
                    Intercept: 22.7326
                   NumSubsets: 64
    
    

    explainer は、学習データ tblX プロパティに格納します。

    tbl のすべての観測値のシャープレイ値を計算します。Parallel Computing Toolbox™ のライセンスがある場合は、名前と値の引数 UseParallel を使用して計算を高速化します。

    explainer = fit(explainer,tbl,UseParallel=true);
    Starting parallel pool (parpool) using the 'Processes' profile ...
    10-Jan-2024 14:09:35: Job Queued. Waiting for parallel pool job with ID 5 to start ...
    Connected to parallel pool with 6 workers.
    

    回帰モデルの場合、shapley は予測応答を使用してシャープレイ値を計算し、ShapleyValues プロパティに格納します。explainer に複数のクエリ点のシャープレイ値が格納されているため、代わりに平均絶対シャープレイ値を表示します。

    explainer.MeanAbsoluteShapley
    ans=6×2 table
          Predictor       ShapleyValue
        ______________    ____________
    
        "Acceleration"      0.52233   
        "Cylinders"          1.0412   
        "Displacement"      0.80485   
        "Horsepower"         0.7589   
        "Model_Year"        0.82285   
        "Weight"            0.98453   
    
    

    それぞれの予測子について、すべてのクエリ点で平均化したシャープレイ値の絶対値が平均絶対シャープレイ値になります。予測子 Cylinders の平均絶対シャープレイ値が最も大きく、予測子 Acceleration の平均絶対シャープレイ値が最も小さくなっています。

    オブジェクト関数 swarmchart を使用してシャープレイ値を可視化します。"copper" カラーマップを使用するように指定します。

    swarmchart(explainer,ColorMap="copper")

    Figure contains an axes object. The axes object with title Shapley Summary Plot, xlabel Shapley Value, ylabel Predictor contains 7 objects of type constantline, scatter.

    それぞれの予測子について、クエリ点のシャープレイ値が関数によって表示されます。対応する粒子群チャートにシャープレイ値の分布が表示されます。予測子の順序は、平均絶対シャープレイ値を使用して関数で決定されます。

    Weight の値が小さいクエリ点は、シャープレイ値が大きい正の値になっているように見えます。つまり、それらのクエリ点については、予測子 WeightMPG の予測される値の平均からの差が大きくなるのに寄与しています。同様に、Weight の値が大きいクエリ点は、シャープレイ値が大きい負の値になっているように見えます。つまり、それらのクエリ点については、予測子 WeightMPG の予測される値の平均からの差が小さくなるのに寄与しています。これらの結果は、自動車の重量は MPG の値と逆の相関があるという考え方に一致しています。

    入力引数

    すべて折りたたむ

    blackbox モデルを説明するオブジェクト。shapley オブジェクトとして指定します。

    fit が予測を説明するクエリ点。数値行列または table として指定します。queryPoints の各行が 1 つのクエリ点に対応します。

    • 数値行列の場合

      • queryPoints の列を構成する変数の順序は、explainer の予測子データ X の順序と同じでなければなりません。

      • 予測子データ explainer.X が table の場合、table に含まれている変数がすべて数値変数であれば、queryPoints を数値行列にすることができます。

    • テーブルの場合

      • 予測子データ explainer.X が table の場合、queryPoints 内のすべての予測子変数は変数名およびデータ型が explainer.X 内の変数と同じでなければなりません。ただし、queryPoints の列の順序が explainer.X の列の順序に対応する必要はありません。

      • 予測子データ explainer.X が数値行列の場合、explainer.BlackboxModel.PredictorNames 内の予測子名と queryPoints 内の対応する予測子変数名が同じでなければなりません。学習時に予測子の名前を指定するには、名前と値の引数 PredictorNames を使用します。queryPoints 内の予測子変数はすべて数値ベクトルでなければなりません。

      • queryPoints に追加の変数 (応答変数や観測値の重みなど) を含めることができますが、fit はこれらを無視します。

      • fit は、文字ベクトルの cell 配列ではない cell 配列や複数列の変数をサポートしません。

    連続予測子について queryPointsNaN が含まれており、Method"conditional" である場合、返されたオブジェクトに含まれているシャープレイ値 (ShapleyValues) は NaN になります。ガウス過程回帰 (GPR)、カーネル、線形、ニューラル ネットワーク、またはサポート ベクター マシン (SVM) のモデルである回帰モデルを使用する場合、欠損値がある予測子や学習時にはなかったカテゴリを含むクエリ点については、fit はシャープレイ値として NaN を返します。それ以外のすべてのモデルについては、fit は欠損値を explainer.BlackboxModel (explainer.BlackboxModel のオブジェクト関数 predict または blackbox で指定された関数ハンドル) と同じ方法で処理します。

    R2024a より前: クエリ点が 1 つだけの場合は、数値の行ベクトルまたは単一行 table を使用して指定できます。

    例: explainer.X(1,:) は、explainer の予測子データ X の最初の観測値としてクエリ点を指定します。

    データ型: single | double | table

    名前と値の引数

    オプションの引数のペアを Name1=Value1,...,NameN=ValueN として指定します。ここで Name は引数名、Value は対応する値です。名前と値の引数は他の引数の後ろにする必要がありますが、ペアの順序は関係ありません。

    例: fit(explainer,q,Method="conditional",UseParallel=true) は、Kernel SHAP アルゴリズムの拡張機能を使用してクエリ点 q のシャープレイ値を計算し、その計算を並列実行します。

    シャープレイ値の計算に使用する予測子サブセットの最大数。正の整数を指定します。

    fit が使用するサブセットを選択する方法の詳細については、計算コストを参照してください。

    この引数は、関数 fit で Kernel SHAP アルゴリズムまたは Kernel SHAP アルゴリズムの拡張機能を使用する場合に有効です。Method"interventional" の場合に引数 MaxNumSubsets を設定すると、Kernel SHAP アルゴリズムが使用されます。詳細については、アルゴリズムを参照してください。

    例: MaxNumSubsets=100

    データ型: single | double

    シャープレイ値の計算アルゴリズム。"interventional" または "conditional" として指定します。

    • "interventional"fit は、介入型の価値関数でシャープレイ値を計算します。

      fit には介入型のアルゴリズムが 3 つあります。Kernel SHAP [1]、Linear SHAP [1]、および Tree SHAP [2]です。クエリ点ごとに、機械学習モデル explainer.BlackboxModel とその他のオプションの指定に基づいてアルゴリズムが選択されます。詳細については、介入型アルゴリズムを参照してください。

    • "conditional"fit は、条件付きの価値関数を伴う Kernel SHAP アルゴリズムの拡張機能[3]を使用します。

    選択されたアルゴリズムの名前が newExplainerMethod プロパティに格納されます。詳細については、アルゴリズムを参照してください。

    既定では、関数 fit は、explainerMethod プロパティで指定されたアルゴリズムを使用します。

    R2023a より前: この引数は "interventional-kernel" または "conditional-kernel" として指定できます。fit は、Kernel SHAP アルゴリズムと Kernel SHAP アルゴリズムの拡張機能をサポートしています。

    例: Method="conditional"

    データ型: char | string

    R2024a 以降

    各クエリ点の評価後に呼び出される関数。関数ハンドルとして指定します。シャープレイ値の計算の停止、変数の作成、結果のプロットなど、さまざまなタスクを出力関数で実行できます。独自の出力関数を記述する方法の詳細と例については、Shapley Output Functionsを参照してください。

    この引数は、関数 fit で複数のクエリ点のシャープレイ値を計算する場合で、UseParallel の値が false の場合のみ有効です。

    データ型: function_handle

    並列実行のフラグ。数値または logical の 1 (true) または 0 (false) として指定します。UseParallel=true を指定した場合、関数 fitparfor を使用して for ループの反復を実行します。Parallel Computing Toolbox™ がある場合、ループが並列に実行されます。

    この引数は、関数 fit で複数のクエリ点のシャープレイ値を計算する場合、または木のアンサンブル用の Tree SHAP アルゴリズム、Kernel SHAP アルゴリズム、または Kernel SHAP アルゴリズムの拡張機能を使用して 1 つのクエリ点のシャープレイ値を計算する場合のみ有効です。

    例: UseParallel=true

    データ型: logical

    出力引数

    すべて折りたたむ

    blackbox モデルを説明するオブジェクト。shapley オブジェクトとして返されます。オブジェクトの ShapleyValues プロパティには、計算されたシャープレイ値が含まれています。

    入力引数 explainer を上書きするには、fit の出力を explainer に代入します。

    explainer = fit(explainer,queryPoints);

    詳細

    すべて折りたたむ

    シャープレイ値

    ゲーム理論におけるプレーヤーのシャープレイ値とは、協力ゲームでのプレーヤーの平均限界貢献度です。機械学習予測のコンテキストでは、クエリ点の特徴量のシャープレイ値は、指定したクエリ点での予測 (回帰の場合は応答、分類の場合は各クラスのスコア) に対する特徴量の寄与を説明します。

    クエリ点の特徴量のシャープレイ値は、平均予測からの偏差に対する特徴量の寄与です。クエリ点について、すべての特徴量に関するシャープレイ値の合計は、予測の平均からの合計偏差に対応します。つまり、平均予測とすべての特徴量に関するシャープレイ値の合計は、クエリ点についての予測に対応します。

    詳細については、機械学習モデルのシャープレイ値を参照してください。

    参照

    [1] Lundberg, Scott M., and S. Lee. "A Unified Approach to Interpreting Model Predictions." Advances in Neural Information Processing Systems 30 (2017): 4765–774.

    [2] Lundberg, Scott M., G. Erion, H. Chen, et al. "From Local Explanations to Global Understanding with Explainable AI for Trees." Nature Machine Intelligence 2 (January 2020): 56–67.

    [3] Aas, Kjersti, Martin Jullum, and Anders Løland. "Explaining Individual Predictions When Features Are Dependent: More Accurate Approximations to Shapley Values." Artificial Intelligence 298 (September 2021).

    拡張機能

    バージョン履歴

    R2021a で導入

    すべて展開する