Main Content

グラフ畳み込みネットワークを使用したノード分類

この例では、グラフ畳み込みネットワーク (GCN) を使用してグラフ内のノードを分類する方法を説明します。

グラフ内のノードのカテゴリカル ラベルを予測するには、GCN [1] を使用します。たとえば、GCN を使用すると、分子構造 (グラフとして表された化学結合) が与えられたときに、分子に含まれる原子のタイプ (炭素や酸素など) を予測できます。

GCN は、畳み込みニューラル ネットワークのバリアントで、次の 2 つの入力を取ります。

  1. NC 列の特徴行列 X。ここで、N はグラフ内のノードの数、C はノードごとのチャネル数です。

  2. グラフ内のノード間の接続を表す NN 列の隣接行列 A

この図は、グラフのノード分類の例を示しています。

このグラフのデータはスパースであるため、GCN の学習にはカスタム学習ループが最適です。この例では、カスタム学習ループを使用し、QM7 データ セット [2] [3] で GCN に学習させる方法を説明します。この分子データ セットには、最大 23 個の原子で構成された分子が 7165 個含まれます。つまり、原子の数が最も多い分子には 23 個の原子が含まれています。

QM7 データのダウンロードと読み込み

http://quantum-machine.org/data/qm7.mat から QM7 データ セットをダウンロードします。このデータ セットには、5 種類の原子 (炭素、水素、窒素、酸素、硫黄) が含まれています。

dataURL = "http://quantum-machine.org/data/qm7.mat";
outputFolder = fullfile(tempdir,"qm7Data");
dataFile = fullfile(outputFolder,"qm7.mat");

if ~exist(dataFile,"file")
    mkdir(outputFolder);
    disp("Downloading QM7 data...");
    websave(dataFile, dataURL);
    disp("Done.")
end

MAT ファイルには、異なる 5 つの配列が含まれています。この例では、配列 X および Z を使用します。これらはそれぞれ、各分子のクーロン行列 [3] 表現、および分子に含まれる各原子の原子番号を表しています。データに含まれる分子のうち、原子の数が 23 個未満の分子はゼロでパディングされます。

MAT ファイルから QM7 データを読み込みます。

data = load(dataFile)
data = struct with fields:
    X: [7165×23×23 single]
    R: [7165×23×3 single]
    Z: [7165×23 single]
    T: [-417.9600 -712.4200 -564.2100 -404.8800 -808.8700 -677.1600 -796.9800 -860.3300 -1.0085e+03 -861.7300 -708.3700 -725.9300 -879.3800 -618.7200 -871.1900 -653.4400 -1.0109e+03 -1.1594e+03 -1.0039e+03 -1.0184e+03 -1.0250e+03 -1.1750e+03 … ]
    P: [5×1433 int64]

読み込んだ構造体から、クーロン データと原子番号を抽出します。3 番目の次元が観測値に対応するように、クーロン データを並べ替えます。原子番号を降順で並べ替えます。

coulombData = double(permute(data.X, [2 3 1]));
atomData = sort(data.Z,2,'descend');

最初の観測値の原子を表示します。非ゼロの要素の数は、分子に含まれる原子の種類の数を表します。それぞれの非ゼロの要素は、分子に含まれる特定の元素の原子番号に対応します。

atomData(1,:)
ans = 1×23 single row vector

     6     1     1     1     1     0     0     0     0     0     0     0     0     0     0     0     0     0     0     0     0     0     0

グラフ データの前処理

この例で使用する GCN は、特徴入力としてクーロン行列を必要とし、さらに対応する隣接行列を必要とします。

学習データに含まれるクーロン行列を隣接行列に変換するには、この例にサポート ファイルとして添付されている関数 coulomb2Adjacency を使用します。このファイルにアクセスするには、例をライブ スクリプトとして開きます。関数 coloumb2Adjacency は、学習用、検証用、および推論用にデータを各分子に簡単に分割できるよう、パディングされたゼロをデータから除去しません。

adjacencyData = coulomb2Adjacency(coulombData,atomData);

最初のいくつかの分子をプロットに可視化します。各分子について、パディングされていない隣接行列を抽出し、ラベル付けされたノードとともにグラフをプロットします。原子番号を元素記号に変換するには、この例にサポート ファイルとして添付されている関数 atomicSymbol を使用します。この関数にアクセスするには、例をライブ スクリプトとして開きます。

