Main Content

このページの翻訳は最新ではありません。ここをクリックして、英語の最新版を参照してください。

深層学習を使用した脳腫瘍の 3 次元セグメンテーション

この例では、3 次元 U-Net ニューラル ネットワークに学習させて、3 次元医用画像から脳腫瘍のセマンティック セグメンテーションを実行する方法を説明します。この例では、3 次元 U-Net ネットワークに学習させる方法を示し、さらに事前学習済みのネットワークも示します。3 次元セマンティック セグメンテーションには、Compute Capability 3.0 以上の CUDA 対応 NVIDIA™ GPU の使用が強く推奨されます (Parallel Computing Toolbox™ が必要)。

はじめに

セマンティック セグメンテーションでは、イメージの各ピクセルまたは 3 次元ボリュームのボクセルにクラスでラベル付けします。この例では、深層学習の各種方法を使用して、MRI (Magnetic Resonance Imaging) スキャンで脳腫瘍のバイナリ セマンティック セグメンテーションを実行する方法を説明します。このバイナリ セグメンテーションでは、各ピクセルを腫瘍または背景としてラベル付けします。

この例では、3 次元 U-Net アーキテクチャ [1] を使用して、脳腫瘍のセグメンテーションを実行します。U-Net は、セマンティック セグメンテーションの分野で一般的になった高速かつ高効率でシンプルなネットワークです。

医用画像のセグメンテーションの課題の 1 つとして、3 次元ボリュームの格納と処理に必要なメモリの量があります。GPU リソースの制約があるため、全入力ボリュームでネットワークに学習させることは現実的ではありません。この例では、イメージ パッチでネットワークに学習させることによって問題を解決します。この例で使用するオーバーラップタイル手法では、テスト パッチをつなぎ合わせてセグメント化された完全なテスト ボリュームにします。この例では、ニューラル ネットワークでの畳み込みの有効な部分を使用して、境界アーティファクトを回避します [5]。

医用画像のセグメンテーションのもう 1 つの課題として、データにクラスの不均衡があり、従来の交差エントロピー損失の使用時に学習の妨げになることがあります。この例では、重み付きマルチクラス Dice 損失関数 [4] を使用して、問題を解決します。クラスへの重み付けは、Dice スコアに対する大きな領域の影響を無効にするのに役立ち、ネットワークがより小さい領域をセグメント化する方法を学習するのを容易にします。

学習データ、検証データ、テスト データのダウンロード

この例では、BraTS データセット [2] を使用します。BraTS データセットには、脳腫瘍、すなわち最も一般的な原発性悪性脳腫瘍である神経膠腫の MRI スキャンが格納されています。データ ファイルのサイズは~ 7 GB です。BraTS データセットをダウンロードしない場合、この例の事前学習済みのネットワークとサンプル テスト セットのダウンロードの節に進みます。

BraTS データセットを格納するディレクトリを作成します。

imageDir = fullfile(tempdir,'BraTS');
if ~exist(imageDir,'dir')
    mkdir(imageDir);
end

BraTS データをダウンロードするには、Medical Segmentation Decathlon の Web サイトに移動し、[Download Data] リンクをクリックします。"Task01_BrainTumour.tar" ファイル [3] をダウンロードします。変数 imageDir で指定されたディレクトリに TAR ファイルを解凍します。正常に解凍されると、imageDir には imagesTrimagesTs、および labelsTr という 3 つのサブカテゴリがある Task01_BrainTumour という名前のディレクトリが含まれます。

データセットには 750 個の 4 次元ボリュームが格納されており、それぞれが 3 次元イメージのスタックを表します。各 4 次元ボリュームのサイズは 240 x 240 x 155 x 4 であり、最初の 3 つの次元は 3 次元ボリューム イメージの高さ、幅、奥行に対応します。4 番目の次元は異なるスキャン モダリティに対応します。データセットはボクセル ラベルを含む 484 個の学習ボリュームおよび 266 個のテスト ボリュームに分割されています。テスト ボリュームにはラベルがないため、この例ではテスト データを使用しません。代わりに、この例では 484 個の学習ボリュームを、学習、検証、およびテストに使用される 3 つの個別のセットに分割します。

学習データと検証データの前処理

