Main Content

このページの内容は最新ではありません。最新版の英語を参照するには、ここをクリックします。

freezeParameters

ONNXParameters の学習可能なネットワーク パラメーターを学習不能に変換する

R2020b 以降

    説明

    params = freezeParameters(params,names) は、ONNXParameters オブジェクト paramsnames で指定されたネットワーク パラメーターを凍結します。関数は、指定されたパラメーターを入力引数 paramsparams.Learnables から出力引数 paramsparams.Nonlearnables に移動します。

    すべて折りたたむ

    squeezenet畳み込みニューラル ネットワークを関数としてインポートし、事前学習済みのネットワークを転移学習で微調整し、新しいイメージ コレクションで分類を実行します。

    この例では、いくつかの補助関数を使用しています。これらの関数のコードを見るには、補助関数を参照してください。

    新しいイメージを解凍してイメージ データストアとして読み込みます。imageDatastore は、フォルダー名に基づいてイメージに自動的にラベルを付け、データを ImageDatastore オブジェクトとして格納します。イメージ データストアを使用すると、メモリに収まらないデータなどの大きなイメージ データを格納し、畳み込みニューラル ネットワークの学習中にイメージをバッチ単位で効率的に読み取ることができます。ミニバッチのサイズを指定します。

    unzip("MerchData.zip");
    miniBatchSize = 8;
    imds = imageDatastore("MerchData", ...
        IncludeSubfolders=true, ...
        LabelSource="foldernames", ...
        ReadSize=miniBatchSize);

    このデータ セットは小さく、75 個の学習イメージが含まれています。いくつかのサンプル イメージを表示します。

    numImages = numel(imds.Labels);
    idx = randperm(numImages,16);
    figure
    for i = 1:16
        subplot(4,4,i)
        I = readimage(imds,idx(i));
        imshow(I)
    end

    学習セットを抽出し、カテゴリカル分類ラベルを one-hot 符号化します。

    XTrain = readall(imds);
    XTrain = single(cat(4,XTrain{:}));
    YTrain_categ = categorical(imds.Labels);
    YTrain = onehotencode(YTrain_categ,2)';

    データ内のクラスの数を判定します。

    classes = categories(YTrain_categ);
    numClasses = numel(classes)
    numClasses = 5
    

    squeezenet は、ImageNet データベースの 100 万個を超えるイメージで学習済みの畳み込みニューラル ネットワークです。結果として、このネットワークは広範囲のイメージに対する豊富な特徴表現を学習しています。このネットワークは、イメージを 1000 個のオブジェクト カテゴリ (キーボード、マウス、鉛筆、多くの動物など) に分類できます。

    事前学習済みの squeezenet ネットワークを関数としてインポートします。

    squeezenetONNX()
    params = importONNXFunction("squeezenet.onnx","squeezenetFcn")
    Function containing the imported ONNX network architecture was saved to the file squeezenetFcn.m.
    To learn how to use this function, type: help squeezenetFcn.
    
    params = 
      ONNXParameters with properties:
    
                 Learnables: [1×1 struct]
              Nonlearnables: [1×1 struct]
                      State: [1×1 struct]
              NumDimensions: [1×1 struct]
        NetworkFunctionName: 'squeezenetFcn'
    
    

    params は、ネットワーク パラメーターを含む ONNXParameters オブジェクトです。squeezenetFcn は、ネットワーク アーキテクチャを含むモデル関数です。importONNXFunction は、squeezenetFcn を現在のフォルダーに保存します。

    新しい学習セットに対する事前学習済みネットワークの分類精度を計算します。

    accuracyBeforeTraining = getNetworkAccuracy(XTrain,YTrain,params);
    fprintf("%.2f accuracy before transfer learning\n",accuracyBeforeTraining);
    0.01 accuracy before transfer learning
    

    非常に低い精度です。

    params.Learnables と入力して、ネットワークの学習可能なパラメーターを表示します。重み (W) やバイアス (B) など、畳み込み層や全結合層のこれらのパラメーターは、学習時にネットワークによって更新されます。学習不能パラメーターは、学習時に一定のままとなります。

    事前学習済みネットワークの最後の 2 つの学習可能パラメーターは、1000 個のクラスに対して構成されています。

    conv10_W: [1×1×512×1000 dlarray]

    conv10_B: [1000×1 dlarray]

    パラメーター conv10_Wconv10_B は、新しい分類問題に対して微調整しなければなりません。パラメーターを初期化して、5 つのクラスを分類するようにパラメーターを移動します。

    params.Learnables.conv10_W = rand(1,1,512,5);
    params.Learnables.conv10_B = rand(5,1);

    ネットワークのすべてのパラメーターを凍結して、それらを学習不能なパラメーターに変換します。多くの初期層の重みを凍結することで、凍結層の勾配を計算する必要をなくし、ネットワークの学習を大幅に高速化できます。

    params = freezeParameters(params,"all");

    ネットワークの最後の 2 つのパラメーターを凍結解除して、それらを学習可能なパラメーターに変換します。

    params = unfreezeParameters(params,"conv10_W");
    params = unfreezeParameters(params,"conv10_B");

    ネットワークの学習準備が整いました。学習オプションを指定します。

    velocity = [];
    numEpochs = 5;
    miniBatchSize = 16;
    initialLearnRate = 0.01;
    momentum = 0.9;
    decay = 0.01;

    学習の進行状況モニター用に合計反復回数を計算します。

    numObservations = size(YTrain,2);
    numIterationsPerEpoch = floor(numObservations./miniBatchSize);
    numIterations = numEpochs*numIterationsPerEpoch;

    TrainingProgressMonitor オブジェクトを初期化します。監視オブジェクトを作成するとタイマーが開始されるため、学習ループの直後でオブジェクトを作成するようにしてください。

    monitor = trainingProgressMonitor(Metrics="Loss",Info="Epoch",XLabel="Iteration");

    ネットワークに学習をさせます。

    epoch = 0;
    iteration = 0;
    executionEnvironment = "cpu"; % Change to "gpu" to train on a GPU.
    
    % Loop over epochs.
    while epoch < numEpochs && ~monitor.Stop
    
        epoch = epoch + 1;
        
        % Shuffle data.
        idx = randperm(numObservations);
        XTrain = XTrain(:,:,:,idx);
        YTrain = YTrain(:,idx);
        
        % Loop over mini-batches.
        i = 0;
        while i < numIterationsPerEpoch && ~monitor.Stop
            i = i + 1;
            iteration = iteration + 1;
            
            % Read mini-batch of data.
            idx = (i-1)*miniBatchSize+1:i*miniBatchSize;
            X = XTrain(:,:,:,idx);        
            Y = YTrain(:,idx);
            
            % If training on a GPU, then convert data to gpuArray.
            if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
                X = gpuArray(X);         
            end
            
            % Evaluate the model gradients and loss using dlfeval and the
            % modelGradients function.
            [gradients,loss,state] = dlfeval(@modelGradients,X,Y,params);
            params.State = state;
            
            % Determine the learning rate for the time-based decay learning rate schedule.
            learnRate = initialLearnRate/(1 + decay*iteration);
            
            % Update the network parameters using the SGDM optimizer.
            [params.Learnables,velocity] = sgdmupdate(params.Learnables,gradients,velocity);
            
            % Update the training progress monitor.
            recordMetrics(monitor,iteration,Loss=loss);
            updateInfo(monitor,Epoch=epoch);
            monitor.Progress = 100 * iteration/numIterations;
        end
    end

    微調整後のネットワークの分類精度を計算します。

    accuracyAfterTraining = getNetworkAccuracy(XTrain,YTrain,params);
    fprintf("%.2f accuracy after transfer learning\n",accuracyAfterTraining);
    1.00 accuracy after transfer learning
    

    補助関数

    このセクションでは、この例で使用されている補助関数のコードを示します。

    関数 getNetworkAccuracy は、分類精度を計算することによってネットワーク性能を評価します。

    function accuracy = getNetworkAccuracy(X,Y,onnxParams)
    
    N = size(X,4);
    Ypred = squeezenetFcn(X,onnxParams,Training=false);
    
    [~,YIdx] = max(Y,[],1);
    [~,YpredIdx] = max(Ypred,[],1);
    numIncorrect = sum(abs(YIdx-YpredIdx) > 0);
    accuracy = 1 - numIncorrect/N;
    
    end

    関数 modelGradients は、損失と勾配を計算します。

    function [grad, loss, state] = modelGradients(X,Y,onnxParams)
    
    [y,state] = squeezenetFcn(X,onnxParams,Training=true);
    loss = crossentropy(y,Y,DataFormat="CB");
    grad = dlgradient(loss,onnxParams.Learnables);
    
    end

    関数 squeezenetONNX は、squeezenetネットワークの ONNX モデルを生成します。

    function squeezenetONNX()
        
    exportONNXNetwork(squeezenet,"squeezenet.onnx");
    
    end

    入力引数

    すべて折りたたむ

    ネットワーク パラメーター。ONNXParameters オブジェクトとして指定します。params には、インポートされた ONNX™ モデルのネットワーク パラメーターが格納されます。

    凍結するパラメーターの名前。'all' または string 配列として指定します。names'all' に設定すると、すべての学習可能パラメーターを凍結します。1 行 k 列の string 配列 names でパラメーター名を定義することにより、k 個の学習可能なパラメーターを凍結します。

    例: 'all'

    例: ["gpu_0_sl_pred_b_0", "gpu_0_sl_pred_w_0"]

    データ型: char | string

    出力引数

    すべて折りたたむ

    ネットワーク パラメーター。ONNXParameters オブジェクトとして返されます。params には、freezeParameters によって更新されたネットワーク パラメーターが格納されます。

    バージョン履歴

    R2020b で導入