Main Content

このページの翻訳は最新ではありません。ここをクリックして、英語の最新版を参照してください。

classify

学習済み深層学習ニューラル ネットワークを使用したデータの分類

説明

1 つの CPU または 1 つの GPU で深層学習用の学習済みニューラル ネットワークを使用して予測を実行できます。GPU を使用するには、Parallel Computing Toolbox™ および Compute Capability 3.0 以上の CUDA® 対応 NVIDIA® GPU が必要です。名前と値のペアの引数 ExecutionEnvironment を使用してハードウェア要件を指定します。

複数の出力があるネットワークの場合、predict を使用して、'ReturnCategorial' オプションを true に設定します。

YPred = classify(net,imds) は、学習済みネットワーク net を使用して、イメージ データストア imds 内のイメージのクラス ラベルを予測します。

YPred = classify(net,ds) は、データストア ds 内のデータのクラス ラベルを予測します。

YPred = classify(net,X) は、数値配列 X で指定されたイメージ データまたは特徴データのクラス ラベルを予測します。

YPred = classify(net,X1,...,XN) は、複数入力ネットワーク net に対する数値配列 X1、…、XN のデータのクラス ラベルを予測します。入力 Xi は、ネットワーク入力 net.InputNames(i) に対応します。

YPred = classify(net,sequences) は、再帰型ネットワーク (LSTM ネットワークや GRU ネットワークなど) net に対する sequences 内の時系列データまたはシーケンス データのクラス ラベルを予測します。

YPred = classify(net,tbl) は、table tbl 内のデータのクラス ラベルを予測します。

YPred = classify(___,Name,Value) は、前述の構文のいずれかを使用して、1 つ以上の名前と値のペアの引数によって指定された追加オプションを使用してクラス ラベルを予測します。

[YPred,scores] = classify(___) は、前述の構文のいずれかを使用して、クラス ラベルに対応する分類スコアも返します。

ヒント

長さが異なるシーケンスで予測を行うと、ミニバッチのサイズが、入力データに追加されるパディングの量に影響し、異なる予測値が得られることがあります。さまざまな値を使用して、ネットワークに最適なものを確認してください。ミニバッチのサイズとパディングのオプションを指定するには、'MiniBatchSize' および 'SequenceLength' オプションをそれぞれ使用します。

すべて折りたたむ

標本データを読み込みます。

[XTrain,YTrain] = digitTrain4DArrayData;

digitTrain4DArrayData は数字の学習セットを 4 次元配列データとして読み込みます。XTrain は 28 x 28 x 1 x 5000 の配列で、28 はイメージの高さ、28 は幅です。1 はチャネルの数で、5000 は手書きの数字の合成イメージの数です。YTrain は各観測値のラベルを含む categorical ベクトルです。

畳み込みニューラル ネットワーク アーキテクチャを構築します。

layers = [ ...
    imageInputLayer([28 28 1])
    convolution2dLayer(5,20)
    reluLayer
    maxPooling2dLayer(2,'Stride',2)
    fullyConnectedLayer(10)
    softmaxLayer
    classificationLayer];

モーメンタム項付き確率的勾配降下法の既定の設定にオプションを設定します。

options = trainingOptions('sgdm');

ネットワークに学習をさせます。

