メインコンテンツ

ディープ ネットワーク デザイナーを使用した PyTorch モデルのインポート

R2023b 以降

この例では、ディープ ネットワーク デザイナー アプリを使用して、PyTorch® モデルを対話形式でインポートする方法を示します。

この例では、次のことを行います。

  1. PyTorch ネットワークを ディープ ネットワーク デザイナー にインポートします。

  2. インポート レポートを使用して問題を検査します。

  3. プレースホルダー関数を完成させます。

  4. ネットワークを解析して、問題が残っていないことを確認します。

ネットワークのインポート

コマンド ラインで deepNetworkDesigner と入力して、ディープ ネットワーク デザイナー を開きます。

PyTorch モデルをインポートするには、ディープ ネットワーク デザイナー スタート ページの [PyTorch から] で、[インポート] をクリックします。

[PyTorch モデルのインポート] ダイアログ ボックスが開きます。モデル ファイルの場所を dNetworkWithUnsupportedOps.pt に設定します。インポート中に、アプリによってカスタム層が現在のフォルダーに保存される場合があります。インポートする前に、現在の作業ディレクトリに対する書き込み権限があることを確認してください。モデルをインポートするには、[インポート] をクリックします。ネットワークのインポートには時間がかかる場合があります。

ネットワークをインポートするには、PyTorch が必要とする順序で入力サイズを指定します。この例では、ネットワーク入力サイズは [1 3 8 16] になります。PyTorch の入力サイズの詳細については、TensorFlow、PyTorch、および ONNX からモデルをインポートする際のヒントを参照してください。

[インポート] をクリックします。[デザイナー] キャンバスにネットワークが表示されます。このネットワークには、入れ子ネットワークを含む層である networklayer オブジェクトが含まれています。ネットワーク層ごとに、その中に含まれる層の数を確認できます。

ネットワーク層の内部を表示するには、層ブロックをダブル クリックします。ネットワーク層の読み取り専用コンテンツが表示されます。

問題の修正

アプリは、インポート中にソフトウェアによって検出された問題をリストしたインポート レポートを生成します。アクションを必要とするプレースホルダー関数があることがわかります。ソフトウェアが PyTorch 層を組み込みの MATLAB® 層に変換できない場合、または関連付けられた MATLAB 関数を使用してカスタム層を生成できない場合、関数はプレースホルダー関数を使用してカスタム層を作成します。プレースホルダー関数は、ネットワークを使用する前に補完しなければなりません。

プレースホルダー関数の編集

この問題を修正するには、[関数の編集] をクリックしてプレースホルダー関数を開きます。ネットワークを使用する前にこの関数を完成させなければなりません。

この関数にはプレースホルダー テキストが含まれており、これをこの層を実装する関数に置き換えなければなりません。Deep Learning Toolbox™ と PyTorch では次元の順序が異なります。プレースホルダー関数の場合、ソフトウェアは入力と出力が正しい形式であることを確認するための補助コードを生成します。

関数の入力と想定する出力は、"value""rank" という 2 つのフィールドをもつ構造体です。この特定の例では、入力値は dlarray です。ランクは値の次元数であり、整数として指定します。

Input.png

プレースホルダー関数を完成させるには、次のことが必要です。

  1. 入力構造体配列からデータを抽出します。

  2. 層のコア機能を実装します。

  3. データを必要な出力形式に変換し、エラーを無効にします。

DNDImportDiagramResized.png

次のセクションでは、これらの各手順について詳細に説明します。完成した関数を確認するには、関数の完成を参照してください。

関数入力の抽出と検査

まず、入力構造体配列から値を抽出します。最初の行にブレークポイントを設定して、構造体配列を表示できるようにします。ブレークポイントを使用して値を調べる方法の詳細については、ブレークポイントを設定するを参照してください。

DNDSetBreakpointResized.png

ディープ ネットワーク デザイナー で、[解析] をクリックします。ソフトウェアはネットワークを解析し、pyAtenMish 層のブレークポイントに達すると停止します。コマンド ウィンドウで inputs{1} を呼び出して、関数への入力を検査します。値フィールドは、5 つの次元をもつ (rank: 5) 形式を整えていないdlarrayオブジェクトであることがわかります。

inputs{1}

