Main Content

classify

学習済み深層学習ニューラル ネットワークを使用したデータの分類

説明

1 つの CPU または 1 つの GPU で深層学習用の学習済みニューラル ネットワークを使用して予測を実行できます。GPU を使用するには Parallel Computing Toolbox™ とサポートされている GPU デバイスが必要です。サポートされているデバイスについては、リリース別の GPU サポート (Parallel Computing Toolbox)を参照してください。名前と値の引数 ExecutionEnvironment を使用して、ハードウェア要件を指定します。

複数の出力があるネットワークの場合、代わりに関数 predict を使用して、ReturnCategorical オプションを true に設定します。

Y = classify(net,images) は、学習済みネットワーク net を使用して、指定されたイメージのクラス ラベルを予測します。

Y = classify(net,sequences) は、学習済みネットワーク net を使用して、指定されたシーケンスのクラス ラベルを予測します。

Y = classify(net,features) は、学習済みネットワーク net を使用して、指定された特徴データのクラス ラベルを予測します。

Y = classify(net,X1,...,XN) は、多入力ネットワーク net に対する数値配列または cell 配列 X1、…、XN のデータのクラス ラベルを予測します。入力 Xi は、ネットワーク入力 net.InputNames(i) に対応します。

Y = classify(net,mixed) は、混合するデータ型から成る複数の入力をもつ学習済みネットワーク net を使用してクラス ラベルを予測します。

[Y,scores] = classify(___) は、前述の入力引数のいずれかを使用して、クラス ラベルに対応する分類スコアも返します。

___ = classify(___,Name=Value) は、1 つ以上の名前と値の引数で指定された追加オプションを使用して、クラス ラベルを予測します。

ヒント

長さが異なるシーケンスで予測を行うと、ミニバッチのサイズが、入力データに追加されるパディングの量に影響し、予測値が変わることがあります。さまざまな値を使用して、ネットワークに最適なものを確認してください。ミニバッチのサイズとパディングのオプションを指定するには、MiniBatchSize オプションと SequenceLength オプションをそれぞれ使用します。

すべて折りたたむ

事前学習済みのネットワーク digitsNet を読み込みます。このネットワークには、手書きの数字を分類する分類畳み込みニューラル ネットワークが含まれています。

load digitsNet

ネットワーク層を表示します。ネットワークの出力層は分類層です。

layers = net.Layers
layers = 
  15x1 Layer array with layers:

     1   'imageinput'    Image Input             28x28x1 images with 'zerocenter' normalization
     2   'conv_1'        Convolution             8 3x3x1 convolutions with stride [1  1] and padding 'same'
     3   'batchnorm_1'   Batch Normalization     Batch normalization with 8 channels
     4   'relu_1'        ReLU                    ReLU
     5   'maxpool_1'     Max Pooling             2x2 max pooling with stride [2  2] and padding [0  0  0  0]
     6   'conv_2'        Convolution             16 3x3x8 convolutions with stride [1  1] and padding 'same'
     7   'batchnorm_2'   Batch Normalization     Batch normalization with 16 channels
     8   'relu_2'        ReLU                    ReLU
     9   'maxpool_2'     Max Pooling             2x2 max pooling with stride [2  2] and padding [0  0  0  0]
    10   'conv_3'        Convolution             32 3x3x16 convolutions with stride [1  1] and padding 'same'
    11   'batchnorm_3'   Batch Normalization     Batch normalization with 32 channels
    12   'relu_3'        ReLU                    ReLU
    13   'fc'            Fully Connected         10 fully connected layer
    14   'softmax'       Softmax                 softmax
    15   'classoutput'   Classification Output   crossentropyex with '0' and 9 other classes

テスト イメージを読み込みます。

digitDatasetPath = fullfile(toolboxdir("nnet"),"nndemos","nndatasets","DigitDataset");
imdsTest = imageDatastore(digitDatasetPath,IncludeSubfolders=true);

関数 classify を使用してイメージを分類します。

YTest = classify(net,imdsTest);

いくつかのテスト イメージを、それらの予測と共にランダムに表示します。

numImages = 9;
idx = randperm(numel(imdsTest.Files),numImages);

figure
tiledlayout("flow")
for i = 1:numImages
    nexttile
    imshow(imdsTest.Files{idx(i)});
    title("Predicted Label: " + string(YTest(idx(i))))
end

Figure contains 9 axes objects. Axes object 1 with title Predicted Label: 8 contains an object of type image. Axes object 2 with title Predicted Label: 9 contains an object of type image. Axes object 3 with title Predicted Label: 1 contains an object of type image. Axes object 4 with title Predicted Label: 9 contains an object of type image. Axes object 5 with title Predicted Label: 6 contains an object of type image. Axes object 6 with title Predicted Label: 0 contains an object of type image. Axes object 7 with title Predicted Label: 2 contains an object of type image. Axes object 8 with title Predicted Label: 5 contains an object of type image. Axes object 9 with title Predicted Label: 9 contains an object of type image.