rng('default')
net = trainNetwork(XTrain,YTrain,layers,options);
Training on single CPU.
Initializing input data normalization.
|========================================================================================|
|  Epoch  |  Iteration  |  Time Elapsed  |  Mini-batch  |  Mini-batch  |  Base Learning  |
|         |             |   (hh:mm:ss)   |   Accuracy   |     Loss     |      Rate       |
|========================================================================================|
|       1 |           1 |       00:00:00 |       10.16% |       2.3195 |          0.0100 |
|       2 |          50 |       00:00:03 |       50.78% |       1.7102 |          0.0100 |
|       3 |         100 |       00:00:06 |       63.28% |       1.1632 |          0.0100 |
|       4 |         150 |       00:00:10 |       60.16% |       1.0859 |          0.0100 |
|       6 |         200 |       00:00:13 |       68.75% |       0.8996 |          0.0100 |
|       7 |         250 |       00:00:17 |       76.56% |       0.7920 |          0.0100 |
|       8 |         300 |       00:00:20 |       73.44% |       0.8411 |          0.0100 |
|       9 |         350 |       00:00:24 |       81.25% |       0.5508 |          0.0100 |
|      11 |         400 |       00:00:29 |       90.62% |       0.4744 |          0.0100 |
|      12 |         450 |       00:00:33 |       92.19% |       0.3614 |          0.0100 |
|      13 |         500 |       00:00:37 |       94.53% |       0.3160 |          0.0100 |
|      15 |         550 |       00:00:42 |       96.09% |       0.2544 |          0.0100 |
|      16 |         600 |       00:00:46 |       92.19% |       0.2765 |          0.0100 |
|      17 |         650 |       00:00:48 |       95.31% |       0.2460 |          0.0100 |
|      18 |         700 |       00:00:51 |       99.22% |       0.1418 |          0.0100 |
|      20 |         750 |       00:00:55 |       98.44% |       0.1000 |          0.0100 |
|      21 |         800 |       00:00:58 |       98.44% |       0.1449 |          0.0100 |
|      22 |         850 |       00:01:01 |       98.44% |       0.0989 |          0.0100 |
|      24 |         900 |       00:01:05 |       96.88% |       0.1315 |          0.0100 |
|      25 |         950 |       00:01:08 |      100.00% |       0.0859 |          0.0100 |
|      26 |        1000 |       00:01:12 |      100.00% |       0.0701 |          0.0100 |
|      27 |        1050 |       00:01:17 |      100.00% |       0.0759 |          0.0100 |
|      29 |        1100 |       00:01:21 |       99.22% |       0.0663 |          0.0100 |
|      30 |        1150 |       00:01:28 |       98.44% |       0.0776 |          0.0100 |
|      30 |        1170 |       00:01:31 |       99.22% |       0.0732 |          0.0100 |
|========================================================================================|

テスト セットについて学習済みネットワークを実行します。

[XTest,YTest]= digitTest4DArrayData;
YPred = classify(net,XTest);

テスト データの最初の 10 個のイメージを表示して、classify の分類と比較します。

[YTest(1:10,:) YPred(1:10,:)]
ans = 10x2 categorical
     0      0 
     0      0 
     0      0 
     0      0 
     0      0 
     0      0 
     0      0 
     0      0 
     0      0 
     0      0 

classify の結果は、最初の 10 個のイメージの真の数字に一致しています。

すべてのテスト データで精度を計算します。

accuracy = sum(YPred == YTest)/numel(YTest)
accuracy = 0.9820

事前学習済みのネットワークを読み込みます。JapaneseVowelsNet は、[1] および [2] で説明されているように Japanese Vowels データセットで学習させた事前学習済みの LSTM ネットワークです。これは、ミニバッチのサイズ 27 を使用して、シーケンス長で並べ替えられたシーケンスで学習させています。

load JapaneseVowelsNet

ネットワーク アーキテクチャを表示します。

net.Layers
ans = 
  5x1 Layer array with layers:

     1   'sequenceinput'   Sequence Input          Sequence input with 12 dimensions
     2   'lstm'            LSTM                    LSTM with 100 hidden units
     3   'fc'              Fully Connected         9 fully connected layer
     4   'softmax'         Softmax                 softmax
     5   'classoutput'     Classification Output   crossentropyex with '1' and 8 other classes

テスト データを読み込みます。

[XTest,YTest] = japaneseVowelsTestData;

テスト データを分類します。

YPred = classify(net,XTest);

最初の 10 個のシーケンスのラベルと、その予測ラベルを表示します。

[YTest(1:10) YPred(1:10)]
ans = 10x2 categorical
     1      1 
     1      1 
     1      1 
     1      1 
     1      1 
     1      1 
     1      1 
     1      1 
     1      1 
     1      1 

予測の分類精度を計算します。

accuracy = sum(YPred == YTest)/numel(YTest)
accuracy = 0.8595

事前学習済みのネットワーク TransmissionCasingNet を読み込みます。このネットワークは、数値センサーの読み取り値、統計量、およびカテゴリカル入力の混合を所与として、トランスミッション システムの歯車の状態を分類します。

load TransmissionCasingNet.mat

ネットワーク アーキテクチャを表示します。

