深層学習を使用した脳腫瘍の 3 次元セグメンテーション
この例では、3 次元医用画像から脳腫瘍のセマンティック セグメンテーションを実行する方法を説明します。
セマンティック セグメンテーションでは、イメージの各ピクセルまたは 3 次元ボリュームのボクセルにクラスでラベル付けします。この例では、3 次元 U-Net 深層学習ネットワークを使用して、磁気共鳴法 (MRI) スキャンで脳腫瘍のバイナリ セマンティック セグメンテーションを実行する方法を説明します。U-Net は、セマンティック セグメンテーションの分野で一般的になった高速かつ高効率でシンプルなネットワークです [1]。
医用画像のセグメンテーションの課題の 1 つとして、3 次元ボリュームの格納と処理に必要なメモリの量があります。GPU リソースの制約があるため、全入力ボリュームでネットワークに学習させてセグメンテーションを実行することは現実的ではありません。この例では、学習用およびセグメンテーション用にイメージを小さなパッチ (ブロック) に分割することで問題を解決します。
医用画像のセグメンテーションのもう 1 つの課題として、データにクラスの不均衡があり、従来の交差エントロピー損失の使用時に学習の妨げになることがあります。この例では、重み付きマルチクラス Dice 損失関数 [4] を使用して、問題を解決します。クラスへの重み付けは、Dice スコアに対する大きな領域の影響を無効にするのに役立ち、ネットワークがより小さい領域をセグメント化する方法を学習するのを容易にします。
この例では、事前学習済みの 3 次元 U-Net アーキテクチャを使用して脳腫瘍のセグメンテーションを実行する方法、および一連のテスト イメージを使用してネットワーク性能を評価する方法を説明します。オプションとして、BraTS データ セット [2] を使用して 3 次元 U-Net に学習させることもできます。
事前学習済みの 3 次元 U-Net を使用した脳腫瘍のセグメンテーションの実行
事前学習済みの 3 次元 U-Net のダウンロード
事前学習済みの 3 次元 U-Net を net
という変数にダウンロードします。
dataDir = fullfile(tempdir,"BraTS"); if ~exist(dataDir,'dir') mkdir(dataDir); end trained3DUnetURL = "https://www.mathworks.com/supportfiles/"+ ... "vision/data/brainTumor3DUNetValid.mat"; downloadTrainedNetwork(trained3DUnetURL,dataDir); load(dataDir+filesep+"brainTumor3DUNetValid.mat");
BraTS サンプル データのダウンロード
補助関数 downloadBraTSSampleTestData
を使用して、5 つのサンプル テスト ボリュームとそれらに対応する BraTS データ セットのラベルをダウンロードします [3]。この補助関数は、この例にサポート ファイルとして添付されています。サンプル データを使用すると、データ セット全体をダウンロードすることなく、テスト データに対してセグメンテーションを実行できます。
downloadBraTSSampleTestData(dataDir);
いずれかのボリューム サンプルをピクセル ラベルのグラウンド トゥルースとともに読み込みます。
testDir = dataDir+filesep+"sampleBraTSTestSetValid"; data = load(fullfile(testDir,"imagesTest","BraTS446.mat")); labels = load(fullfile(testDir,"labelsTest","BraTS446.mat")); volTest = data.cropVol; volTestLabels = labels.cropLabel;
セマンティック セグメンテーションの実行
この例では、オーバーラップタイル手法を使用して大きなボリュームを処理します。オーバーラップタイル手法では、オーバーラップしているブロックを選択し、関数semanticseg
(Computer Vision Toolbox)を使用して各ブロックのラベルを予測してから、ブロックを再度組み合わせてセグメント化された完全なテスト ボリュームにします。この手法を使用すると、メモリ リソースが限られている GPU で効率のよい処理を行うことができます。また、この手法を使用すると、ニューラル ネットワークの畳み込みの有効な部分 [5] を使用して境界アーティファクトを減らすことができます。
オーバーラップタイル手法を実装するには、ボリューム データをblockedImage
(Image Processing Toolbox)オブジェクトとして保存し、関数apply
(Image Processing Toolbox)を使用してブロックを処理します。
前のセクションでダウンロードしたサンプル ボリューム用の blockedImage
オブジェクトを作成します。
bim = blockedImage(volTest);
関数 apply
は、blockedImage
内の各ブロックに対してカスタム関数を実行します。各ブロックに対して実行する関数として、semanticsegBlock
を定義します。
semanticsegBlock = @(bstruct)semanticseg(bstruct.Data,net);
ネットワークの出力サイズとして、ブロックのサイズを指定します。オーバーラップするブロックを作成するには、非ゼロの境界サイズを指定します。この例では、ブロックと境界を合わせたサイズがネットワークの入力サイズと等しくなるように境界サイズを指定します。
networkInputSize = net.Layers(1).InputSize; networkOutputSize = net.Layers(end).OutputSize; blockSize = [networkOutputSize(1:3) networkInputSize(end)]; borderSize = (networkInputSize(1:3) - blockSize(1:3))/2;
blockedImage
apply
を使用し、部分ブロックのパディングを true
に設定した状態で、セマンティック セグメンテーションを実行します。ボリューム データに複数のモダリティが含まれているため、既定のパディング手法である "replicate"
が適しています。メモリ リソースが限られている GPU でメモリ不足エラーが発生するのを防ぐため、バッチ サイズを 1 として指定します。ただし、GPU に十分なメモリがある場合は、ブロック サイズを増やすことで処理速度を上げることができます。
batchSize = 1; results = apply(bim, ... semanticsegBlock, ... BlockSize=blockSize, ... BorderSize=borderSize,... PadPartialBlocks=true, ... BatchSize=batchSize); predictedLabels = results.Source;
グラウンド トゥルース ラベルおよび予測されたラベルの中心スライスが奥行方向に沿って示されたモンタージュを表示します。
zID = size(volTest,3)/2;
zSliceGT = labeloverlay(volTest(:,:,zID),volTestLabels(:,:,zID));
zSlicePred = labeloverlay(volTest(:,:,zID),predictedLabels(:,:,zID));
figure
montage({zSliceGT,zSlicePred},Size=[1 2],BorderSize=5)
title("Labeled Ground Truth (Left) vs. Network Prediction (Right)")
次のイメージは、いずれかのボリュームの全体にわたってスライスを逐次的に表示した結果を示しています。左側はラベル付きのグラウンド トゥルース、右側はネットワーク予測です。
3 次元 U-Net の学習
この例のこの部分では、3 次元 U-Net に学習させる方法を示します。学習データ セットのダウンロードやネットワークの学習を行わない場合は、この例のネットワーク性能の評価のセクションに進んでください。
BraTS データ セットのダウンロード
この例では、BraTS データ セット [2] を使用します。BraTS データセットには、脳腫瘍、すなわち最も一般的な原発性悪性脳腫瘍である神経膠腫の MRI スキャンが格納されています。データ ファイルのサイズは~ 7 GB です。
BraTS データをダウンロードするには、Medical Segmentation Decathlon の Web サイトに移動し、[Download Data] リンクをクリックします。"Task01_BrainTumour.tar" ファイル [3] をダウンロードします。変数 imageDir
で指定されたディレクトリに TAR ファイルを解凍します。正常に解凍されると、imageDir
には imagesTr
、imagesTs
、および labelsTr
という 3 つのサブディレクトリを持つ Task01_BrainTumour
という名前のディレクトリが含まれます。
データセットには 750 個の 4 次元ボリュームが格納されており、それぞれが 3 次元イメージのスタックを表します。各 4 次元ボリュームのサイズは 240 x 240 x 155 x 4 であり、最初の 3 つの次元は 3 次元ボリューム イメージの高さ、幅、奥行に対応します。4 番目の次元は異なるスキャン モダリティに対応します。このデータセットは、ボクセル ラベルを含む 484 個の学習ボリュームと 266 個のテスト ボリュームに分割されています。テスト ボリュームにはラベルがないため、この例ではテスト データを使用しません。代わりに、この例では 484 個の学習ボリュームを、学習、検証、およびテストに使用される 3 つの個別のセットに分割します。
学習データと検証データの前処理
より効率的に 3 次元 U-Net ネットワークに学習させるには、補助関数 preprocessBraTSDataset
を使用して MRI データを前処理します。この関数は、この例にサポート ファイルとして添付されています。この補助関数は以下の操作を実行します。
主に脳と腫瘍を含む領域に合わせてデータをトリミングします。データをトリミングすると、各 MRI ボリュームの最も重要な部分とそれに対応するラベルを維持しながら、データのサイズが小さくなります。
平均を減算し、トリミングされた脳の領域の標準偏差で除算することにより、各ボリュームの各モダリティを個別に正規化します。
484 個の学習ボリュームを 400 個の学習セット、29 個の検証セット、および 55 個のテスト セットに分割します。
データの前処理は、完了するのに約 30 分かかることがあります。
sourceDataLoc = dataDir+filesep+"Task01_BrainTumour"; preprocessDataLoc = dataDir+filesep+"preprocessedDataset"; preprocessBraTSDataset(preprocessDataLoc,sourceDataLoc);
学習および検証用のランダム パッチ抽出データストアの作成
imageDatastore
を作成して 3 次元イメージ データを格納します。MAT ファイル形式は非標準イメージ形式であるため、イメージ データを読み取るために MAT ファイル リーダーを使用しなければなりません。補助 MAT ファイル リーダー matRead
を使用できます。この関数は、この例にサポート ファイルとして添付されています。
volLoc = fullfile(preprocessDataLoc,"imagesTr"); volds = imageDatastore(volLoc,FileExtensions=".mat",ReadFcn=@matRead);
ラベルを保存する pixelLabelDatastore
(Computer Vision Toolbox) を作成します。
lblLoc = fullfile(preprocessDataLoc,"labelsTr"); classNames = ["background","tumor"]; pixelLabelID = [0 1]; pxds = pixelLabelDatastore(lblLoc,classNames,pixelLabelID, ... FileExtensions=".mat",ReadFcn=@matRead);
グラウンド トゥルース イメージおよび対応するピクセル ラベル データからランダム パッチを抽出するrandomPatchExtractionDatastore
(Image Processing Toolbox)を作成します。パッチ サイズとして 132 x 132 x 132 ボクセルを指定します。"PatchesPerImage"
を指定して、学習中にボリュームとラベルの各ペアからランダムに配置された 16 個のパッチを抽出します。ミニバッチ サイズとして 8 を指定します。
patchSize = [132 132 132];
patchPerImage = 16;
miniBatchSize = 8;
patchds = randomPatchExtractionDatastore(volds,pxds,patchSize, ...
PatchesPerImage=patchPerImage);
patchds.MiniBatchSize = miniBatchSize;
検証イメージとピクセル ラベル データからパッチを抽出する randomPatchExtractionDatastore
を作成します。検証データを使用して、ネットワークが継続的に学習しているか、時間の経過に伴って適合不足や過適合が発生していないかを評価できます。
volLocVal = fullfile(preprocessDataLoc,"imagesVal"); voldsVal = imageDatastore(volLocVal,FileExtensions=".mat", ... ReadFcn=@matRead); lblLocVal = fullfile(preprocessDataLoc,"labelsVal"); pxdsVal = pixelLabelDatastore(lblLocVal,classNames,pixelLabelID, ... FileExtensions=".mat",ReadFcn=@matRead); dsVal = randomPatchExtractionDatastore(voldsVal,pxdsVal,patchSize, ... PatchesPerImage=patchPerImage); dsVal.MiniBatchSize = miniBatchSize;
3 次元 U-Net 層のセットアップ
この例では、3 次元 U-Net ネットワーク [1] を使用します。U-Net では、最初の一連の畳み込み層に最大プーリング層が点在し、入力イメージの解像度を逐次下げていきます。これらの層に、一連の畳み込み層が続き、その中にアップサンプリング演算処理が点在し、入力イメージの解像度を逐次上げていきます。バッチ正規化層は各 ReLU 層の前に作成されます。U-Net の名前は、このネットワークが文字「U」のように対称の形状で描けることに由来しています。
関数unetLayers
(Computer Vision Toolbox)を使用して、既定の 3 次元 U-Net ネットワークを作成します。2 クラス セグメンテーションを指定します。また、有効な畳み込みパディングを指定して、テスト ボリュームの予測にオーバーラップタイル手法を使用する際に境界アーティファクトを回避します。
numChannels = 4; inputPatchSize = [patchSize numChannels]; numClasses = 2; [lgraph,outPatchSize] = unet3dLayers(inputPatchSize, ... numClasses,ConvolutionPadding="valid");
関数transform
を、補助関数 augmentAndCrop3dPatch
によって指定されたカスタム前処理演算と共に使用して、学習データと検証データを拡張します。この関数は、この例にサポート ファイルとして添付されています。関数 augmentAndCrop3dPatch
は以下の操作を実行します。
学習データをランダムに回転および反転させて、学習をさらにロバストにする。この関数では検証データの回転または反転は行われません。
応答パッチをトリミングし、ネットワークの出力サイズを 44 x 44 x 44 ボクセルにする。
dsTrain = transform(patchds, ... @(patchIn)augmentAndCrop3dPatch(patchIn,outPatchSize,"Training")); dsVal = transform(dsVal, ... @(patchIn)augmentAndCrop3dPatch(patchIn,outPatchSize,"Validation"));
小さい腫瘍領域をより適切にセグメント化して大きい背景領域の影響を軽減するため、この例では dicePixelClassificationLayer
(Computer Vision Toolbox) を使用します。ピクセル分類層を Dice ピクセル分類層に置き換えます。
outputLayer = dicePixelClassificationLayer(Name="Output"); lgraph = replaceLayer(lgraph,"Segmentation-Layer",outputLayer);
この例の学習データと検証データの前処理の節で既にデータは正規化されています。image3dInputLayer
でのデータ正規化は不要なため、入力層をデータ正規化が行われない入力層に置き換えます。
inputLayer = image3dInputLayer(inputPatchSize, ... Normalization="none",Name="ImageInputLayer"); lgraph = replaceLayer(lgraph,"ImageInputLayer",inputLayer);
または、"ディープ ネットワーク デザイナー" アプリを使用して 3 次元 U-Net ネットワークを変更できます。
deepNetworkDesigner(lgraph)
学習オプションの指定
adam
最適化ソルバーを使用してネットワークに学習させます。関数trainingOptions
を使用してハイパーパラメーター設定を指定します。学習率の初期値は 5e-4 に設定されており、学習が進むにつれて徐々に減少します。GPU メモリに基づいて MiniBatchSize
プロパティを試すことができます。GPU メモリを最大限に活用するには、バッチ サイズを大きくすることより入力パッチを大きくすることを優先します。MiniBatchSize
の値が小さい場合、バッチ正規化層の効果が小さくなることに注意してください。MiniBatchSize
に基づいて初期学習率を微調整します。
options = trainingOptions("adam", ... MaxEpochs=50, ... InitialLearnRate=5e-4, ... LearnRateSchedule="piecewise", ... LearnRateDropPeriod=5, ... LearnRateDropFactor=0.95, ... ValidationData=dsVal, ... ValidationFrequency=400, ... Plots="training-progress", ... Verbose=false, ... MiniBatchSize=miniBatchSize);
ネットワークの学習
この例では既定で、ダウンロードした事前学習済みの 3 次元 U-Net ネットワークを使用します。この事前学習済みのネットワークを使用することで、学習の完了を待たずにセマンティック セグメンテーションを実行してセグメンテーションの結果を評価できます。
ネットワークに学習させるには、次のコードで変数 doTraining
を true
に設定します。関数trainNetwork
を使用してネットワークに学習させます。
GPU が利用できる場合、GPU で学習を行います。GPU を使用するには、Parallel Computing Toolbox™、および CUDA® 対応の NVIDIA® GPU が必要です。詳細については、GPU 計算の要件 (Parallel Computing Toolbox)を参照してください。学習には 4 つの NVIDIA™ Titan Xp GPU を使用したマルチ GPU システムで約 30 時間を要します。ご使用の GPU ハードウェアによっては、さらに長い時間がかかる可能性もあります。
doTraining =false; if doTraining [net,info] = trainNetwork(dsTrain,lgraph,options); modelDateTime = string(datetime("now",Format="yyyy-MM-dd-HH-mm-ss")); save("trained3DUNet-"+modelDateTime+".mat","net"); end
ネットワーク性能の評価
テスト用のグラウンド トゥルース ボリュームおよびラベルを含むテスト データのソースを選択します。次のコードで変数 useFullTestSet
を false
のままにしておくと、この例では 5 つのサンプル ボリュームがテストで使用されます。変数 useFullTestSet
を true
に設定すると、この例ではデータセット全体から選択された 55 個のテスト イメージが使用されます。
useFullTestSet =false; if useFullTestSet volLocTest = fullfile(preprocessDataLoc,"imagesTest"); lblLocTest = fullfile(preprocessDataLoc,"labelsTest"); else volLocTest = fullfile(testDir,"imagesTest"); lblLocTest = fullfile(testDir,"labelsTest"); end
変数 voldsTest
は、グラウンド トゥルース テスト イメージを格納します。変数 pxdsTest
は、グラウンド トゥルース ラベルを格納します。
voldsTest = imageDatastore(volLocTest,FileExtensions=".mat", ... ReadFcn=@matRead); pxdsTest = pixelLabelDatastore(lblLocTest,classNames,pixelLabelID, ... FileExtensions=".mat",ReadFcn=@matRead);
各テスト ボリュームについて、関数apply
(Image Processing Toolbox)を使用して各ブロックを処理します。関数 apply
は、この例の終わりで定義されている補助関数 calculateBlockMetrics
によって指定された演算を実行します。関数 calculateBlockMetrics
は、各ブロックのセマンティック セグメンテーションを実行し、予測ラベルとグラウンド トゥルース ラベルの混同行列を計算します。
imageIdx = 1; datasetConfMat = table; while hasdata(voldsTest) % Read volume and label data vol = read(voldsTest); volLabels = read(pxdsTest); % Create blockedImage for volume and label data testVolume = blockedImage(vol); testLabels = blockedImage(volLabels{1}); % Calculate block metrics blockConfMatOneImage = apply(testVolume, ... @(block,labeledBlock) ... calculateBlockMetrics(block,labeledBlock,net), ... ExtraImages=testLabels, ... PadPartialBlocks=true, ... BlockSize=blockSize, ... BorderSize=borderSize, ... UseParallel=false); % Read all the block results of an image and update the image number blockConfMatOneImageDS = blockedImageDatastore(blockConfMatOneImage); blockConfMat = readall(blockConfMatOneImageDS); blockConfMat = struct2table([blockConfMat{:}]); blockConfMat.ImageNumber = imageIdx.*ones(height(blockConfMat),1); datasetConfMat = [datasetConfMat;blockConfMat]; imageIdx = imageIdx + 1; end
関数evaluateSemanticSegmentation
(Computer Vision Toolbox)を使用して、セグメンテーションのデータ セット メトリクスとブロック メトリクスを評価します。
[metrics,blockMetrics] = evaluateSemanticSegmentation( ... datasetConfMat,classNames,Metrics="all");
Evaluating semantic segmentation results ---------------------------------------- * Selected metrics: global accuracy, class accuracy, IoU, weighted IoU. * Processed 5 images. * Finalizing... Done. * Data set metrics: GlobalAccuracy MeanAccuracy MeanIoU WeightedIoU ______________ ____________ _______ ___________ 0.99902 0.97955 0.95978 0.99808
各イメージについて計算されたジャッカード スコアを表示します。
metrics.ImageMetrics.MeanIoU
ans = 5×1
0.9613
0.9570
0.9551
0.9656
0.9594
サポート関数
補助関数 calculateBlockMetrics
は、ブロックのセマンティック セグメンテーションを実行し、予測ラベルとグラウンド トゥルース ラベルの混同行列を計算します。この関数は、ブロックに関する混同行列とメタデータが格納されたフィールドをもつ構造体を返します。この構造体と関数 evaluateSemanticSegmentation
を使用することで、メトリクスを計算してブロック単位で結果を集約できます。
function blockMetrics = calculateBlockMetrics(bstruct,gtBlockLabels,net) % Segment block predBlockLabels = semanticseg(bstruct.Data,net); % Trim away border region from gtBlockLabels blockStart = bstruct.BorderSize + 1; blockEnd = blockStart + bstruct.BlockSize - 1; gtBlockLabels = gtBlockLabels( ... blockStart(1):blockEnd(1), ... blockStart(2):blockEnd(2), ... blockStart(3):blockEnd(3)); % Evaluate segmentation results against ground truth confusionMat = segmentationConfusionMatrix(predBlockLabels,gtBlockLabels); % blockMetrics is a struct with confusion matrices, image number, % and block information. blockMetrics.ConfusionMatrix = confusionMat; blockMetrics.ImageNumber = bstruct.ImageNumber; blockInfo.Start = bstruct.Start; blockInfo.End = bstruct.End; blockMetrics.BlockInfo = blockInfo; end
参考文献
[1] Çiçek, Ö., A. Abdulkadir, S. S. Lienkamp, T. Brox, and O. Ronneberger. "3D U-Net: Learning Dense Volumetric Segmentation from Sparse Annotation." In Proceedings of the International Conference on Medical Image Computing and Computer-Assisted Intervention - MICCAI 2016. Athens, Greece, Oct. 2016, pp. 424-432.
[2] Isensee, F., P. Kickingereder, W. Wick, M. Bendszus, and K. H. Maier-Hein. "Brain Tumor Segmentation and Radiomics Survival Prediction: Contribution to the BRATS 2017 Challenge." In Proceedings of BrainLes: International MICCAI Brainlesion Workshop. Quebec City, Canada, Sept. 2017, pp. 287-297.
[3] "Brain Tumours". Medical Segmentation Decathlon. http://medicaldecathlon.com/
BraTS データセットは、CC-BY-SA 4.0 のライセンスに基づき Medical Segmentation Decathlon によって提供されます。一切の保証および表明を行いません。詳細はライセンスを参照してください。MathWorks® は、この例の BraTS サンプル データのダウンロードのセクションにリンクが示されているデータ セットを変更しています。変更されたサンプル データ セットは主に脳と腫瘍を含む領域に合わせてトリミングされており、各チャネルは平均を減算し、トリミングされた脳の領域の標準偏差で除算することにより、個別に正規化されています。
[4] Sudre, C. H., W. Li, T. Vercauteren, S. Ourselin, and M. J. Cardoso. "Generalised Dice Overlap as a Deep Learning Loss Function for Highly Unbalanced Segmentations." Deep Learning in Medical Image Analysis and Multimodal Learning for Clinical Decision Support: Third International Workshop. Quebec City, Canada, Sept. 2017, pp. 240-248.
[5] Ronneberger, O., P. Fischer, and T. Brox. "U-Net:Convolutional Networks for Biomedical Image Segmentation." In Proceedings of the International Conference on Medical Image Computing and Computer-Assisted Intervention - MICCAI 2015. Munich, Germany, Oct. 2015, pp. 234-241. Available at arXiv:1505.04597.
参考
randomPatchExtractionDatastore
(Image Processing Toolbox) | trainNetwork
| trainingOptions
| transform
| pixelLabelDatastore
(Computer Vision Toolbox) | imageDatastore
| semanticseg
(Computer Vision Toolbox) | dicePixelClassificationLayer
(Computer Vision Toolbox)