figure
tiledlayout("flow")

for i = 1:9
    % Extract unpadded adjacency matrix.
    atomicNumbers = nonzeros(atomData(i,:));
    numNodes = numel(atomicNumbers);
    A = adjacencyData(1:numNodes,1:numNodes,i);

    % Convert adjacency matrix to graph.
    G = graph(A);

    % Convert atomic numbers to symbols.
    symbols = atomicSymbol(atomicNumbers);

    % Plot graph.
    nexttile
    plot(G,NodeLabel=symbols,Layout="force")
    title("Molecule " + i)
end

Figure contains 9 axes objects. Axes object 1 with title Molecule 1 contains an object of type graphplot. Axes object 2 with title Molecule 2 contains an object of type graphplot. Axes object 3 with title Molecule 3 contains an object of type graphplot. Axes object 4 with title Molecule 4 contains an object of type graphplot. Axes object 5 with title Molecule 5 contains an object of type graphplot. Axes object 6 with title Molecule 6 contains an object of type graphplot. Axes object 7 with title Molecule 7 contains an object of type graphplot. Axes object 8 with title Molecule 8 contains an object of type graphplot. Axes object 9 with title Molecule 9 contains an object of type graphplot.

ヒストグラムを使用して、各ラベル カテゴリの出現頻度を可視化します。

figure
histogram(categorical(atomicSymbol(atomData)))
xlabel("Node Label")
ylabel("Frequency")
title("Label Counts")

Figure contains an axes object. The axes object with title Label Counts contains an object of type categoricalhistogram.

学習用、検証用、テスト用に、それぞれ 80%、10%、10% の比率でデータを分割します。データをランダムに分割するため、この例にサポート ファイルとして添付されている関数 trainingPartitions を使用します。このファイルにアクセスするには、例をライブ スクリプトとして開きます。

numObservations = size(adjacencyData,3);
[idxTrain,idxValidation,idxTest] = trainingPartitions(numObservations,[0.8 0.1 0.1]);

adjacencyDataTrain = adjacencyData(:,:,idxTrain);
adjacencyDataValidation = adjacencyData(:,:,idxValidation);
adjacencyDataTest = adjacencyData(:,:,idxTest);

coulombDataTrain = coulombData(:,:,idxTrain);
coulombDataValidation = coulombData(:,:,idxValidation);
coulombDataTest = coulombData(:,:,idxTest);

atomDataTrain = atomData(idxTrain,:);
atomDataValidation = atomData(idxValidation,:);
atomDataTest = atomData(idxTest,:);

この例のデータ前処理関数のセクションで定義された関数 preprocessData を使用して、学習データと検証データを前処理します。関数 preprocessData は、さまざまなグラフ インスタンスの隣接行列から成るスパース ブロック対角行列を構築します。この行列の各ブロックは、1 つのグラフ インスタンスの隣接行列に対応します。GCN は 1 つの隣接行列を入力として受け取りますが、この例では複数のグラフ インスタンスを扱うため、この前処理が必要になります。この関数は、クーロン行列の非ゼロの対角要素を受け取り、それらを特徴として割り当てます。そのため、この例ではノードごとの入力特徴の数は 1 になります。

[ATrain,XTrain,labelsTrain] = preprocessData(adjacencyDataTrain,coulombDataTrain,atomDataTrain);
size(XTrain)
ans = 1×2

       88424           1

size(labelsTrain)
ans = 1×2

       88424           1

[AValidation,XValidation,labelsValidation] = preprocessData(adjacencyDataValidation,coulombDataValidation,atomDataValidation);

学習特徴の平均と分散を使用して、特徴を正規化します。同じ統計量を使用して、検証特徴を正規化します。

muX = mean(XTrain);
sigsqX = var(XTrain,1);

XTrain = (XTrain - muX)./sqrt(sigsqX);
XValidation = (XValidation - muX)./sqrt(sigsqX);

深層学習モデルの定義

以下の深層学習モデルを定義します。これは、隣接行列 A と特徴行列 X を入力として受け取り、カテゴリカル予測を出力します。

