Main Content

このページの翻訳は最新ではありません。ここをクリックして、英語の最新版を参照してください。

カスタム ミニバッチ データストアの開発

"ミニバッチ データストア" とは、バッチ単位でのデータの読み取りをサポートするデータストアの実装です。ミニバッチ データストアは、Deep Learning Toolbox™ を使用する深層学習アプリケーションの学習データセット、検証データセット、テスト データセット、および予測データセットのソースとして使用できます。

シーケンス データ、時系列データ、またはテキスト データを前処理するには、ここで説明するフレームワークを使用して独自のミニバッチ データストアを構築します。カスタム ミニバッチ データストアを使用する方法を示す例については、シーケンス データのカスタム ミニバッチ データストアを使用したネットワークの学習を参照してください。

概要

カスタム データストアのクラスおよびオブジェクトを使用して、カスタム データストア インターフェイスを構築します。次に、カスタム データストアを使用してデータを MATLAB® に読み込みます。

カスタム ミニバッチ データストアの設計には、matlab.io.Datastore および matlab.io.datastore.MiniBatchable クラスからの継承や、必要なプロパティおよびメソッドの実装が含まれます。オプションで、学習中のシャッフルのサポートを追加できます。

処理のニーズ

クラス

Deep Learning Toolbox での学習データセット、検証データセット、テスト データセット、および予測データセット用のミニバッチ データストア

matlab.io.Datastore および matlab.io.datastore.MiniBatchable

MiniBatchable データストアの実装を参照してください。

学習中のシャッフルをサポートするミニバッチ データストア

matlab.io.Datastorematlab.io.datastore.MiniBatchable および matlab.io.datastore.Shuffleable

シャッフルのサポートの追加を参照してください。

MiniBatchable データストアの実装

MyDatastore という名前のカスタム ミニバッチ データストアを実装するには、スクリプト MyDatastore.m を作成します。このスクリプトは MATLAB パス上になければならず、適切なクラスから継承し、必要なメソッドを定義するコードを含んでいる必要があります。Deep Learning Toolbox での学習データセット、検証データセット、テスト データセット、および予測データセット用のミニバッチ データストアを作成するコードは以下でなければなりません。

これらの手順に加えて、データの処理および解析に必要なその他のプロパティやメソッドを定義できます。

メモ

ネットワークに学習させていて、trainingOptions'Shuffle''once' または 'every-epoch' に指定している場合、matlab.io.datastore.Shuffleable クラスからも継承しなければなりません。詳細については、シャッフルのサポートの追加を参照してください。

データストアの読み取り関数は、table でデータを返さなければなりません。table の要素は、スカラー、行ベクトルであるか、数値配列が格納された 1 行 1 列の cell 配列でなければなりません。

単一の入力層をもつネットワークの場合、最初の列と 2 番目の列はそれぞれ予測子と応答を指定します。

ヒント

複数の入力層があるネットワークにデータストアを使用するには、関数 combine および transform を使用して、列数が (numInputs + 1) の cell 配列を出力するデータストアを作成します。ここで、numInputs はネットワーク入力の数です。この場合、最初の numInputs 列は各入力の予測子を指定し、最後の列は応答を指定します。入力の順序は、層グラフ layersInputNames プロパティによって指定します。

予測子の形式は、データのタイプによって異なります。

データ予測子の形式
2 次元イメージ

h x w x c の数値配列。ここで、h、w、および c は、それぞれイメージの高さ、幅、およびチャネル数です。

3 次元イメージ

h x w x d x c の数値配列。ここで、h、w、d、および c は、それぞれイメージの高さ、幅、深さ、およびチャネル数です。

ベクトル シーケンス

c 行 s 列の行列。ここで、c はシーケンスの特徴の数、s はシーケンス長です。

2 次元イメージ シーケンス

h x w x c x s の配列。ここで、h、w、および c はそれぞれイメージの高さ、幅、およびチャネル数に対応します。s はシーケンス長です。