より効率的に 3 次元 U-Net ネットワークに学習させるには、補助関数 preprocessBraTSdataset を使用して MRI データを前処理します。この関数は、この例にサポート ファイルとして添付されています。

この補助関数は以下の操作を実行します。

  • 主に脳と腫瘍を含む領域に合わせてデータをトリミングします。データをトリミングすると、各 MRI ボリュームの最も重要な部分とそれに対応するラベルを維持しながら、データのサイズが小さくなります。

  • 平均を減算し、トリミングされた脳の領域の標準偏差で除算することにより、各ボリュームの各モダリティを個別に正規化します。

  • 484 個の学習ボリュームを 400 個の学習セット、29 個の検証セット、および 55 個のテスト セットに分割します。

データの前処理は、完了するのに約 30 分かかることがあります。

sourceDataLoc = [imageDir filesep 'Task01_BrainTumour'];
preprocessDataLoc = fullfile(tempdir,'BraTS','preprocessedDataset');
preprocessBraTSdataset(preprocessDataLoc,sourceDataLoc);

学習および検証用のランダム パッチ抽出データストアの作成

ランダム パッチ抽出データストアを使用して、ネットワークに学習データを供給し、学習の進行状況を検証します。このデータストアは、グラウンド トゥルース イメージと対応するピクセル ラベル データからランダム パッチを抽出します。パッチは、任意の大きさのボリュームでの学習時にメモリ不足を防ぐための一般的な手法です。

imageDatastore を作成して 3 次元イメージ データを格納します。MAT ファイル形式は非標準イメージ形式であるため、イメージ データを読み取るために MAT ファイル リーダーを使用しなければなりません。補助 MAT ファイル リーダー matRead を使用できます。この関数は、この例にサポート ファイルとして添付されています。

volReader = @(x) matRead(x);
volLoc = fullfile(preprocessDataLoc,'imagesTr');
volds = imageDatastore(volLoc, ...
    'FileExtensions','.mat','ReadFcn',volReader);

ラベルを保存する pixelLabelDatastore を作成します。

lblLoc = fullfile(preprocessDataLoc,'labelsTr');
classNames = ["background","tumor"];
pixelLabelID = [0 1];
pxds = pixelLabelDatastore(lblLoc,classNames,pixelLabelID, ...
    'FileExtensions','.mat','ReadFcn',volReader);

1 つのイメージ ボリュームとラベルをプレビューします。関数 labelvolshow を使用して、ラベル付きボリュームを表示します。背景ラベルの可視性 (1) を 0 に設定して、背景を完全に透明にします。

volume = preview(volds);
label = preview(pxds);

viewPnl = uipanel(figure,'Title','Labeled Training Volume');
hPred = labelvolshow(label,volume(:,:,:,1),'Parent',viewPnl, ...
    'LabelColor',[0 0 0;1 0 0]);
hPred.LabelVisibility(1) = 0;

学習イメージとピクセル ラベル データを格納する randomPatchExtractionDatastore を作成します。パッチ サイズとして 132 x 132 x 132 ボクセルを指定します。'PatchesPerImage' を指定して、学習中にボリュームとラベルの各ペアからランダムに配置された 16 個のパッチを抽出します。ミニバッチ サイズとして 8 を指定します。

patchSize = [132 132 132];
patchPerImage = 16;
miniBatchSize = 8;
patchds = randomPatchExtractionDatastore(volds,pxds,patchSize, ...
    'PatchesPerImage',patchPerImage);
patchds.MiniBatchSize = miniBatchSize;

同じ手順に従って、検証イメージとピクセル ラベル データを格納する randomPatchExtractionDatastore を作成します。検証データを使用して、ネットワークが継続的に学習しているか、時間の経過に伴って適合不足や過適合が発生していないかを評価できます。

volLocVal = fullfile(preprocessDataLoc,'imagesVal');
voldsVal = imageDatastore(volLocVal, ...
    'FileExtensions','.mat','ReadFcn',volReader);

lblLocVal = fullfile(preprocessDataLoc,'labelsVal');
pxdsVal = pixelLabelDatastore(lblLocVal,classNames,pixelLabelID, ...
    'FileExtensions','.mat','ReadFcn',volReader);

dsVal = randomPatchExtractionDatastore(voldsVal,pxdsVal,patchSize, ...
    'PatchesPerImage',patchPerImage);
dsVal.MiniBatchSize = miniBatchSize;