乗算処理では、学習可能な重みによる重み付き乗算が行われます。

詳しく見ると、このモデルは Zl+1=σl(Dˆ-1/2AˆDˆ-1/2ZlWl)+Zl という形式の一連の演算となっています (最後の演算に加算ステップは含まれません)。この式は以下のようになっています。

  • σl は活性化関数。

  • Z1=X.

  • Wl は乗算用の重み行列。

  • Aˆ=A+IN は、グラフ G の隣接行列に自己結合を追加したもの。IN は単位行列。

  • DˆAˆ の次数行列。

Dˆ-1/2AˆDˆ-1/2 は、グラフの "正規化" 隣接行列とも呼ばれます。

モデル パラメーターの初期化

各演算のパラメーターを定義して構造体に含めます。parameters.OperationName.ParameterName の形式を使用します。ここで、parameters は構造体、OperationName は演算名 (multiply1 など)、ParameterName はパラメーター名 (Weights など) です。

モデル パラメーターを含む構造体 parameters を作成します。

parameters = struct;

この例にサポート ファイルとして添付されている関数 initializeGlorot を使用して、学習可能な重みを初期化します。この関数にアクセスするには、例をライブ スクリプトとして開きます。

最初の乗算の重みを初期化します。出力サイズが 32 となるように重みを初期化します。入力サイズは、入力特徴データのチャネルの数です。

numHiddenFeatureMaps = 32;
numInputFeatures = size(XTrain,2);

sz = [numInputFeatures numHiddenFeatureMaps];
numOut = numHiddenFeatureMaps;
numIn = numInputFeatures;
parameters.mult1.Weights = initializeGlorot(sz,numOut,numIn,"double");

2 番目の乗算の重みを初期化します。出力サイズが前の乗算と同じになるように重みを初期化します。入力サイズは、前の乗算の出力サイズです。

sz = [numHiddenFeatureMaps numHiddenFeatureMaps];
numOut = numHiddenFeatureMaps;
numIn = numHiddenFeatureMaps;
parameters.mult2.Weights = initializeGlorot(sz,numOut,numIn,"double");

3 番目の乗算の重みを初期化します。出力サイズがクラスの数と同じになるように重みを初期化します。入力サイズは、前の乗算の出力サイズです。

classes = categories(labelsTrain);
numClasses = numel(classes);

sz = [numHiddenFeatureMaps numClasses];
numOut = numClasses;
numIn = numHiddenFeatureMaps;
parameters.mult3.Weights = initializeGlorot(sz,numOut,numIn,"double");

パラメーターの構造を表示します。

parameters
parameters = struct with fields:
    mult1: [1×1 struct]
    mult2: [1×1 struct]
    mult3: [1×1 struct]

最初の乗算のパラメーターを表示します。

parameters.mult1
ans = struct with fields:
    Weights: [1×32 dlarray]

モデルの関数の定義

この例のモデル関数のセクションで定義されている関数 model を作成します。この関数は、モデル パラメーター、特徴データ、および隣接行列を入力として受け取り、予測を返します。

モデル損失関数の定義

この例のモデル損失関数のセクションで定義されている関数 modelLoss を作成します。この関数は、モデル パラメーター、特徴データ、隣接行列、および one-hot 符号化されたターゲットを入力として受け取り、パラメーターについての損失とその損失の勾配、およびネットワーク予測を返します。

学習オプションの指定

1500 エポック学習させ、Adam ソルバーの学習率を 0.01 に設定します。

numEpochs = 1500;
learnRate = 0.01;

300 エポックごとにネットワークを検証します。

validationFrequency = 300;

モデルの学習

Adam 用にパラメーターを初期化します。

trailingAvg = [];
trailingAvgSq = [];

学習用と検証用の特徴データを dlarray オブジェクトに変換します。

XTrain = dlarray(XTrain);
XValidation = dlarray(XValidation);

GPU が利用できる場合、GPU で学習を行います。GPU を使用するには、Parallel Computing Toolbox™ とサポートされている GPU デバイスが必要です。サポートされているデバイスについては、GPU 計算の要件 (Parallel Computing Toolbox) (Parallel Computing Toolbox) を参照してください。GPU で学習を行うには、データを gpuArray オブジェクトに変換します。

