Main Content

このページの翻訳は最新ではありません。ここをクリックして、英語の最新版を参照してください。

深層学習を使用したビデオとオプティカル フロー データのアクティビティ認識

この例では、ビデオの RGB データとオプティカル フロー データを使用し、アクティビティ認識用に Inflated 3-D (I3D) 2 ストリーム畳み込みニューラル ネットワークに学習させる方法を説明します [1]

視覚ベースのアクティビティ認識では、一連のビデオ フレームを使用し、歩く、泳ぐ、座るといったオブジェクトのアクションを予測します。ビデオのアクティビティ認識は、ヒューマン コンピューター インタラクション、ロボット学習、異常検出、監視、オブジェクト検出といったさまざまな分野で応用されています。たとえば、複数のカメラから入力されるビデオから複数のアクションをオンラインで予測することは、ロボット学習にとって重要です。ビデオのデータ セットに含まれるラベルはノイズが多く、ビデオ内のアクターによって実行される各種アクションのクラスは非常に偏っており、大規模なビデオのデータ セットで効率よく事前学習させることが難しいため、ビデオを使用したアクション認識はイメージ分類と比べてモデル化が困難です。I3D 2 ストリーム畳み込みネットワーク [1] などのいくつかの深層学習手法では、大規模なイメージ分類データ セットで事前学習を行うことによりパフォーマンスを改善できることが示されています。

データの読み込み

この例では、HMDB51 データ セットを使用して I3D ネットワークに学習させます。この例の最後にリストされている補助関数 downloadHMDB51 を使用し、HMDB51 データ セットを hmdb51 という名前のフォルダーにダウンロードします。

downloadFolder = fullfile(tempdir,"hmdb51");
downloadHMDB51(downloadFolder);

ダウンロードが完了したら、RAR ファイル hmdb51_org.rarhmdb51 フォルダーに解凍します。次に、この例の最後にリストされている補助関数 checkForHMDB51Folder を使用し、ダウンロードして解凍したファイルが所定の場所にあることを確認します。

allClasses = checkForHMDB51Folder(downloadFolder);

データ セットには、"飲む"、"走る"、"握手する" など、51 を超えるクラスの 7000 個のクリップから成る約 2 GB のビデオ データが格納されています。各ビデオ フレームの高さは 240 ピクセルで、最小幅は 176 ピクセルです。フレーム数は、18 から約 1000 までの範囲になります。

この例では、学習時間を短縮するために、データ セットに含まれる 51 のすべてのクラスではなく、5 つのアクション クラスを分類するようにアクティビティ認識ネットワークに学習させます。51 のすべてのクラスについて学習させる場合は useAllDatatrue に設定します。

useAllData = false;

if useAllData
    classes = allClasses;
else
    classes = ["kiss","laugh","pick","pour","pushup"];
end
dataFolder = fullfile(downloadFolder, "hmdb51_org");

データセットを、ネットワークに学習させるための学習セットとネットワークを評価するためのテスト セットに分割します。データの 80% を学習セットに使用し、残りをテスト セットに使用します。imageDatastore を使用し、各ラベルに基づき、データを学習データ セットとテスト データ セットに分割します。この処理は、各ラベルのファイルの比率をランダムに選択することによって行われます。

imds = imageDatastore(fullfile(dataFolder,classes),...
    'IncludeSubfolders', true,...
    'LabelSource', 'foldernames',...
    'FileExtensions', '.avi');

[trainImds,testImds] = splitEachLabel(imds,0.8,'randomized');

trainFilenames = trainImds.Files;
testFilenames  = testImds.Files;

ネットワークの入力データを正規化するために、データ セットの最小値および最大値が、この例に添付されている MAT ファイル inputStatistics.mat に用意されています。異なるデータ セットの最小値および最大値を見つけるには、この例の最後にリストされているサポート関数 inputStatistics を使用します。

inputStatsFilename = 'inputStatistics.mat';
if ~exist(inputStatsFilename, 'file')
    disp("Reading all the training data for input statistics...")
    inputStats = inputStatistics(dataFolder);
else
    d = load(inputStatsFilename);
    inputStats = d.inputStats;    
end

ネットワーク学習用のデータストアの作成

この例の終わりに定義されているサポート関数 createFileDatastore を使用し、学習用と検証用の 2 つの FileDatastore オブジェクトを作成します。各データストアは、ビデオ ファイルを読み取って、RGB データ、オプティカル フロー データ、対応するラベル情報を提供します。

データストアによる各読み取りのフレーム数を指定します。一般的な値は、16、32、64、または 128 です。より多くのフレームを使用すればより多くの時間情報を取得できますが、学習と予測により多くのメモリが必要になります。メモリ使用量とパフォーマンスのバランスをとるために、フレーム数を 64 に設定します。システム リソースによっては、この値を下げることが必要になる場合もあります。

numFrames = 64;

読み取るデータストアのフレームの高さと幅を指定します。高さと幅を同じ値に固定することにより、ネットワークに関するデータのバッチ処理が簡単になります。一般的な値は [112, 112]、[224, 224]、[256, 256] です。HMDB51 データ セット内のビデオ フレームの高さと幅の最小値は、それぞれ 240 と 176 です。空間情報を犠牲にしてより多くのフレームを取得するには、[112, 112] と指定します。読み取るデータストアのフレーム サイズを最小値よりも大きく ([256, 256] など) 指定する場合、まず、imresize を使用してフレームのサイズを変更します。

frameSize = [112,112];

fileDatastore の読み取り関数が指定された入力サイズを読み取れるように、inputSizeinputStats 構造体に設定します。

inputSize = [frameSize, numFrames];
inputStats.inputSize = inputSize;
inputStats.Classes = classes;

2 つの FileDatastore オブジェクトを作成します。1 つは学習用、もう 1 つは検証用です。

isDataForValidation = false;
dsTrain = createFileDatastore(trainFilenames,inputStats,isDataForValidation);

isDataForValidation = true;
dsVal = createFileDatastore(testFilenames,inputStats,isDataForValidation);

disp("Training data size: " + string(numel(dsTrain.Files)))
Training data size: 436
disp("Validation data size: " + string(numel(dsVal.Files)))
Validation data size: 109

ネットワーク アーキテクチャの定義

I3D ネットワーク

3 次元 CNN の使用は、ビデオから時空間特徴を抽出するための自然なアプローチです。2 次元フィルターとプーリング カーネルを 3 次元に拡張することにより、事前学習済みの 2 次元イメージ分類ネットワーク (Inception v1 や ResNet-50 など) から I3D ネットワークを作成できます。この手順では、イメージ分類タスクから学習した重みを再利用し、ビデオ認識タスクをブートストラップします。