関数 transform を、補助関数 augmentAndCrop3dPatch によって指定されたカスタム前処理演算と共に使用して、学習データと検証データを拡張します。この関数は、この例にサポート ファイルとして添付されています。

関数 augmentAndCrop3dPatch は以下の操作を実行します。

  1. 学習データをランダムに回転および反転させて、学習をさらにロバストにする。この関数では検証データの回転または反転は行われません。

  2. 応答パッチをトリミングし、ネットワークの出力サイズを 44 x 44 x 44 ボクセルにする。

dataSource = 'Training';
dsTrain = transform(patchds,@(patchIn)augmentAndCrop3dPatch(patchIn,dataSource));

dataSource = 'Validation';
dsVal = transform(dsVal,@(patchIn)augmentAndCrop3dPatch(patchIn,dataSource));

3 次元 U-Net 層のセットアップ

この例では、3 次元 U-Net ネットワーク [1] を使用します。U-Net では、最初の一連の畳み込み層に最大プーリング層が点在し、入力イメージの解像度を逐次下げていきます。これらの層に、一連の畳み込み層が続き、その中にアップサンプリング演算処理が点在し、入力イメージの解像度を逐次上げていきます。バッチ正規化層は各 ReLU 層の前に作成されます。U-Net の名前は、このネットワークが文字「U」のように対称の形状で描けることに由来しています。

関数 unetLayers を使用して、既定の 3 次元 U-Net ネットワークを作成します。2 クラス セグメンテーションを指定します。また、有効な畳み込みパディングを指定して、テスト ボリュームの予測にオーバーラップタイル手法を使用する際に境界アーティファクトを回避します。

inputPatchSize = [132 132 132 4];
numClasses = 2;
[lgraph,outPatchSize] = unet3dLayers(inputPatchSize,numClasses,'ConvolutionPadding','valid');

小さい腫瘍領域をより適切にセグメント化して大きい背景領域の影響を軽減するため、この例では dicePixelClassificationLayer を使用します。ピクセル分類層を Dice ピクセル分類層に置き換えます。

outputLayer = dicePixelClassificationLayer('Name','Output');
lgraph = replaceLayer(lgraph,'Segmentation-Layer',outputLayer);

この例の学習データと検証データの前処理の節で既にデータが正規化されています。image3dInputLayer (Deep Learning Toolbox) でのデータ正規化は不要なため、入力層をデータ正規化が行われない入力層に置き換えます。

inputLayer = image3dInputLayer(inputPatchSize,'Normalization','none','Name','ImageInputLayer');
lgraph = replaceLayer(lgraph,'ImageInputLayer',inputLayer);

または、Deep Learning Toolbox™ のディープ ネットワーク デザイナー アプリを使用して 3 次元 U-Net ネットワークを変更できます。

更新された 3 次元 U-Net ネットワークのグラフをプロットします。

analyzeNetwork(lgraph)

学習オプションの指定

adam 最適化ソルバーを使用してネットワークに学習させます。関数 trainingOptions (Deep Learning Toolbox) を使用してハイパーパラメーター設定を指定します。学習率の初期値は 5e-4 に設定されており、学習が進むにつれて徐々に減少します。GPU メモリに基づいて MiniBatchSize プロパティを試すことができます。GPU メモリを最大限に活用するには、バッチ サイズを大きくすることより入力パッチを大きくすることを優先します。MiniBatchSize の値が小さい場合、バッチ正規化層の効果が小さくなることに注意してください。MiniBatchSize に基づいて初期学習率を微調整します。

options = trainingOptions('adam', ...
    'MaxEpochs',50, ...
    'InitialLearnRate',5e-4, ...
    'LearnRateSchedule','piecewise', ...
    'LearnRateDropPeriod',5, ...
    'LearnRateDropFactor',0.95, ...
    'ValidationData',dsVal, ...
    'ValidationFrequency',400, ...
    'Plots','training-progress', ...
    'Verbose',false, ...
    'MiniBatchSize',miniBatchSize);

事前学習済みのネットワークとサンプル テスト セットのダウンロード

オプションで、事前学習済みバージョンの 3 次元 U-Net と 5 つのサンプル テスト ボリューム、さらにこれらに対応する BraTS データセットのラベル [3] をダウンロードします。事前学習済みのモデルとサンプル データを使用すると、データセット全体をダウンロードしたりネットワークが学習するのを待機したりすることなく、テスト データに対してセグメンテーションを実行できます。

