Why is the cross-entropy in Neural Network trained w/ GPU different from that in CPU trained one?

3 ビュー (過去 30 日間)
LBW
LBW 2017 年 4 月 9 日
コメント済み: LBW 2017 年 4 月 22 日
I've been trying to use GPU to train an neural network. I've followed the instruction (https://www.mathworks.com/help/nnet/ug/neural-networks-with-parallel-and-gpu-computing.html) to build an NN and everything looks working well.
But when I plot performance, the cross-entropy doesn't change.
The cross-entropy decreases continuously when I perform the same training with CPU.
What causes this difference? And how can I make GPU training works just as how it works in CPU? Here's my code.
x = nndata2gpu(data50');
t = nndata2gpu(label50');
trainFcn = 'trainscg';
hiddenLayerSize = 10;
net = patternnet(hiddenLayerSize, trainFcn);
net.divideParam.trainRatio = 70/100;
net.divideParam.valRatio = 15/100;
net.divideParam.testRatio = 15/100;
net = configure(net,data50',label50');
[net,tr] = train(net,x,t,'useGPU','yes');

回答 (1 件)

Naoya
Naoya 2017 年 4 月 14 日
GPU モードと CPU モードでの実行結果の違いですが、乱数の状態の相違が原因であるものと推測します。 ニューラルネットワークの初期重みやバイアス値は、通常、rand 関数を用いた乱数で与えられます。 その為、乱数のシードを固定化しない場合は、gpu/cpu上限らず、で毎回、実行結果が異なることになります。
下記のように、 rng() でシードを固定化することより、実行結果が統一されるか確認できますでしょうか?
ex)
rng(10,'twister');
....
net = patternnet(hiddenLayerSize, trainFcn);
....
[net,tr] = train(net,x,t,'useGPU','yes');
rng(10,'twister');
....
net = patternnet(hiddenLayerSize, trainFcn);
....
[net,tr] = train(net,x,t,'useGPU','no');
  1 件のコメント
LBW
LBW 2017 年 4 月 22 日
ご回答ありがとうございます。返事が遅くなりましてすみません。
ご指摘の通り、rng(10,'twister');をつけて再度試してみましたが、結果は変わりませんでした。しかし、どうやら最初にnndata2gpuを使ったのが良くなかったようです。 nndata2gpuを使わずに直接x = data50';として、trainの部分で'useGPU','yes'とすれば自動でgpuarrayに変えて計算してくれるようです。このように操作した場合、Cross-EntropyはCPUの時と同じように減少しました。ただ、この時nntraintoolの方にはCalculations: GPU と表示され、nndata2gpuを使った際にはCalculations: GPU(CUDA)と表示されます。
また、rng(10,'twister');をつけて実行してみましたが、CPUとGPUの最終的な結果は同じにはならなかったようです。

サインインしてコメントする。

カテゴリ

Help Center および File ExchangeDeep Learning Toolbox についてさらに検索

Community Treasure Hunt

Find the treasures in MATLAB Central and discover how the community can help you!

Start Hunting!

Translated by