事前学習済みのネットワーク JapaneseVowelsNet を読み込みます。このネットワークは、[1] および [2] で説明されているように Japanese Vowels データ セットで学習させた事前学習済みの LSTM ネットワークです。これは、ミニバッチのサイズ 27 を使用して、シーケンス長で並べ替えられたシーケンスで学習させています。

load JapaneseVowelsNet

ネットワーク アーキテクチャを表示します。

net.Layers
ans = 
  5x1 Layer array with layers:

     1   'sequenceinput'   Sequence Input          Sequence input with 12 dimensions
     2   'lstm'            LSTM                    LSTM with 100 hidden units
     3   'fc'              Fully Connected         9 fully connected layer
     4   'softmax'         Softmax                 softmax
     5   'classoutput'     Classification Output   crossentropyex with '1' and 8 other classes

テスト データを読み込みます。

[XTest,TTest] = japaneseVowelsTestData;

テスト データを分類します。

YTest = classify(net,XTest);

混同チャートで予測を表示します。

figure
confusionchart(TTest,YTest)

Figure contains an object of type ConfusionMatrixChart.

予測の分類精度を計算します。

accuracy = mean(YTest == TTest)
accuracy = 0.8595

事前学習済みのネットワーク TransmissionCasingNet を読み込みます。このネットワークは、数値センサーの読み取り値、統計量、およびカテゴリカル入力の混合を所与として、トランスミッション システムの歯車の状態を分類します。

load TransmissionCasingNet

ネットワーク アーキテクチャを表示します。

net.Layers
ans = 
  7x1 Layer array with layers:

     1   'input'         Feature Input           22 features with 'zscore' normalization
     2   'fc_1'          Fully Connected         50 fully connected layer
     3   'batchnorm'     Batch Normalization     Batch normalization with 50 channels
     4   'relu'          ReLU                    ReLU
     5   'fc_2'          Fully Connected         2 fully connected layer
     6   'softmax'       Softmax                 softmax
     7   'classoutput'   Classification Output   crossentropyex with classes 'No Tooth Fault' and 'Tooth Fault'

CSV ファイル "transmissionCasingData.csv" からトランスミッション ケーシング データを読み取ります。

filename = "transmissionCasingData.csv";
tbl = readtable(filename,TextType="string");

関数 convertvars を使用して、予測のラベルを categorical に変換します。

labelName = "GearToothCondition";
tbl = convertvars(tbl,labelName,"categorical");

カテゴリカル特徴量を使用して予測を行うには、最初にカテゴリカル特徴量を数値に変換しなければなりません。まず、関数 convertvars を使用して、すべてのカテゴリカル入力変数の名前を格納した string 配列を指定することにより、カテゴリカル予測子を categorical に変換します。このデータ セットには、"SensorCondition""ShaftCondition" という名前の 2 つのカテゴリカル特徴量があります。

categoricalInputNames = ["SensorCondition" "ShaftCondition"];
tbl = convertvars(tbl,categoricalInputNames,"categorical");

カテゴリカル入力変数をループ処理します。各変数について次を行います。

  • 関数 onehotencode を使用して、カテゴリカル値を one-hot 符号化ベクトルに変換する。

  • 関数 addvars を使用して、one-hot ベクトルを table に追加する。対応するカテゴリカル データが含まれる列の後にベクトルを挿入するように指定する。

  • カテゴリカル データが含まれる対応する列を削除する。

for i = 1:numel(categoricalInputNames)
    name = categoricalInputNames(i);
    oh = onehotencode(tbl(:,name));
    tbl = addvars(tbl,oh,After=name);
    tbl(:,name) = [];
end

関数 splitvars を使用して、ベクトルを別々の列に分割します。

tbl = splitvars(tbl);

table の最初の数行を表示します。