次の図は、2 次元畳み込み層を 3 次元畳み込み層に拡張する方法を示したサンプルです。この拡張では、3 番目の次元 (時間次元) を追加することにより、フィルター サイズ、重み、バイアスを拡張します。

2 ストリーム I3D ネットワーク

ビデオ データは、空間コンポーネントと時間コンポーネントという 2 つの部分をもつと見なすことができます。

  • 空間コンポーネントは、ビデオ内のオブジェクトの形状、テクスチャ、色に関する情報で構成されます。RGB データにはこの情報が含まれています。

  • 時間コンポーネントは、フレーム全体のオブジェクトのモーションに関する情報で構成され、カメラと、シーン内のオブジェクトの間の重要な動作を表します。オプティカル フローの計算は、ビデオから時間情報を抽出するための一般的な手法です。

2 ストリーム CNN には、空間サブネットワークと時間サブネットワークが組み込まれています [2]。密度の高いオプティカル フローとビデオ データ ストリームで学習させた畳み込みニューラル ネットワークは、スタックされた生の RGB フレームよりも、制限された学習データを使ってパフォーマンスを改善できます。次の図は、典型的な 2 ストリーム I3D ネットワークを表しています。

2 ストリーム I3D ネットワークの作成

この例では、ImageNet データベースで事前学習させたネットワークである GoogLeNet を使用し、I3D ネットワークを作成します。

RGB サブネットワークのチャネル数を 3、オプティカル フロー サブネットワークのチャネル数を 2 に指定します。オプティカル フロー データの 2 つのチャネルは、速度の x 成分と y 成分である VxVy をそれぞれ表します。

rgbChannels = 3;
flowChannels = 2;

inputStatistics.mat ファイルから読み込まれた inputStats 構造体の RGB データおよびオプティカル フロー データの最小値と最大値を取得します。これらの値は、I3D ネットワークの image3dInputLayer が入力データを正規化するために必要です。

rgbInputSize = [frameSize, numFrames, rgbChannels];
flowInputSize = [frameSize, numFrames, flowChannels];

rgbMin = inputStats.rgbMin;
rgbMax = inputStats.rgbMax;
oflowMin = inputStats.oflowMin(:,:,1:2);
oflowMax = inputStats.oflowMax(:,:,1:2);

rgbMin = reshape(rgbMin,[1,size(rgbMin)]);
rgbMax = reshape(rgbMax,[1,size(rgbMax)]);
oflowMin = reshape(oflowMin,[1,size(oflowMin)]);
oflowMax = reshape(oflowMax,[1,size(oflowMax)]);

ネットワークの学習を行うために、クラスの数を指定します。

numClasses = numel(classes);

この例に添付されているサポート関数 Inflated3D を使用し、I3D RGB サブネットワークとオプティカル フロー サブネットワークを作成します。サブネットワークは GoogLeNet から作成されます。

cnnNet = googlenet;

netRGB = Inflated3D(numClasses,rgbInputSize,rgbMin,rgbMax,cnnNet);
netFlow = Inflated3D(numClasses,flowInputSize,oflowMin,oflowMax,cnnNet);

各 I3D ネットワークの層グラフから dlnetwork オブジェクトを作成します。

dlnetRGB = dlnetwork(netRGB);
dlnetFlow = dlnetwork(netFlow);

モデル勾配関数の定義

この例の最後にリストされているサポート関数 modelGradients を作成します。関数 modelGradients は、RGB サブネットワーク dlnetRGB、オプティカル フロー サブネットワーク dlnetFlow、入力データ dlRGB および dlFlow のミニバッチ、グラウンド トゥルース ラベル データ dlY のミニバッチを入力として受け取ります。関数は、学習損失値、それぞれのサブネットワークの学習可能なパラメーターについての損失の勾配、サブネットワークのミニバッチの精度を返します。

損失は、各サブネットワークから得られる予測の交差エントロピー損失の平均を求めることによって計算されます。ネットワークの出力予測は、各クラスについて 0 ~ 1 の確率となります。

rgbLoss=crossentropy(rgbPrediction)

flowLoss=crossentropy(flowPrediction)

loss=mean([rgbLoss,flowLoss])

各サブネットワークの精度は、RGB およびオプティカル フローの予測の平均を受け取り、それを入力のグラウンド トゥルース ラベルと比較することによって計算されます。

学習オプションの指定

ミニバッチ サイズを 20、反復回数を 1500 として学習させます。SaveBestAfterIteration パラメーターを使用し、最大の検証精度でモデルを保存するまでの反復回数を指定します。

コサインアニーリング学習率スケジュール [3] パラメーターを指定します。両方のネットワークで次を使用します。

  • 最小学習率として 1e-4。

  • 最大学習率として 1e-3。

  • 学習率スケジュール サイクルが再開するまでのコサインの反復数として 300、500、700。オプション CosineNumIterations では、各コサイン サイクルの幅を定義します。

SGDM 最適化のパラメーターを指定します。RGB ネットワークとオプティカル フロー ネットワークのそれぞれについて、学習開始時に SGDM 最適化パラメーターを初期化します。両方のネットワークで次を使用します。

  • モーメンタムとして 0.9。

  • [] として初期化された初期速度パラメーター。

  • L2 正則化係数として 0.0005。

並列プールを使用し、バックグラウンドでデータのディスパッチを実行するように指定します。DispatchInBackground が true に設定されている場合、並列ワーカーの指定された数で並列プールを開き、この例の一部として提供される DispatchInBackgroundDatastore を作成します。そして、非同期のデータの読み込みと前処理を使用して学習を高速化するために、バックグラウンドでデータのディスパッチを行います。既定では、この例は利用可能な GPU がある場合にそれを使用します。そうでない場合は CPU が使用されます。GPU を使用するには、Parallel Computing Toolbox™、および CUDA® 対応の NVIDIA® GPU が必要です。サポートされる Compute Capability の詳細については、リリース別の GPU サポート (Parallel Computing Toolbox)を参照してください。

params.Classes = classes;
params.MiniBatchSize = 20;
params.NumIterations = 1500;
params.SaveBestAfterIteration = 900;
params.CosineNumIterations = [300, 500, 700];
params.MinLearningRate = 1e-4;
params.MaxLearningRate = 1e-3;
params.Momentum = 0.9;
params.VelocityRGB = [];
params.VelocityFlow = [];
params.L2Regularization = 0.0005;
params.ProgressPlot = false;
params.Verbose = true;
params.ValidationData = dsVal;
params.DispatchInBackground = false;
params.NumWorkers = 4;

