現在この質問をフォロー中です
- フォローしているコンテンツ フィードに更新が表示されます。
- コミュニケーション基本設定に応じて電子メールを受け取ることができます。
matlabのディープラーニングでは、なぜテストデータを使わずにバリデーションデータを使うのか
4 ビュー (過去 30 日間)
古いコメントを表示
ssk
2019 年 3 月 7 日
プログラミング初心者です。
下記リンクにつきまして、
[imdsTrain,imdsValidation] = splitEachLabel(imds,0.7);
という一文がありますが、なぜ、テストデータを使わずにバリデーションデータを使うのでしょうか。
imdsValidationではなく、imdsTestだと納得できるのですが不思議です。
もしバリデーションデータを使うのであれば、テストデータは使わなくてもいいかご教示頂けますと幸いです。
採用された回答
Kenta
2019 年 3 月 12 日
単に、ここではバリデーションデータをテストデータと読み替えて問題ないと思います。また、以下のように、
[imdsTrain,imdsValidation, imdsTest] = splitEachLabel(imds,0.7,0.2);
などとして、画像を訓練、バリデーション、テストデータに分けると良いかもしれません。
リンクの学習曲線のところでは、バリデーションデータを使います。
そして、最後のところで
[YPred,probs] = classify(net,imdsTest);
accuracy = mean(YPred == imdsTest.Labels)
とすると、テストデータで正答率を計算できます。ここで、optionsのところに
'ValidationPatience', 3
を追加すれば学習の早期終了ができます。「'ValidationPatience' の値は、ネットワークの学習が停止するまでに、検証セットでの損失が前の最小損失以上になることが許容される回数です。」
とあります。学習がある程度のところで限界が来たらそこで学習がストップするので学習時間を短縮できたり、過学習が抑えられる可能性があります。
11 件のコメント
ssk
2019 年 3 月 12 日
itakuraさま、いつもご回答いただきまして誠にありがとうございます。本例につきまして、バリデーションデータとテストデータを置き換えることができる旨、ご教示いただきありがとうございます。imdstestも使った例でも試してみます。また、'ValidationPatience', も利用してみたいと思います。
あわせて、下記リンクにつきましてitakura様宛に追加でご質問させていただきましたのでご覧頂けますと幸いです。
https://jp.mathworks.com/matlabcentral/answers/447586-cross-validation
ssk
2019 年 3 月 13 日
追加でのご質問失礼いたします。
[imdsTrain,imdsTest,imdsValidation] = splitEachLabel(imds,0.7,0.2);と仮にテストデータとバリデーションデータを分けた場合、クロスバリデーションは以前ご教示いただいた方法と同じものを使っても差し支えないでしょうか。それとも、imdstest向けのコードの変更が必要でしょうか。
Kenta
2019 年 3 月 13 日
3つに分けたかったら少し改変する必要があります。
具体的には、9つ分のイメージデータストアを合わせたものをもう一度、imdsTrain, imdsValidationにわける必要があります。
ただ、今回のような結果だといままでどおりの交差検証のコードでひとまずやってみたらよいと思います。枚数を増やしてうまくいかなかったらまた対策を考える方向がよいかと思います。
ssk
2019 年 3 月 13 日
ご教示いただきありがとうございます。まずはもともとの交差検証のコードで進めたいと思います。念のため、9つ分のイメージデータストアを合わせたものをもう一度、imdsTrain, imdsValidationにわけるコードにつき、以前ご教示いただいたコードを基に自身で作成してみます。
ssk
2019 年 3 月 13 日
トレーニング、テスト、バリデーションの3つに分けたコードを試しに作成してみたのですが、以下のコードでご趣旨を反映できておりますでしょうか。
%% cross validation
[imds01,imds02,imds03,imds04,imds05,imds06,imds07,imds08,imds09,imds010]...
= splitEachLabel(imds,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,'randomize');
imdsTrain1 = imageDatastore(cat(1,imds01.Files,imds02.Files,imds03.Files,imds04.Files,imds05.Files,imds06.Files,imds07.Files,imds08.Files,imds09.Files));
imdsTrain1.Labels = cat(1,imds01.Labels,imds02.Labels,imds03.Labels,imds04.Labels,imds05.Labels,imds06.Labels,imds07.Labels,imds08.Labels,imds09.Labels);
imdsTrain2 = imageDatastore(cat(1,imds01.Files,imds02.Files,imds03.Files,imds04.Files,imds05.Files,imds06.Files,imds07.Files,imds08.Files,imds010.Files));
imdsTrain2.Labels = cat(1,imds01.Labels,imds02.Labels,imds03.Labels,imds04.Labels,imds05.Labels,imds06.Labels,imds07.Labels,imds08.Labels,imds010.Labels);
imdsTrain3 = imageDatastore(cat(1,imds01.Files,imds02.Files,imds03.Files,imds04.Files,imds05.Files,imds06.Files,imds07.Files,imds09.Files,imds010.Files));
imdsTrain3.Labels = cat(1,imds01.Labels,imds02.Labels,imds03.Labels,imds04.Labels,imds05.Labels,imds06.Labels,imds07.Labels,imds09.Labels,imds010.Labels);
imdsTrain4 = imageDatastore(cat(1,imds01.Files,imds02.Files,imds03.Files,imds04.Files,imds05.Files,imds06.Files,imds08.Files,imds09.Files,imds010.Files));
imdsTrain4.Labels = cat(1,imds01.Labels,imds02.Labels,imds03.Labels,imds04.Labels,imds05.Labels,imds06.Labels,imds08.Labels,imds09.Labels,imds010.Labels);
imdsTrain5 = imageDatastore(cat(1,imds01.Files,imds02.Files,imds03.Files,imds04.Files,imds05.Files,imds07.Files,imds08.Files,imds09.Files,imds010.Files));
imdsTrain5.Labels = cat(1,imds01.Labels,imds02.Labels,imds03.Labels,imds04.Labels,imds05.Labels,imds07.Labels,imds08.Labels,imds09.Labels,imds010.Labels);
imdsTrain6 = imageDatastore(cat(1,imds01.Files,imds02.Files,imds03.Files,imds04.Files,imds06.Files,imds07.Files,imds08.Files,imds09.Files,imds010.Files));
imdsTrain6.Labels = cat(1,imds01.Labels,imds02.Labels,imds03.Labels,imds04.Labels,imds06.Labels,imds07.Labels,imds08.Labels,imds09.Labels,imds010.Labels);
imdsTrain7 = imageDatastore(cat(1,imds01.Files,imds02.Files,imds03.Files,imds06.Files,imds05.Files,imds07.Files,imds08.Files,imds09.Files,imds010.Files));
imdsTrain7.Labels = cat(1,imds01.Labels,imds02.Labels,imds03.Labels,imds05.Labels,imds06.Labels,imds07.Labels,imds08.Labels,imds09.Labels,imds010.Labels);
imdsTrain8 = imageDatastore(cat(1,imds01.Files,imds02.Files,imds04.Files,imds05.Files,imds06.Files,imds07.Files,imds08.Files,imds09.Files,imds010.Files));
imdsTrain8.Labels = cat(1,imds01.Labels,imds02.Labels,imds04.Labels,imds05.Labels,imds06.Labels,imds07.Labels,imds08.Labels,imds09.Labels,imds010.Labels);
imdsTrain9 = imageDatastore(cat(1,imds01.Files,imds03.Files,imds04.Files,imds05.Files,imds06.Files,imds07.Files,imds08.Files,imds09.Files,imds010.Files));
imdsTrain9.Labels = cat(1,imds01.Labels,imds03.Labels,imds04.Labels,imds05.Labels,imds06.Labels,imds07.Labels,imds08.Labels,imds09.Labels,imds010.Labels);
imdsTrain10 = imageDatastore(cat(1,imds02.Files,imds03.Files,imds04.Files,imds05.Files,imds06.Files,imds07.Files,imds08.Files,imds09.Files,imds010.Files));
imdsTrain10.Labels = cat(1,imds02.Labels,imds03.Labels,imds04.Labels,imds05.Labels,imds06.Labels,imds07.Labels,imds08.Labels,imds09.Labels,imds010.Labels);
%% training
accuracy=zeros(1,10);
for i=1:10
stname1=sprintf('imdsTrain%d',i);
eval(['trainimds' ,'=', stname1,';'])
%trainimds.ReadFcn = @(filename)resize(filename);
i2=10-i+1;
stname2=sprintf('imds0%d',i2);
eval(['imdsValidation' ,'=', stname2,';'])
imdsValidation.ReadFcn = @(filename)resize(filename);
[imds11,imds12,imds13,imds14,imds15]...
= splitEachLabel(imds,0.2,0.2,0.2,0.2,'randomize');
imdsTest11 = imageDatastore(cat(1,imds11.Files,imds12.Files,imds13.Files,imds14.Files));
imdsTest11.Labels = cat(1,imds11.Labels,imds12.Labels,imds13.Labels,imds14.Labels);
imdsTest12 = imageDatastore(cat(1,imds11.Files,imds12.Files,imds13.Files,imds15.Files));
imdsTest12.Labels = cat(1,imds11.Labels,imds12.Labels,imds13.Labels,imds15.Labels);
imdsTest13 = imageDatastore(cat(1,imds11.Files,imds12.Files,imds14.Files,imds15.Files));
imdsTest13.Labels = cat(1,imds11.Labels,imds12.Labels,imds14.Labels,imds15.Labels);
imdsTest14 = imageDatastore(cat(1,imds11.Files,imds13.Files,imds14.Files,imds15.Files));
imdsTest14.Labels = cat(1,imds11.Labels,imds13.Labels,imds14.Labels,imds15.Labels);
imdsTest15 = imageDatastore(cat(1,imds12.Files,imds13.Files,imds14.Files,imds15.Files));
imdsTest15.Labels = cat(1,imds11.Labels,imds13.Labels,imds14.Labels,imds15.Labels);
%% training for test data
accuracy=zeros(11,15);
for i3=11:15
stname3=sprintf('imdsTest%d',i3);
eval(['imdsTest' ,'=', stname3,';'])
%imdsTest.ReadFcn = @(filename)resize(filename);
i4=15-i+1;
stname4=sprintf('imds0%d',i4);
eval(['imdsValidation' ,'=', stname4,';'])
imdsValidation.ReadFcn = @(filename)resize(filename);
%%train network(中略)
[YPred,probs] = classify(net,augimdsValidation);
accuracy = mean(YPred == imdsValidation.Labels)
[YPred,probs] = classify(net,imdsTest);
accuracy = mean(YPred == imdsTest.Labels)
Kenta
2019 年 3 月 14 日
i番目のループのなかで、トレーニングデータ(仮)をトレーニングデータとバリデーションデータに分けたらいいと思います。そして、バリデーションデータをテストデータ(ただ名前を変えるだけ)としてテストしたらいいです。
ある程度までロスが下がり切ったりしたら計算時間が冗長になるし、訓練データに過適合するのを防げます。ただ、たくさんの枚数をこなしたときに必ずしももこの操作が必要かどうかは不明です。1クラス100枚くらいで交差検証なしでやってみてはどうでしょうか。CPUで計算してもそこまで計算時間はかからないと思います。
%% cross validation
[imds01,imds02,imds03,imds04,imds05,imds06,imds07,imds08,imds09,imds010]...
= splitEachLabel(imds,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,'randomize');
imdsTrain1 = imageDatastore(cat(1,imds01.Files,imds02.Files,imds03.Files,imds04.Files,imds05.Files,imds06.Files,imds07.Files,imds08.Files,imds09.Files));
imdsTrain1.Labels = cat(1,imds01.Labels,imds02.Labels,imds03.Labels,imds04.Labels,imds05.Labels,imds06.Labels,imds07.Labels,imds08.Labels,imds09.Labels);
imdsTrain2 = imageDatastore(cat(1,imds01.Files,imds02.Files,imds03.Files,imds04.Files,imds05.Files,imds06.Files,imds07.Files,imds08.Files,imds010.Files));
imdsTrain2.Labels = cat(1,imds01.Labels,imds02.Labels,imds03.Labels,imds04.Labels,imds05.Labels,imds06.Labels,imds07.Labels,imds08.Labels,imds010.Labels);
imdsTrain3 = imageDatastore(cat(1,imds01.Files,imds02.Files,imds03.Files,imds04.Files,imds05.Files,imds06.Files,imds07.Files,imds09.Files,imds010.Files));
imdsTrain3.Labels = cat(1,imds01.Labels,imds02.Labels,imds03.Labels,imds04.Labels,imds05.Labels,imds06.Labels,imds07.Labels,imds09.Labels,imds010.Labels);
imdsTrain4 = imageDatastore(cat(1,imds01.Files,imds02.Files,imds03.Files,imds04.Files,imds05.Files,imds06.Files,imds08.Files,imds09.Files,imds010.Files));
imdsTrain4.Labels = cat(1,imds01.Labels,imds02.Labels,imds03.Labels,imds04.Labels,imds05.Labels,imds06.Labels,imds08.Labels,imds09.Labels,imds010.Labels);
imdsTrain5 = imageDatastore(cat(1,imds01.Files,imds02.Files,imds03.Files,imds04.Files,imds05.Files,imds07.Files,imds08.Files,imds09.Files,imds010.Files));
imdsTrain5.Labels = cat(1,imds01.Labels,imds02.Labels,imds03.Labels,imds04.Labels,imds05.Labels,imds07.Labels,imds08.Labels,imds09.Labels,imds010.Labels);
imdsTrain6 = imageDatastore(cat(1,imds01.Files,imds02.Files,imds03.Files,imds04.Files,imds06.Files,imds07.Files,imds08.Files,imds09.Files,imds010.Files));
imdsTrain6.Labels = cat(1,imds01.Labels,imds02.Labels,imds03.Labels,imds04.Labels,imds06.Labels,imds07.Labels,imds08.Labels,imds09.Labels,imds010.Labels);
imdsTrain7 = imageDatastore(cat(1,imds01.Files,imds02.Files,imds03.Files,imds06.Files,imds05.Files,imds07.Files,imds08.Files,imds09.Files,imds010.Files));
imdsTrain7.Labels = cat(1,imds01.Labels,imds02.Labels,imds03.Labels,imds05.Labels,imds06.Labels,imds07.Labels,imds08.Labels,imds09.Labels,imds010.Labels);
imdsTrain8 = imageDatastore(cat(1,imds01.Files,imds02.Files,imds04.Files,imds05.Files,imds06.Files,imds07.Files,imds08.Files,imds09.Files,imds010.Files));
imdsTrain8.Labels = cat(1,imds01.Labels,imds02.Labels,imds04.Labels,imds05.Labels,imds06.Labels,imds07.Labels,imds08.Labels,imds09.Labels,imds010.Labels);
imdsTrain9 = imageDatastore(cat(1,imds01.Files,imds03.Files,imds04.Files,imds05.Files,imds06.Files,imds07.Files,imds08.Files,imds09.Files,imds010.Files));
imdsTrain9.Labels = cat(1,imds01.Labels,imds03.Labels,imds04.Labels,imds05.Labels,imds06.Labels,imds07.Labels,imds08.Labels,imds09.Labels,imds010.Labels);
imdsTrain10 = imageDatastore(cat(1,imds02.Files,imds03.Files,imds04.Files,imds05.Files,imds06.Files,imds07.Files,imds08.Files,imds09.Files,imds010.Files));
imdsTrain10.Labels = cat(1,imds02.Labels,imds03.Labels,imds04.Labels,imds05.Labels,imds06.Labels,imds07.Labels,imds08.Labels,imds09.Labels,imds010.Labels);
%% training
accuracy=zeros(1,10);
for i=1:10
stname1=sprintf('imdsTrain%d',i);
eval(['trainimds' ,'=', stname1,';'])
%trainimds.ReadFcn = @(filename)resize(filename);
[imdstrain,imdsvalidation]=splitEachLabel(trainimds,0.8);
i2=10-i+1;
stname2=sprintf('imds0%d',i2);
eval(['imdsTest' ,'=', stname2,';'])
imdsTest.ReadFcn = @(filename)resize(filename);
%% training for test data
%imdstrainで訓練
%imdsvalidationをoptionsのなかのvalidationに指定
%imdstestでテスト
ssk
2019 年 3 月 14 日
編集済み: ssk
2019 年 3 月 14 日
ありがとうございます!コードを試したところ無事に動きました。本コードにおけるクロスバリデーションのニュアンスの確認をしたいのですが、はじめに全ての画像をtrainingとして均等に10分割し、さらに10分割した画像をそれぞれtraining:validation = 8:2で分ける。このとき、testはvalidationと同視できるので、training:test = 8:2である。(つまり、本データの8割をtraining、2割をtest(validation)として使う。その後、組み合わせをかえてそれぞれの画像のaccuracyを調べて平均を取る。上記の認識でよろしいでしょうか?
以前あった例ですと、
[imdsTrain,imdsValidation, imdsTest] = splitEachLabel(imds,0.7,0.2,0.1); で合計が100%ですが、今回の場合は、[imdsTrain,imdsValidation, imdsTest] = splitEachLabel(imds,0.8,0.2,0.2);で合計120%のような気もするのですが、例えば[imdsTrain,imdsValidation, imdsTest] = splitEachLabel(imds,0.6,0.2,0.2);のような形で修正する必要はないのでしょうか?
また、なぜテストデータとバリデーションデータを同視できるか理由をご存知でしたらご教示いただけますと幸いです。
Kenta
2019 年 3 月 15 日
今回の場合は、1000枚あったとすると、900, 100に分けて、その900をさらに800と100に分けたイメージです。
もちろんおっしゃるように、1000枚を一気に8:1:1に分けても等価です。そのサイクルを10回して、それを平均すればいいです。
同視できるというのは、バリデーションデータと名付けられたデータでテストをしているので、その場合に限り、バリデーションデータをテストデータと読み替えて問題ないのでは?ということです。
本来、バリデーションデータとテストデータは異なったニュアンスを持っているものと思います。
ssk
2019 年 3 月 15 日
ニュアンスをご教示いただきありがとうございました!
覚えが悪く大変申し訳なのでもう一度確認しますが、(1,000枚の画像がある場合、)まず900(training), 100(test)に分けて、その後で900を更に800(training)、100(varidation)に分けるということですね。以上から、800(training)、100(test)、100(validation)になるということでしょうか。
上記コードですと、for loop内では以下のようになっておりますので
[imdstrain,imdsvalidation]=splitEachLabel(trainimds,0.8);
imstrain1の場合、
training:900(枚)*0.8= 720(枚)
validation:900(枚)*0.2= 180(枚)
test: 100(枚)・・・imds10
これを順に10回続けていって・・・
imstrain10の場合、
training:900(枚)*0.8= 720(枚)
validation:900(枚)*0.2= 180(枚)
test: 100(枚)・・・imds01
以上の平均を求めるという認識でよろしいでしょうか?
その他の回答 (0 件)
参考
カテゴリ
Help Center および File Exchange で Deep Learning Toolbox についてさらに検索
Community Treasure Hunt
Find the treasures in MATLAB Central and discover how the community can help you!
Start Hunting!エラーが発生しました
ページに変更が加えられたため、アクションを完了できません。ページを再度読み込み、更新された状態を確認してください。
Web サイトの選択
Web サイトを選択すると、翻訳されたコンテンツにアクセスし、地域のイベントやサービスを確認できます。現在の位置情報に基づき、次のサイトの選択を推奨します:
また、以下のリストから Web サイトを選択することもできます。
最適なサイトパフォーマンスの取得方法
中国のサイト (中国語または英語) を選択することで、最適なサイトパフォーマンスが得られます。その他の国の MathWorks のサイトは、お客様の地域からのアクセスが最適化されていません。
南北アメリカ
- América Latina (Español)
- Canada (English)
- United States (English)
ヨーロッパ
- Belgium (English)
- Denmark (English)
- Deutschland (Deutsch)
- España (Español)
- Finland (English)
- France (Français)
- Ireland (English)
- Italia (Italiano)
- Luxembourg (English)
- Netherlands (English)
- Norway (English)
- Österreich (Deutsch)
- Portugal (English)
- Sweden (English)
- Switzerland
- United Kingdom(English)
アジア太平洋地域
- Australia (English)
- India (English)
- New Zealand (English)
- 中国
- 日本Japanese (日本語)
- 한국Korean (한국어)