head(tbl)
ans=8×23 table
    SigMean     SigMedian    SigRMS    SigVar     SigPeak    SigPeak2Peak    SigSkewness    SigKurtosis    SigCrestFactor    SigMAD     SigRangeCumSum    SigCorrDimension    SigApproxEntropy    SigLyapExponent    PeakFreq    HighFreqPower    EnvPower    PeakSpecKurtosis    No Sensor Drift    Sensor Drift    No Shaft Wear    Shaft Wear    GearToothCondition
    ________    _________    ______    _______    _______    ____________    ___________    ___________    ______________    _______    ______________    ________________    ________________    _______________    ________    _____________    ________    ________________    _______________    ____________    _____________    __________    __________________

    -0.94876     -0.9722     1.3726    0.98387    0.81571       3.6314        -0.041525       2.2666           2.0514         0.8081        28562              1.1429             0.031581            79.931            0          6.75e-06       3.23e-07         162.13                0                1                1              0           No Tooth Fault  
    -0.97537    -0.98958     1.3937    0.99105    0.81571       3.6314        -0.023777       2.2598           2.0203        0.81017        29418              1.1362             0.037835            70.325            0          5.08e-08       9.16e-08         226.12                0                1                1              0           No Tooth Fault  
      1.0502      1.0267     1.4449    0.98491     2.8157       3.6314         -0.04162       2.2658           1.9487        0.80853        31710              1.1479             0.031565            125.19            0          6.74e-06       2.85e-07         162.13                0                1                0              1           No Tooth Fault  
      1.0227      1.0045     1.4288    0.99553     2.8157       3.6314        -0.016356       2.2483           1.9707        0.81324        30984              1.1472             0.032088             112.5            0          4.99e-06        2.4e-07         162.13                0                1                0              1           No Tooth Fault  
      1.0123      1.0024     1.4202    0.99233     2.8157       3.6314        -0.014701       2.2542           1.9826        0.81156        30661              1.1469              0.03287            108.86            0          3.62e-06       2.28e-07         230.39                0                1                0              1           No Tooth Fault  
      1.0275      1.0102     1.4338     1.0001     2.8157       3.6314         -0.02659       2.2439           1.9638        0.81589        31102              1.0985             0.033427            64.576            0          2.55e-06       1.65e-07         230.39                0                1                0              1           No Tooth Fault  
      1.0464      1.0275     1.4477     1.0011     2.8157       3.6314        -0.042849       2.2455           1.9449        0.81595        31665              1.1417             0.034159            98.838            0          1.73e-06       1.55e-07         230.39                0                1                0              1           No Tooth Fault  
      1.0459      1.0257     1.4402    0.98047     2.8157       3.6314        -0.035405       2.2757            1.955        0.80583        31554              1.1345               0.0353            44.223            0          1.11e-06       1.39e-07         230.39                0                1                0              1           No Tooth Fault  

table からテスト ラベルを抽出します。

TTest = tbl{:,labelName};

学習済みネットワークを使用してテスト データのラベルを予測し、精度を計算します。学習に使用されるサイズと同じミニバッチ サイズを指定します。

YTest = classify(net,tbl(:,1:end-1));

混同行列で予測を可視化します。

figure
confusionchart(TTest,YTest)

Figure contains an object of type ConfusionMatrixChart.

分類精度を計算します。精度は、ネットワークが正しく予測するラベルの比率です。

accuracy = mean(YTest == TTest)
accuracy = 0.9952

入力引数

すべて折りたたむ

学習済みネットワーク。SeriesNetwork または DAGNetwork オブジェクトとして指定します。事前学習済みのネットワークをインポートする (たとえば、関数 googlenet を使用する)、または trainNetwork を使用して独自のネットワークに学習させることによって、学習済みネットワークを取得できます。

イメージ データ。次のいずれかとして指定します。

データ型説明使用例
データストアImageDatastoreディスクに保存されたイメージのデータストア

イメージのサイズがすべて等しい場合に、ディスクに保存されているイメージを使用して予測を行います。

イメージのサイズが異なる場合は AugmentedImageDatastore オブジェクトを使用します。

AugmentedImageDatastoreサイズ変更、回転、反転、せん断、平行移動を含む、ランダムなアフィン幾何学的変換を適用するデータストア

イメージのサイズが異なる場合に、ディスクに保存されているイメージを使用して予測を行います。

TransformedDatastoreカスタム変換関数を使用して、基になるデータストアから読み取ったデータのバッチを変換するデータストア

  • classify でサポートされていない出力をもつデータストアを変換する。

  • データストアの出力にカスタム変換を適用する。

CombinedDatastore2 つ以上の基になるデータストアからデータを読み取るデータストア

  • 複数の入力をもつネットワークを使用して予測を行う。

  • 異なるデータ ソースから取得した予測子を結合する。

カスタム ミニバッチ データストアデータのミニバッチを返すカスタム データストア

他のデータストアでサポートされていない形式のデータを使用して予測を行います。

詳細は、カスタム ミニバッチ データストアの開発を参照してください。

数値配列数値配列として指定されたイメージメモリに収まり、なおかつサイズ変更などの追加の処理を必要としないデータを使用して予測を行います。
tabletable として指定されたイメージtable に格納されたデータを使用して予測を行います。

複数の入力をもつネットワークでデータストアを使用する場合、データストアは TransformedDatastore オブジェクトまたは CombinedDatastore オブジェクトでなければなりません。

ヒント

ビデオ データのようなイメージのシーケンスの場合、入力引数 sequences を使用します。

データストア

データストアは、イメージと応答のミニバッチを読み取ります。データストアは、メモリに収まらないデータがある場合や、入力データのサイズを変更したい場合に使用します。

以下のデータストアは、イメージ データ用の classify と直接互換性があります。

ImageDatastore オブジェクトを使用すると、事前取得を使用して JPG または PNG イメージ ファイルのバッチ読み取りを行うことができる点に注意してください。イメージの読み取りにカスタム関数を使用する場合、ImageDatastore オブジェクトは事前取得を行いません。

ヒント

イメージのサイズ変更を含む深層学習用のイメージの前処理を効率的に行うには、augmentedImageDatastore を使用します。

関数 imageDatastorereadFcn オプションは通常、速度が大幅に低下するため、前処理またはサイズ変更に使用しないでください。

