ドキュメンテーション

最新のリリースでは、このページがまだ翻訳されていません。 このページの最新版は英語でご覧になれます。

カスタム ミニバッチ データストアの開発

"ミニバッチ データストア" とは、バッチ単位でのデータの読み取りをサポートするデータストアの実装です。ミニバッチ データストアは、Deep Learning Toolbox™ を使用する深層学習アプリケーションの学習データセット、検証データセット、テスト データセット、および予測データセットのソースとして使用できます。

シーケンス データ、時系列データ、またはテキスト データを前処理するには、ここで説明するフレームワークを使用して独自のミニバッチ データストアを構築します。カスタム ミニバッチ データストアを使用する方法を示す例については、シーケンス データのカスタム ミニバッチ データストアを使用したネットワークの学習を参照してください。

概要

カスタム データストアのクラスおよびオブジェクトを使用して、カスタム データストア インターフェイスを構築します。次に、カスタム データストアを使用してデータを MATLAB® に読み込みます。

カスタム ミニバッチ データストアの設計には、matlab.io.Datastore および matlab.io.datastore.MiniBatchable クラスからの継承や、必要なプロパティおよびメソッドの実装が含まれます。オプションで、学習中のシャッフルのサポートを追加できます。

処理のニーズ

クラス

Deep Learning Toolbox での学習データセット、検証データセット、テスト データセット、および予測データセット用のミニバッチ データストア

matlab.io.Datastore および matlab.io.datastore.MiniBatchable

MiniBatchable データストアの実装を参照してください。

学習中のシャッフルをサポートするミニバッチ データストア

matlab.io.Datastorematlab.io.datastore.MiniBatchable および matlab.io.datastore.Shuffleable

シャッフルのサポートの追加を参照してください。

MiniBatchable データストアの実装

MyDatastore という名前のカスタム ミニバッチ データストアを実装するには、スクリプト MyDatastore.m を作成します。このスクリプトは MATLAB パス上になければならず、適切なクラスから継承し、必要なメソッドを定義するコードを含んでいる必要があります。Deep Learning Toolbox での学習データセット、検証データセット、テスト データセット、および予測データセット用のミニバッチ データストアを作成するコードは以下でなければなりません。

これらの手順に加えて、データの処理および解析に必要なその他のプロパティやメソッドを定義できます。

メモ

ネットワークに学習させていて、trainingOptions'Shuffle''once' または 'every-epoch' に指定している場合、matlab.io.datastore.Shuffleable クラスからも継承しなければなりません。詳細は、シャッフルのサポートの追加を参照してください。

この例では、シーケンス データを処理するためのカスタム ミニバッチ データストアを作成する方法を説明します。スクリプトを MySequenceDatastore.m という名前のファイルに保存します。

手順実装

 

  1. クラスの定義を開始します。基底クラス matlab.io.Datastore および matlab.io.datastore.MiniBatchable クラスから継承します。

     

  2. プロパティを定義します。

    • MiniBatchSize および NumObservations のプロパティを再定義します。オプションで、追加のプロパティ属性をどちらかのプロパティに割り当てることができます。詳細は、プロパティの属性 (MATLAB)を参照してください。

    • カスタム ミニバッチ データストアに固有のプロパティを定義することもできます。

     

  3. メソッドを定義します。

    • カスタム ミニバッチ データストア コンストラクターを実装します。

    • hasdata メソッドを実装します。

    • read メソッドを実装します。このメソッドは、1 番目の列が予測子、2 番目の列が応答である table としてデータを返さなければなりません。

      シーケンス データの場合、シーケンスはサイズが D 行 S 列の行列でなければなりません。D は特徴の数、S はシーケンス長です。S の値はミニバッチ間で異なる場合があります。

    • reset メソッドを実装します。

    • progress メソッドを実装します。

    • カスタム ミニバッチ データストアに固有のメソッドを定義することもできます。

     

  4. classdef セクションを終了します。