if canUseGPU
    XTrain = gpuArray(XTrain);
end

関数 onehotencode を使用して、学習ラベルと検証ラベルを one-hot 符号化ベクトルに変換します。

TTrain = onehotencode(labelsTrain,2,ClassNames=classes);
TValidation = onehotencode(labelsValidation,2,ClassNames=classes);

TrainingProgressMonitor オブジェクトを初期化します。

monitor = trainingProgressMonitor( ...
    Metrics=["TrainingLoss","ValidationLoss"], ...
    Info="Epoch", ...
    XLabel="Epoch");

groupSubPlot(monitor,"Loss",["TrainingLoss","ValidationLoss"])

カスタム学習ループを使用してモデルに学習させます。この学習では、全バッチによる勾配降下法を使用します。

各エポックで次を行います。

  • 関数 dlfeval および modelLoss を使用してモデルの損失と勾配を評価します。

  • adamupdate を使用してネットワーク パラメーターを更新します。

  • 学習プロットを更新します。

  • 必要に応じて、関数 model で予測を行ってネットワークを検証し、検証損失をプロットします。

epoch = 0;

while epoch < numEpochs && ~monitor.Stop
    epoch = epoch + 1;

    % Evaluate the model loss and gradients.
    [loss,gradients] = dlfeval(@modelLoss,parameters,XTrain,ATrain,TTrain);

    % Update the network parameters using the Adam optimizer.
    [parameters,trailingAvg,trailingAvgSq] = adamupdate(parameters,gradients, ...
        trailingAvg,trailingAvgSq,epoch,learnRate);

    % Record the training loss and epoch.
    recordMetrics(monitor,epoch,TrainingLoss=loss);
    updateInfo(monitor,Epoch=(epoch+" of "+numEpochs));

    % Display the validation metrics.
    if epoch == 1 || mod(epoch,validationFrequency) == 0
        YValidation = model(parameters,XValidation,AValidation);
        lossValidation = crossentropy(YValidation,TValidation,DataFormat="BC");

        % Record the validation loss.
        recordMetrics(monitor,epoch,ValidationLoss=lossValidation);
    end

    monitor.Progress = 100*(epoch/numEpochs);
end

モデルのテスト

テスト データを使用して、モデルをテストします。

学習データおよび検証データの場合と同じ手順に従って、テスト データを前処理します。

[ATest,XTest,labelsTest] = preprocessData(adjacencyDataTest,coulombDataTest,atomDataTest);
XTest = (XTest - muX)./sqrt(sigsqX);

テスト用の特徴データを dlarray オブジェクトに変換します。

XTest = dlarray(XTest);

このデータで予測を行い、関数 onehotdecode を使用して確率をカテゴリカル ラベルに変換します。

YTest = model(parameters,XTest,ATest);
YTest = onehotdecode(YTest,classes,2);

精度を計算します。

accuracy = mean(YTest == labelsTest)
accuracy = 0.8930

モデルがどのようにして間違った予測を行うかを可視化し、クラスごとの適合率とクラスごとの再現率に基づいてモデルを評価するには、関数 confusionchart を使用して混同行列を計算します。

クラスごとの適合率は、クラスに関する陽性予測の総数に対する真陽性の割合です。陽性予測の総数には、真陽性と偽陽性が含まれます。偽陽性は、あるクラスが観測値に含まれているとモデルが間違って予測した結果です。クラスごとの再現率 (これは真陽性率とも呼ばれます) は、クラスに関する陽性観測値の総数に対する真陽性の割合です。陽性観測値の総数には、真陽性と偽陰性が含まれます。偽陰性は、あるクラスが観測値に含まれていないとモデルが間違って予測した結果です。

figure
cm = confusionchart(labelsTest,YTest, ...
    ColumnSummary="column-normalized", ...
    RowSummary="row-normalized");
title("GCN QM7 Confusion Chart");

Figure contains an object of type ConfusionMatrixChart. The chart of type ConfusionMatrixChart has title GCN QM7 Confusion Chart.

クラスごとの適合率スコアは、チャートの列の要約における最初の行に示され、クラスごとの再現率スコアは、チャートの行の要約における最初の列に示されます。

新しいデータを使用した予測