関数 transform および combine を使用して、予測を行うための他の組み込みデータストアを使用できます。これらの関数は、データストアから読み取られたデータを、classify に必要な形式に変換できます。

データストア出力に必要な形式は、ネットワーク アーキテクチャによって異なります。

ネットワーク アーキテクチャデータストア出力出力の例
単一入力

table または cell 配列。最初の列は予測子を指定します。

table の要素は、スカラー、行ベクトルであるか、数値配列が格納された 1 行 1 列の cell 配列でなければなりません。

カスタム データストアは table を出力しなければなりません。

data = read(ds)
data =

  4×1 table

        Predictors    
    __________________

    {224×224×3 double}
    {224×224×3 double}
    {224×224×3 double}
    {224×224×3 double}
data = read(ds)
data =

  4×1 cell array

    {224×224×3 double}
    {224×224×3 double}
    {224×224×3 double}
    {224×224×3 double}
複数入力

少なくとも numInputs 個の列をもつ cell 配列。numInputs はネットワーク入力の数です。

最初の numInputs 個の列は、各入力の予測子を指定します。

入力の順序は、ネットワークの InputNames プロパティによって指定されます。

data = read(ds)
data =

  4×2 cell array

    {224×224×3 double}    {128×128×3 double}
    {224×224×3 double}    {128×128×3 double}
    {224×224×3 double}    {128×128×3 double}
    {224×224×3 double}    {128×128×3 double}

予測子の形式は、データのタイプによって異なります。

データ形式
2 次元イメージ

h×w×c の数値配列。ここで、h、w、および c は、それぞれイメージの高さ、幅、およびチャネル数です。

3 次元イメージh×w×d×c の数値配列。ここで、h、w、d、および c は、それぞれイメージの高さ、幅、深さ、およびチャネル数です。

詳細については、深層学習用のデータストアを参照してください。

数値配列

メモリに収まり、なおかつ拡張などの追加の処理を必要としないデータの場合、イメージのデータ セットを数値配列として指定できます。

数値配列のサイズと形状は、イメージ データのタイプによって異なります。

データ形式
2 次元イメージ

h×w×c×N の数値配列。ここで、h、w、および c は、それぞれイメージの高さ、幅、およびチャネル数です。N はイメージの数です。

3 次元イメージh×w×d×c×N の数値配列。ここで、h、w、d、および c は、それぞれイメージの高さ、幅、深さ、およびチャネル数です。N はイメージの数です。

table

データストアまたは数値配列の代わりに、イメージを table で指定することもできます。

イメージを table で指定した場合、table の各行は観測値に対応します。

イメージ入力の場合、予測子は table の最初の列に格納し、次のいずれかとして指定しなければなりません。

  • イメージの絶対ファイル パスまたは相対ファイル パス。文字ベクトルとして指定します。

  • 2 次元イメージを表す h×w×c の数値配列が格納された 1 行 1 列の cell 配列。ここで、h、w、および c は、それぞれイメージの高さ、幅、およびチャネル数に対応します。

シーケンス データまたは時系列データ。次のいずれかとして指定します。

データ型説明使用例
データストアTransformedDatastoreカスタム変換関数を使用して、基になるデータストアから読み取ったデータのバッチを変換するデータストア

  • classify でサポートされていない出力をもつデータストアを変換する。

  • データストアの出力にカスタム変換を適用する。

CombinedDatastore2 つ以上の基になるデータストアからデータを読み取るデータストア

  • 複数の入力をもつネットワークを使用して予測を行う。

  • 異なるデータ ソースから取得した予測子を結合する。

カスタム ミニバッチ データストアデータのミニバッチを返すカスタム データストア

他のデータストアでサポートされていない形式のデータを使用して予測を行います。

詳細は、カスタム ミニバッチ データストアの開発を参照してください。

数値配列または cell 配列数値配列として指定した、単一のシーケンス。または数値配列の cell 配列として指定した、シーケンスのデータ セットメモリに収まり、なおかつカスタム変換などの追加の処理を必要としないデータを使用して、予測を行います。

データストア

データストアは、シーケンスと応答のミニバッチを読み取ります。データストアは、データがメモリに収まらない場合や、データに変換を適用したい場合に使用します。

以下のデータストアは、シーケンス データ用の classify と直接互換性があります。

関数 transform および combine を使用して、予測を行うための他の組み込みデータストアを使用できます。これらの関数は、データストアから読み取られたデータを、classify に必要な table または cell 配列形式に変換できます。たとえば、ArrayDatastore オブジェクトおよび TabularTextDatastore オブジェクトをそれぞれ使用して、インメモリ配列および CSV ファイルから読み取ったデータの変換と結合を行うことができます。

データストアは、table または cell 配列でデータを返さなければなりません。カスタム ミニバッチ データストアは、table を出力しなければなりません。

データストア出力出力の例
table
data = read(ds)
data =

  4×2 table

        Predictors    
    __________________

    {12×50 double}
    {12×50 double}
    {12×50 double}
    {12×50 double}
cell 配列
data = read(ds)
data =

  4×2 cell array

    {12×50 double}
    {12×50 double}
    {12×50 double}
    {12×50 double}