net.Layers
ans = 
  7x1 Layer array with layers:

     1   'input'         Feature Input           22 features with 'zscore' normalization
     2   'fc_1'          Fully Connected         50 fully connected layer
     3   'batchnorm'     Batch Normalization     Batch normalization with 50 channels
     4   'relu'          ReLU                    ReLU
     5   'fc_2'          Fully Connected         2 fully connected layer
     6   'softmax'       Softmax                 softmax
     7   'classoutput'   Classification Output   crossentropyex with classes 'No Tooth Fault' and 'Tooth Fault'

CSV ファイル "transmissionCasingData.csv" からトランスミッション ケーシング データを読み取ります。

filename = "transmissionCasingData.csv";
tbl = readtable(filename,'TextType','String');

関数 convertvars を使用して、予測のラベルを categorical に変換します。

labelName = "GearToothCondition";
tbl = convertvars(tbl,labelName,'categorical');

カテゴリカル特徴量を使用して予測を行うには、最初にカテゴリカル特徴量を数値に変換しなければなりません。まず、関数 convertvars を使用して、すべてのカテゴリカル入力変数の名前を格納した string 配列を指定することにより、カテゴリカル予測子を categorical に変換します。このデータ セットには、"SensorCondition""ShaftCondition" という名前の 2 つのカテゴリカル特徴量があります。

categoricalInputNames = ["SensorCondition" "ShaftCondition"];
tbl = convertvars(tbl,categoricalInputNames,'categorical');

カテゴリカル入力変数をループ処理します。各変数について次を行います。

  • 関数 onehotencode を使用して、カテゴリカル値を one-hot 符号化ベクトルに変換する。

  • 関数 addvars を使用して、one-hot ベクトルを table に追加する。対応するカテゴリカル データが含まれる列の後にベクトルを挿入するように指定する。

  • カテゴリカル データが含まれる対応する列を削除する。

for i = 1:numel(categoricalInputNames)
    name = categoricalInputNames(i);
    oh = onehotencode(tbl(:,name));
    tbl = addvars(tbl,oh,'After',name);
    tbl(:,name) = [];
end

関数 splitvars を使用して、ベクトルを別々の列に分割します。

tbl = splitvars(tbl);

table の最初の数行を表示します。

head(tbl)
ans=8×23 table
    SigMean     SigMedian    SigRMS    SigVar     SigPeak    SigPeak2Peak    SigSkewness    SigKurtosis    SigCrestFactor    SigMAD     SigRangeCumSum    SigCorrDimension    SigApproxEntropy    SigLyapExponent    PeakFreq    HighFreqPower    EnvPower    PeakSpecKurtosis    No Sensor Drift    Sensor Drift    No Shaft Wear    Shaft Wear    GearToothCondition
    ________    _________    ______    _______    _______    ____________    ___________    ___________    ______________    _______    ______________    ________________    ________________    _______________    ________    _____________    ________    ________________    _______________    ____________    _____________    __________    __________________

    -0.94876     -0.9722     1.3726    0.98387    0.81571       3.6314        -0.041525       2.2666           2.0514         0.8081        28562              1.1429             0.031581            79.931            0          6.75e-06       3.23e-07         162.13                0                1                1              0           No Tooth Fault  
    -0.97537    -0.98958     1.3937    0.99105    0.81571       3.6314        -0.023777       2.2598           2.0203        0.81017        29418              1.1362             0.037835            70.325            0          5.08e-08       9.16e-08         226.12                0                1                1              0           No Tooth Fault  
      1.0502      1.0267     1.4449    0.98491     2.8157       3.6314         -0.04162       2.2658           1.9487        0.80853        31710              1.1479             0.031565            125.19            0          6.74e-06       2.85e-07         162.13                0                1                0              1           No Tooth Fault  
      1.0227      1.0045     1.4288    0.99553     2.8157       3.6314        -0.016356       2.2483           1.9707        0.81324        30984              1.1472             0.032088             112.5            0          4.99e-06        2.4e-07         162.13                0                1                0              1           No Tooth Fault  
      1.0123      1.0024     1.4202    0.99233     2.8157       3.6314        -0.014701       2.2542           1.9826        0.81156        30661              1.1469              0.03287            108.86            0          3.62e-06       2.28e-07         230.39                0                1                0              1           No Tooth Fault  
      1.0275      1.0102     1.4338     1.0001     2.8157       3.6314         -0.02659       2.2439           1.9638        0.81589        31102              1.0985             0.033427            64.576            0          2.55e-06       1.65e-07         230.39                0                1                0              1           No Tooth Fault  
      1.0464      1.0275     1.4477     1.0011     2.8157       3.6314        -0.042849       2.2455           1.9449        0.81595        31665              1.1417             0.034159            98.838            0          1.73e-06       1.55e-07         230.39                0                1                0              1           No Tooth Fault  
      1.0459      1.0257     1.4402    0.98047     2.8157       3.6314        -0.035405       2.2757            1.955        0.80583        31554              1.1345               0.0353            44.223            0          1.11e-06       1.39e-07         230.39                0                1                0              1           No Tooth Fault  