ネットワークの学習

RGB データとオプティカル フロー データを使用し、サブネットワークに学習させます。変数 doTrainingfalse に設定し、学習の完了を待つことなく、事前学習済みのサブネットワークをダウンロードします。あるいは、サブネットワークに学習させる場合は、変数 doTrainingtrue に設定します。

doTraining = false;

各エポックで次を行います。

  • データのミニバッチをループ処理する前にデータをシャッフルします。

  • minibatchqueue を使用してミニバッチをループ処理します。この例の最後にリストされているサポート関数 createMiniBatchQueue は、指定された学習データストアを使用して minibatchqueue を作成します。

  • 検証データ dsVal を使用してネットワークを検証します。

  • この例の最後にリストされているサポート関数 displayVerboseOutputEveryEpoch を使用し、各エポックの損失と精度の結果を表示します。

各ミニバッチで次を行います。

  • イメージ データまたはオプティカル フロー データとラベルを、基となる型が single の dlarray オブジェクトに変換します。

  • ビデオとオプティカル フロー データの時間次元を空間次元の 1 つとして扱い、3 次元 CNN を使用した処理を有効にします。RGB データまたはオプティカル フロー データに次元ラベル "SSSCB" (spatial、spatial、spatial、channel、batch) を指定し、ラベル データに "CB" を指定します。

minibatchqueue オブジェクトは、この例の最後にリストされているサポート関数 batchRGBAndFlow を使用し、RGB データとオプティカル フロー データをバッチ処理します。

modelFilename = "I3D-RGBFlow-" + numClasses + "Classes-hmdb51.mat";
if doTraining 
    epoch = 1;
    bestValAccuracy = 0;
    accTrain = [];
    accTrainRGB = [];
    accTrainFlow = [];
    lossTrain = [];
        
    iteration = 1;
    shuffled = shuffleTrainDs(dsTrain);
    
    % Number of outputs is three: One for RGB frames, one for optical flow
    % data, and one for ground truth labels.
    numOutputs = 3;
    mbq = createMiniBatchQueue(shuffled, numOutputs, params);
    start = tic;
    trainTime = start;
    
    % Use the initializeTrainingProgressPlot and initializeVerboseOutput
    % supporting functions, listed at the end of the example, to initialize
    % the training progress plot and verbose output to display the training
    % loss, training accuracy, and validation accuracy.
    plotters = initializeTrainingProgressPlot(params);
    initializeVerboseOutput(params);
    
    while iteration <= params.NumIterations

        % Iterate through the data set.
        [dlX1,dlX2,dlY] = next(mbq);

        % Evaluate the model gradients and loss using dlfeval.
        [gradRGB,gradFlow,loss,acc,accRGB,accFlow,stateRGB,stateFlow] = ...
            dlfeval(@modelGradients,dlnetRGB,dlnetFlow,dlX1,dlX2,dlY);
        
        % Accumulate the loss and accuracies.
        lossTrain = [lossTrain, loss];
        accTrain = [accTrain, acc];
        accTrainRGB = [accTrainRGB, accRGB];
        accTrainFlow = [accTrainFlow, accFlow];
        % Update the network state.
        dlnetRGB.State = stateRGB;
        dlnetFlow.State = stateFlow;
        
        % Update the gradients and parameters for the RGB and optical flow
        % subnetworks using the SGDM optimizer.
        [dlnetRGB,gradRGB,params.VelocityRGB,learnRate] = ...
            updateDlNetwork(dlnetRGB,gradRGB,params,params.VelocityRGB,iteration);
        [dlnetFlow,gradFlow,params.VelocityFlow] = ...
            updateDlNetwork(dlnetFlow,gradFlow,params,params.VelocityFlow,iteration);
        
        if ~hasdata(mbq) || iteration == params.NumIterations
            % Current epoch is complete. Do validation and update progress.
            trainTime = toc(trainTime);

            [validationTime,cmat,lossValidation,accValidation,accValidationRGB,accValidationFlow] = ...
                doValidation(params, dlnetRGB, dlnetFlow);

            % Update the training progress.
            displayVerboseOutputEveryEpoch(params,start,learnRate,epoch,iteration,...
                mean(accTrain),mean(accTrainRGB),mean(accTrainFlow),...
                accValidation,accValidationRGB,accValidationFlow,...
                mean(lossTrain),lossValidation,trainTime,validationTime);
            updateProgressPlot(params,plotters,epoch,iteration,start,mean(lossTrain),mean(accTrain),accValidation);
            
            % Save model with the trained dlnetwork and accuracy values.
            % Use the saveData supporting function, listed at the
            % end of this example.
            if iteration >= params.SaveBestAfterIteration
                if accValidation > bestValAccuracy
                    bestValAccuracy = accValidation;
                    saveData(modelFilename, dlnetRGB, dlnetFlow, cmat, accValidation);
                end
            end
        end
        
        if ~hasdata(mbq) && iteration < params.NumIterations
            % Current epoch is complete. Initialize the training loss, accuracy
            % values, and minibatchqueue for the next epoch.
            accTrain = [];
            accTrainRGB = [];
            accTrainFlow = [];
            lossTrain = [];
        
            trainTime = tic;
            epoch = epoch + 1;
            shuffled = shuffleTrainDs(dsTrain);
            numOutputs = 3;
            mbq = createMiniBatchQueue(shuffled, numOutputs, params);
            
        end 
        
        iteration = iteration + 1;
    end
    
    % Display a message when training is complete.
    endVerboseOutput(params);
    
    disp("Model saved to: " + modelFilename);
end

% Download the pretrained model and video file for prediction.
filename = "activityRecognition-I3D-HMDB51.zip";
downloadURL = "https://ssd.mathworks.com/supportfiles/vision/data/" + filename;

filename = fullfile(downloadFolder,filename);
if ~exist(filename,'file')
    disp('Downloading the pretrained network...');
    websave(filename,downloadURL);
end
% Unzip the contents to the download folder.
unzip(filename,downloadFolder);
if ~doTraining
    modelFilename = fullfile(downloadFolder, modelFilename);
end

学習済みネットワークの評価

テスト データ セットを使用し、学習済みのサブネットワークの精度を評価します。

学習中に保存された最適なモデルを読み込みます。

d = load(modelFilename);
dlnetRGB = d.data.dlnetRGB;
dlnetFlow = d.data.dlnetFlow;

minibatchqueue オブジェクトを作成し、テスト データのバッチを読み込みます。

numOutputs = 3;
mbq = createMiniBatchQueue(params.ValidationData, numOutputs, params);

