Group K-fold partitioning of a dataset

Ivan Abraham, 31 July 2018
Answered: Jaimin, 9 January 2025
The scikit-learn package in Python has a GroupKFold splitter that divides a dataset into train/test folds while ensuring that the same "group" never appears in more than one fold, so no group ends up in both the training and the test set of a split. This is useful, for example, in studies where the same subject generates multiple data points and we want to make sure the samples belonging to one subject don't appear in both the training and testing folds.
I was wondering whether MATLAB has a way to do this, either as an option of the cvpartition function or in some other way. The default options only seem to preserve relative class sizes (stratification).
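The requirement above can be stated as a simple check: a fold assignment is group-safe when every group ID maps to exactly one fold. A minimal MATLAB sketch of that check, where the example vectors `groups`, `safeFolds`, `leakyFolds` and the anonymous function `isGroupSafe` are hypothetical names chosen for illustration:

```matlab
% Hypothetical example: subject ID of each sample (5 subjects, 10 samples)
groups = [1 1 2 2 2 3 4 4 5 5]';

% A fold assignment is group-safe if all samples of each group share one fold.
% For every unique group u, count the distinct fold IDs among its samples.
isGroupSafe = @(groups, foldIds) all(arrayfun( ...
    @(u) numel(unique(foldIds(groups == u))) == 1, unique(groups)));

safeFolds  = [1 1 2 2 2 1 3 3 2 2]';  % each subject stays in one fold
leakyFolds = [1 1 2 3 2 1 3 3 2 2]';  % subject 2 is split across folds 2 and 3

okSafe  = isGroupSafe(groups, safeFolds);   % true
okLeaky = isGroupSafe(groups, leakyFolds);  % false: subject 2 leaks
fprintf('safeFolds group-safe: %d, leakyFolds group-safe: %d\n', okSafe, okLeaky);
```

A plain sample-level K-fold split behaves like `leakyFolds` whenever a subject's samples happen to land in different folds, which is exactly what a group-aware split avoids.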

Answers (1)

Jaimin, 9 January 2025
While MATLAB does not offer a built-in function exactly like scikit-learn's GroupKFold, you can achieve similar results by building the group-based cross-validation partitions yourself.
Here is how you can do it:
  1. Determine the unique groups in your dataset.
  2. Randomly shuffle these groups and then split them into k folds.
  3. Assign each data point to a fold based on its group.
% Sample data
data = rand(100, 5); % 100 samples, 5 features
labels = randi([0, 1], 100, 1); % Binary labels
groups = randi([1, 20], 100, 1); % Group ID of each sample (up to 20 distinct groups)
% Number of folds
k = 5;
% Get unique groups
uniqueGroups = unique(groups);
numGroups = numel(uniqueGroups);
% Shuffle groups
shuffledGroups = uniqueGroups(randperm(numGroups));
% Split groups into k folds by dealing them out round-robin.
% Unlike fixed blocks of ceil(numGroups/k) groups, this never leaves
% a fold empty as long as numGroups >= k.
folds = cell(k, 1);
for i = 1:k
    folds{i} = shuffledGroups(i:k:end);
end
% Create cross-validation partitions: cvIndices(j) is the test fold of sample j
cvIndices = zeros(size(groups));
for i = 1:k
    testGroups = folds{i};
    testIdx = ismember(groups, testGroups);
    cvIndices(testIdx) = i;
end
% Each fold in turn is the test set; all remaining groups form the training set
for i = 1:k
    testIdx = (cvIndices == i);
    trainIdx = ~testIdx;
    trainData = data(trainIdx, :);
    trainLabels = labels(trainIdx);
    testData = data(testIdx, :);
    testLabels = labels(testIdx);
    fprintf('Fold %d: Train on %d samples, Test on %d samples\n', i, sum(trainIdx), sum(testIdx));
end
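Splitting the shuffled groups into k folds roughly balances the number of groups per fold, not the number of samples, so test folds can differ noticeably in size when some groups are much larger than others. scikit-learn's GroupKFold instead aims for test folds with approximately equal numbers of samples. The sketch below shows one greedy way to approximate that in MATLAB: groups are assigned, largest first, to the fold that currently holds the fewest samples. This is an assumed variant for illustration, not a MathWorks function; `groupSizes`, `foldOfGroup`, and `foldSizes` are hypothetical names, and `groups` and `k` are redefined as in the example above.

```matlab
% Sketch (assumed variant): greedy, sample-balanced group K-fold assignment
groups = randi([1, 20], 100, 1);   % example group IDs, as in the code above
k = 5;                             % number of folds

[uniqueGroups, ~, g] = unique(groups);   % g maps each sample to its group index
groupSizes = accumarray(g, 1);           % number of samples in each group
[~, order] = sort(groupSizes, 'descend');% place the largest groups first

foldOfGroup = zeros(numel(uniqueGroups), 1); % fold ID of each unique group
foldSizes = zeros(k, 1);                     % samples assigned to each fold so far
for j = order'
    [~, lightest] = min(foldSizes);          % fold with the fewest samples (ties: lowest index)
    foldOfGroup(j) = lightest;
    foldSizes(lightest) = foldSizes(lightest) + groupSizes(j);
end

cvIndices = foldOfGroup(g);   % fold ID of each sample; a whole group shares one fold
disp(foldSizes');             % test-fold sizes, which should be close to each other
```

The resulting `cvIndices` can be used in the same train/test loop as above. This assignment is deterministic for a given set of group sizes; if fewer than k groups exist, some folds will be empty.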
For more information, refer to the following MathWorks documentation.
