Main Content

このページの翻訳は最新ではありません。ここをクリックして、英語の最新版を参照してください。

LIME を使用した表形式データに関する深層ネットワークの予測の解釈

この例では、Locally Interpretable Model-agnostic Explanations (LIME) 技術を使用し、表形式データの分類を行う深層ニューラル ネットワークの予測を解釈する方法を説明します。LIME 技術を使用して、ネットワークによる分類の判定に最も重要な予測子を把握できます。

この例では、LIME を使用して、特徴データ分類用のネットワークを解釈します。指定されたクエリ観測値に対して、LIME は、各特徴の統計量が実データ セットと一致する合成データ セットを生成します。この合成データ セットは、分類を取得するために深層ニューラル ネットワークに渡され、単純かつ解釈しやすいモデルが当てはめられます。この単純なモデルは、ネットワークの分類判定に対する上位のいくつかの特徴の重要度を理解するために使用できます。この解釈可能モデルの学習において、合成観測値はクエリ観測値からの距離で重み付けされます。そのため、説明はクエリ観測値に対して "局所的" になります。

この例では、lime (Statistics and Machine Learning Toolbox)fit (Statistics and Machine Learning Toolbox)を使用して合成データ セットを生成し、単純かつ解釈可能なモデルを合成データ セットに当てはめます。学習済みのイメージ分類ニューラル ネットワークによる予測を理解するには、imageLIMEを使用してください。詳細については、LIME を使用したネットワーク予測の理解を参照してください。

データの読み込み

フィッシャーのアヤメのデータ セットを読み込みます。このデータには 150 個の観測値のほかに、植物のパラメーターを表す 4 つの入力特徴と、植物の種類を表す 1 つのカテゴリカル応答が含まれています。各観測値は、3 種類 (setosa、versicolor、virginica) のいずれかに分類されます。各観測値には、4 つの測定値 (がく片の幅、がく片の長さ、花弁の幅、花弁の長さ) があります。

filename = fullfile(toolboxdir('stats'),'statsdemos','fisheriris.mat');
load(filename)

数値データを table に変換します。

features = ["Sepal length","Sepal width","Petal length","Petal width"];

predictors = array2table(meas,"VariableNames",features);
trueLabels = array2table(categorical(species),"VariableNames","Response");

最後の列を応答とする学習データの table を作成します。

data = [predictors trueLabels];

観測値の数、特徴の数、およびクラスの数を計算します。

numObservations = size(predictors,1);
numFeatures = size(predictors,2);
numClasses = length(categories(data{:,5}));

学習セット、検証セット、およびテスト セットへのデータの分割

データ セットを学習セット、検証セット、およびテスト セットに分割します。データの 15% を検証用に、15% をテスト用に残しておきます。

各区画の観測数を求めます。乱数シードを設定してデータを分割し、CPU での学習が再現可能になるようにします。

rng('default');
numObservationsTrain = floor(0.7*numObservations);
numObservationsValidation = floor(0.15*numObservations);

観測値に対応するランダムなインデックスの配列を作成し、区画サイズを使用して分割します。

idx = randperm(numObservations);
idxTrain = idx(1:numObservationsTrain);
idxValidation = idx(numObservationsTrain + 1:numObservationsTrain + numObservationsValidation);
idxTest = idx(numObservationsTrain + numObservationsValidation + 1:end);

インデックスを使用して、データの table を学習用、検証用、およびテスト用の区画に分割します。

dataTrain = data(idxTrain,:);
dataVal = data(idxValidation,:);
dataTest = data(idxTest,:);

ネットワーク アーキテクチャの定義

5 個のニューロンと ReLU 活性化がある 1 つの隠れ層をもつ単純な多層パーセプトロンを作成します。特徴入力層は、フィッシャーのアヤメのデータ セットのような、特徴量を表す数値スカラーを含むデータを受け入れます。

numHiddenUnits = 5;
layers = [
    featureInputLayer(numFeatures)
    fullyConnectedLayer(numHiddenUnits)
    reluLayer
    fullyConnectedLayer(numClasses)
    softmaxLayer
    classificationLayer];

学習オプションの定義とネットワークの学習

モーメンタム項付き確率的勾配降下 (SGDM) を使用してネットワークに学習させます。学習データに含まれている観測値の数は多くないため、エポックの最大回数を 30 に設定し、ミニバッチのサイズ 15 を使用します。

opts = trainingOptions("sgdm", ...
    "MaxEpochs",30, ...
    "MiniBatchSize",15, ...
    "Shuffle","every-epoch", ...
    "ValidationData",dataVal, ...
    "ExecutionEnvironment","cpu");

ネットワークに学習をさせます。