テスト データの各バッチについて、RGB ネットワークとオプティカル フロー ネットワークを使用して予測を行い、予測の平均を求め、混同行列を使用して予測精度を計算します。

cmat = sparse(numClasses,numClasses);
while hasdata(mbq)
    [dlRGB, dlFlow, dlY] = next(mbq);
    
    % Pass the video input as RGB and optical flow data through the
    % two-stream subnetworks to get the separate predictions.
    dlYPredRGB = predict(dlnetRGB,dlRGB);
    dlYPredFlow = predict(dlnetFlow,dlFlow);

    % Fuse the predictions by calculating the average of the predictions.
    dlYPred = (dlYPredRGB + dlYPredFlow)/2;
    
    % Calculate the accuracy of the predictions.
    [~,YTest] = max(dlY,[],1);
    [~,YPred] = max(dlYPred,[],1);

    cmat = aggregateConfusionMetric(cmat,YTest,YPred);
end

学習済みのネットワークの平均分類精度を計算します。

accuracyEval = sum(diag(cmat))./sum(cmat,"all")
accuracyEval = 
      0.60909

混同行列を表示します。

figure
chart = confusionchart(cmat,classes);

学習サンプルの数が少ないため、精度を 61% より高くすることは困難です。ネットワークのロバスト性を向上させるには、大規模なデータ セットでさらに学習を行う必要があります。また、Kinetics [1] などの大規模なデータ セットによる事前学習は、結果を改善するのに役立ちます。

新しいビデオを使用した予測

学習済みのネットワークを使用し、新しいビデオのアクションを予測できるようになりました。VideoReadervision.VideoPlayer を使用し、ビデオ pour.avi を読み取って表示します。

videoFilename = fullfile(downloadFolder, "pour.avi");

videoReader = VideoReader(videoFilename);
videoPlayer = vision.VideoPlayer;
videoPlayer.Name = "pour";

while hasFrame(videoReader)
   frame = readFrame(videoReader);
   step(videoPlayer,frame);
end
release(videoPlayer);

この例の最後にリストされているサポート関数 readRGBAndFlow を使用し、RGB データとオプティカル フロー データを読み取ります。

isDataForValidation = true;
readFcn = @(f,u)readRGBAndFlow(f,u,inputStats,isDataForValidation);

この読み取り関数は、ファイルから読み取るデータがさらに存在するかどうかを示す logical isDone 値を返します。この例の終わりに定義されているサポート関数 batchRGBAndFlow を使用し、2 ストリーム サブネットワークを通過するデータをバッチ処理して予測を求めます。

hasdata = true;
userdata = [];
YPred = [];
while hasdata
    [data,userdata,isDone] = readFcn(videoFilename,userdata);
    
    [dlRGB, dlFlow] = batchRGBAndFlow(data(:,1),data(:,2),data(:,3));
    
    % Pass video input as RGB and optical flow data through the two-stream
    % subnetworks to get the separate predictions.
    dlYPredRGB = predict(dlnetRGB,dlRGB);
    dlYPredFlow = predict(dlnetFlow,dlFlow);

    % Fuse the predictions by calculating the average of the predictions.
    dlYPred = (dlYPredRGB + dlYPredFlow)/2;
    [~,YPredCurr] = max(dlYPred,[],1);
    YPred = horzcat(YPred,YPredCurr);
    hasdata = ~isDone;
end
YPred = extractdata(YPred);

histcounts を使用して正しい予測の数をカウントし、正しい予測の最大数を使用して予測されたアクションを取得します。

classes = params.Classes;
counts = histcounts(YPred,1:numel(classes));
[~,clsIdx] = max(counts);
action = classes(clsIdx)
action = 
"pour"

サポート関数

inputStatistics

関数 inputStatistics は、HMDB51 データを格納するフォルダーの名前を入力として受け取り、RGB データとオプティカル フロー データの最小値と最大値を計算します。最小値と最大値は、ネットワークの入力層への正規化入力として使用されます。この関数は、ネットワークの学習中およびテスト中に、後で使用する各ビデオ ファイルのフレーム数も取得します。異なるデータ セットの最小値および最大値を見つけるには、データ セットを格納しているフォルダー名でこの関数を使用します。

function inputStats = inputStatistics(dataFolder)
    ds = createDatastore(dataFolder);
    ds.ReadFcn = @getMinMax;

    tic;
    tt = tall(ds);
    varnames = {'rgbMax','rgbMin','oflowMax','oflowMin'};
    stats = gather(groupsummary(tt,[],{'max','min'}, varnames));
    inputStats.Filename = gather(tt.Filename);
    inputStats.NumFrames = gather(tt.NumFrames);
    inputStats.rgbMax = stats.max_rgbMax;
    inputStats.rgbMin = stats.min_rgbMin;
    inputStats.oflowMax = stats.max_oflowMax;
    inputStats.oflowMin = stats.min_oflowMin;
    save('inputStatistics.mat','inputStats');
    toc;
end

function data = getMinMax(filename)
    reader = VideoReader(filename);
    opticFlow = opticalFlowFarneback;
    data = [];
    while hasFrame(reader)
        frame = readFrame(reader);
        [rgb,oflow] = findMinMax(frame,opticFlow);
        data = assignMinMax(data, rgb, oflow);
    end

    totalFrames = floor(reader.Duration * reader.FrameRate);
    totalFrames = min(totalFrames, reader.NumFrames);
    
    [labelName, filename] = getLabelFilename(filename);
    data.Filename = fullfile(labelName, filename);
    data.NumFrames = totalFrames;

    data = struct2table(data,'AsArray',true);
end

function data = assignMinMax(data, rgb, oflow)
    if isempty(data)
        data.rgbMax = rgb.Max;
        data.rgbMin = rgb.Min;
        data.oflowMax = oflow.Max;
        data.oflowMin = oflow.Min;
        return;
    end
    data.rgbMax = max(data.rgbMax, rgb.Max);
    data.rgbMin = min(data.rgbMin, rgb.Min);

    data.oflowMax = max(data.oflowMax, oflow.Max);
    data.oflowMin = min(data.oflowMin, oflow.Min);
end

function [rgbMinMax,oflowMinMax] = findMinMax(rgb, opticFlow)
    rgbMinMax.Max = max(rgb,[],[1,2]);
    rgbMinMax.Min = min(rgb,[],[1,2]);

    gray = rgb2gray(rgb);
    flow = estimateFlow(opticFlow,gray);
    oflow = cat(3,flow.Vx,flow.Vy,flow.Magnitude);

    oflowMinMax.Max = max(oflow,[],[1,2]);
    oflowMinMax.Min = min(oflow,[],[1,2]);
end