classdef MySequenceDatastore < matlab.io.Datastore & ...
                       matlab.io.datastore.MiniBatchable
    
    properties
        Datastore
        Labels
        NumClasses
        SequenceDimension
        MiniBatchSize
    end
    
    properties(SetAccess = protected)
        NumObservations
    end

    properties(Access = private)
        % This property is inherited from Datastore
        CurrentFileIndex
    end


    methods
        
        function ds = MySequenceDatastore(folder)
            % Construct a MySequenceDatastore object

            % Create a file datastore. The readSequence function is
            % defined following the class definition.
            fds = fileDatastore(folder, ...
                'ReadFcn',@readSequence, ...
                'IncludeSubfolders',true);
            ds.Datastore = fds;

            % Read labels from folder names
            numObservations = numel(fds.Files);
            for i = 1:numObservations
                file = fds.Files{i};
                filepath = fileparts(file);
                [~,label] = fileparts(filepath);
                labels{i,1} = label;
            end
            ds.Labels = categorical(labels);
            ds.NumClasses = numel(unique(labels));
            
            % Determine sequence dimension. When you define the LSTM
            % network architecture, you can use this property to
            % specify the input size of the sequenceInputLayer.
            X = preview(fds);
            ds.SequenceDimension = size(X,1);
            
            % Initialize datastore properties.
            ds.MiniBatchSize = 128;
            ds.NumObservations = numObservations;
            ds.CurrentFileIndex = 1;
        end

        function tf = hasdata(ds)
            % Return true if more data is available
            tf = ds.CurrentFileIndex + ds.MiniBatchSize - 1 ...
                <= ds.NumObservations;
        end

        function [data,info] = read(ds)            
            % Read one mini-batch batch of data
            miniBatchSize = ds.MiniBatchSize;
            info = struct;
            
            for i = 1:miniBatchSize
                predictors{i,1} = read(ds.Datastore);
                responses(i,1) = ds.Labels(ds.CurrentFileIndex);
                ds.CurrentFileIndex = ds.CurrentFileIndex + 1;
            end
            
            data = preprocessData(ds,predictors,responses);
        end

        function data = preprocessData(ds,predictors,responses)
            % data = preprocessData(ds,predictors,responses) preprocesses
            % the data in predictors and responses and returns the table
            % data
            
            miniBatchSize = ds.MiniBatchSize;
            
            % Pad data to length of longest sequence.
            sequenceLengths = cellfun(@(X) size(X,2),predictors);
            maxSequenceLength = max(sequenceLengths);
            for i = 1:miniBatchSize
                X = predictors{i};
                
                % Pad sequence with zeros.
                if size(X,2) < maxSequenceLength
                    X(:,maxSequenceLength) = 0;
                end
                
                predictors{i} = X;
            end
            
            % Return data as a table.
            data = table(predictors,responses);
        end

        function reset(ds)
            % Reset to the start of the data
            reset(ds.Datastore);
            ds.CurrentFileIndex = 1;
        end
        
    end 

    methods (Hidden = true)

        function frac = progress(ds)
            % Determine percentage of data read from datastore
            frac = (ds.CurrentFileIndex - 1) / ds.NumObservations;
        end

    end

end % end class definition
カスタム データストアの読み取り方法の実装には、readSequence という関数が使用されます。MAT ファイルからシーケンス データを読み取るにはこの関数を作成しなければなりません。
function data = readSequence(filename)
% data = readSequence(filename) reads the sequence X from the MAT-file
% filename

S = load(filename);
data = S.X;
end

シャッフルのサポートの追加

シャッフルのサポートを追加するには、まず、MiniBatchable データストアの実装の手順に従います。次に MySequenceDatastore.m の実装コードを次のように更新します。

この例のコードでは、シャッフルのサポートを MySequenceDatastore クラスに追加します。縦並びの省略記号は、MySequenceDatastore の実装からコードをコピーする必要がある場所を示します。

手順実装

 

  1. クラス定義を更新して、matlab.io.datastore.Shuffleable クラスからも継承します。

     

  2. shuffle の定義を既存の methods セクションに追加します。

classdef MySequenceDatastore < matlab.io.Datastore & ...
                       matlab.io.datastore.MiniBatchable & ...
                       matlab.io.datastore.Shuffleable
   
   % previously defined properties 
   .
   .
   . 


   methods

        % previously defined methods
        .
        .
        . 
   
        function dsNew = shuffle(ds)
            % dsNew = shuffle(ds) shuffles the files and the
            % corresponding labels in the datastore.
            
            % Create a copy of datastore
            dsNew = copy(ds);
            dsNew.Datastore = copy(ds.Datastore);
            fds = dsNew.Datastore;
            
            % Shuffle files and corresponding labels
            numObservations = dsNew.NumObservations;
            idx = randperm(numObservations);
            fds.Files = fds.Files(idx);
            dsNew.Labels = dsNew.Labels(idx);
        end

     end

end
  

カスタム ミニバッチ データストアの検証

ここに記載したすべての手順に従うと、カスタム ミニバッチ データストアの実装が完了します。このデータストアを使用する前に、カスタム データ ストアのテストのガイドライン (MATLAB)に記載されているガイドラインを使用して、データストアが適切か確認します。

参考

関連する例

詳細