ミニバッチ内の各シーケンスは、同じシーケンス長でなければなりません。

3 次元イメージ シーケンス

h x w x d x c x s の配列。ここで、h、w、d、および c は、それぞれイメージの高さ、幅、深さ、およびチャネル数に対応します。s はシーケンス長です。

ミニバッチ内の各シーケンスは、同じシーケンス長でなければなりません。

特徴

c 行 1 列の列ベクトル。c は特徴の数です。

table 要素には、数値スカラー、数値行ベクトルが含まれているか、数値配列が格納された 1 行 1 列の cell 配列が含まれていなければなりません。

関数 trainNetwork は、複数のシーケンス入力層をもつネットワークをサポートしていません。

応答の形式は、タスクのタイプによって異なります。

タスク応答の形式
分類categorical スカラー
回帰

  • スカラー

  • 数値ベクトル

  • イメージを表す 3 次元数値配列

sequence-to-sequence 分類

カテゴリカル ラベルの 1 行 s 列のシーケンス。ここで、s は対応する予測子シーケンスのシーケンス長です。

sequence-to-sequence 回帰

R 行 s 列の行列。ここで、R は応答の数、s は対応する予測子シーケンスのシーケンス長です。

table 要素には、categorical スカラー、数値スカラー、数値行ベクトルが含まれているか、数値配列が格納された 1 行 1 列の cell 配列が含まれていなければなりません。

この例では、シーケンス データを処理するためのカスタム ミニバッチ データストアを作成する方法を説明します。スクリプトを MySequenceDatastore.m という名前のファイルに保存します。

手順実装

  1. クラスの定義を開始します。基底クラス matlab.io.Datastore および matlab.io.datastore.MiniBatchable クラスから継承します。

  2. プロパティを定義します。

    • MiniBatchSize および NumObservations のプロパティを再定義します。オプションで、追加のプロパティ属性をどちらかのプロパティに割り当てることができます。詳細については、プロパティの属性を参照してください。

    • カスタム ミニバッチ データストアに固有のプロパティを定義することもできます。

  3. メソッドを定義します。

    • カスタム ミニバッチ データストア コンストラクターを実装します。

    • hasdata メソッドを実装します。

    • read メソッドを実装します。このメソッドは、1 番目の列が予測子、2 番目の列が応答であるテーブルとしてデータを返さなければなりません。

      シーケンス データの場合、シーケンスはサイズが c 行 s 列の行列でなければなりません。c は特徴の数、s はシーケンス長です。s の値はミニバッチ間で異なる場合があります。

    • reset メソッドを実装します。

    • progress メソッドを実装します。

    • カスタム ミニバッチ データストアに固有のメソッドを定義することもできます。

  4. classdef セクションを終了します。