function ds = createDatastore(folder)    
    ds = fileDatastore(folder,...
        'IncludeSubfolders', true,...
        'FileExtensions', '.avi',...
        'UniformRead', true,...
        'ReadFcn', @getMinMax);
    disp("NumFiles: " + numel(ds.Files));
end

createFileDatastore

関数 createFileDatastore は、指定されたファイル名を使用して FileDatastore オブジェクトを作成します。FileDatastore オブジェクトは 'partialfile' モードでデータを読み取るため、すべての読み取りにおいてビデオから部分的に読み取ったフレームを返すことができます。この特徴は、大きなビデオ ファイルの読み取りで、フレームのすべてをメモリに取り込めない場合に役立ちます。

function datastore = createFileDatastore(filenames,inputStats,isDataForValidation)
    readFcn = @(f,u)readRGBAndFlow(f,u,inputStats,isDataForValidation);
    datastore = fileDatastore(filenames,...
        'ReadFcn',readFcn,...
        'ReadMode','partialfile');
end

readRGBAndFlow

関数 readRGBAndFlow は、RGB フレーム、対応するオプティカル フロー データ、指定されたビデオ ファイルのラベル値を読み取ります。学習中、読み取り関数は、ランダムに選択された開始フレームを使用し、ネットワークの入力サイズに従って特定の数のフレームを読み取ります。オプティカル フロー データは、ビデオ ファイルの最初から計算されますが、開始フレームに到達するまではスキップされます。テスト中、すべてのフレームは順番に読み取られ、対応するオプティカル フロー データが計算されます。RGB フレームとオプティカル フロー データは、学習用には必要なネットワーク入力サイズになるようランダムにトリミングされ、テストおよび検証用には中心がトリミングされます。

function [data,userdata,done] = readRGBAndFlow(filename,userdata,inputStats,isDataForValidation)
    if isempty(userdata)
        userdata.reader      = VideoReader(filename);
        userdata.batchesRead = 0;
        userdata.opticalFlow = opticalFlowFarneback;
        
        [totalFrames,userdata.label] = getTotalFramesAndLabel(inputStats,filename);
        if isempty(totalFrames)
            totalFrames = floor(userdata.reader.Duration * userdata.reader.FrameRate);
            totalFrames = min(totalFrames, userdata.reader.NumFrames);
        end
        userdata.totalFrames = totalFrames;
    end
    reader      = userdata.reader;
    totalFrames = userdata.totalFrames;
    label       = userdata.label;
    batchesRead = userdata.batchesRead;
    opticalFlow = userdata.opticalFlow;

    inputSize = inputStats.inputSize;
    H = inputSize(1);
    W = inputSize(2);
    rgbC = 3;
    flowC = 2;
    numFrames = inputSize(3);

    if numFrames > totalFrames
        numBatches = 1;
    else
        numBatches = floor(totalFrames/numFrames);
    end

    imH = userdata.reader.Height;
    imW = userdata.reader.Width;
    imsz = [imH,imW];

    if ~isDataForValidation

        augmentFcn = augmentTransform([imsz,3]);
        cropWindow = randomCropWindow2d(imsz, inputSize(1:2));
        %  1. Randomly select required number of frames,
        %     starting randomly at a specific frame.
        if numFrames >= totalFrames
            idx = 1:totalFrames;
            % Add more frames to fill in the network input size.
            additional = ceil(numFrames/totalFrames);
            idx = repmat(idx,1,additional);
            idx = idx(1:numFrames);
        else
            startIdx = randperm(totalFrames - numFrames);
            startIdx = startIdx(1);
            endIdx = startIdx + numFrames - 1;
            idx = startIdx:endIdx;
        end

        video = zeros(H,W,rgbC,numFrames);
        oflow = zeros(H,W,flowC,numFrames);
        i = 1;
        % Discard the first set of frames to initialize the optical flow.
        for ii = 1:idx(1)-1
            frame = read(reader,ii);
            getRGBAndFlow(frame,opticalFlow,augmentFcn,cropWindow);
        end
        % Read the next set of required number of frames for training.
        for ii = idx
            frame = read(reader,ii);
            [rgb,vxvy] = getRGBAndFlow(frame,opticalFlow,augmentFcn,cropWindow);
            video(:,:,:,i) = rgb;
            oflow(:,:,:,i) = vxvy;
            i = i + 1;
        end
    else
        augmentFcn = @(data)(data);
        cropWindow = centerCropWindow2d(imsz, inputSize(1:2));
        toRead = min([numFrames,totalFrames]);
        video = zeros(H,W,rgbC,toRead);
        oflow = zeros(H,W,flowC,toRead);
        i = 1;
        while hasFrame(reader) && i <= numFrames
            frame = readFrame(reader);
            [rgb,vxvy] = getRGBAndFlow(frame,opticalFlow,augmentFcn,cropWindow);
            video(:,:,:,i) = rgb;
            oflow(:,:,:,i) = vxvy;
            i = i + 1;
        end
        if numFrames > totalFrames
            additional = ceil(numFrames/totalFrames);
            video = repmat(video,1,1,1,additional);
            oflow = repmat(oflow,1,1,1,additional);
            video = video(:,:,:,1:numFrames);
            oflow = oflow(:,:,:,1:numFrames);
        end
    end

    % The network expects the video and optical flow input in 
    % the following dlarray format: 
    % "SSSCB" ==> Height x Width x Frames x Channels x Batch
    %
    % Permute the data 
    %  from
    %      Height x Width x Channels x Frames
    %  to 
    %      Height x Width x Frames x Channels
    video = permute(video, [1,2,4,3]);
    oflow = permute(oflow, [1,2,4,3]);

    data = {video, oflow, label};

    batchesRead = batchesRead + 1;

    userdata.batchesRead = batchesRead;

    % Set the done flag to true, if the reader has read all the frames or
    % if it is training.
    done = batchesRead == numBatches || ~isDataForValidation;
end

function [rgb,vxvy] = getRGBAndFlow(rgb,opticalFlow,augmentFcn,cropWindow)
    rgb = augmentFcn(rgb);
    gray = rgb2gray(rgb);
    flow = estimateFlow(opticalFlow,gray);
    vxvy = cat(3,flow.Vx,flow.Vy,flow.Vy);

    rgb = imcrop(rgb, cropWindow);
    vxvy = imcrop(vxvy, cropWindow);
    vxvy = vxvy(:,:,1:2);
end

function [label,fname] = getLabelFilename(filename)
    [folder,name,ext] = fileparts(string(filename));
    [~,label] = fileparts(folder);
    fname = name + ext;
    label = string(label);
    fname = string(fname);
end

