カスタム ミニバッチ データストアの開発
"ミニバッチ データストア" とは、バッチ単位でのデータの読み取りをサポートするデータストアの実装です。ミニバッチ データストアは、Deep Learning Toolbox™ を使用する深層学習アプリケーションの学習データ セット、検証データ セット、テスト データ セット、および予測データ セットのソースとして使用できます。
シーケンス データ、時系列データ、またはテキスト データを前処理するには、ここで説明するフレームワークを使用して独自のミニバッチ データストアを構築します。カスタム ミニバッチ データストアを使用する方法を示す例については、シーケンス データのカスタム ミニバッチ データストアを使用したネットワークの学習を参照してください。
概要
カスタム データストアのクラスおよびオブジェクトを使用して、カスタム データストア インターフェイスを構築します。次に、カスタム データストアを使用してデータを MATLAB® に読み込みます。
カスタム ミニバッチ データストアの設計には、matlab.io.Datastore
および matlab.io.datastore.MiniBatchable
クラスからの継承や、必要なプロパティおよびメソッドの実装が含まれます。オプションで、学習中のシャッフルのサポートを追加できます。
処理のニーズ | クラス |
---|---|
Deep Learning Toolbox での学習データ セット、検証データ セット、テスト データ セット、および予測データ セット用のミニバッチ データストア |
MiniBatchable データストアの実装を参照してください。 |
学習中のシャッフルをサポートするミニバッチ データストア |
シャッフルのサポートの追加を参照してください。 |
MiniBatchable
データストアの実装
MyDatastore
という名前のカスタム ミニバッチ データストアを実装するには、スクリプト MyDatastore.m
を作成します。このスクリプトは MATLAB パス上になければならず、適切なクラスから継承し、必要なメソッドを定義するコードを含んでいる必要があります。Deep Learning Toolbox での学習データ セット、検証データ セット、テスト データ セット、および予測データ セット用のミニバッチ データストアを作成するコードは以下でなければなりません。
クラス
matlab.io.Datastore
およびmatlab.io.datastore.MiniBatchable
から継承する。プロパティ
MiniBatchSize
およびNumObservations
を定義する。
これらの手順に加えて、データの処理および解析に必要なその他のプロパティやメソッドを定義できます。
メモ
ネットワークに学習させていて、trainingOptions
が 'Shuffle'
を 'once'
または 'every-epoch'
に指定している場合、matlab.io.datastore.Shuffleable
クラスからも継承しなければなりません。詳細については、シャッフルのサポートの追加を参照してください。
データストアの読み取り関数は、table でデータを返さなければなりません。table の要素は、スカラー、行ベクトルであるか、数値配列が格納された 1 行 1 列の cell 配列でなければなりません。
単一の入力層をもつネットワークの場合、最初の列と 2 番目の列はそれぞれ予測子と応答を指定します。
ヒント
複数の入力層または複数の出力をもつネットワークに学習させるには、combine
関数と transform
関数を使用して、(numInputs
+ numOutputs
) 列の cell 配列を出力するデータストアを作成します。ここで、numInputs
はネットワーク入力の数、numOutputs
はネットワーク出力の数です。最初の numInputs
個の列は各入力の予測子を指定し、最後の numOutputs
個の列は応答を指定します。ニューラル ネットワークの InputNames
プロパティと OutputNames
プロパティによって、それぞれ入力と出力の順序が決まります。
予測子の形式は、データのタイプによって異なります。
データ | 予測子の形式 |
---|---|
2 次元イメージ | h x w x c の数値配列。ここで、h、w、および c は、それぞれイメージの高さ、幅、およびチャネル数です。 |
3 次元イメージ | h x w x d x c の数値配列。ここで、h、w、d、および c は、それぞれイメージの高さ、幅、深さ、およびチャネル数です。 |
ベクトル シーケンス | s 行 c 列の行列。ここで、s はシーケンス長、c はシーケンスの特徴の数です。 |
1 次元イメージ シーケンス | h x c x s の配列。ここで、h および c はそれぞれイメージの高さおよびチャネル数に対応します。s はシーケンス長です。 ミニバッチ内の各シーケンスは、同じシーケンス長でなければなりません。 |
2 次元イメージ シーケンス | h x w x c x s の配列。ここで、h、w、および c はそれぞれイメージの高さ、幅、およびチャネル数に対応します。s はシーケンス長です。 ミニバッチ内の各シーケンスは、同じシーケンス長でなければなりません。 |
3 次元イメージ シーケンス | h x w x d x c x s の配列。ここで、h、w、d、および c は、それぞれイメージの高さ、幅、深さ、およびチャネル数に対応します。s はシーケンス長です。 ミニバッチ内の各シーケンスは、同じシーケンス長でなければなりません。 |
特徴 | c 行 1 列の列ベクトル。c は特徴の数です。 |
table 要素には、数値スカラー、数値行ベクトルが含まれているか、数値配列が格納された 1 行 1 列の cell 配列が含まれていなければなりません。
応答の形式は、タスクのタイプによって異なります。
タスク | 応答の形式 |
---|---|
分類 | categorical スカラー |
回帰 |
|
sequence-to-sequence 分類 | カテゴリカル ラベルの 1 行 s 列のシーケンス。ここで、s は対応する予測子シーケンスのシーケンス長です。 |
sequence-to-sequence 回帰 | R 行 s 列の行列。ここで、R は応答の数、s は対応する予測子シーケンスのシーケンス長です。 |
table 要素には、categorical スカラー、数値スカラー、数値行ベクトルが含まれているか、数値配列が格納された 1 行 1 列の cell 配列が含まれていなければなりません。
この例では、シーケンス データを処理するためのカスタム ミニバッチ データストアを作成する方法を説明します。スクリプトを MySequenceDatastore.m
という名前のファイルに保存します。
手順 | 実装 |
---|---|
| classdef MySequenceDatastore < matlab.io.Datastore & ... matlab.io.datastore.MiniBatchable properties Datastore Labels NumClasses SequenceDimension MiniBatchSize end properties(SetAccess = protected) NumObservations end properties(Access = private) % This property is inherited from Datastore CurrentFileIndex end methods function ds = MySequenceDatastore(folder) % Construct a MySequenceDatastore object % Create a file datastore. The readSequence function is % defined following the class definition. fds = fileDatastore(folder, ... 'ReadFcn',@readSequence, ... 'IncludeSubfolders',true); ds.Datastore = fds; % Read labels from folder names numObservations = numel(fds.Files); for i = 1:numObservations file = fds.Files{i}; filepath = fileparts(file); [~,label] = fileparts(filepath); labels{i,1} = label; end ds.Labels = categorical(labels); ds.NumClasses = numel(unique(labels)); % Determine sequence dimension. When you define the LSTM % network architecture, you can use this property to % specify the input size of the sequenceInputLayer. X = preview(fds); ds.SequenceDimension = size(X,1); % Initialize datastore properties. ds.MiniBatchSize = 128; ds.NumObservations = numObservations; ds.CurrentFileIndex = 1; end function tf = hasdata(ds) % Return true if more data is available tf = ds.CurrentFileIndex + ds.MiniBatchSize - 1 ... <= ds.NumObservations; end function [data,info] = read(ds) % Read one mini-batch batch of data miniBatchSize = ds.MiniBatchSize; info = struct; for i = 1:miniBatchSize predictors{i,1} = read(ds.Datastore); responses(i,1) = ds.Labels(ds.CurrentFileIndex); ds.CurrentFileIndex = ds.CurrentFileIndex + 1; end data = preprocessData(ds,predictors,responses); end function data = preprocessData(ds,predictors,responses) % data = preprocessData(ds,predictors,responses) preprocesses % the data in predictors and responses and returns the table % data miniBatchSize = ds.MiniBatchSize; % Pad data to length of longest sequence. sequenceLengths = cellfun(@(X) size(X,2),predictors); maxSequenceLength = max(sequenceLengths); for i = 1:miniBatchSize X = predictors{i}; % Pad sequence with zeros. if size(X,2) < maxSequenceLength X(:,maxSequenceLength) = 0; end predictors{i} = X; end % Return data as a table. data = table(predictors,responses); end function reset(ds) % Reset to the start of the data reset(ds.Datastore); ds.CurrentFileIndex = 1; end end methods (Hidden = true) function frac = progress(ds) % Determine percentage of data read from datastore frac = (ds.CurrentFileIndex - 1) / ds.NumObservations; end end end % end class definition readSequence という関数が使用されます。MAT ファイルからシーケンス データを読み取るにはこの関数を作成しなければなりません。function data = readSequence(filename) % data = readSequence(filename) reads the sequence X from the MAT-file % filename S = load(filename); data = S.X; end |
シャッフルのサポートの追加
シャッフルのサポートを追加するには、まず、MiniBatchable データストアの実装の手順に従います。次に MySequenceDatastore.m
の実装コードを次のように更新します。
追加クラス
matlab.io.datastore.Shuffleable
から継承します。追加メソッド
shuffle
を定義します。
この例のコードでは、シャッフルのサポートを MySequenceDatastore
クラスに追加します。縦並びの省略記号は、MySequenceDatastore
の実装からコードをコピーする必要がある場所を示します。
手順 | 実装 |
---|---|
| classdef MySequenceDatastore < matlab.io.Datastore & ... matlab.io.datastore.MiniBatchable & ... matlab.io.datastore.Shuffleable % previously defined properties . . . methods % previously defined methods . . . function dsNew = shuffle(ds) % dsNew = shuffle(ds) shuffles the files and the % corresponding labels in the datastore. % Create a copy of datastore dsNew = copy(ds); dsNew.Datastore = copy(ds.Datastore); fds = dsNew.Datastore; % Shuffle files and corresponding labels numObservations = dsNew.NumObservations; idx = randperm(numObservations); fds.Files = fds.Files(idx); dsNew.Labels = dsNew.Labels(idx); end end end |
カスタム ミニバッチ データストアの検証
ここに記載したすべての手順に従うと、カスタム ミニバッチ データストアの実装が完了します。このデータストアを使用する前に、カスタム データストアのテストのガイドラインに記載されているガイドラインを使用して、データストアが適切か確認します。
参考
trainnet
| trainingOptions
| dlnetwork