予測子の形式は、データのタイプによって異なります。

データ予測子の形式
ベクトル シーケンス

c 行 s 列の行列。ここで、c はシーケンスの特徴の数、s はシーケンス長です。

1 次元イメージ シーケンス

h x c x s の配列。ここで、h および c はそれぞれイメージの高さおよびチャネル数に対応します。s はシーケンス長です。

ミニバッチ内の各シーケンスは、同じシーケンス長でなければなりません。

2 次元イメージ シーケンス

h x w x c x s の配列。ここで、h、w、および c はそれぞれイメージの高さ、幅、およびチャネル数に対応します。s はシーケンス長です。

ミニバッチ内の各シーケンスは、同じシーケンス長でなければなりません。

3 次元イメージ シーケンス

h x w x d x c x s の配列。ここで、h、w、d、および c は、それぞれイメージの高さ、幅、深さ、およびチャネル数に対応します。s はシーケンス長です。

ミニバッチ内の各シーケンスは、同じシーケンス長でなければなりません。

予測子が table で返される場合、数値スカラーまたは数値行ベクトルが要素に含まれているか、数値配列が格納された 1 行 1 列の cell 配列が要素に含まれていなければなりません。

詳細については、深層学習用のデータストアを参照してください。

数値配列または cell 配列

メモリに収まり、なおかつカスタム変換などの追加の処理を必要としないデータの場合、単一のシーケンスを数値配列として指定するか、シーケンスのデータ セットを数値配列の cell 配列として指定することができます。

cell 配列入力の場合、cell 配列は、数値配列から成る N 行 1 列の cell 配列でなければなりません。ここで、N は観測値の数です。シーケンスを表す数値配列のサイズと形状は、シーケンス データのタイプによって異なります。

入力説明
ベクトル シーケンスc 行 s 列の行列。ここで、c はシーケンスの特徴の数、s はシーケンス長です。
1 次元イメージ シーケンスh×c×s の配列。ここで、h および c はそれぞれイメージの高さおよびチャネル数に対応します。s はシーケンス長です。
2 次元イメージ シーケンスh×w×c×s の配列。ここで、h、w、および c は、それぞれイメージの高さ、幅、およびチャネル数に対応します。s はシーケンス長です。
3 次元イメージ シーケンスh×w×d×c×s。ここで、h、w、d、および c は、それぞれ 3 次元イメージの高さ、幅、深さ、およびチャネル数に対応します。s はシーケンス長です。

特徴データ。次のいずれかとして指定します。

データ型説明使用例
データストアTransformedDatastoreカスタム変換関数を使用して、基になるデータストアから読み取ったデータのバッチを変換するデータストア

  • classify でサポートされていない出力をもつデータストアを変換する。

  • データストアの出力にカスタム変換を適用する。

CombinedDatastore2 つ以上の基になるデータストアからデータを読み取るデータストア

  • 複数の入力をもつネットワークを使用して予測を行う。

  • 異なるデータ ソースから取得した予測子を結合する。

カスタム ミニバッチ データストアデータのミニバッチを返すカスタム データストア

他のデータストアでサポートされていない形式のデータを使用して予測を行います。

詳細は、カスタム ミニバッチ データストアの開発を参照してください。

tabletable として指定された特徴データtable に格納されたデータを使用して予測を行います。
数値配列数値配列として指定された特徴データメモリに収まり、なおかつカスタム変換などの追加の処理を必要としないデータを使用して、予測を行います。

データストア

データストアは、特徴データと応答のミニバッチを読み取ります。データストアは、データがメモリに収まらない場合や、データに変換を適用したい場合に使用します。

以下のデータストアは、特徴データ用の classify と直接互換性があります。

関数 transform および combine を使用して、予測を行うための他の組み込みデータストアを使用できます。これらの関数は、データストアから読み取られたデータを、classify に必要な table または cell 配列形式に変換できます。詳細については、深層学習用のデータストアを参照してください。

複数の入力があるネットワークの場合、データストアは TransformedDatastore オブジェクトまたは CombinedDatastore オブジェクトでなければなりません。

データストアは、table または cell 配列でデータを返さなければなりません。カスタム ミニバッチ データストアは、table を出力しなければなりません。データストア出力の形式は、ネットワーク アーキテクチャによって異なります。

ネットワーク アーキテクチャデータストア出力出力の例
単入力層

少なくとも 1 つの列をもつ table または cell 配列。最初の列は予測子を指定します。

table の要素は、スカラー、行ベクトルであるか、数値配列が格納された 1 行 1 列の cell 配列でなければなりません。

カスタム ミニバッチ データストアは、table を出力しなければなりません。

1 つの入力があるネットワークの table:

data = read(ds)
data =

  4×2 table

        Predictors    
    __________________

    {24×1 double}
    {24×1 double}
    {24×1 double}
    {24×1 double}

1 つの入力があるネットワークの cell 配列:

data = read(ds)
data =

  4×1 cell array

    {24×1 double}
    {24×1 double}
    {24×1 double}
    {24×1 double}

