メインコンテンツ

畳み込み LSTM ネットワークのコード生成

この例では、畳み込み層と双方向長短期記憶 (BiLSTM) 層を含む深層学習ネットワークの MEX 関数を生成する方法を示します。生成される関数は、いずれのサードパーティ ライブラリも使用しません。生成された MEX 関数は、指定されたビデオ ファイルのデータをビデオ フレームのシーケンスとして読み取り、ビデオのアクティビティを分類するラベルを出力します。このネットワークの学習の詳細については、深層学習を使用したビデオの分類 (Deep Learning Toolbox)の例を参照してください。サポートされるコンパイラの詳細については、深層学習に MATLAB Coder を使用するための前提条件を参照してください。

この例は、Mac®、Linux®、および Windows® の各プラットフォームでサポートされます。MATLAB® Online™ ではサポートされていません。

入力ビデオの準備

readvideo 補助関数を使用して、ビデオ ファイル pushup.mp4 を読み取ります。ビデオを確認するには、ビデオ ファイルの個々のフレームに対してループ処理を行い、関数 imshow を使用します。

filename = "pushup.mp4";
video = readVideo(filename);
numFrames = size(video,4);
figure
for i = 1:numFrames
    frame = video(:,:,:,i);
    imshow(frame/255);
    drawnow
end

centerCrop 補助関数を使用して、入力ビデオ フレームを学習済みネットワークの入力サイズに合わせて中央にトリミングします。

inputSize = [224 224 3];
video = centerCrop(video,inputSize);

"video_classify" エントリポイント関数

エントリポイント関数 video_classify.m はイメージ シーケンスを受け取り、それを予測のために学習済みネットワークに渡します。この関数は深層学習を使用したビデオの分類 (Deep Learning Toolbox)の例の畳み込み LSTM ネットワークを使用します。この関数は net.mat ファイルから永続変数にネットワーク オブジェクトを読み込み、classify (Deep Learning Toolbox)関数を使用して予測を実行します。それ以降の呼び出しでは、関数はこの永続オブジェクトを再利用します。

type('video_classify.m')
function out = video_classify(in) %#codegen
%   Copyright 2021-2024 The MathWorks, Inc.

% A persistent object dlnet is used to load the dlnetwork object. At the
% first call to this function, the persistent object is constructed and
% setup. When the function is called subsequent times, the same object is
% reused to call predict on inputs, thus avoiding reconstructing and
% reloading the network object. A categorial arrary labels is also loaded

persistent dlnet;
persistent labels;

if isempty(dlnet)
    dlnet = coder.loadDeepLearningNetwork('dlnet.mat');
    labels = coder.load('labels.mat');
end

% The dlnetwork object require dlarrays as inputs, convert input to a
% dlarray
dlIn = dlarray(in, 'SSCT');

% pass input to network and perform prediction
dlOut = predict(dlnet, dlIn); 
scores = extractdata(dlOut);

classNames = labels.classNames;

% Convert prediction scores to labels
out = scores2label(scores,classNames,1);

事前学習済みのネットワークのダウンロード

downloadVideoClassificationNetwork 補助関数を実行してビデオ分類ネットワークをダウンロードし、そのネットワークを MAT ファイル net.mat に保存します。

downloadVideoClassificationNetwork();

MEX 関数の生成

MEX 関数を生成するには、cfg という名前のcoder.MexCodeConfigオブジェクトを作成します。cfgTargetLang プロパティを C++ に設定します。いずれのサードパーティ ライブラリも使用しないコードを生成するには、targetlibnone に設定してcoder.DeepLearningConfig関数を使用します。これを cfg オブジェクトの DeepLearningConfig プロパティに割り当てます。

cfg = coder.config('mex');
cfg.TargetLang = 'C++';
cfg.DeepLearningConfig = coder.DeepLearningConfig('none');

関数coder.typeofを使用してエントリポイント関数への入力引数の型とサイズを指定します。この例では、入力はサイズ 224×224×3 と可変のシーケンス長をもつ single 型です。

Input = coder.typeof(single(0),[224 224 3 Inf],[false false false true]);

codegenコマンドを実行して、MEX 関数を生成します。

codegen -config cfg video_classify -args {Input} -report
Code generation successful: View report

生成された MEX 関数の実行

中央にトリミングされたビデオ入力を使用して、生成された MEX 関数を実行します。

output = video_classify_mex(single(video))
output = categorical
     pushup 

入力ビデオに予測を重ねて表示します。

video = readVideo(filename);
numFrames = size(video,4);
figure
for i = 1:numFrames
    frame = video(:,:,:,i);
    frame = insertText(frame, [1 1], char(output), 'TextColor', [255 255 255],'FontSize',30, 'BoxColor', [0 0 0]);
    imshow(frame/255);
    drawnow
end

補助関数

readVideo 補助関数は、MATLAB または Jetson™ デバイスのいずれかでビデオ ファイルを読み取り、4 次元配列として返します。

function video = readVideo(filename, frameSize)

if coder.target('MATLAB')
    vr = VideoReader(filename);
else
    hwobj = jetson();
    vr = VideoReader(hwobj, filename, 'Width', frameSize(1), 'Height', frameSize(2));
end
H = vr.Height;
W = vr.Width;
C = 3;

% Preallocate video array
numFrames = floor(vr.Duration * vr.FrameRate);
video = zeros(H,W,C,numFrames);

% Read frames
i = 0;
while hasFrame(vr)
    i = i + 1;
    video(:,:,:,i) = readFrame(vr);
end

% Remove unallocated frames
if size(video,4) > i
    video(:,:,:,i+1:end) = [];
end

end

centerCrop 補助関数は、ビデオを向きに基づいて四角形にトリミングし、指定された入力サイズに合わせてサイズを変更します。

function videoResized = centerCrop(video,inputSize)
%   Copyright 2020-2021 The MathWorks, Inc.

sz = size(video);
videoTmp = video;

if sz(1) < sz(2)
    % Video is landscape
    idx = floor((sz(2) - sz(1))/2);
    videoTmp(:,1:(idx-1),:,:) = [];
    videoTmp(:,(sz(1)+1):end,:,:) = [];
    
elseif sz(2) < sz(1)
    % Video is portrait
    idx = floor((sz(1) - sz(2))/2);
    videoTmp(1:(idx-1),:,:,:) = [];
    videoTmp((sz(2)+1):end,:,:,:) = [];
end

videoResized = imresize(videoTmp,inputSize(1:2));
videoResized = reshape(videoResized, inputSize(1), inputSize(2), inputSize(3), []);
end

参考

| |

トピック