学習済みネットワークを使用してテスト データのラベルを予測し、精度を計算します。学習に使用されるサイズと同じミニバッチ サイズを指定します。

YPred = classify(net,tbl(:,1:end-1));

分類精度を計算します。精度は、ネットワークが正しく予測するラベルの比率です。

YTest = tbl{:,labelName};
accuracy = sum(YPred == YTest)/numel(YTest)
accuracy = 0.9952

入力引数

すべて折りたたむ

学習済みネットワーク。SeriesNetwork または DAGNetwork オブジェクトとして指定します。事前学習済みのネットワークをインポートする (たとえば、関数 googlenet を使用する)、または trainNetwork を使用して独自のネットワークに学習させることによって、学習済みネットワークを取得できます。

イメージ データストア。ImageDatastore オブジェクトとして指定します。

ImageDatastore を使用すると、事前取得を使用して JPG または PNG イメージ ファイルのバッチ読み取りを行うことができます。イメージの読み取りにカスタム関数を使用する場合、ImageDatastore は事前取得を行いません。

ヒント

イメージのサイズ変更を含む深層学習用のイメージの前処理を効率的に行うには、augmentedImageDatastore を使用します。

imageDatastorereadFcn オプションは通常、速度が大幅に低下するため、前処理またはサイズ変更に使用しないでください。

メモリ外のデータおよび前処理用のデータストア。データストアは、table または cell 配列でデータを返さなければなりません。データストア出力の形式は、ネットワーク アーキテクチャによって異なります。

ネットワーク アーキテクチャデータストア出力出力の例
単一入力

table または cell 配列。最初の列は予測子を指定します。

table の要素は、スカラー、行ベクトルであるか、数値配列が格納された 1 行 1 列の cell 配列でなければなりません。

カスタム データストアは table を出力しなければなりません。

data = read(ds)
data =

  4×1 table

        Predictors    
    __________________

    {224×224×3 double}
    {224×224×3 double}
    {224×224×3 double}
    {224×224×3 double}
data = read(ds)
data =

  4×1 cell array

    {224×224×3 double}
    {224×224×3 double}
    {224×224×3 double}
    {224×224×3 double}
複数入力

少なくとも numInputs 個の列をもつ cell 配列。numInputs はネットワーク入力の数です。

最初の numInputs 個の列は、各入力の予測子を指定します。

入力の順序は、ネットワークの InputNames プロパティによって指定されます。

data = read(ds)
data =

  4×2 cell array

    {224×224×3 double}    {128×128×3 double}
    {224×224×3 double}    {128×128×3 double}
    {224×224×3 double}    {128×128×3 double}
    {224×224×3 double}    {128×128×3 double}

予測子の形式は、データのタイプによって異なります。

データ予測子の形式
2 次元イメージ

h x w x c の数値配列。ここで、h、w、および c は、それぞれイメージの高さ、幅、およびチャネル数です。

3 次元イメージ

h x w x d x c の数値配列。ここで、h、w、d、および c は、それぞれイメージの高さ、幅、深さ、およびチャネル数です。

ベクトル シーケンス

c 行 s 列の行列。ここで、c はシーケンスの特徴の数、s はシーケンス長です。

2 次元イメージ シーケンス

h x w x c x s の配列。ここで、h、w、および c はそれぞれイメージの高さ、幅、およびチャネル数に対応します。s はシーケンス長です。