多入力層

少なくとも numInputs 個の列をもつ cell 配列。numInputs はネットワーク入力の数です。

最初の numInputs 個の列は、各入力の予測子を指定します。

入力の順序は、ネットワークの InputNames プロパティによって指定されます。

2 つの入力があるネットワークの cell 配列:

data = read(ds)
data =

  4×3 cell array

    {24×1 double}    {28×1 double}
    {24×1 double}    {28×1 double}
    {24×1 double}    {28×1 double}
    {24×1 double}    {28×1 double}

予測子は、c 行 1 列の列ベクトルでなければなりません。ここで、c は特徴の数です。

詳細については、深層学習用のデータストアを参照してください。

table

メモリに収まり、なおかつカスタム変換などの追加の処理を必要としない特徴データの場合、特徴データと応答を table として指定できます。

table の各行は観測値に対応します。table の列での予測子の配置は、タスクのタイプによって異なります。

タスク予測子
特徴分類

1 つ以上の列でスカラーとして指定された特徴。

数値配列

メモリに収まり、なおかつカスタム変換などの追加の処理を必要としない特徴データの場合、特徴データを数値配列として指定できます。

数値配列は、N 行 numFeatures 列の数値配列でなければなりません。ここで、N は観測値の数、numFeatures は入力データの特徴の数です。

複数の入力をもつネットワークの数値配列または cell 配列。

イメージ、シーケンス、および特徴の予測子入力の場合、予測子の形式は、imagessequences、または features のそれぞれの引数の説明に記載されている形式と一致しなければなりません。

複数の入力をもつネットワークに学習させる方法を説明する例については、イメージ データおよび特徴データにおけるネットワークの学習を参照してください。

混在データ。次のいずれかとして指定します。

データ型説明使用例
TransformedDatastoreカスタム変換関数を使用して、基になるデータストアから読み取ったデータのバッチを変換するデータストア

  • 複数の入力をもつネットワークを使用して予測を行う。

  • classify でサポートされていないデータストアの出力を、必要な形式に変換する。

  • データストアの出力にカスタム変換を適用する。

CombinedDatastore2 つ以上の基になるデータストアからデータを読み取るデータストア

  • 複数の入力をもつネットワークを使用して予測を行う。

  • 異なるデータ ソースから取得した予測子を結合する。

カスタム ミニバッチ データストアデータのミニバッチを返すカスタム データストア

他のデータストアでサポートされていない形式のデータを使用して予測を行います。

詳細は、カスタム ミニバッチ データストアの開発を参照してください。

関数 transform および combine を使用して、予測を行うための他の組み込みデータストアを使用できます。これらの関数は、データストアから読み取られたデータを、classify に必要な table または cell 配列形式に変換できます。詳細については、深層学習用のデータストアを参照してください。

データストアは、table または cell 配列でデータを返さなければなりません。カスタム ミニバッチ データストアは、table を出力しなければなりません。データストア出力の形式は、ネットワーク アーキテクチャによって異なります。

データストア出力出力の例

numInputs 列の cell 配列。numInputs はネットワーク入力の数です。

入力の順序は、ネットワークの InputNames プロパティによって指定されます。

data = read(ds)
data =

  4×3 cell array

    {24×1 double}    {28×1 double}
    {24×1 double}    {28×1 double}
    {24×1 double}    {28×1 double}
    {24×1 double}    {28×1 double}

イメージ、シーケンス、および特徴の予測子入力の場合、予測子の形式は、imagessequences、または features のそれぞれの引数の説明に記載されている形式と一致しなければなりません。

複数の入力をもつネットワークに学習させる方法を説明する例については、イメージ データおよび特徴データにおけるネットワークの学習を参照してください。

ヒント

数値配列をデータストアに変換するには、arrayDatastore を使用します。

名前と値の引数

オプションの引数のペアを Name1=Value1,...,NameN=ValueN として指定します。ここで、Name は引数名で、Value は対応する値です。名前と値の引数は他の引数の後に指定しなければなりませんが、ペアの順序は重要ではありません。

R2021a より前では、コンマを使用して名前と値の各ペアを区切り、Name を引用符で囲みます。

例: MiniBatchSize=256 はミニバッチのサイズを 256 に指定します。

予測に使用するミニバッチのサイズ。正の整数として指定します。ミニバッチのサイズが大きくなるとより多くのメモリが必要になりますが、予測時間が短縮される可能性があります。

長さが異なるシーケンスで予測を行うと、ミニバッチのサイズが、入力データに追加されるパディングの量に影響し、予測値が変わることがあります。さまざまな値を使用して、ネットワークに最適なものを確認してください。ミニバッチのサイズとパディングのオプションを指定するには、MiniBatchSize オプションと SequenceLength オプションをそれぞれ使用します。

データ型: single | double | int8 | int16 | int32 | int64 | uint8 | uint16 | uint32 | uint64