この例のモデル予測関数のセクションにリストされている関数 modelPredictions を使用して、ラベル付けされていないデータで予測を行います。簡単にするために、テスト データに含まれる最初のいくつかの観測値を使用します。

numObservationsNew = 4;
adjacencyDataNew = adjacencyDataTest(:,:,1:numObservationsNew);
coulombDataNew = coulombDataTest(:,:,1:numObservationsNew);

predictions = modelPredictions(parameters,coulombDataNew,adjacencyDataNew,muX,sigsqX,classes);

プロットで予測を可視化します。各分子について、隣接行列を使用してグラフ表現を作成し、予測に基づいてノードをラベル付けします。

figure
tiledlayout("flow")

for i = 1:numObservationsNew
    % Extract unpadded adjacency data.
    numNodes = find(any(adjacencyDataTest(:,:,i)),1,"last");

    A = adjacencyDataTest(1:numNodes,1:numNodes,i);

    % Create and plot graph representation.
    nexttile
    G = graph(A);
    plot(G,NodeLabel=string(predictions{i}),Layout="force")
    title("Observation " + i + " Prediction")
end

Figure contains 4 axes objects. Axes object 1 with title Observation 1 Prediction contains an object of type graphplot. Axes object 2 with title Observation 2 Prediction contains an object of type graphplot. Axes object 3 with title Observation 3 Prediction contains an object of type graphplot. Axes object 4 with title Observation 4 Prediction contains an object of type graphplot.

サポート関数

データ前処理関数

関数 preprocessData は、次の手順を使用して、隣接、クーロン、原子のデータを前処理します。

  • この例の予測子前処理関数のセクションにリストされている関数 preprocessPredictors を使用して、隣接行列とクーロン行列を前処理。

  • 原子のデータをカテゴリカル ラベルから成るフラット化された配列に変換。

function [adjacency,features,labels] = preprocessData(adjacencyData,coulombData,atomData)

[adjacency, features] = preprocessPredictors(adjacencyData,coulombData);
labels = [];

% Convert labels to categorical.
for i = 1:size(adjacencyData,3)
    % Extract and append unpadded data.
    T = nonzeros(atomData(i,:));
    labels = [labels; T];
end

labels2 = nonzeros(atomData);
assert(isequal(labels2,labels2))

atomicNumbers = unique(labels);
atomNames =  atomicSymbol(atomicNumbers);
labels = categorical(labels, atomicNumbers, atomNames);

end

予測子前処理関数

関数 preprocessPredictors は、次の手順を使用して、隣接行列とクーロン行列を前処理します。

各分子について次を行います。

  • パディングされていないデータを抽出。

  • パディングされていないクーロン行列の対角要素から特徴ベクトルを抽出。

  • 抽出したデータを出力配列に追加。

GCN で隣接行列を入力するには、各隣接行列を含む単一のスパース ブロック対角行列が必要となります。ここで、各ブロックは、1 つのグラフ インスタンスの隣接行列に対応します。この関数は、ブロック対角行列にデータを追加するために、関数 blkdiag を使用します。

function [adjacency,features] = preprocessPredictors(adjacencyData,coulombData)

adjacency = sparse([]);
features = [];

for i = 1:size(adjacencyData, 3)
    % Extract unpadded data.
    numNodes = find(any(adjacencyData(:,:,i)),1,"last");

    A = adjacencyData(1:numNodes,1:numNodes,i);
    X = coulombData(1:numNodes,1:numNodes,i);

    % Extract feature vector from diagonal of Coulomb matrix.
    X = diag(X);

    % Append extracted data.
    adjacency = blkdiag(adjacency,A);
    features = [features; X];
end

end

モデル関数

関数 model は、モデル パラメーター parameters、特徴行列 X、および隣接行列 A を入力として受け取り、ネットワーク予測を返します。関数 model は、前処理のステップにおいて、この例の隣接正規化関数のセクションにリストされている関数 normalizeAdjacency を使用し、"正規化された" 隣接行列を計算します。正規化された隣接行列は、下に示す式の Dˆ-1/2AˆDˆ-1/2 に対応します。

この深層学習モデルは、隣接行列 A と特徴行列 X を入力として受け取り、カテゴリカル予測を出力します。

