Why is the cross-entropy in Neural Network trained w/ GPU different from that in CPU trained one?

1 回表示 (過去 30 日間)
LBW
LBW 2017 年 4 月 9 日
コメント済み: LBW 2017 年 4 月 22 日
I've been trying to use GPU to train an neural network. I've followed the instruction (https://www.mathworks.com/help/nnet/ug/neural-networks-with-parallel-and-gpu-computing.html) to build an NN and everything looks working well.
But when I plot performance, the cross-entropy doesn't change.
The cross-entropy decreases continuously when I perform the same training with CPU.
What causes this difference? And how can I make GPU training works just as how it works in CPU? Here's my code.
x = nndata2gpu(data50');
t = nndata2gpu(label50');
trainFcn = 'trainscg';
hiddenLayerSize = 10;
net = patternnet(hiddenLayerSize, trainFcn);
net.divideParam.trainRatio = 70/100;
net.divideParam.valRatio = 15/100;
net.divideParam.testRatio = 15/100;
net = configure(net,data50',label50');
[net,tr] = train(net,x,t,'useGPU','yes');

回答 (1 件)

Naoya
Naoya 2017 年 4 月 14 日
GPU モードと CPU モードでの実行結果の違いですが、乱数の状態の相違が原因であるものと推測します。 ニューラルネットワークの初期重みやバイアス値は、通常、rand 関数を用いた乱数で与えられます。 その為、乱数のシードを固定化しない場合は、gpu/cpu上限らず、で毎回、実行結果が異なることになります。
下記のように、 rng() でシードを固定化することより、実行結果が統一されるか確認できますでしょうか?
ex)
rng(10,'twister');
....
net = patternnet(hiddenLayerSize, trainFcn);
....
[net,tr] = train(net,x,t,'useGPU','yes');
rng(10,'twister');
....
net = patternnet(hiddenLayerSize, trainFcn);
....
[net,tr] = train(net,x,t,'useGPU','no');
  1 件のコメント
LBW
LBW 2017 年 4 月 22 日
ご回答ありがとうございます。返事が遅くなりましてすみません。
ご指摘の通り、rng(10,'twister');をつけて再度試してみましたが、結果は変わりませんでした。しかし、どうやら最初にnndata2gpuを使ったのが良くなかったようです。 nndata2gpuを使わずに直接x = data50';として、trainの部分で'useGPU','yes'とすれば自動でgpuarrayに変えて計算してくれるようです。このように操作した場合、Cross-EntropyはCPUの時と同じように減少しました。ただ、この時nntraintoolの方にはCalculations: GPU と表示され、nndata2gpuを使った際にはCalculations: GPU(CUDA)と表示されます。
また、rng(10,'twister');をつけて実行してみましたが、CPUとGPUの最終的な結果は同じにはならなかったようです。

サインインしてコメントする。

カテゴリ

Help Center および File ExchangeSequence and Numeric Feature Data Workflows についてさらに検索

Community Treasure Hunt

Find the treasures in MATLAB Central and discover how the community can help you!

Start Hunting!

Translated by