classdef MySequenceDatastore < matlab.io.Datastore & ...
                       matlab.io.datastore.MiniBatchable
    
    properties
        Datastore
        Labels
        NumClasses
        SequenceDimension
        MiniBatchSize
    end
    
    properties(SetAccess = protected)
        NumObservations
    end

    properties(Access = private)
        % This property is inherited from Datastore
        CurrentFileIndex
    end


    methods
        
        function ds = MySequenceDatastore(folder)
            % Construct a MySequenceDatastore object

            % Create a file datastore. The readSequence function is
            % defined following the class definition.
            fds = fileDatastore(folder, ...
                'ReadFcn',@readSequence, ...
                'IncludeSubfolders',true);
            ds.Datastore = fds;

            % Read labels from folder names
            numObservations = numel(fds.Files);
            for i = 1:numObservations
                file = fds.Files{i};
                filepath = fileparts(file);
                [~,label] = fileparts(filepath);
                labels{i,1} = label;
            end
            ds.Labels = categorical(labels);
            ds.NumClasses = numel(unique(labels));
            
            % Determine sequence dimension. When you define the LSTM
            % network architecture, you can use this property to
            % specify the input size of the sequenceInputLayer.
            X = preview(fds);
            ds.SequenceDimension = size(X,1);
            
            % Initialize datastore properties.
            ds.MiniBatchSize = 128;
            ds.NumObservations = numObservations;
            ds.CurrentFileIndex = 1;
        end

        function tf = hasdata(ds)
            % Return true if more data is available
            tf = ds.CurrentFileIndex + ds.MiniBatchSize - 1 ...
                <= ds.NumObservations;
        end

        function [data,info] = read(ds)            
            % Read one mini-batch batch of data
            miniBatchSize = ds.MiniBatchSize;
            info = struct;
            
            for i = 1:miniBatchSize
                predictors{i,1} = read(ds.Datastore);
                responses(i,1) = ds.Labels(ds.CurrentFileIndex);
                ds.CurrentFileIndex = ds.CurrentFileIndex + 1;
            end
            
            data = preprocessData(ds,predictors,responses);
        end

        function data = preprocessData(ds,predictors,responses)
            % data = preprocessData(ds,predictors,responses) preprocesses
            % the data in predictors and responses and returns the table
            % data
            
            miniBatchSize = ds.MiniBatchSize;
            
            % Pad data to length of longest sequence.
            sequenceLengths = cellfun(@(X) size(X,2),predictors);
            maxSequenceLength = max(sequenceLengths);
            for i = 1:miniBatchSize
                X = predictors{i};
                
                % Pad sequence with zeros.
                if size(X,2) < maxSequenceLength
                    X(:,maxSequenceLength) = 0;
                end
                
                predictors{i} = X;
            end
            
            % Return data as a table.
            data = table(predictors,responses);
        end

        function reset(ds)
            % Reset to the start of the data
            reset(ds.Datastore);
            ds.CurrentFileIndex = 1;
        end
        
    end 

    methods (Hidden = true)

        function frac = progress(ds)
            % Determine percentage of data read from datastore
            frac = (ds.CurrentFileIndex - 1) / ds.NumObservations;
        end

    end

end % end class definition
カスタム データストアの読み取り方法の実装には、readSequence という関数が使用されます。MAT ファイルからシーケンス データを読み取るにはこの関数を作成しなければなりません。
function data = readSequence(filename)
% data = readSequence(filename) reads the sequence X from the MAT-file
% filename

S = load(filename);
data = S.X;
end

シャッフルのサポートの追加

シャッフルのサポートを追加するには、まず、MiniBatchable データストアの実装の手順に従います。次に MySequenceDatastore.m の実装コードを次のように更新します。

この例のコードでは、シャッフルのサポートを MySequenceDatastore クラスに追加します。縦並びの省略記号は、MySequenceDatastore の実装からコードをコピーする必要がある場所を示します。

手順実装

  1. クラス定義を更新して、matlab.io.datastore.Shuffleable クラスからも継承します。

  2. shuffle の定義を既存の methods セクションに追加します。

classdef MySequenceDatastore < matlab.io.Datastore & ...
                       matlab.io.datastore.MiniBatchable & ...
                       matlab.io.datastore.Shuffleable
   
   % previously defined properties 
   .
   .
   . 


   methods

        % previously defined methods
        .
        .
        . 
   
        function dsNew = shuffle(ds)
            % dsNew = shuffle(ds) shuffles the files and the
            % corresponding labels in the datastore.
            
            % Create a copy of datastore
            dsNew = copy(ds);
            dsNew.Datastore = copy(ds.Datastore);
            fds = dsNew.Datastore;
            
            % Shuffle files and corresponding labels
            numObservations = dsNew.NumObservations;
            idx = randperm(numObservations);
            fds.Files = fds.Files(idx);
            dsNew.Labels = dsNew.Labels(idx);
        end

     end

end
  

カスタム ミニバッチ データストアの検証

ここに記載したすべての手順に従うと、カスタム ミニバッチ データストアの実装が完了します。このデータストアを使用する前に、カスタム データ ストアのテストのガイドラインに記載されているガイドラインを使用して、データストアが適切か確認します。

参考

関連する例

詳細