ミニバッチ内の各シーケンスは、同じシーケンス長でなければなりません。

3 次元イメージ シーケンス

h x w x d x c x s の配列。ここで、h、w、d、および c は、それぞれイメージの高さ、幅、深さ、およびチャネル数に対応します。s はシーケンス長です。

ミニバッチ内の各シーケンスは、同じシーケンス長でなければなりません。

特徴

c 行 1 列の列ベクトル。c は特徴の数です。

詳細については、深層学習用のデータストアを参照してください。

イメージ データまたは特徴データ。数値配列として指定します。配列のサイズは入力のタイプによって以下のように異なります。

入力説明
2 次元イメージh x w x c x N の数値配列。ここで、h、w、および c は、それぞれイメージの高さ、幅、およびチャネル数です。N はイメージの数です。
3 次元イメージh x w x d x c x N の数値配列。ここで、h、w、d、および c は、それぞれイメージの高さ、幅、深さ、およびチャネル数です。N はイメージの数です。
特徴N 行 numFeatures 列の数値配列。ここで、N は観測値の数、numFeatures は入力データの特徴の数です。

配列に NaN が含まれる場合、ネットワーク全体に伝播されます。

複数の入力があるネットワークの場合、複数の配列 X1、…、XN を指定できます。ここで、N はネットワーク入力の数であり、入力 Xi はネットワーク入力 net.InputNames(i) に対応します。

シーケンス データまたは時系列データ。数値配列の N 行 1 列の cell 配列、1 つのシーケンスを表す数値配列、またはデータストアとして指定します。ここで、N は観測値の数です。

cell 配列入力または数値配列入力の場合、シーケンスが含まれる数値配列の次元は、データのタイプによって異なります。

入力説明
ベクトル シーケンスc 行 s 列の行列。ここで、c はシーケンスの特徴の数、s はシーケンス長です。
2 次元イメージ シーケンスh x w x c x s の配列。ここで、h、w、および c は、それぞれイメージの高さ、幅、およびチャネル数に対応します。s はシーケンス長です。
3 次元イメージ シーケンスh x w x d x c x s。ここで、h、w、d、および c は、それぞれ 3 次元イメージの高さ、幅、深さ、およびチャネル数に対応します。s はシーケンス長です。

データストア入力の場合、データストアはシーケンスの cell 配列、または最初の列にシーケンスが含まれる table としてデータを返さなければなりません。シーケンス データの次元は、上記の table に対応していなければなりません。

イメージ データまたは特徴データの table。table の各行は観測値に対応します。

table の列での予測子の配置は、入力データのタイプによって異なります。

入力予測子
イメージ データ
  • イメージの絶対ファイル パスまたは相対ファイル パス。単一列の文字ベクトルとして指定します。

  • イメージ。3 次元数値配列として指定します。

単一列で予測子を指定します。

特徴データ

数値スカラー。

table の numFeatures 個の列で予測子を指定します。numFeatures は入力データの特徴の数です。

この引数は、単一の入力のみがあるネットワークをサポートします。

データ型: table

名前と値のペアの引数

例: 'MiniBatchSize','256' はミニバッチのサイズを 256 に指定します。

オプションの引数 Name,Value のコンマ区切りのペアを指定します。Name は引数名で、Value は対応する値です。Name は一重引用符 (' ') で囲まなければなりません。

予測に使用するミニバッチのサイズ。正の整数として指定します。ミニバッチのサイズが大きくなるとより多くのメモリが必要になりますが、予測時間が短縮される可能性があります。

長さが異なるシーケンスで予測を行うと、ミニバッチのサイズが、入力データに追加されるパディングの量に影響し、異なる予測値が得られることがあります。さまざまな値を使用して、ネットワークに最適なものを確認してください。ミニバッチのサイズとパディングのオプションを指定するには、'MiniBatchSize' および 'SequenceLength' オプションをそれぞれ使用します。

例: 'MiniBatchSize',256