パフォーマンスの最適化。次のいずれかとして指定します。

  • "auto" — 入力ネットワークとハードウェア リソースに適した最適化の回数を自動的に適用します。

  • "mex" — MEX 関数をコンパイルして実行します。このオプションは GPU の使用時にのみ利用できます。GPU を使用するには Parallel Computing Toolbox とサポートされている GPU デバイスが必要です。サポートされているデバイスについては、リリース別の GPU サポート (Parallel Computing Toolbox)を参照してください。Parallel Computing Toolbox または適切な GPU が利用できない場合、エラーが返されます。

  • "none" — すべての高速化を無効にします。

Acceleration"auto" の場合、MATLAB® は互換性のある最適化を複数適用し、MEX 関数を生成しません。

"auto" オプションおよび "mex" オプションは、パフォーマンス上のメリットがありますが、初期実行時間が長くなります。互換性のあるパラメーターを使用した後続の呼び出しは、より高速になります。新しい入力データを使用して関数を複数回呼び出す場合は、パフォーマンスの最適化を使用してください。

"mex" オプションは、関数の呼び出しに使用されたネットワークとパラメーターに基づいて MEX 関数を生成し、実行します。複数の MEX 関数を一度に 1 つのネットワークに関連付けることができます。ネットワークの変数をクリアすると、そのネットワークに関連付けられている MEX 関数もクリアされます。

"mex" オプションは、単一の GPU の使用時に利用できます。

"mex" オプションを使用するには、C/C++ コンパイラがインストールされ、GPU Coder™ Interface for Deep Learning Libraries サポート パッケージがなければなりません。MATLAB でアドオン エクスプローラーを使用してサポート パッケージをインストールします。設定手順については、MEX の設定 (GPU Coder)を参照してください。GPU Coder は不要です。

"mex" オプションではサポートされていない層があります。サポートされている層の一覧については、サポートされている層 (GPU Coder)を参照してください。

"mex" オプションを使用する場合、MATLAB Compiler™ はネットワークの展開をサポートしません。

ハードウェア リソース。次のいずれかとして指定します。

  • "auto" — 利用可能な場合は GPU を使用し、そうでない場合は CPU を使用します。

  • "gpu" — GPU を使用します。GPU を使用するには Parallel Computing Toolbox とサポートされている GPU デバイスが必要です。サポートされているデバイスについては、リリース別の GPU サポート (Parallel Computing Toolbox)を参照してください。Parallel Computing Toolbox または適切な GPU が利用できない場合、エラーが返されます。

  • "cpu" — CPU を使用します。

  • "multi-gpu" — 既定のクラスター プロファイルに基づいてローカルの並列プールを使用して、1 つのマシンで複数の GPU を使用します。現在の並列プールがない場合、使用可能な GPU の数と等しいプール サイズの並列プールが起動されます。

  • "parallel" — 既定のクラスター プロファイルに基づいてローカルまたはリモートの並列プールを使用します。現在の並列プールがない場合、既定のクラスター プロファイルを使用して 1 つのプールが起動されます。プールから GPU にアクセスできる場合、固有の GPU を持つワーカーのみが計算を実行します。プールに GPU がない場合、代わりに使用可能なすべての CPU ワーカーで計算が実行されます。

さまざまな実行環境をどのような場合に使用するかの詳細は、Scale Up Deep Learning in Parallel, on GPUs, and in the Cloudを参照してください。

"gpu""multi-gpu"、および "parallel" のオプションを使用するには、Parallel Computing Toolbox が必要です。深層学習に GPU を使用するには、サポートされている GPU デバイスもなければなりません。サポートされているデバイスについては、リリース別の GPU サポート (Parallel Computing Toolbox)を参照してください。これらのいずれかのオプションの選択時に Parallel Computing Toolbox または適切な GPU を利用できない場合、エラーが返されます。

"multi-gpu" オプションおよび "parallel" オプションは、状態パラメーターをもつカスタム層、または予測時にステートフルな組み込み層 (LSTMLayer オブジェクト、BiLSTMLayer オブジェクト、GRULayer オブジェクトなどの再帰層など) を含むネットワークをサポートしていません。

入力シーケンスのパディング、切り捨て、または分割を行うオプション。次のいずれかに指定します。

  • "longest" — 各ミニバッチで、最長のシーケンスと同じ長さになるようにシーケンスのパディングを行います。このオプションを使用するとデータは破棄されませんが、パディングによってネットワークにノイズが生じることがあります。

  • "shortest" — 各ミニバッチで、最短のシーケンスと同じ長さになるようにシーケンスの切り捨てを行います。このオプションを使用するとパディングは追加されませんが、データが破棄されます。

  • 正の整数 — 各ミニバッチで、ミニバッチで最長のシーケンスより大きい、指定長の最も近い倍数になるようにシーケンスのパディングを行った後、それらのシーケンスを指定長のより小さなシーケンスに分割します。分割が発生すると、追加のミニバッチが作成されます。シーケンス全体がメモリに収まらない場合は、このオプションを使用します。または、MiniBatchSize オプションをより小さい値に設定して、ミニバッチごとのシーケンス数を減らしてみます。

入力シーケンスのパディング、切り捨て、および分割の効果の詳細は、シーケンスのパディング、切り捨て、および分割を参照してください。

