Deep Learning ANN Classification Model
Hi,
I am trying to develop a pattern recognition classification ANN model for 10 different classes using 11 inputs. The model runs, but the performance is poor (less than 20% accuracy). I want to try a deep learning technique. However, all the examples I have found are for image-based classification. My problem is numerical (no images). Is there a way to build a deep learning pattern-based classification model in MATLAB? Is there an example of how to do that?
Accepted Answer
Cris LaPierre
5 Dec 2023
Edited: Cris LaPierre
5 Dec 2023
Absolutely. Here is a page showing multiple examples, none of which are images: https://www.mathworks.com/help/deeplearning/gs/pattern-recognition-with-a-shallow-neural-network.html
These are all shallow networks. You can turn a shallow network into a deep network by adding more layers.
I would suggest using the Neural Network Pattern Recognition App to create a network, and then exporting the code. You can then expand that code manually. Here's an example I built up based on the Iris data set example in the app (4 inputs, 3 outputs):
% Solve a Pattern Recognition Problem with a Neural Network
% Script generated by Neural Pattern Recognition app
% Created 05-Dec-2023 10:55:42
%
% This script assumes these variables are defined:
%
% irisInputs - input data.
% irisTargets - target data.
x = irisInputs;
t = irisTargets;
% Choose a Training Function
% For a list of all training functions type: help nntrain
% 'trainlm' is usually fastest.
% 'trainbr' takes longer but may be better for challenging problems.
% 'trainscg' uses less memory. Suitable in low memory situations.
trainFcn = 'trainscg'; % Scaled conjugate gradient backpropagation.
This is the section of code that creates the layers. It is a shallow network because there is only one hidden layer.
% Create a Pattern Recognition Network
hiddenLayerSize = 10;
net = patternnet(hiddenLayerSize, trainFcn);
You can turn this into a deep learning network by adding more hidden layers. For example, this code would create a network with three hidden layers.
% Three hidden layer NN
hiddenLayerSize1 = 10;
hiddenLayerSize2 = 20;
hiddenLayerSize3 = 15;
net = patternnet([hiddenLayerSize1 hiddenLayerSize2 hiddenLayerSize3], trainFcn);
% Setup Division of Data for Training, Validation, Testing
net.divideParam.trainRatio = 70/100;
net.divideParam.valRatio = 15/100;
net.divideParam.testRatio = 15/100;
% Train the Network
[net,tr] = train(net,x,t);
% Test the Network
y = net(x);
e = gsubtract(t,y);
performance = perform(net,t,y)
tind = vec2ind(t);
yind = vec2ind(y);
percentErrors = sum(tind ~= yind)/numel(tind);
% View the Network
view(net)
% Plots
% Uncomment these lines to enable various plots.
%figure, plotperform(tr)
%figure, plottrainstate(tr)
%figure, ploterrhist(e)
%figure, plotconfusion(t,y)
%figure, plotroc(t,y)
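The error rate computed in the script can be complemented by per-class counts, which is often more informative when accuracy is low. The following is a minimal sketch under stated assumptions: the targets and outputs are small made-up matrices (not data from this thread), and `max` is used in place of `vec2ind` so that the snippet runs without any toolbox.

```matlab
% Toy one-hot targets and network outputs (hypothetical values):
% 3 classes (rows), 6 observations (columns)
t = [1 1 0 0 0 0;
     0 0 1 1 0 0;
     0 0 0 0 1 1];
y = [0.8 0.3 0.1 0.2 0.1 0.2;
     0.1 0.6 0.7 0.3 0.2 0.1;
     0.1 0.1 0.2 0.5 0.7 0.7];

% Class index = row of the largest value in each column
[~, tind] = max(t, [], 1);
[~, yind] = max(y, [], 1);

% Fraction of observations classified correctly
accuracy = mean(tind == yind);

% Confusion counts: rows = true class, columns = predicted class
numClasses = size(t, 1);
C = accumarray([tind(:) yind(:)], 1, [numClasses numClasses]);

% Per-class recall; a class with no observations would give 0/0 = NaN
recall = diag(C) ./ sum(C, 2);
```

With real data, `t` and `y` would be the target matrix and the network output from the script above.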
7 Comments
You'll need to load the Iris data set to run this code.
% Load the Iris dataset
[irisInputs,irisTargets] = iris_dataset;
irisInputs is a 4-by-150 matrix: 150 observations of four input features corresponding to four measurements:
- Sepal length in cm
- Sepal width in cm
- Petal length in cm
- Petal width in cm
irisTargets is a 3-by-150 matrix that indicates the classification of each observation by placing a 1 in the row corresponding to the correct species: first row for setosa, second row for versicolor, third row for virginica.
For more information on this dataset, run the following help command
help iris_dataset
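To make the data layout concrete, here is a short sketch that loads the data set and checks the one-hot structure described above. It assumes the Deep Learning Toolbox is installed (for `iris_dataset` and `vec2ind`). The variable names match those expected by the script in the answer.

```matlab
[irisInputs, irisTargets] = iris_dataset;   % requires Deep Learning Toolbox

size(irisInputs)     % 4 features by 150 observations
size(irisTargets)    % 3 classes by 150 observations

% Every column of the target matrix should contain exactly one 1
assert(all(sum(irisTargets, 1) == 1))

% Convert the one-hot columns to class indices
% (1 = setosa, 2 = versicolor, 3 = virginica)
classIdx = vec2ind(irisTargets);

% Number of observations in each class
countsPerClass = sum(irisTargets, 2)'
```

The same checks apply to any target matrix passed to `train`: one row per class, one column per observation.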
Mostafa
6 Dec 2023
Thanks for your answer. It was of great help. Unfortunately, while the accuracy did improve a little, it is still low (27% from the confusion matrix). Any other suggestions on how to improve the accuracy? Some of my inputs were 0 or 1 (yes or no), and I tried removing them, thinking that they might be causing prediction problems. It did help, but not significantly.
Mostafa
6 Dec 2023
I actually found that MATLAB does not define my output classes correctly. I have 34,610 observations divided into 10 output classes. However, when I look at the confusion matrix, some of the output classes show NaN with zero observations. To define the output classes, I used binary (one-hot) codes. For example, the first output class is:
0 0 0 0 0 0 0 0 0 1
The second is:
0 0 0 0 0 0 0 0 1 0
Any ideas what may be causing this?
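One possible cause worth ruling out is that some classes have no observations in the target matrix, or that the matrix is oriented with observations in rows instead of columns. The sketch below is only a diagnostic under stated assumptions: the integer labels are made up for illustration, and the target matrix is built with base MATLAB (the toolbox function `ind2vec` is an alternative for this step).

```matlab
% Hypothetical integer class labels (1..10) for 10 observations;
% classes 4 and 9 never occur in this toy example
labels = [1 2 2 3 5 6 7 8 10 10];
numClasses = 10;

% Build one-hot targets: one row per class, one column per observation
t = zeros(numClasses, numel(labels));
t(sub2ind(size(t), labels, 1:numel(labels))) = 1;

% Observations per class; a zero here means an empty row or column
% in the confusion matrix
obsPerClass = sum(t, 2)';
emptyClasses = find(obsPerClass == 0);

% Columns that do not contain exactly one 1 indicate a malformed encoding
badColumns = find(sum(t, 1) ~= 1);
```

In the encoding quoted above, the first class places its 1 in the last position, so if that code becomes a column of the target matrix, class 1 corresponds to row 10.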
Mostafa
6 Dec 2023
As an update, I reran the file with only 219 observations and I am facing the same problems. Attached is the confusion matrix.
Cris LaPierre
6 Dec 2023
Consider attaching your data, and share your code. Without that, we can only talk generally.
Mostafa
6 Dec 2023
Here is the code and small data set. I imported the data and transposed them.
Cris LaPierre
6 Dec 2023
Edited: Cris LaPierre
6 Dec 2023
Not knowing anything about your data, I think you may need to look into feature engineering. Three of your inputs are highly correlated, meaning they aren't adding anything new to the model. However, even after eliminating them, I get similar results. To me, this means there is not enough difference in your inputs to generate an accurate model.
It = readmatrix("InputsMethod1.xlsx");
Ot = readmatrix("Outputs1.xlsx");
xnames = "input" + (1:4);
x = It';
t = Ot';
% turn t into vector of 'class labels'
f = max((0.1:0.1:1)' .* t);
% Normalize data
[x,ps]=mapminmax(x,0,1);
figure
plot([x;f])

% Inputs 1, 2 and 3 are highly correlated
figure
gplotmatrix(It,[],f,[],[],[],[],[],xnames)
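The scatter-plot matrix shows the correlation visually. A numeric check could look like the sketch below, which uses `corrcoef` from base MATLAB. The data here is synthetic (three deliberately correlated columns plus one independent column) and the 0.95 threshold is an assumed cutoff, not a value taken from this thread.

```matlab
% Hypothetical data: 200 observations (rows) of 4 features (columns).
% Features 2 and 3 are noisy linear functions of feature 1.
rng(0);                                   % reproducible toy data
base = randn(200, 1);
X = [base, ...
     2*base + 0.01*randn(200, 1), ...
     -base + 0.01*randn(200, 1), ...
     randn(200, 1)];

% Pairwise correlation coefficients between columns (4-by-4 matrix)
R = corrcoef(X);

% List feature pairs whose absolute correlation exceeds the threshold,
% looking only above the diagonal so each pair appears once
threshold = 0.95;                         % assumed cutoff, tune for your data
[rowIdx, colIdx] = find(triu(abs(R) > threshold, 1));
highlyCorrelatedPairs = [rowIdx colIdx]
```

For the poster's data, `X` would be the observations-by-features matrix `It`.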

% Determine which features have the highest predictive power
[idx,scores] = fscmrmr(It,f')
idx = 1×4
1 4 2 3
scores = 1×4
0.0397 0.0034 0.0031 0.0205
bar(scores(idx))
xlabel('Predictor rank')
ylabel('Predictor importance score')
xticklabels(xnames(idx));

Here are the results I get with a shallow network for the 2 top features.
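The code used for that final model is not shown in the thread. As a rough sketch of how a shallow network might be trained on the two highest-ranked features, the example below uses the Iris data as a stand-in for the poster's spreadsheets (an assumption), and requires both the Deep Learning Toolbox and the Statistics and Machine Learning Toolbox. The hidden layer size of 10 mirrors the earlier script rather than any tuned value.

```matlab
% Stand-in data (assumption): Iris instead of the poster's spreadsheets
[x, t] = iris_dataset;            % x: 4-by-150 features, t: 3-by-150 one-hot targets
labels = vec2ind(t)';             % 150-by-1 vector of class indices

% Rank features with MRMR (fscmrmr expects observations in rows)
idx = fscmrmr(x', labels);
top2 = idx(1:2);                  % indices of the two highest-ranked features

% Shallow pattern recognition network trained on the selected features
rng('default');                   % reproducible data division and initial weights
net = patternnet(10, 'trainscg');
net.trainParam.showWindow = false;    % suppress the training GUI
[net, tr] = train(net, x(top2, :), t);

% Accuracy on the held-out test observations only
yTest = net(x(top2, tr.testInd));
testAccuracy = mean(vec2ind(yTest) == vec2ind(t(:, tr.testInd)))
```

Evaluating on `tr.testInd` rather than on all observations gives an estimate that is not inflated by the training data.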