trained3DUnet_url = 'https://www.mathworks.com/supportfiles/vision/data/brainTumor3DUNetValid.mat';
sampleData_url = 'https://www.mathworks.com/supportfiles/vision/data/sampleBraTSTestSetValid.tar.gz';

imageDir = fullfile(tempdir,'BraTS');
if ~exist(imageDir,'dir')
    mkdir(imageDir);
end

downloadTrained3DUnetSampleData(trained3DUnet_url,sampleData_url,imageDir);
Downloading pretrained 3-D U-Net for BraTS data set.
This will take several minutes to download...
Done.

Downloading sample BraTS test dataset.
This will take several minutes to download and unzip...
Done.

ネットワークの学習

学習オプションとデータ ソースを構成した後、関数 trainNetwork (Deep Learning Toolbox) を使用して 3 次元 U-Net ネットワークに学習させます。ネットワークに学習させるには、次のコードで変数 doTrainingtrue に設定します。学習には、Compute Capability 3.0 以上の CUDA 対応 NVIDIA™ GPU の使用が強く推奨されます。

次のコードで変数 doTrainingfalse のままにしておくと、この例は事前学習済みの 3 次元 U-Net ネットワークを返します。

メモ: 4 つの NVIDIA™ Titan Xp GPU を使用したマルチ GPU システムでの学習には約 30 時間を要します。ご使用の GPU ハードウェアによっては、さらに長い時間がかかる可能性もあります。

doTraining = false;
if doTraining
    modelDateTime = datestr(now,'dd-mmm-yyyy-HH-MM-SS');
    [net,info] = trainNetwork(dsTrain,lgraph,options);
    save(['trained3DUNetValid-' modelDateTime '-Epoch-' num2str(options.MaxEpochs) '.mat'],'net');
else
    inputPatchSize = [132 132 132 4];
    outPatchSize = [44 44 44 2];
    load(fullfile(imageDir,'trained3DUNet','brainTumor3DUNetValid.mat'));
end

これで、U-Net を使用して、脳腫瘍を意味ごとにセグメント化できます。

テスト データのセグメンテーションの実行

イメージ ボリュームのセマンティック セグメンテーションを実行するには、GPU の使用が強く推奨されます (Parallel Computing Toolbox™ が必要)。

テスト用のグラウンド トゥルース ボリュームおよびラベルを含むテスト データのソースを選択します。次のコードで変数 useFullTestSetfalse のままにしておくと、この例ではテスト用のボリュームが 5 つ使用されます。変数 useFullTestSettrue に設定すると、この例ではデータセット全体から選択された 55 個のテスト イメージが使用されます。

useFullTestSet = false;
if useFullTestSet
    volLocTest = fullfile(preprocessDataLoc,'imagesTest');
    lblLocTest = fullfile(preprocessDataLoc,'labelsTest');
else
    volLocTest = fullfile(imageDir,'sampleBraTSTestSetValid','imagesTest');
    lblLocTest = fullfile(imageDir,'sampleBraTSTestSetValid','labelsTest');
    classNames = ["background","tumor"];
    pixelLabelID = [0 1];
end

変数 voldsTest は、グラウンド トゥルース テスト イメージを格納します。変数 pxdsTest は、グラウンド トゥルース ラベルを格納します。

volReader = @(x) matRead(x);
voldsTest = imageDatastore(volLocTest, ...
    'FileExtensions','.mat','ReadFcn',volReader);
pxdsTest = pixelLabelDatastore(lblLocTest,classNames,pixelLabelID, ...
    'FileExtensions','.mat','ReadFcn',volReader);

オーバーラップタイル手法を使用して、各テスト ボリュームのラベルを予測します。入力サイズがネットワークの出力サイズの倍数になるように各テスト ボリュームがパディングされて、有効な畳み込みの影響を補正します。オーバーラップタイル アルゴリズムでは、オーバーラップ パッチを選択し、関数 semanticseg を使用して各パッチのラベルを予測して、パッチを再度組み合わせます。