データ型: single | double | int8 | int16 | int32 | int64 | uint8 | uint16 | uint32 | uint64 | char | string

パディングまたは切り捨ての方向。次のいずれかに指定します。

  • "right" — シーケンスの右側に対してパディングまたは切り捨てを行います。シーケンスは同じタイム ステップで始まり、シーケンスの末尾に対して切り捨てまたはパディングの追加が行われます。

  • "left" — シーケンスの左側に対してパディングまたは切り捨てを行います。シーケンスが同じタイム ステップで終わるように、シーケンスの先頭に対して切り捨てまたはパディングの追加が行われます。

再帰層は 1 タイム ステップずつシーケンス データを処理するため、再帰層の OutputMode プロパティが 'last' の場合、最後のタイム ステップでパディングを行うと層の出力に悪影響を与える可能性があります。シーケンス データの左側に対してパディングまたは切り捨てを行うには、SequencePaddingDirection オプションを "left" に設定します。

sequence-to-sequence ネットワークの場合 (各再帰層について OutputMode プロパティが 'sequence' である場合)、最初のタイム ステップでパティングを行うと、それ以前のタイム ステップの予測に悪影響を与える可能性があります。シーケンスの右側に対してパディングまたは切り捨てを行うには、SequencePaddingDirection オプションを "right" に設定します。

入力シーケンスのパディング、切り捨て、および分割の効果の詳細は、シーケンスのパディング、切り捨て、および分割を参照してください。

入力シーケンスをパディングする値。スカラーとして指定します。

このオプションは、SequenceLength"longest" または正の整数の場合にのみ有効です。ネットワーク全体にエラーが伝播される可能性があるため、NaN でシーケンスをパディングしないでください。

データ型: single | double | int8 | int16 | int32 | int64 | uint8 | uint16 | uint32 | uint64

出力引数

すべて折りたたむ

予測クラス ラベル。categorical ベクトル、または categorical ベクトルの cell 配列として返されます。Y の形式は、タスクのタイプによって異なります。

次の表は、分類タスクの形式について説明しています。

タスク形式
イメージ分類または特徴分類ラベルの N 行 1 列の categorical ベクトル。N は観測値の数です。
sequence-to-label 分類
sequence-to-sequence 分類

ラベルのカテゴリカル シーケンスの N 行 1 列の cell 配列。N は観測値の数です。SequenceLength オプションが各ミニバッチに個別に適用された後は、各シーケンスに、対応する入力シーケンスと同じ数のタイム ステップが含まれます。

観測値が 1 つの sequence-to-sequence 分類タスクでは、sequences を行列にすることができます。この場合、Y はラベルのカテゴリカル シーケンスです。

予測スコアまたは応答。行列、または行列の cell 配列として返されます。scores の形式は、タスクのタイプによって異なります。

次の表は、scores の形式について説明しています。

タスク形式
イメージ分類N 行 K 列の行列。N は観測値の数、K はクラスの数です。
sequence-to-label 分類
特徴分類
sequence-to-sequence 分類

行列の N 行 1 列の cell 配列。N は観測値の数です。シーケンスは K 行の行列で、K はクラスの数です。SequenceLength オプションが各ミニバッチに個別に適用された後は、各シーケンスに、対応する入力シーケンスと同じ数のタイム ステップが含まれます。

観測値が 1 つの sequence-to-sequence 分類タスクでは、sequences を行列にすることができます。この場合、scores は、予測クラス スコアの行列です。

分類スコアを調べる例については、深層学習を使用した Web カメラ イメージの分類を参照してください。

アルゴリズム

関数 trainNetwork を使用してネットワークに学習させる場合や、DAGNetwork オブジェクトおよび SeriesNetwork オブジェクトと共に予測関数または検証関数を使用する場合、ソフトウェアは単精度浮動小数点演算を使用して、これらの計算を実行します。学習、予測、および検証のための関数には、trainNetworkpredictclassify、および activations が含まれます。CPU と GPU の両方を使用してネットワークに学習させる場合、単精度演算が使用されます。

代替方法

複数の出力層をもつネットワークを使用してデータを分類するには、関数 predict を使用し、ReturnCategorical オプションを 1 (true) に設定します。

予測された分類スコアを計算するために、関数 predict を使用することもできます。

ネットワーク層から活性化を計算するには、関数 activations を使用します。

LSTM ネットワークなどの再帰型ネットワークでは、classifyAndUpdateState および predictAndUpdateState を使用してネットワークの状態の予測および更新を実行できます。

参照

[1] Kudo, Mineichi, Jun Toyama, and Masaru Shimbo. “Multidimensional Curve Classification Using Passing-through Regions.” Pattern Recognition Letters 20, no. 11–13 (November 1999): 1103–11. https://doi.org/10.1016/S0167-8655(99)00077-X.

[2] UCI Machine Learning Repository: Japanese Vowels Dataset. https://archive.ics.uci.edu/ml/datasets/Japanese+Vowels.

拡張機能

バージョン履歴

R2016a で導入