function [totalFrames,label] = getTotalFramesAndLabel(info, filename)
    filenames = info.Filename;
    frames = info.NumFrames;
    [labelName, fname] = getLabelFilename(filename);
    idx = strcmp(filenames, fullfile(labelName,fname));
    totalFrames = frames(idx);    
    label = categorical(string(labelName), string(info.Classes));
end

augmentTransform

関数 augmentTransform は、ランダムな左右反転係数とスケーリング係数を使った拡張方法を作成します。

function augmentFcn = augmentTransform(sz)
% Randomly flip and scale the image.
tform = randomAffine2d('XReflection',true,'Scale',[1 1.1]);
rout = affineOutputView(sz,tform,'BoundsStyle','CenterOutput');

augmentFcn = @(data)augmentData(data,tform,rout);

    function data = augmentData(data,tform,rout)
        data = imwarp(data,tform,'OutputView',rout);
    end
end

modelGradients

関数 modelGradients は、RGB データ dlRGB のミニバッチ、対応するオプティカル フロー データ dlFlow、および対応するターゲット dlY を入力として受け取り、対応する損失、学習可能なパラメーターについての損失の勾配、および学習精度を返します。勾配を計算するため、学習ループの中で関数 dlfeval を使用して、関数 modelGradients を評価します。

function [gradientsRGB,gradientsFlow,loss,acc,accRGB,accFlow,stateRGB,stateFlow] = modelGradients(dlnetRGB,dlnetFlow,dlRGB,dlFlow,Y)

% Pass video input as RGB and optical flow data through the two-stream
% network.
[dlYPredRGB,stateRGB] = forward(dlnetRGB,dlRGB);
[dlYPredFlow,stateFlow] = forward(dlnetFlow,dlFlow);

% Calculate fused loss, gradients, and accuracy for the two-stream
% predictions.
rgbLoss = crossentropy(dlYPredRGB,Y);
flowLoss = crossentropy(dlYPredFlow,Y);
% Fuse the losses.
loss = mean([rgbLoss,flowLoss]);

gradientsRGB = dlgradient(loss,dlnetRGB.Learnables);
gradientsFlow = dlgradient(loss,dlnetFlow.Learnables);

% Fuse the predictions by calculating the average of the predictions.
dlYPred = (dlYPredRGB + dlYPredFlow)/2;

% Calculate the accuracy of the predictions.
[~,YTest] = max(Y,[],1);
[~,YPred] = max(dlYPred,[],1);

acc = gather(extractdata(sum(YTest == YPred)./numel(YTest)));

% Calculate the accuracy of the RGB and flow predictions.
[~,YTest] = max(Y,[],1);
[~,YPredRGB] = max(dlYPredRGB,[],1);
[~,YPredFlow] = max(dlYPredFlow,[],1);

accRGB = gather(extractdata(sum(YTest == YPredRGB)./numel(YTest)));
accFlow = gather(extractdata(sum(YTest == YPredFlow)./numel(YTest)));
end

doValidation

関数 doValidation は、検証データを使用してネットワークを検証します。

function [validationTime, cmat, lossValidation, accValidation, accValidationRGB, accValidationFlow] = doValidation(params, dlnetRGB, dlnetFlow)

    validationTime = tic;

    numOutputs = 3;
    mbq = createMiniBatchQueue(params.ValidationData, numOutputs, params);

    lossValidation = [];
    numClasses = numel(params.Classes);
    cmat = sparse(numClasses,numClasses);
    cmatRGB = sparse(numClasses,numClasses);
    cmatFlow = sparse(numClasses,numClasses);
    while hasdata(mbq)

        [dlX1,dlX2,dlY] = next(mbq);

        [loss,YTest,YPred,YPredRGB,YPredFlow] = predictValidation(dlnetRGB,dlnetFlow,dlX1,dlX2,dlY);

        lossValidation = [lossValidation,loss];
        cmat = aggregateConfusionMetric(cmat,YTest,YPred);
        cmatRGB = aggregateConfusionMetric(cmatRGB,YTest,YPredRGB);
        cmatFlow = aggregateConfusionMetric(cmatFlow,YTest,YPredFlow);
    end
    lossValidation = mean(lossValidation);
    accValidation = sum(diag(cmat))./sum(cmat,"all");
    accValidationRGB = sum(diag(cmatRGB))./sum(cmatRGB,"all");
    accValidationFlow = sum(diag(cmatFlow))./sum(cmatFlow,"all");

    validationTime = toc(validationTime);
end

predictValidation

関数 predictValidation は、RGB データとオプティカル フロー データに対して、指定された dlnetwork オブジェクトを使用して損失と予測値を計算します。

function [loss,YTest,YPred,YPredRGB,YPredFlow] = predictValidation(dlnetRGB,dlnetFlow,dlRGB,dlFlow,Y)

% Pass the video input through the two-stream
% network.
dlYPredRGB = predict(dlnetRGB,dlRGB);
dlYPredFlow = predict(dlnetFlow,dlFlow);

% Calculate the cross-entropy separately for the two-stream
% outputs.
rgbLoss = crossentropy(dlYPredRGB,Y);
flowLoss = crossentropy(dlYPredFlow,Y);

% Fuse the losses.
loss = mean([rgbLoss,flowLoss]);

% Fuse the predictions by calculating the average of the predictions.
dlYPred = (dlYPredRGB + dlYPredFlow)/2;

% Calculate the accuracy of the predictions.
[~,YTest] = max(Y,[],1);
[~,YPred] = max(dlYPred,[],1);

[~,YPredRGB] = max(dlYPredRGB,[],1);
[~,YPredFlow] = max(dlYPredFlow,[],1);

end

updateDlnetwork

関数 updateDlnetwork は、SGDM 最適化関数 sgdmupdate を使用して、勾配や他のパラメーターを含む指定された dlnetwork オブジェクトを更新します。

function [dlnet,gradients,velocity,learnRate] = updateDlNetwork(dlnet,gradients,params,velocity,iteration)
    % Determine the learning rate using the cosine-annealing learning rate schedule.
    learnRate = cosineAnnealingLearnRate(iteration, params);

    % Apply L2 regularization to the weights.
    idx = dlnet.Learnables.Parameter == "Weights";
    gradients(idx,:) = dlupdate(@(g,w) g + params.L2Regularization*w, gradients(idx,:), dlnet.Learnables(idx,:));

    % Update the network parameters using the SGDM optimizer.
    [dlnet, velocity] = sgdmupdate(dlnet, gradients, velocity, learnRate, params.Momentum);
end

cosineAnnealingLearnRate