ans = 

  struct with fields:

    value: [5-D dlarray]
     rank: 5

Deep Learning Toolbox と PyTorch では次元の順序が異なります。プレースホルダー層関数では、入力は PyTorch® の順序になっています。関数が想定する出力は、PyTorch とは逆の順序になります。詳細については、入力次元の順序を参照してください。

Mish 関数の実装

次に、mish 活性化関数を実装します。mish 関数は、入力の形状、サイズ、ランクを維持します。

mish(x)=xtanh(softplus(x))

詳細については、PyTorch のドキュメンテーションを参照してください。

function varargout = pyAtenMish(varargin)
% Function for the PyTorch operator named aten::mish.

% ...

Xval = inputs{1}.value;
Xrank = inputs{1}.rank;

% Softplus function
Yval = log(1+exp(Xval));

% Mish function
Yval = Xval .* tanh(Yval);

% ...
end

関数の完成

最後に、mish 関数の出力を層が想定する出力タイプに変換し、エラーを無効にします。関数は、"value" フィールドと "rank" フィールドをもつ構造体配列を出力しなければなりません。"value" フィールドは形式を整えた dlarray オブジェクトでなければなりません。データ形式の詳細については、fmtを参照してください。出力のランクを入力のランクと同じに設定します。関数内のエラーをコメント アウトするか削除して無効にします。また、関数の冒頭にある実装方法を説明するコメントも削除できます。

完成した関数は次のようになります。

function varargout = pyAtenMish(varargin)
% Function for the PyTorch operator named aten::mish.
%
% Inputs:
%   Each input argument is a struct with fields:
%       value: Input data
%       rank:  Number of dimensions of the input data, including any
%              trailing dimensions of size 1, specified as a scalar.
%   Most functions will have a single input. Some functions can have
%   multiple inputs if the original PyTorch operator expects multiple
%   inputs. If a function requires multiple inputs, then varargin{1} is the
%   first input, varargin{2} is the second input, and so on.
%
% Outputs:
%   Each output argument is a struct with fields:
%       value: Output data
%       rank:  Number of dimensions of the output data, including any
%              trailing dimensions of size 1, specified as a scalar.
%   The function can have multiple outputs if the original PyTorch
%   operator expects multiple outputs. If a function returns multiple
%   outputs, then varargout{1} is the first output, varargout{2} is the
%   second output, and so on.

import dNetworkWithUnsupportedOps.ops.*

%% Do Not Edit - Code Generated by PyTorch Importer
% This code permutes the dimensions of the inputs into PyTorch ordering.
% When you implement the rest of this function, assume that the dimensions
% of the arrays are in the same order that they appear in the original
% PyTorch model.
inputs = cell(1,nargin);
[inputs{:}] = permuteToPyTorchDimensionOrder(varargin{:});

%% Do Not Edit - Code Generated by PyTorch Importer
% This code creates a cell array for the outputs.
outputs = cell(1,nargout);


%% To Do - Implement Function
% Write code to implement the function here. The results must be assigned
% to a cell array named 'outputs' where each element is a structure array
% containing the values of the output and rank. For example, if the first
% output has value Y with Yrank number of dimensions, then
% outputs{1} = struct('value',Y,'rank',Yrank);

% Extract the value and rank of X from the input struct. The dimensions
% are indexed in PyTorch order.
Xval = inputs{1}.value;
Xrank = inputs{1}.rank;

% Softplus function
Yval = log(1+exp(Xval));

% Mish function
Yval = Xval .* tanh(Yval);

% Determine rank and dimension format of the output.
Yrank = Xrank;
Yfmt = repmat('U',1,Yrank);

% Convert the output to a dlarray.
Yval = dlarray(Yval, Yfmt);

% Return a struct containing 'value' and 'rank' fields.
outputs{1} = struct('value',Yval,'rank',Yrank);


%% Do Not Edit - Code Generated by PyTorch Importer
% This code permutes the dimensions of the outputs back into reverse-PyTorch
% ordering.
varargout = cell(1,nargout);
[varargout{:}] = permutePyTorchToReversePyTorch(outputs{:});
end

ネットワークの確認

ネットワークが完全であることを確認するには、[解析] をクリックします。ネットワーク アナライザーはエラーをゼロと報告します。

参考

トピック