パフォーマンスの最適化。'Acceleration' と次のいずれかで構成されるコンマ区切りのペアとして指定します。

  • 'auto' — 入力ネットワークとハードウェア リソースに適した最適化の回数を自動的に適用します。

  • 'mex' — MEX 関数をコンパイルして実行します。このオプションは GPU の使用時にのみ利用できます。GPU を使用するには、Parallel Computing Toolbox および Compute Capability 3.0 以上の CUDA 対応 NVIDIA GPU が必要です。Parallel Computing Toolbox または適切な GPU が利用できない場合、エラーが返されます。

  • 'none' — すべての高速化を無効にします。

既定のオプションは 'auto' です。'auto' が指定されている場合、MATLAB® は互換性のある最適化を複数適用します。'auto' オプションを使用する場合、MATLAB は MEX 関数を生成しません。

'Acceleration' オプション 'auto' および 'mex' を使用すると、パフォーマンス上のメリットが得られますが、初期実行時間が長くなります。互換性のあるパラメーターを使用した後続の呼び出しは、より高速になります。新しい入力データを使用して関数を複数回呼び出す場合は、パフォーマンスの最適化を使用してください。

'mex' オプションは、関数の呼び出しに使用されたネットワークとパラメーターに基づいて MEX 関数を生成し、実行します。複数の MEX 関数を一度に 1 つのネットワークに関連付けることができます。ネットワークの変数をクリアすると、そのネットワークに関連付けられている MEX 関数もクリアされます。

'mex' オプションは、GPU の使用時にのみ利用できます。C/C++ コンパイラと GPU Coder™ Interface for Deep Learning Libraries サポート パッケージがインストールされていなければなりません。MATLAB でアドオン エクスプローラーを使用してサポート パッケージをインストールします。設定手順については、MEX の設定 (GPU Coder)を参照してください。GPU Coder は不要です。

'mex' オプションではサポートされていない層があります。サポートされている層の一覧については、サポートされている層 (GPU Coder)を参照してください。sequenceInputLayer を含む再帰型ニューラル ネットワーク (RNN) はサポートされていません。

'mex' オプションは、複数の入力層または複数の出力層をもつネットワークをサポートしていません。

'mex' オプションの使用時に、MATLAB Compiler™ を使用してネットワークを配布することはできません。

例: 'Acceleration','mex'

ハードウェア リソース。'ExecutionEnvironment' と次のいずれかで構成されるコンマ区切りのペアとして指定します。

  • 'auto' — 利用可能な場合は GPU を使用し、そうでない場合は CPU を使用します。

  • 'gpu' — GPU を使用します。GPU を使用するには、Parallel Computing Toolbox および Compute Capability 3.0 以上の CUDA 対応 NVIDIA GPU が必要です。Parallel Computing Toolbox または適切な GPU が利用できない場合、エラーが返されます。

  • 'cpu' — CPU を使用します。

例: 'ExecutionEnvironment','cpu'

入力シーケンスのパディング、切り捨て、または分割を行うオプション。次のいずれかに指定します。

  • 'longest' — 各ミニバッチで、最長のシーケンスと同じ長さになるようにシーケンスのパディングを行います。このオプションを使用するとデータは破棄されませんが、パディングによってネットワークにノイズが生じることがあります。

  • 'shortest' — 各ミニバッチで、最短のシーケンスと同じ長さになるようにシーケンスの切り捨てを行います。このオプションを使用するとパディングは追加されませんが、データが破棄されます。

  • 正の整数 — 各ミニバッチで、ミニバッチで最長のシーケンスより大きい、指定長の最も近い倍数になるようにシーケンスのパディングを行った後、それらのシーケンスを指定長のより小さなシーケンスに分割します。分割が発生すると、追加のミニバッチが作成されます。シーケンス全体がメモリに収まらない場合は、このオプションを使用します。または、'MiniBatchSize' オプションをより小さい値に設定して、ミニバッチごとのシーケンス数を減らしてみます。

入力シーケンスのパディング、切り捨て、および分割の効果の詳細は、シーケンスのパディング、切り捨て、および分割を参照してください。

例: 'SequenceLength','shortest'

パディングまたは切り捨ての方向。次のいずれかに指定します。

  • 'right' — シーケンスの右側に対してパディングまたは切り捨てを行います。シーケンスは同じタイム ステップで始まり、シーケンスの末尾に対して切り捨てまたはパディングの追加が行われます。

  • 'left' — シーケンスの左側に対してパディングまたは切り捨てを行います。シーケンスが同じタイム ステップで終わるように、シーケンスの先頭に対して切り捨てまたはパディングの追加が行われます。

