グラフ畳み込みネットワークを使用したノード分類
この例では、グラフ畳み込みネットワーク (GCN) を使用してグラフ内のノードを分類する方法を説明します。
グラフ内のノードのカテゴリカル ラベルを予測するには、GCN [1] を使用します。たとえば、GCN を使用すると、分子構造 (グラフとして表された化学結合) が与えられたときに、分子に含まれる原子のタイプ (炭素や酸素など) を予測できます。
GCN は、畳み込みニューラル ネットワークのバリアントで、次の 2 つの入力を取ります。
行 列の特徴行列 。ここで、 はグラフ内のノードの数、 はノードごとのチャネル数です。
グラフ内のノード間の接続を表す 行 列の隣接行列 。
この図は、グラフのノード分類の例を示しています。
このグラフのデータはスパースであるため、GCN の学習にはカスタム学習ループが最適です。この例では、カスタム学習ループを使用し、QM7 データ セット [2] [3] で GCN に学習させる方法を説明します。この分子データ セットには、最大 23 個の原子で構成された分子が 7165 個含まれます。つまり、原子の数が最も多い分子には 23 個の原子が含まれています。
QM7 データのダウンロードと読み込み
http://quantum-machine.org/data/qm7.mat から QM7 データ セットをダウンロードします。このデータ セットには、5 種類の原子 (炭素、水素、窒素、酸素、硫黄) が含まれています。
dataURL = "http://quantum-machine.org/data/qm7.mat"; outputFolder = fullfile(tempdir,"qm7Data"); dataFile = fullfile(outputFolder,"qm7.mat"); if ~exist(dataFile,"file") mkdir(outputFolder); disp("Downloading QM7 data..."); websave(dataFile, dataURL); disp("Done.") end
MAT ファイルには、異なる 5 つの配列が含まれています。この例では、配列 X
および Z
を使用します。これらはそれぞれ、各分子のクーロン行列 [3] 表現、および分子に含まれる各原子の原子番号を表しています。データに含まれる分子のうち、原子の数が 23 個未満の分子はゼロでパディングされます。
MAT ファイルから QM7 データを読み込みます。
data = load(dataFile)
data = struct with fields:
X: [7165×23×23 single]
R: [7165×23×3 single]
Z: [7165×23 single]
T: [-417.9600 -712.4200 -564.2100 -404.8800 -808.8700 -677.1600 -796.9800 -860.3300 -1.0085e+03 -861.7300 -708.3700 -725.9300 -879.3800 -618.7200 -871.1900 -653.4400 -1.0109e+03 -1.1594e+03 -1.0039e+03 -1.0184e+03 -1.0250e+03 -1.1750e+03 … ]
P: [5×1433 int64]
読み込んだ構造体から、クーロン データと原子番号を抽出します。3 番目の次元が観測値に対応するように、クーロン データを並べ替えます。原子番号を降順で並べ替えます。
coulombData = double(permute(data.X, [2 3 1]));
atomData = sort(data.Z,2,'descend');
最初の観測値の原子を表示します。非ゼロの要素の数は、分子に含まれる原子の種類の数を表します。それぞれの非ゼロの要素は、分子に含まれる特定の元素の原子番号に対応します。
atomData(1,:)
ans = 1×23 single row vector
6 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
グラフ データの前処理
この例で使用する GCN は、特徴入力としてクーロン行列を必要とし、さらに対応する隣接行列を必要とします。
学習データに含まれるクーロン行列を隣接行列に変換するには、この例にサポート ファイルとして添付されている関数 coulomb2Adjacency
を使用します。このファイルにアクセスするには、例をライブ スクリプトとして開きます。関数 coloumb2Adjacency
は、学習用、検証用、および推論用にデータを各分子に簡単に分割できるよう、パディングされたゼロをデータから除去しません。
adjacencyData = coulomb2Adjacency(coulombData,atomData);
最初のいくつかの分子をプロットに可視化します。各分子について、パディングされていない隣接行列を抽出し、ラベル付けされたノードとともにグラフをプロットします。原子番号を元素記号に変換するには、この例にサポート ファイルとして添付されている関数 atomicSymbol
を使用します。この関数にアクセスするには、例をライブ スクリプトとして開きます。
figure tiledlayout("flow") for i = 1:9 % Extract unpadded adjacency matrix. atomicNumbers = nonzeros(atomData(i,:)); numNodes = numel(atomicNumbers); A = adjacencyData(1:numNodes,1:numNodes,i); % Convert adjacency matrix to graph. G = graph(A); % Convert atomic numbers to symbols. symbols = atomicSymbol(atomicNumbers); % Plot graph. nexttile plot(G,NodeLabel=symbols,Layout="force") title("Molecule " + i) end
ヒストグラムを使用して、各ラベル カテゴリの出現頻度を可視化します。
figure histogram(categorical(atomicSymbol(atomData))) xlabel("Node Label") ylabel("Frequency") title("Label Counts")
学習用、検証用、テスト用に、それぞれ 80%、10%、10% の比率でデータを分割します。データをランダムに分割するため、この例にサポート ファイルとして添付されている関数 trainingPartitions
を使用します。このファイルにアクセスするには、例をライブ スクリプトとして開きます。
numObservations = size(adjacencyData,3); [idxTrain,idxValidation,idxTest] = trainingPartitions(numObservations,[0.8 0.1 0.1]); adjacencyDataTrain = adjacencyData(:,:,idxTrain); adjacencyDataValidation = adjacencyData(:,:,idxValidation); adjacencyDataTest = adjacencyData(:,:,idxTest); coulombDataTrain = coulombData(:,:,idxTrain); coulombDataValidation = coulombData(:,:,idxValidation); coulombDataTest = coulombData(:,:,idxTest); atomDataTrain = atomData(idxTrain,:); atomDataValidation = atomData(idxValidation,:); atomDataTest = atomData(idxTest,:);
この例のデータ前処理関数のセクションで定義された関数 preprocessData
を使用して、学習データと検証データを前処理します。関数 preprocessData
は、さまざまなグラフ インスタンスの隣接行列から成るスパース ブロック対角行列を構築します。この行列の各ブロックは、1 つのグラフ インスタンスの隣接行列に対応します。GCN は 1 つの隣接行列を入力として受け取りますが、この例では複数のグラフ インスタンスを扱うため、この前処理が必要になります。この関数は、クーロン行列の非ゼロの対角要素を受け取り、それらを特徴として割り当てます。そのため、この例ではノードごとの入力特徴の数は 1 になります。
[ATrain,XTrain,labelsTrain] = preprocessData(adjacencyDataTrain,coulombDataTrain,atomDataTrain); size(XTrain)
ans = 1×2
88424 1
size(labelsTrain)
ans = 1×2
88424 1
[AValidation,XValidation,labelsValidation] = preprocessData(adjacencyDataValidation,coulombDataValidation,atomDataValidation);
学習特徴の平均と分散を使用して、特徴を正規化します。同じ統計量を使用して、検証特徴を正規化します。
muX = mean(XTrain); sigsqX = var(XTrain,1); XTrain = (XTrain - muX)./sqrt(sigsqX); XValidation = (XValidation - muX)./sqrt(sigsqX);
深層学習モデルの定義
以下の深層学習モデルを定義します。これは、隣接行列 と特徴行列 を入力として受け取り、カテゴリカル予測を出力します。
乗算処理では、学習可能な重みによる重み付き乗算が行われます。
詳しく見ると、このモデルは という形式の一連の演算となっています (最後の演算に加算ステップは含まれません)。この式は以下のようになっています。
は活性化関数。
.
は乗算用の重み行列。
は、グラフ の隣接行列に自己結合を追加したもの。 は単位行列。
は の次数行列。
式 は、グラフの "正規化" 隣接行列とも呼ばれます。
モデル パラメーターの初期化
各演算のパラメーターを定義して構造体に含めます。parameters.OperationName.ParameterName
の形式を使用します。ここで、parameters
は構造体、OperationName
は演算名 (multiply1
など)、ParameterName
はパラメーター名 (Weights
など) です。
モデル パラメーターを含む構造体 parameters
を作成します。
parameters = struct;
この例にサポート ファイルとして添付されている関数 initializeGlorot
を使用して、学習可能な重みを初期化します。この関数にアクセスするには、例をライブ スクリプトとして開きます。
最初の乗算の重みを初期化します。出力サイズが 32 となるように重みを初期化します。入力サイズは、入力特徴データのチャネルの数です。
numHiddenFeatureMaps = 32;
numInputFeatures = size(XTrain,2);
sz = [numInputFeatures numHiddenFeatureMaps];
numOut = numHiddenFeatureMaps;
numIn = numInputFeatures;
parameters.mult1.Weights = initializeGlorot(sz,numOut,numIn,"double");
2 番目の乗算の重みを初期化します。出力サイズが前の乗算と同じになるように重みを初期化します。入力サイズは、前の乗算の出力サイズです。
sz = [numHiddenFeatureMaps numHiddenFeatureMaps];
numOut = numHiddenFeatureMaps;
numIn = numHiddenFeatureMaps;
parameters.mult2.Weights = initializeGlorot(sz,numOut,numIn,"double");
3 番目の乗算の重みを初期化します。出力サイズがクラスの数と同じになるように重みを初期化します。入力サイズは、前の乗算の出力サイズです。
classes = categories(labelsTrain);
numClasses = numel(classes);
sz = [numHiddenFeatureMaps numClasses];
numOut = numClasses;
numIn = numHiddenFeatureMaps;
parameters.mult3.Weights = initializeGlorot(sz,numOut,numIn,"double");
パラメーターの構造を表示します。
parameters
parameters = struct with fields:
mult1: [1×1 struct]
mult2: [1×1 struct]
mult3: [1×1 struct]
最初の乗算のパラメーターを表示します。
parameters.mult1
ans = struct with fields:
Weights: [1×32 dlarray]
モデルの関数の定義
この例のモデル関数のセクションで定義されている関数 model
を作成します。この関数は、モデル パラメーター、特徴データ、および隣接行列を入力として受け取り、予測を返します。
モデル損失関数の定義
この例のモデル損失関数のセクションで定義されている関数 modelLoss
を作成します。この関数は、モデル パラメーター、特徴データ、隣接行列、および one-hot 符号化されたターゲットを入力として受け取り、パラメーターについての損失とその損失の勾配、およびネットワーク予測を返します。
学習オプションの指定
1500 エポック学習させ、Adam ソルバーの学習率を 0.01 に設定します。
numEpochs = 1500; learnRate = 0.01;
300 エポックごとにネットワークを検証します。
validationFrequency = 300;
モデルの学習
Adam 用にパラメーターを初期化します。
trailingAvg = []; trailingAvgSq = [];
学習用と検証用の特徴データを dlarray
オブジェクトに変換します。
XTrain = dlarray(XTrain); XValidation = dlarray(XValidation);
GPU が利用できる場合、GPU で学習を行います。GPU を使用するには、Parallel Computing Toolbox™ とサポートされている GPU デバイスが必要です。サポートされているデバイスの詳細については、GPU 計算の要件 (Parallel Computing Toolbox) (Parallel Computing Toolbox) を参照してください。GPU で学習を行うには、データを gpuArray
オブジェクトに変換します。
if canUseGPU XTrain = gpuArray(XTrain); end
関数 onehotencode
を使用して、学習ラベルと検証ラベルを one-hot 符号化ベクトルに変換します。
TTrain = onehotencode(labelsTrain,2,ClassNames=classes); TValidation = onehotencode(labelsValidation,2,ClassNames=classes);
TrainingProgressMonitor
オブジェクトを初期化します。
monitor = trainingProgressMonitor( ... Metrics=["TrainingLoss","ValidationLoss"], ... Info="Epoch", ... XLabel="Epoch"); groupSubPlot(monitor,"Loss",["TrainingLoss","ValidationLoss"])
カスタム学習ループを使用してモデルに学習させます。この学習では、全バッチによる勾配降下法を使用します。
各エポックで次を行います。
関数
dlfeval
およびmodelLoss
を使用してモデルの損失と勾配を評価します。adamupdate
を使用してネットワーク パラメーターを更新します。学習プロットを更新します。
必要に応じて、関数
model
で予測を行ってネットワークを検証し、検証損失をプロットします。
epoch = 0; while epoch < numEpochs && ~monitor.Stop epoch = epoch + 1; % Evaluate the model loss and gradients. [loss,gradients] = dlfeval(@modelLoss,parameters,XTrain,ATrain,TTrain); % Update the network parameters using the Adam optimizer. [parameters,trailingAvg,trailingAvgSq] = adamupdate(parameters,gradients, ... trailingAvg,trailingAvgSq,epoch,learnRate); % Record the training loss and epoch. recordMetrics(monitor,epoch,TrainingLoss=loss); updateInfo(monitor,Epoch=(epoch+" of "+numEpochs)); % Display the validation metrics. if epoch == 1 || mod(epoch,validationFrequency) == 0 YValidation = model(parameters,XValidation,AValidation); lossValidation = crossentropy(YValidation,TValidation,DataFormat="BC"); % Record the validation loss. recordMetrics(monitor,epoch,ValidationLoss=lossValidation); end monitor.Progress = 100*(epoch/numEpochs); end
モデルのテスト
テスト データを使用して、モデルをテストします。
学習データおよび検証データの場合と同じ手順に従って、テスト データを前処理します。
[ATest,XTest,labelsTest] = preprocessData(adjacencyDataTest,coulombDataTest,atomDataTest); XTest = (XTest - muX)./sqrt(sigsqX);
テスト用の特徴データを dlarray
オブジェクトに変換します。
XTest = dlarray(XTest);
このデータで予測を行い、関数 onehotdecode
を使用して確率をカテゴリカル ラベルに変換します。
YTest = model(parameters,XTest,ATest); YTest = onehotdecode(YTest,classes,2);
精度を計算します。
accuracy = mean(YTest == labelsTest)
accuracy = 0.8930
モデルがどのようにして間違った予測を行うかを可視化し、クラスごとの適合率とクラスごとの再現率に基づいてモデルを評価するには、関数 confusionchart
を使用して混同行列を計算します。
クラスごとの適合率は、クラスに関する陽性予測の総数に対する真陽性の割合です。陽性予測の総数には、真陽性と偽陽性が含まれます。偽陽性は、あるクラスが観測値に含まれているとモデルが間違って予測した結果です。クラスごとの再現率 (これは真陽性率とも呼ばれます) は、クラスに関する陽性観測値の総数に対する真陽性の割合です。陽性観測値の総数には、真陽性と偽陰性が含まれます。偽陰性は、あるクラスが観測値に含まれていないとモデルが間違って予測した結果です。
figure cm = confusionchart(labelsTest,YTest, ... ColumnSummary="column-normalized", ... RowSummary="row-normalized"); title("GCN QM7 Confusion Chart");
クラスごとの適合率スコアは、チャートの列の要約における最初の行に示され、クラスごとの再現率スコアは、チャートの行の要約における最初の列に示されます。
新しいデータを使用した予測
この例のモデル予測関数のセクションにリストされている関数 modelPredictions
を使用して、ラベル付けされていないデータで予測を行います。簡単にするために、テスト データに含まれる最初のいくつかの観測値を使用します。
numObservationsNew = 4; adjacencyDataNew = adjacencyDataTest(:,:,1:numObservationsNew); coulombDataNew = coulombDataTest(:,:,1:numObservationsNew); predictions = modelPredictions(parameters,coulombDataNew,adjacencyDataNew,muX,sigsqX,classes);
プロットで予測を可視化します。各分子について、隣接行列を使用してグラフ表現を作成し、予測に基づいてノードをラベル付けします。
figure tiledlayout("flow") for i = 1:numObservationsNew % Extract unpadded adjacency data. numNodes = find(any(adjacencyDataTest(:,:,i)),1,"last"); A = adjacencyDataTest(1:numNodes,1:numNodes,i); % Create and plot graph representation. nexttile G = graph(A); plot(G,NodeLabel=string(predictions{i}),Layout="force") title("Observation " + i + " Prediction") end
サポート関数
データ前処理関数
関数 preprocessData
は、次の手順を使用して、隣接、クーロン、原子のデータを前処理します。
この例の予測子前処理関数のセクションにリストされている関数
preprocessPredictors
を使用して、隣接行列とクーロン行列を前処理。原子のデータをカテゴリカル ラベルから成るフラット化された配列に変換。
function [adjacency,features,labels] = preprocessData(adjacencyData,coulombData,atomData) [adjacency, features] = preprocessPredictors(adjacencyData,coulombData); labels = []; % Convert labels to categorical. for i = 1:size(adjacencyData,3) % Extract and append unpadded data. T = nonzeros(atomData(i,:)); labels = [labels; T]; end labels2 = nonzeros(atomData); assert(isequal(labels2,labels2)) atomicNumbers = unique(labels); atomNames = atomicSymbol(atomicNumbers); labels = categorical(labels, atomicNumbers, atomNames); end
予測子前処理関数
関数 preprocessPredictors
は、次の手順を使用して、隣接行列とクーロン行列を前処理します。
各分子について次を行います。
パディングされていないデータを抽出。
パディングされていないクーロン行列の対角要素から特徴ベクトルを抽出。
抽出したデータを出力配列に追加。
GCN で隣接行列を入力するには、各隣接行列を含む単一のスパース ブロック対角行列が必要となります。ここで、各ブロックは、1 つのグラフ インスタンスの隣接行列に対応します。この関数は、ブロック対角行列にデータを追加するために、関数 blkdiag
を使用します。
function [adjacency,features] = preprocessPredictors(adjacencyData,coulombData) adjacency = sparse([]); features = []; for i = 1:size(adjacencyData, 3) % Extract unpadded data. numNodes = find(any(adjacencyData(:,:,i)),1,"last"); A = adjacencyData(1:numNodes,1:numNodes,i); X = coulombData(1:numNodes,1:numNodes,i); % Extract feature vector from diagonal of Coulomb matrix. X = diag(X); % Append extracted data. adjacency = blkdiag(adjacency,A); features = [features; X]; end end
モデル関数
関数 model
は、モデル パラメーター parameters
、特徴行列 X
、および隣接行列 A
を入力として受け取り、ネットワーク予測を返します。関数 model
は、前処理のステップにおいて、この例の隣接正規化関数のセクションにリストされている関数 normalizeAdjacency
を使用し、"正規化された" 隣接行列を計算します。正規化された隣接行列は、下に示す式の に対応します。
この深層学習モデルは、隣接行列 と特徴行列 を入力として受け取り、カテゴリカル予測を出力します。
乗算処理では、学習可能な重みによる重み付き乗算が行われます。
詳しく見ると、このモデルは という形式の一連の演算となっています (最後の演算に加算ステップは含まれません)。この式は以下のようになっています。
は活性化関数。
.
は乗算用の重み行列。
は、グラフ の隣接行列に自己結合を追加したもの。 は単位行列。
は の次数行列。
function Y = model(parameters,X,A) ANorm = normalizeAdjacency(A); Z1 = X; Z2 = ANorm * Z1 * parameters.mult1.Weights; Z2 = relu(Z2) + Z1; Z3 = ANorm * Z2 * parameters.mult2.Weights; Z3 = relu(Z3) + Z2; Z4 = ANorm * Z3 * parameters.mult3.Weights; Y = softmax(Z4,DataFormat="BC"); end
モデル損失関数
関数 modelLoss
は、モデル パラメーター parameters
、特徴行列 X
、および隣接行列 A
を入力として受け取り、one-hot 符号化されたターゲット データ T
、およびモデル パラメーターについての損失とその損失の勾配を返します。
function [loss,gradients] = modelLoss(parameters,X,A,T) Y = model(parameters,X,A); loss = crossentropy(Y,T,DataFormat="BC"); gradients = dlgradient(loss, parameters); end
モデル予測関数
関数 modelPredictions
は、モデル パラメーター、クーロンと隣接の入力データ、正規化統計量 (mu
、sigsq
)、およびクラス名のリストを入力として受け取り、入力データについて予測されたノード ラベルの cell 配列を返します。この関数は、入力グラフを一度に 1 回ずつループ処理することで予測を行います。
function predictions = modelPredictions(parameters,coulombData,adjacencyData,mu,sigsq,classes) predictions = {}; numObservations = size(coulombData,3); for i = 1:numObservations % Extract unpadded data. numNodes = find(any(adjacencyData(:,:,i)),1,"last"); A = adjacencyData(1:numNodes,1:numNodes,i); X = coulombData(1:numNodes,1:numNodes,i); % Preprocess data. [A,X] = preprocessPredictors(A,X); X = (X - mu)./sqrt(sigsq); X = dlarray(X); % Make predictions. Y = model(parameters,X,A); Y = onehotdecode(Y,classes,2); predictions{end+1} = Y; end end
隣接正規化関数
関数 normalizeAdjacency
は、隣接行列 を入力として受け取り、"正規化された" 隣接行列 を返します。ここで、 は、グラフの隣接行列に自己結合を追加したもの、 は単位行列、 は の次数行列です。
function ANorm = normalizeAdjacency(A) % Add self connections to adjacency matrix. A = A + speye(size(A)); % Compute inverse square root of degree. degree = sum(A, 2); degreeInvSqrt = sparse(sqrt(1./degree)); % Normalize adjacency matrix. ANorm = diag(degreeInvSqrt) * A * diag(degreeInvSqrt); end
参考文献
Kipf, Thomas N., and Max Welling. “Semi-Supervised Classification with Graph Convolutional Networks.” Paper presented at ICLR 2017, Toulon, France, April 2017.
Blum, Lorenz C., and Jean-Louis Reymond. “970 Million Druglike Small Molecules for Virtual Screening in the Chemical Universe Database GDB-13.” Journal of the American Chemical Society 131, no. 25 (July 1, 2009): 8732–33. https://doi.org/10.1021/ja902302h.
Rupp, Matthias, Alexandre Tkatchenko, Klaus-Robert Müller, and O. Anatole von Lilienfeld. “Fast and Accurate Modeling of Molecular Atomization Energies with Machine Learning.” Physical Review Letters 108, no. 5 (January 31, 2012): 058301. https://doi.org/10.1103/PhysRevLett.108.058301.
Copyright 2021, The MathWorks, Inc.
参考
dlarray
| dlfeval
| dlgradient
| minibatchqueue