id = 1;
while hasdata(voldsTest)
    disp(['Processing test volume ' num2str(id)]);
    
    tempGroundTruth = read(pxdsTest);
    groundTruthLabels{id} = tempGroundTruth{1};
    vol{id} = read(voldsTest);
    
    % Use reflection padding for the test image. 
    % Avoid padding of different modalities.
    volSize = size(vol{id},(1:3));
    padSizePre  = (inputPatchSize(1:3)-outPatchSize(1:3))/2;
    padSizePost = (inputPatchSize(1:3)-outPatchSize(1:3))/2 + (outPatchSize(1:3)-mod(volSize,outPatchSize(1:3)));
    volPaddedPre = padarray(vol{id},padSizePre,'symmetric','pre');
    volPadded = padarray(volPaddedPre,padSizePost,'symmetric','post');
    [heightPad,widthPad,depthPad,~] = size(volPadded);
    [height,width,depth,~] = size(vol{id});
    
    tempSeg = categorical(zeros([height,width,depth],'uint8'),[0;1],classNames);
    
    % Overlap-tile strategy for segmentation of volumes.
    for k = 1:outPatchSize(3):depthPad-inputPatchSize(3)+1
        for j = 1:outPatchSize(2):widthPad-inputPatchSize(2)+1
            for i = 1:outPatchSize(1):heightPad-inputPatchSize(1)+1
                patch = volPadded( i:i+inputPatchSize(1)-1,...
                    j:j+inputPatchSize(2)-1,...
                    k:k+inputPatchSize(3)-1,:);
                patchSeg = semanticseg(patch,net);
                tempSeg(i:i+outPatchSize(1)-1, ...
                    j:j+outPatchSize(2)-1, ...
                    k:k+outPatchSize(3)-1) = patchSeg;
            end
        end
    end
    
    % Crop out the extra padded region.
    tempSeg = tempSeg(1:height,1:width,1:depth);

    % Save the predicted volume result.
    predictedLabels{id} = tempSeg;
    id=id+1;
end
Processing test volume 1
Processing test volume 2
Processing test volume 3
Processing test volume 4
Processing test volume 5

グラウンド トゥルースとネットワークの予測の比較

いずれかのテスト イメージを選択して、セマンティック セグメンテーションの精度を評価します。4 次元ボリューム データから最初のモダリティを抽出し、この 3 次元ボリューム データを変数 vol3d に格納します。

volId = 1;
vol3d = vol{volId}(:,:,:,1);

モンタージュに、グラウンド トゥルース ラベルおよび予測されたラベルの中心スライスを奥行方向に沿って表示します。

zID = size(vol3d,3)/2;
zSliceGT = labeloverlay(vol3d(:,:,zID),groundTruthLabels{volId}(:,:,zID));
zSlicePred = labeloverlay(vol3d(:,:,zID),predictedLabels{volId}(:,:,zID));

figure
montage({zSliceGT,zSlicePred},'Size',[1 2],'BorderSize',5) 
title('Labeled Ground Truth (Left) vs. Network Prediction (Right)')

関数 labelvolshow を使用して、グラウンド トゥルース ラベル付きボリュームを表示します。背景ラベルの可視性 (1) を 0 に設定して、背景を完全に透明にします。腫瘍は脳組織の内部にあるため、腫瘍を見られるように脳のボクセルを一部透明にします。脳のボクセルを一部透明にするには、ボリュームのしきい値を [0, 1] の範囲の数値として指定します。このしきい値を下回るすべての正規化ボリューム強度は、完全に透明になります。この例では、脳の内部にある腫瘍の空間的な位置を確認できるように脳のピクセルの一部を表示されたままにするために、ボリュームのしきい値を 1 未満に設定します。

viewPnlTruth = uipanel(figure,'Title','Ground-Truth Labeled Volume');
hTruth = labelvolshow(groundTruthLabels{volId},vol3d,'Parent',viewPnlTruth, ...
    'LabelColor',[0 0 0;1 0 0],'VolumeThreshold',0.68);
hTruth.LabelVisibility(1) = 0;

同じボリュームについて、予測ラベルを表示します。

viewPnlPred = uipanel(figure,'Title','Predicted Labeled Volume');
hPred = labelvolshow(predictedLabels{volId},vol3d,'Parent',viewPnlPred, ...
    'LabelColor',[0 0 0;1 0 0],'VolumeThreshold',0.68);

hPred.LabelVisibility(1) = 0;