関数 cosineAnnealingLearnRate は、現在の反復回数、最小学習率、最大学習率、およびアニーリングの反復回数に基づいて学習率を計算します [3]。

function lr = cosineAnnealingLearnRate(iteration, params)
    if iteration == params.NumIterations
        lr = params.MinLearningRate;
        return;
    end
    cosineNumIter = [0, params.CosineNumIterations];
    csum = cumsum(cosineNumIter);
    block = find(csum >= iteration, 1,'first');
    cosineIter = iteration - csum(block - 1);
    annealingIteration = mod(cosineIter, cosineNumIter(block));
    cosineIteration = cosineNumIter(block);
    minR = params.MinLearningRate;
    maxR = params.MaxLearningRate;
    cosMult = 1 + cos(pi * annealingIteration / cosineIteration);
    lr = minR + ((maxR - minR) *  cosMult / 2);
end

aggregateConfusionMetric

関数 aggregateConfusionMetric は、予測される結果 YPred と期待される結果 YTest をインクリメンタルに混同行列に追加します。

function cmat = aggregateConfusionMetric(cmat,YTest,YPred)
YTest = gather(extractdata(YTest));
YPred = gather(extractdata(YPred));
[m,n] = size(cmat);
cmat = cmat + full(sparse(YTest,YPred,1,m,n));
end

createMiniBatchQueue

関数 createMiniBatchQueue は、指定されたデータストアからのデータ量 miniBatchSize を提供する minibatchqueue オブジェクトを作成します。また、並列プールが開いている場合は、DispatchInBackgroundDatastore も作成します。

function mbq = createMiniBatchQueue(datastore, numOutputs, params)
if params.DispatchInBackground && isempty(gcp('nocreate'))
    % Start a parallel pool, if DispatchInBackground is true, to dispatch
    % data in the background using the parallel pool.
    c = parcluster('local');
    c.NumWorkers = params.NumWorkers;
    parpool('local',params.NumWorkers);
end
p = gcp('nocreate');
if ~isempty(p)
    datastore = DispatchInBackgroundDatastore(datastore, p.NumWorkers);
end
inputFormat(1:numOutputs-1) = "SSSCB";
outputFormat = "CB";
mbq = minibatchqueue(datastore, numOutputs, ...
    "MiniBatchSize", params.MiniBatchSize, ...
    "MiniBatchFcn", @batchRGBAndFlow, ...
    "MiniBatchFormat", [inputFormat,outputFormat]);
end

batchRGBAndFlow

関数 batchRGBAndFlow は、イメージ データ、フロー データ、ラベル データをバッチ処理し、それぞれデータ形式 "SSSCB""SSSCB""CB" で、対応する dlarray 値にします。

function [dlX1,dlX2,dlY] = batchRGBAndFlow(images, flows, labels)
% Batch dimension: 5
X1 = cat(5,images{:});
X2 = cat(5,flows{:});

% Batch dimension: 2
labels = cat(2,labels{:});

% Feature dimension: 1
Y = onehotencode(labels,1);

% Cast data to single for processing.
X1 = single(X1);
X2 = single(X2);
Y = single(Y);

% Move data to the GPU if possible.
if canUseGPU
    X1 = gpuArray(X1);
    X2 = gpuArray(X2);
    Y = gpuArray(Y);
end

% Return X and Y as dlarray objects.
dlX1 = dlarray(X1,"SSSCB");
dlX2 = dlarray(X2,"SSSCB");
dlY = dlarray(Y,"CB");
end

shuffleTrainDs

関数 shuffleTrainDs は、学習データストア dsTrain に存在するファイルをシャッフルします。

function shuffled = shuffleTrainDs(dsTrain)
shuffled = copy(dsTrain);
n = numel(shuffled.Files);
shuffledIndices = randperm(n);
shuffled.Files = shuffled.Files(shuffledIndices);
reset(shuffled);
end

saveData

関数 saveData は、指定された dlnetwork オブジェクトと精度の値を MAT ファイルに保存します。

function saveData(modelFilename, dlnetRGB, dlnetFlow, cmat, accValidation)
dlnetRGB = gatherFromGPUToSave(dlnetRGB);
dlnetFlow = gatherFromGPUToSave(dlnetFlow);
data.ValidationAccuracy = accValidation;
data.cmat = cmat;
data.dlnetRGB = dlnetRGB;
data.dlnetFlow = dlnetFlow;
save(modelFilename, 'data');
end

gatherFromGPUToSave

関数 gatherFromGPUToSave は、モデルをディスクに保存するために GPU からデータを収集します。

function dlnet = gatherFromGPUToSave(dlnet)
if ~canUseGPU
    return;
end
dlnet.Learnables = gatherValues(dlnet.Learnables);
dlnet.State = gatherValues(dlnet.State);
    function tbl = gatherValues(tbl)
        for ii = 1:height(tbl)
            tbl.Value{ii} = gather(tbl.Value{ii});
        end
    end
end

checkForHMDB51Folder

関数 checkForHMDB51Folder は、ダウンロード フォルダーにあるダウンロード済みのデータをチェックします。

function classes = checkForHMDB51Folder(dataLoc)
hmdbFolder = fullfile(dataLoc, "hmdb51_org");
if ~exist(hmdbFolder, "dir")
    error("Download 'hmdb51_org.rar' file using the supporting function 'downloadHMDB51' before running the example and extract the RAR file.");    
end

classes = ["brush_hair","cartwheel","catch","chew","clap","climb","climb_stairs",...
    "dive","draw_sword","dribble","drink","eat","fall_floor","fencing",...
    "flic_flac","golf","handstand","hit","hug","jump","kick","kick_ball",...
    "kiss","laugh","pick","pour","pullup","punch","push","pushup","ride_bike",...
    "ride_horse","run","shake_hands","shoot_ball","shoot_bow","shoot_gun",...
    "sit","situp","smile","smoke","somersault","stand","swing_baseball","sword",...
    "sword_exercise","talk","throw","turn","walk","wave"];
expectFolders = fullfile(hmdbFolder, classes);
if ~all(arrayfun(@(x)exist(x,'dir'),expectFolders))
    error("Download hmdb51_org.rar using the supporting function 'downloadHMDB51' before running the example and extract the RAR file.");
end
end

downloadHMDB51

関数 downloadHMDB51 は、データ セットをダウンロードしてディレクトリに保存します。

function downloadHMDB51(dataLoc)

if nargin == 0
    dataLoc = pwd;
end
dataLoc = string(dataLoc);

if ~exist(dataLoc,"dir")
    mkdir(dataLoc);
end

