Training/Cross validation/Test sets

Reuben Addison, March 1, 2019
Answered: Shubham, September 4, 2024
I am new to machine learning and a little lost on these concepts. I trained a model on my training set (60% of the data) and obtained the optimized parameters (theta values). Do I also have to train on the cross-validation set? If yes, what do I do with the new theta values I get from the cross-validation data? If not, why do I still need a cross-validation set, since I can test my model's accuracy with just the test set? I would appreciate it if someone could walk me through this.
Also, how do I assess model accuracy with my test set? Thanks in advance for any help; even a tutorial video or code would be welcome.

Answers (1)

Shubham, September 4, 2024
Hi Reuben,
Here's a basic guide to training, validating, and testing a model in MATLAB, using a simple classification example. MATLAB provides several built-in machine learning functions, including fitcsvm, which trains a support vector machine (SVM) classifier for one- or two-class problems, and fitcecoc, which combines binary SVM learners into a multiclass classifier.
Steps in MATLAB
  1. Load and Split Data: First, split your dataset into training, validation, and test sets.
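  2. Train the Model: Use the training set to learn the model parameters (your theta values). This is the only step that fits parameters.
  3. Validate the Model: Use the validation set to compare candidate models (for example, different regularization strengths or kernel settings) and to check for overfitting. You do not retrain on the validation set; every candidate is trained on the training set, and the validation set only tells you which candidate to keep.
  4. Test the Model: Evaluate the final, chosen model once on the test set to estimate how it will perform on unseen data.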
Example Code
Here's a simple example using MATLAB:
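% Load your dataset
load fisheriris   % Example dataset (3 classes of iris flowers)
X = meas;         % Features: 150x4 numeric matrix
y = species;      % Labels: 150x1 cell array of character vectors
rng(1);           % Fix the random seed so the data split is reproducible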
% Split the data into training (60%) and temporary (40%) sets, stratified by class
cv = cvpartition(y, 'HoldOut', 0.4);
XTrain = X(training(cv), :);
yTrain = y(training(cv), :);
XTemp = X(test(cv), :);
yTemp = y(test(cv), :);
% Further split the temporary data into validation and test sets (20% of the full data each)
cv2 = cvpartition(yTemp, 'HoldOut', 0.5);
XVal = XTemp(training(cv2), :);
yVal = yTemp(training(cv2), :);
XTest = XTemp(test(cv2), :);
yTest = yTemp(test(cv2), :);
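% Train the model using the training set.
% fitcsvm supports only one- or two-class problems and fisheriris has three classes,
% so fitcecoc is used to combine binary SVM learners into a multiclass model.
model = fitcecoc(XTrain, yTrain);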
% Validate the model using the validation set
valPredictions = predict(model, XVal);
valAccuracy = mean(strcmp(valPredictions, yVal));   % labels are cell arrays, so compare with strcmp, not ==
fprintf('Validation Accuracy: %.2f%%\n', valAccuracy * 100);
% Test the model using the test set
testPredictions = predict(model, XTest);
testAccuracy = mean(strcmp(testPredictions, yTest));  % fraction of test samples classified correctly
fprintf('Test Accuracy: %.2f%%\n', testAccuracy * 100);
% Display the confusion matrix (rows = true class, columns = predicted class)
[confMat, classOrder] = confusionmat(yTest, testPredictions);
disp('Class order:');
disp(classOrder);
disp('Confusion Matrix:');
disp(confMat);
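A confusion matrix on its own does not summarize how well each class is handled. The sketch below is one way to turn a confusion matrix into overall accuracy plus per-class precision, recall, and F1 score. The label vectors yTrue and yPred are small hypothetical examples chosen for illustration; with the code above you would pass yTest and testPredictions instead.

```matlab
% Hypothetical labels for illustration only; replace with yTest and testPredictions.
yTrue = {'setosa'; 'setosa'; 'versicolor'; 'versicolor'; 'virginica'; 'virginica'};
yPred = {'setosa'; 'setosa'; 'versicolor'; 'virginica'; 'virginica'; 'virginica'};

% Rows of C are true classes, columns are predicted classes, in the order given by 'order'
[C, order] = confusionmat(yTrue, yPred);

accuracy  = sum(diag(C)) / sum(C(:));   % correctly classified samples / all samples
precision = diag(C) ./ sum(C, 1)';      % correct predictions of a class / all predictions of that class
recall    = diag(C) ./ sum(C, 2);       % correct predictions of a class / all true members of that class
f1        = 2 * precision .* recall ./ (precision + recall);  % harmonic mean of precision and recall

report = table(order(:), precision, recall, f1, ...
    'VariableNames', {'Class', 'Precision', 'Recall', 'F1'});
disp(report);
fprintf('Overall accuracy: %.2f%%\n', accuracy * 100);
```

If a class is never predicted, its precision is 0/0 and shows up as NaN, which is worth checking for on small test sets.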
Explanation
  • Data Splitting: We first split the data into training and temporary sets (60% training, 40% temporary). The temporary set is further split into validation and test sets (20% each).
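  • Model Training: We train an SVM model on the training data. This is the only step that learns model parameters (the equivalent of your theta values).
  • Validation: We evaluate the model on the validation set. The validation set is not used to retrain the model or to produce new parameters; it is used to compare candidate models (for example, different hyperparameter values) and to detect overfitting, such as training accuracy that is much higher than validation accuracy. The candidate with the best validation score is the one you keep.
  • Testing: Finally, we evaluate the chosen model on the test set once. Because the test set played no part in training or model selection, its accuracy is an unbiased estimate of performance on new data. If you also used the test set to choose between models, that estimate would become optimistic, which is why a separate validation set is needed.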
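The example code computes a validation accuracy but never uses it to choose anything. The following sketch shows the selection step the question asks about: several candidate models are trained on the training set only, the validation set picks the best one, and the test set is used once at the end. The grid of BoxConstraint values is an arbitrary illustrative assumption, and fitcecoc with templateSVM is just one possible model choice; any hyperparameter could play the same role.

```matlab
% Sketch: use the validation set for model selection, not for retraining.
load fisheriris
X = meas;
y = species;
rng(1);                                     % reproducible split

% 60% training, 20% validation, 20% test (stratified by class)
cv = cvpartition(y, 'HoldOut', 0.4);
XTrain = X(training(cv), :);  yTrain = y(training(cv));
XTemp  = X(test(cv), :);      yTemp  = y(test(cv));
cv2 = cvpartition(yTemp, 'HoldOut', 0.5);
XVal  = XTemp(training(cv2), :);  yVal  = yTemp(training(cv2));
XTest = XTemp(test(cv2), :);      yTest = yTemp(test(cv2));

% Candidate hyperparameter values (illustrative grid, an assumption)
boxValues = [0.01 0.1 1 10 100];
valAcc = zeros(size(boxValues));
models = cell(size(boxValues));

for k = 1:numel(boxValues)
    % Each candidate learns its parameters from the TRAINING set only
    t = templateSVM('BoxConstraint', boxValues(k));
    models{k} = fitcecoc(XTrain, yTrain, 'Learners', t);
    % ...and is scored on the VALIDATION set
    valAcc(k) = mean(strcmp(predict(models{k}, XVal), yVal));
    fprintf('BoxConstraint = %g: validation accuracy %.2f%%\n', ...
        boxValues(k), 100 * valAcc(k));
end

% Keep the candidate with the best validation accuracy (first one on ties)
[~, best] = max(valAcc);
bestModel = models{best};

% Touch the test set exactly once, with the chosen model
testAcc = mean(strcmp(predict(bestModel, XTest), yTest));
fprintf('Chosen BoxConstraint = %g, test accuracy %.2f%%\n', ...
    boxValues(best), 100 * testAcc);
```

No parameters are learned from the validation data here: predict only applies models that were already trained on the training set. A common variant, not shown, retrains the selected setting on the training and validation data combined before the final test evaluation.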

Categories

Support Vector Machine Regression

Release

R2018a