LSTM 層は 1 タイム ステップずつシーケンス データを処理するため、層の OutputMode プロパティが 'last' の場合、最後のタイム ステップでパディングを行うと層の出力に悪影響を与える可能性があります。シーケンス データの左側に対してパディングまたは切り捨てを行うには、'SequencePaddingDirection' オプションを 'left' に設定します。

sequence-to-sequence ネットワークの場合 (各 LSTM 層について OutputMode プロパティが 'sequence' である場合)、最初のタイム ステップでパティングを行うと、それ以前のタイム ステップの予測に悪影響を与える可能性があります。シーケンスの右側に対してパディングまたは切り捨てを行うには、'SequencePaddingDirection' オプションを 'right' に設定します。

入力シーケンスのパディング、切り捨て、および分割の効果の詳細は、シーケンスのパディング、切り捨て、および分割を参照してください。

入力シーケンスをパディングする値。スカラーとして指定します。このオプションは、SequenceLength'longest' または正の整数の場合にのみ有効です。ネットワーク全体にエラーが伝播される可能性があるため、NaN でシーケンスをパディングしないでください。

例: 'SequencePaddingValue',-1

出力引数

すべて折りたたむ

予測クラス ラベル。categorical ベクトル、または categorical ベクトルの cell 配列として返されます。YPred の形式は、タスクのタイプによって異なります。

次の表は、分類タスクの形式について説明しています。

タスク形式
イメージ分類または特徴分類ラベルの N 行 1 列の categorical ベクトル。N は観測値の数です。
sequence-to-label 分類
sequence-to-sequence 分類

ラベルのカテゴリカル シーケンスの N 行 1 列の cell 配列。N は観測値の数です。SequenceLength オプションを各ミニバッチに個別に適用した後、各シーケンスには、対応する入力シーケンスと同じ数のタイム ステップが含まれています。

観測値が 1 つの sequence-to-sequence 分類タスクでは、sequences を行列にすることができます。この場合、YPred はラベルのカテゴリカル シーケンスです。

予測スコアまたは応答。行列、または行列の cell 配列として返されます。scores の形式は、タスクのタイプによって異なります。

次の表は、scores の形式について説明しています。

タスク形式
イメージ分類N 行 K 列の行列。N は観測値の数、K はクラスの数です。
sequence-to-label 分類
特徴分類
sequence-to-sequence 分類

行列の N 行 1 列の cell 配列。N は観測値の数です。シーケンスは K 行の行列で、K はクラスの数です。SequenceLength オプションを各ミニバッチに個別に適用した後、各シーケンスには、対応する入力シーケンスと同じ数のタイム ステップが含まれています。

観測値が 1 つの sequence-to-sequence 分類タスクでは、sequences を行列にすることができます。この場合、scores は、予測クラス スコアの行列です。

分類スコアを調べる例については、深層学習を使用した Web カメラ イメージの分類を参照してください。

アルゴリズム

Deep Learning Toolbox™ に含まれる深層学習における学習、予測、検証用のすべての関数は、単精度浮動小数点演算を使用して計算を実行します。深層学習用の関数には trainNetworkpredictclassifyactivations などがあります。CPU と GPU の両方を使用してネットワークに学習させる場合、単精度演算が使用されます。

代替方法

複数の出力があるネットワークの場合、predict を使用して、'ReturnCategorial' オプションを true に設定します。

predict を使用して学習済みネットワークから予測スコアを計算できます。

activations を使用してネットワーク層から活性化を計算することもできます。

sequence-to-label および sequence-to-sequence 分類ネットワークでは、classifyAndUpdateState および predictAndUpdateState を使用してネットワークの状態の予測および更新を実行できます。

参照

[1] M. Kudo, J. Toyama, and M. Shimbo. "Multidimensional Curve Classification Using Passing-Through Regions." Pattern Recognition Letters. Vol. 20, No. 11–13, pages 1103–1111.

[2] UCI Machine Learning Repository: Japanese Vowels Dataset. https://archive.ics.uci.edu/ml/datasets/Japanese+Vowels

拡張機能

R2016a で導入