dataUrl     = "http://serre-lab.clps.brown.edu/wp-content/uploads/2013/10/hmdb51_org.rar";
options     = weboptions('Timeout', Inf);
rarFileName = fullfile(dataLoc, 'hmdb51_org.rar');
fileExists  = exist(rarFileName, 'file');

% Download the RAR file and save it to the download folder.
if ~fileExists
    disp("Downloading hmdb51_org.rar (2 GB) to the folder:")
    disp(dataLoc)
    disp("This download can take a few minutes...") 
    websave(rarFileName, dataUrl, options); 
    disp("Download complete.")
    disp("Extract the hmdb51_org.rar file contents to the folder: ") 
    disp(dataLoc)
end
end

initializeTrainingProgressPlot

関数 initializeTrainingProgressPlot は、学習損失、学習精度、および検証精度を表示する 2 つのプロットを構成します。

function plotters = initializeTrainingProgressPlot(params)
if params.ProgressPlot
    % Plot the loss, training accuracy, and validation accuracy.
    figure
    
    % Loss plot
    subplot(2,1,1)
    plotters.LossPlotter = animatedline;
    xlabel("Iteration")
    ylabel("Loss")
    
    % Accuracy plot
    subplot(2,1,2)
    plotters.TrainAccPlotter = animatedline('Color','b');
    plotters.ValAccPlotter = animatedline('Color','g');
    legend('Training Accuracy','Validation Accuracy','Location','northwest');
    xlabel("Iteration")
    ylabel("Accuracy")
else
    plotters = [];
end
end

initializeVerboseOutput

関数 initializeVerboseOutput は、学習値のテーブルの列見出しを表示します。列見出しは、エポック、ミニバッチ精度、その他の学習値を示します。

function initializeVerboseOutput(params)
if params.Verbose
    disp(" ")
    if canUseGPU
        disp("Training on GPU.")
    else
        disp("Training on CPU.")
    end
    p = gcp('nocreate');
    if ~isempty(p)
        disp("Training on parallel cluster '" + p.Cluster.Profile + "'. ")
    end
    disp("NumIterations:" + string(params.NumIterations));
    disp("MiniBatchSize:" + string(params.MiniBatchSize));
    disp("Classes:" + join(string(params.Classes), ","));    
    disp("|=======================================================================================================================================================================|")
    disp("| Epoch | Iteration | Time Elapsed |     Mini-Batch Accuracy    |    Validation Accuracy     | Mini-Batch | Validation |  Base Learning  | Train Time | Validation Time |")
    disp("|       |           |  (hh:mm:ss)  |       (Avg:RGB:Flow)       |       (Avg:RGB:Flow)       |    Loss    |    Loss    |      Rate       | (hh:mm:ss) |   (hh:mm:ss)    |")
    disp("|=======================================================================================================================================================================|")
end
end

displayVerboseOutputEveryEpoch

関数 displayVerboseOutputEveryEpoch は、エポック、ミニバッチ精度、検証精度、ミニバッチ損失など、学習値の詳細出力を表示します。

function displayVerboseOutputEveryEpoch(params,start,learnRate,epoch,iteration,...
        accTrain,accTrainRGB,accTrainFlow,accValidation,accValidationRGB,accValidationFlow,lossTrain,lossValidation,trainTime,validationTime)
    if params.Verbose
        D = duration(0,0,toc(start),'Format','hh:mm:ss');
        trainTime = duration(0,0,trainTime,'Format','hh:mm:ss');
        validationTime = duration(0,0,validationTime,'Format','hh:mm:ss');

        lossValidation = gather(extractdata(lossValidation));
        lossValidation = compose('%.4f',lossValidation);

        accValidation = composePadAccuracy(accValidation);
        accValidationRGB = composePadAccuracy(accValidationRGB);
        accValidationFlow = composePadAccuracy(accValidationFlow);

        accVal = join([accValidation,accValidationRGB,accValidationFlow], " : ");

        lossTrain = gather(extractdata(lossTrain));
        lossTrain = compose('%.4f',lossTrain);

        accTrain = composePadAccuracy(accTrain);
        accTrainRGB = composePadAccuracy(accTrainRGB);
        accTrainFlow = composePadAccuracy(accTrainFlow);

        accTrain = join([accTrain,accTrainRGB,accTrainFlow], " : ");
        learnRate = compose('%.13f',learnRate);

        disp("| " + ...
            pad(string(epoch),5,'both') + " | " + ...
            pad(string(iteration),9,'both') + " | " + ...
            pad(string(D),12,'both') + " | " + ...
            pad(string(accTrain),26,'both') + " | " + ...
            pad(string(accVal),26,'both') + " | " + ...
            pad(string(lossTrain),10,'both') + " | " + ...
            pad(string(lossValidation),10,'both') + " | " + ...
            pad(string(learnRate),13,'both') + " | " + ...
            pad(string(trainTime),10,'both') + " | " + ...
            pad(string(validationTime),15,'both') + " |")
    end

end

function acc = composePadAccuracy(acc)
    acc = compose('%.2f',acc*100) + "%";
    acc = pad(string(acc),6,'left');
end

endVerboseOutput

関数 endVerboseOutput は、学習中の詳細出力の末尾を表示します。

function endVerboseOutput(params)
if params.Verbose
    disp("|=======================================================================================================================================================================|")        
end
end

updateProgressPlot

関数 updateProgressPlot は、学習中の損失と精度の情報で進行状況プロットを更新します。

function updateProgressPlot(params,plotters,epoch,iteration,start,lossTrain,accuracyTrain,accuracyValidation)
if params.ProgressPlot
    
    % Update the training progress.
    D = duration(0,0,toc(start),"Format","hh:mm:ss");
    title(plotters.LossPlotter.Parent,"Epoch: " + epoch + ", Elapsed: " + string(D));
    addpoints(plotters.LossPlotter,iteration,double(gather(extractdata(lossTrain))));
    addpoints(plotters.TrainAccPlotter,iteration,accuracyTrain);
    addpoints(plotters.ValAccPlotter,iteration,accuracyValidation);
    drawnow
end
end

参考文献

[1] Carreira, Joao, and Andrew Zisserman. "Quo Vadis, Action Recognition? A New Model and the Kinetics Dataset." Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR): 6299??6308.Honolulu, HI: IEEE, 2017.

[2] Simonyan, Karen, and Andrew Zisserman. "Two-Stream Convolutional Networks for Action Recognition in Videos." Advances in Neural Information Processing Systems 27, Long Beach, CA: NIPS, 2017.

[3] Loshchilov, Ilya, and Frank Hutter. "SGDR: Stochastic Gradient Descent with Warm Restarts." International Conferencee on Learning Representations 2017. Toulon, France: ICLR, 2017.