乗算処理では、学習可能な重みによる重み付き乗算が行われます。

詳しく見ると、このモデルは Zl+1=σl(Dˆ-1/2AˆDˆ-1/2ZlWl)+Zl という形式の一連の演算となっています (最後の演算に加算ステップは含まれません)。この式は以下のようになっています。

  • σl は活性化関数。

  • Z1=X.

  • Wl は乗算用の重み行列。

  • Aˆ=A+IN は、グラフ G の隣接行列に自己結合を追加したもの。IN は単位行列。

  • DˆAˆ の次数行列。

function Y = model(parameters,X,A)

ANorm = normalizeAdjacency(A);

Z1 = X;

Z2 = ANorm * Z1 * parameters.mult1.Weights;
Z2 = relu(Z2) + Z1;

Z3 = ANorm * Z2 * parameters.mult2.Weights;
Z3 = relu(Z3) + Z2;

Z4 = ANorm * Z3 * parameters.mult3.Weights;
Y = softmax(Z4,DataFormat="BC");

end

モデル損失関数

関数 modelLoss は、モデル パラメーター parameters、特徴行列 X、および隣接行列 A を入力として受け取り、one-hot 符号化されたターゲット データ T、およびモデル パラメーターについての損失とその損失の勾配を返します。

function [loss,gradients] = modelLoss(parameters,X,A,T)

Y = model(parameters,X,A);
loss = crossentropy(Y,T,DataFormat="BC");
gradients = dlgradient(loss, parameters);

end

モデル予測関数

関数 modelPredictions は、モデル パラメーター、クーロンと隣接の入力データ、正規化統計量 (musigsq)、およびクラス名のリストを入力として受け取り、入力データについて予測されたノード ラベルの cell 配列を返します。この関数は、入力グラフを一度に 1 回ずつループ処理することで予測を行います。

function predictions = modelPredictions(parameters,coulombData,adjacencyData,mu,sigsq,classes)

predictions = {};
numObservations = size(coulombData,3);

for i = 1:numObservations
    % Extract unpadded data.
    numNodes = find(any(adjacencyData(:,:,i)),1,"last");
    A = adjacencyData(1:numNodes,1:numNodes,i);
    X = coulombData(1:numNodes,1:numNodes,i);

    % Preprocess data.
    [A,X] = preprocessPredictors(A,X);
    X = (X - mu)./sqrt(sigsq);
    X = dlarray(X);

    % Make predictions.
    Y = model(parameters,X,A);
    Y = onehotdecode(Y,classes,2);
    predictions{end+1} = Y;
end

end

隣接正規化関数

関数 normalizeAdjacency は、隣接行列 A を入力として受け取り、"正規化された" 隣接行列 Dˆ-1/2AˆDˆ-1/2 を返します。ここで、Aˆ=A+IN は、グラフの隣接行列に自己結合を追加したもの、IN は単位行列、DˆAˆ の次数行列です。

function ANorm = normalizeAdjacency(A)

% Add self connections to adjacency matrix.
A = A + speye(size(A));

% Compute inverse square root of degree.
degree = sum(A, 2);
degreeInvSqrt = sparse(sqrt(1./degree));

% Normalize adjacency matrix.
ANorm = diag(degreeInvSqrt) * A * diag(degreeInvSqrt);

end

参考文献

  1. Kipf, Thomas N., and Max Welling. “Semi-Supervised Classification with Graph Convolutional Networks.” Paper presented at ICLR 2017, Toulon, France, April 2017.

  2. Blum, Lorenz C., and Jean-Louis Reymond. “970 Million Druglike Small Molecules for Virtual Screening in the Chemical Universe Database GDB-13.” Journal of the American Chemical Society 131, no. 25 (July 1, 2009): 8732–33. https://doi.org/10.1021/ja902302h.

  3. Rupp, Matthias, Alexandre Tkatchenko, Klaus-Robert Müller, and O. Anatole von Lilienfeld. “Fast and Accurate Modeling of Molecular Atomization Energies with Machine Learning.” Physical Review Letters 108, no. 5 (January 31, 2012): 058301. https://doi.org/10.1103/PhysRevLett.108.058301.

Copyright 2021, The MathWorks, Inc.

参考

| | |

関連するトピック