このイメージは、一方のボリューム全体のスライスを逐次的に表示した結果を示します。左側はラベル付きのグラウンド トゥルース、右側はネットワーク予測です。

セグメンテーションの精度の定量化

関数 dice を使用してセグメンテーション精度を測定します。この関数は、予測セグメンテーションとグラウンド トゥルース セグメンテーションの Dice 類似度係数を計算します。

diceResult = zeros(length(voldsTest.Files),2);

for j = 1:length(vol)
    diceResult(j,:) = dice(groundTruthLabels{j},predictedLabels{j});
end

テスト ボリュームのセットの平均 Dice スコアを計算します。

meanDiceBackground = mean(diceResult(:,1));
disp(['Average Dice score of background across ',num2str(j), ...
    ' test volumes = ',num2str(meanDiceBackground)])
Average Dice score of background across 5 test volumes = 0.9993
meanDiceTumor = mean(diceResult(:,2));
disp(['Average Dice score of tumor across ',num2str(j), ...
    ' test volumes = ',num2str(meanDiceTumor)])
Average Dice score of tumor across 5 test volumes = 0.9585

次の図は、5 つのサンプル テスト ボリュームのセットの Dice スコアに関する統計量を可視化する boxplot (Statistics and Machine Learning Toolbox) を示します。プロットの赤い線は、クラスの Dice 値の中央値を示します。青いボックスの上下の境界はそれぞれ、25 番目の百分位数と 75 番目の百分位数を示します。黒いひげは、外れ値とは見なされない最も極端なデータ点まで延びます。

Statistics and Machine Learning Toolbox™ がある場合、関数 boxplot を使用してすべてのテスト ボリュームの Dice スコアに関する統計量を可視化できます。boxplot を作成するには、次のコードで変数 createBoxplottrue に設定します。

createBoxplot = false;
if createBoxplot
    figure
    boxplot(diceResult)
    title('Test Set Dice Accuracy')
    xticklabels(classNames)
    ylabel('Dice Coefficient')
end

参考文献

[1] Çiçek, Ö., A. Abdulkadir, S. S. Lienkamp, T. Brox, and O. Ronneberger. "3D U-Net: Learning Dense Volumetric Segmentation from Sparse Annotation." In Proceedings of the International Conference on Medical Image Computing and Computer-Assisted Intervention - MICCAI 2016. Athens, Greece, Oct. 2016, pp. 424-432.

[2] Isensee, F., P. Kickingereder, W. Wick, M. Bendszus, and K. H. Maier-Hein. "Brain Tumor Segmentation and Radiomics Survival Prediction: Contribution to the BRATS 2017 Challenge." In Proceedings of BrainLes: International MICCAI Brainlesion Workshop. Quebec City, Canada, Sept. 2017, pp. 287-297.

[3] "Brain Tumours". Medical Segmentation Decathlon. http://medicaldecathlon.com/

BraTS データセットは、CC-BY-SA 4.0 のライセンスに基づき Medical Segmentation Decathlon によって提供されます。一切の保証および表明を行いません。詳細はライセンスを参照してください。MathWorks® は、この例の事前学習済みのネットワークとサンプル テスト セットのダウンロードの節にリンクのあるデータセットに変更を加えています。変更されたサンプル データセットは主に脳と腫瘍を含む領域に合わせてトリミングされており、各チャネルは平均を減算し、トリミングされた脳の領域の標準偏差で除算することにより、個別に正規化されています。

[4] Sudre, C. H., W. Li, T. Vercauteren, S. Ourselin, and M. J. Cardoso. "Generalised Dice Overlap as a Deep Learning Loss Function for Highly Unbalanced Segmentations." Deep Learning in Medical Image Analysis and Multimodal Learning for Clinical Decision Support: Third International Workshop. Quebec City, Canada, Sept. 2017, pp. 240-248.

[5] Ronneberger, O., P. Fischer, and T. Brox. "U-Net:Convolutional Networks for Biomedical Image Segmentation." In Proceedings of the International Conference on Medical Image Computing and Computer-Assisted Intervention - MICCAI 2015. Munich, Germany, Oct. 2015, pp. 234-241. Available at arXiv:1505.04597.

参考

| | | | | | (Deep Learning Toolbox) | (Deep Learning Toolbox)

関連するトピック