フィルターのクリア

ニューラルネットワー​クの学習をdoubl​e型で行うことはでき​ますか?

1 回表示 (過去 30 日間)
Fumiya Watanabe
Fumiya Watanabe 2018 年 6 月 26 日
コメント済み: Fumiya Watanabe 2018 年 7 月 5 日
ニューラルネットワークの学習をdouble型で行うことはできますか?
現在、ある実数値ベクトルを入力とする回帰問題をNeural Network Toolboxを用いて実現しようとしています。 このベクトル入力を画像入力として扱うことで実現を考えています。しかしながら、trainNetworkを実行するとsingle型として扱われてしまう問題が生じており、解決法がわからず困っております。
例えば、次の自作の回帰層を考えます。
classdef testLayer < nnet.layer.RegressionLayer
methods
function layer = testLayer()
end
function loss = forwardLoss(layer, Y, T)
loss = gpuArray(0);
end
function dLdX = backwardLoss(layer, Y, T)
dLdX = gpuArray(zeros(size(Y)));
end
end
end
この自作回帰層を用いて、次のように学習を実行します。
%%学習データ
x_in = rand(10, 1, 1, 6);
y_tr = rand(6, 5);
%%層構造とオプションの定義
layers = [
imageInputLayer([10 1 1], 'Normalization', 'none', 'Name', 'Input')
fullyConnectedLayer(2, 'Name', 'Layer1')
reluLayer('Name', 'ReLU1')
fullyConnectedLayer(5, 'Name', 'Output')
testLayer
];
layers(end).Name = 'Regression';
options = trainingOptions(...
'sgdm',...
'InitialLearnRate', 0.001, ...
'MiniBatchSize', 3, ...
'MaxEpochs', 1);
%%学習開始
net = trainNetwork(x_in, y_tr, layers, options);
すると、次のエラーが発生します。
エラー: trainNetwork (line 154)
Incorrect type of dLdX for 'backwardLoss' in the output layer. Expected gpuArray of underlying type 'single', but instead has
underlying type 'double'.
上記の自作回帰層で、gpuArrayの内部をsingleにキャストすることで実行することが可能となるのですが、実際に使っている自作回帰層ではdouble型でないと計算できない関数を利用しているため、
function loss = forwardLoss(layer, Y, T)
loss = gpuArray(single(myfun(double(Y), double(T))));
end
のようなキャストをしていく必要が生じてしまいます。これを避けるために学習をdouble型で実行したいのですが、解決法はありますでしょうか。

採用された回答

Naoya
Naoya 2018 年 6 月 29 日
Neural Network Toolbox で提供される 畳み込みニューラルネットワークですが、trainNetwork 側で与えるデータ型は single, double 両方を受け付けます。
しかしながら、基本的にGPU上では単精度演算として扱われますので、GPU へ渡すゲートウェイとなるデータ型は single型となってしまいます。
  3 件のコメント
Naoya
Naoya 2018 年 7 月 3 日
ご連絡ありがとうございます。 cpuモードの場合でも backwardLoss 関数のゲートウェイは single型にする必要があります。
Fumiya Watanabe
Fumiya Watanabe 2018 年 7 月 5 日
ご回答ありがとうございます。
入力としてはdouble型を受け付けるが、計算内部はGPU・CPUどちらの場合でもsingle型で実行される形になっており、自作の層を扱う場合はsingle型でほかの層とのやり取りが必要であると理解いたしました。 ありがとうございました。

サインインしてコメントする。

その他の回答 (0 件)

カテゴリ

Help Center および File Exchange深層学習データの前処理 についてさらに検索

製品


リリース

R2018a

Community Treasure Hunt

Find the treasures in MATLAB Central and discover how the community can help you!

Start Hunting!