net = trainNetwork(dataTrain,layers,opts);
|======================================================================================================================|
|  Epoch  |  Iteration  |  Time Elapsed  |  Mini-batch  |  Validation  |  Mini-batch  |  Validation  |  Base Learning  |
|         |             |   (hh:mm:ss)   |   Accuracy   |   Accuracy   |     Loss     |     Loss     |      Rate       |
|======================================================================================================================|
|       1 |           1 |       00:00:00 |       40.00% |       31.82% |       1.3060 |       1.2897 |          0.0100 |
|       8 |          50 |       00:00:00 |       86.67% |       90.91% |       0.4223 |       0.3656 |          0.0100 |
|      15 |         100 |       00:00:00 |       93.33% |       86.36% |       0.2947 |       0.2927 |          0.0100 |
|      22 |         150 |       00:00:00 |       86.67% |       81.82% |       0.2804 |       0.3707 |          0.0100 |
|      29 |         200 |       00:00:01 |       86.67% |       90.91% |       0.2268 |       0.2129 |          0.0100 |
|      30 |         210 |       00:00:01 |       93.33% |       95.45% |       0.2782 |       0.1666 |          0.0100 |
|======================================================================================================================|

ネットワーク性能の評価

学習済みネットワークを使用してテスト セットの観測値を分類します。

predictedLabels = net.classify(dataTest);
trueLabels = dataTest{:,end};

混同行列を使用して結果を可視化します。

figure
confusionchart(trueLabels,predictedLabels)

ネットワークは、テスト観測値について、植物の 4 つの特徴を使用して種類をうまく予測できています。

クラスごとに異なる予測子を使用することの重要性について理解する

LIME を使用して、ネットワークの分類判定に対する各予測子の重要度を理解します。

各観測値の最も重要な 2 つの予測子について調査します。

numImportantPredictors = 2;

lime を使用して、各特徴の統計量が実データ セットと一致する合成データ セットを作成します。深層学習モデル blackbox と、predictors に含まれている予測子データを使用して lime オブジェクトを作成します。'KernelWidth' に小さな値を使用して、lime が、クエリ点に近接するサンプルに焦点が当たるような重みを使用するようにします。

blackbox = @(x)classify(net,x);
explainer = lime(blackbox,predictors,'Type','classification','KernelWidth',0.1);

この LIME explainer を使用して、深層ニューラル ネットワークに対して最も重要な特徴を理解できます。関数は、クエリ観測値の近傍にあるニューラル ネットワークを近似する単純な線形モデルを使用して、特徴の重要度を推定します。

setosa クラスに対応するテスト データ内の、最初の 2 つの観測値のインデックスを検索します。

trueLabelsTest = dataTest{:,end};

label = "setosa";
idxSetosa = find(trueLabelsTest == label,2);

関数 fit を使用して、指定されたクラスの最初の 2 つの観測値に単純な線形モデルを当てはめます。

explainerObs1 = fit(explainer,dataTest(idxSetosa(1),1:4),numImportantPredictors);
explainerObs2 = fit(explainer,dataTest(idxSetosa(2),1:4),numImportantPredictors);

結果をプロットします。

figure
subplot(2,1,1)
plot(explainerObs1);
subplot(2,1,2)
plot(explainerObs2);

setosa クラスで最も重要な予測子は、小さな値の花弁の長さ (Petal length) と大きな値のがく片の幅 (Sepal width) です。

versicolor クラスに対して同じ解析を実行します。

label = "versicolor";
idxVersicolor = find(trueLabelsTest == label,2);

explainerObs1 = fit(explainer,dataTest(idxVersicolor(1),1:4),numImportantPredictors);
explainerObs2 = fit(explainer,dataTest(idxVersicolor(2),1:4),numImportantPredictors);

figure
subplot(2,1,1)
plot(explainerObs1);
subplot(2,1,2)
plot(explainerObs2);

versicolor クラスでは、大きな値の花弁の長さが重要です。

最後に、virginica クラスについて考慮します。

label = "virginica";
idxVirginica = find(trueLabelsTest == label,2);

explainerObs1 = fit(explainer,dataTest(idxVirginica(1),1:4),numImportantPredictors);
explainerObs2 = fit(explainer,dataTest(idxVirginica(2),1:4),numImportantPredictors);

figure
subplot(2,1,1)
plot(explainerObs1);
subplot(2,1,2)
plot(explainerObs2);

virginica クラスでは、大きな値の花弁の長さと小さな値のがく片の幅が重要です。

LIME 仮説の検証

LIME によるプロットは、花弁の長さの値が大きいものは versicolor と virginica のクラスと関連があり、花弁の長さの値が小さいものは setosa クラスと関連があることを示しています。データの詳細を調べて、結果についてさらに調査できます。

データ セットの各イメージの花弁の長さをプロットします。

setosaIdx = ismember(data{:,end},"setosa");
versicolorIdx = ismember(data{:,end},"versicolor");
virginicaIdx = ismember(data{:,end},"virginica");

figure
hold on
plot(data{setosaIdx,"Petal length"},'.')
plot(data{versicolorIdx,"Petal length"},'.')
plot(data{virginicaIdx,"Petal length"},'.')
hold off

xlabel("Observation number")
ylabel("Petal length")
legend(["setosa","versicolor","virginica"])

setosa クラスの花弁の長さの値は他のクラスの値をはるかに下回っており、lime モデルから得られた結果と一致しています。

参考

(Statistics and Machine Learning Toolbox) | (Statistics and Machine Learning Toolbox) | | | |

関連するトピック