Choosing the 3 closest points

Joakim Hansen, 21 Sep 2017
Edited: Jan, 22 Sep 2017
Hi!
I am doing some simple machine learning (a kNN classifier) and am having trouble picking the three nearest points for each point I want to classify. I want to find the three 'tp's that are closest to each of the 'x's and add them to the 3x2 matrix 'cp'. I feel like I am on to something, but the points being chosen are not the closest ones. The last part of the code plots the chosen points next to all the training points, so it is easy to see whether the closest ones were chosen.
clearvars
my1=transpose([1 1]);
my2=transpose([2 2]);
sigma=0.2;
Sigma2=sigma*eye(2);
%Training points
tp=[mvnrnd(my1,Sigma2,50); mvnrnd(my2,Sigma2,50)];
yy=[ones(50,1);(-1)*ones(50,1)];
%Points to be classified
x=[mvnrnd(my1,Sigma2,100); mvnrnd(my2,Sigma2,100)];
%Closest points
cp=zeros(3,2);
%Fills the three spots
if cp(1,1)==0 && cp(1,2)==0
    cp(1,1)=tp(1,1);
    cp(1,2)=tp(1,2);
end
if cp(2,1)==0 && cp(2,2)==0
    cp(2,1)=tp(2,1);
    cp(2,2)=tp(2,2);
end
if cp(3,1)==0 && cp(3,2)==0
    cp(3,1)=tp(3,1);
    cp(3,2)=tp(3,2);
end
%Preallocating memory
CP1=zeros(1);
CP2=zeros(1);
CP3=zeros(1);
for ii=4:100
    for k=1:3
        LengthCPtoX = sqrt(((x(1,1)-cp(k,1))^2)-((x(1,2)-cp(k,2))^2));
        LengthTPtoX = sqrt(((x(1,1)-tp(ii,1))^2)-((x(1,2)-tp(ii,2))^2));
        if LengthCPtoX > LengthTPtoX
            CP1=sqrt(((x(1,1)-cp(1,1))^2)-((x(1,2)-cp(1,2))^2));
            CP2=sqrt(((x(1,1)-cp(2,1))^2)-((x(1,2)-cp(2,2))^2));
            CP3=sqrt(((x(1,1)-cp(3,1))^2)-((x(1,2)-cp(3,2))^2));
            maxCP=max([CP1 CP2 CP3]);
            if maxCP==CP1
                cp(1,1)=tp(ii,1);
                cp(1,2)=tp(ii,2);
            end
            if maxCP==CP2
                cp(2,1)=tp(ii,1);
                cp(2,2)=tp(ii,2);
            end
            if maxCP==CP3
                cp(3,1)=tp(ii,1);
                cp(3,2)=tp(ii,2);
            end
            if cp(1,1)==tp(ii,1) || cp(2,1)==tp(ii,1) || cp(3,1)==tp(ii,1)
                break
            end
        end
    end
end
figure
subplot(2,1,1);
scatter(cp(:,1),cp(:,2),'b')
hold on
scatter(x(1,1),x(1,2),'r')
grid on
xlim([0 3.5]);
ylim([0 3.5]);
subplot(2,1,2);
scatter(tp(:,1),tp(:,2),'b')
hold on
scatter(x(1,1),x(1,2),'r')
grid on
xlim([0 3.5]);
ylim([0 3.5]);

Answers (2)

Image Analyst, 22 Sep 2017
Why are you trying to do a kNN classification manually when the Statistics and Machine Learning Toolbox already has a built-in function for the neighbor search, knnsearch()?
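For reference, a minimal sketch of how knnsearch() could be called for this problem, assuming the Statistics and Machine Learning Toolbox is installed. The toy data below is hypothetical and not taken from the question; only the variable names tp, x and cp follow it.

```matlab
% Hypothetical toy data (not from the question): 5 training points, 2 queries.
tp = [0 0; 1 0; 0 1; 5 5; 6 5];
x  = [0.1 0.2; 5.2 5.1];

% idx(i, :) holds the row indices in tp of the 3 nearest neighbours of
% x(i, :), nearest first; dist holds the matching Euclidean distances.
[idx, dist] = knnsearch(tp, x, 'K', 3);

% Coordinates of the 3 neighbours of the first query point.
cp = tp(idx(1, :), :);
```

Working with the returned indices rather than with copied coordinates makes it straightforward to look up the class label of each neighbour afterwards.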
  1 Comment
Joakim Hansen, 22 Sep 2017
The whole point is to understand the classifier, so using a premade function is not very educational :)



Jan, 22 Sep 2017
Edited: Jan, 22 Sep 2017
The code inside the
if LengthCPtoX > LengthTPtoX
...
end
block does not depend on the index k. Therefore I think the loop over k is not used at all.
You can implement this
find the three 'tp's that are closest to each of the 'x's and add them to
the 3x2 matrix 'cp'
by:
function cp = Find3Nearest(tp, x)
% For each row of x, return the 3 nearest rows of tp in cp(:, :, ix).
nx = size(x, 1);
cp = zeros(3, size(tp, 2), nx);
for ix = 1:nx
    % Get squared distance to every training point (save time for SQRT).
    % Sum along dimension 2 to get one value per row of tp:
    dist2 = sum(bsxfun(@minus, tp, x(ix, :)) .^ 2, 2);
    % Get the minimum value and replace it by Inf iteratively. This is
    % faster than sorting and choosing the 3 smallest values.
    for i3 = 1:3
        [~, minIndex] = min(dist2);
        cp(i3, :, ix) = tp(minIndex, :);
        dist2(minIndex) = Inf;
    end
end
end
Untested, so please debug this.
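As a cross-check for any hand-written search, here is a small vectorized sketch that needs no toolbox. It is not the method above: it swaps the repeated min() for a full sort of all pairwise squared distances, which is simpler to verify but does more work. The toy data and the names D2, order and nearestIdx are assumptions made for illustration.

```matlab
% Hypothetical toy data: 5 training points and 2 query points in 2-D.
tp = [0 0; 1 0; 0 1; 5 5; 6 5];
x  = [0.1 0.2; 5.2 5.1];
k  = 3;

% Pairwise squared distances: D2(i, j) = ||x(i, :) - tp(j, :)||^2,
% i.e. the SUM of the squared coordinate differences.
D2 = zeros(size(x, 1), size(tp, 1));
for d = 1:size(tp, 2)
    D2 = D2 + bsxfun(@minus, x(:, d), tp(:, d).') .^ 2;
end

% Sort each row ascending and keep the indices of the k smallest distances.
[~, order] = sort(D2, 2);
nearestIdx = order(:, 1:k);

% Coordinates of the k nearest training points for the first query point.
cp = tp(nearestIdx(1, :), :);
```

Here nearestIdx(i, :) lists the rows of tp closest to x(i, :), nearest first, so the result can be compared directly with the output of a loop-based implementation.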
  2 Comments
Joakim Hansen, 22 Sep 2017
I worked it out, here's the result :)
clearvars
my1=transpose([1 1]);
my2=transpose([3 3]);
sigma=0.2;
Sigma2=sigma*eye(2);
%Number of vectors to be classified
NumbVecs=200;
%Number of training vectors
NumbTrVecs=100;
%First training vector to compare against the prefilled cp's
StartVec=4;
%kNN rule
k=3;
%Training points
tp=[mvnrnd(my1,Sigma2,50); mvnrnd(my2,Sigma2,50)];
yy=[ones(50,1);(-1)*ones(50,1)];
%Points to be classified
x=[mvnrnd(my1,Sigma2,100); mvnrnd(my2,Sigma2,100)];
%Closest points
cp=zeros(3,2);
%Training points classified
tpc=[tp yy];
%Preallocating memory to x classified as class 1 or 2
x1=zeros(105,2);
x2=zeros(105,2);
%Points to be classified, with a preallocated third column for the label
xc=[x zeros(NumbVecs,1)];
%Fills the three spots
if cp(1,1)==0 && cp(1,2)==0
    cp(1,1)=tp(1,1);
    cp(1,2)=tp(1,2);
end
if cp(2,1)==0 && cp(2,2)==0
    cp(2,1)=tp(2,1);
    cp(2,2)=tp(2,2);
end
if cp(3,1)==0 && cp(3,2)==0
    cp(3,1)=tp(3,1);
    cp(3,2)=tp(3,2);
end
%Preallocating memory for the distances of the three closest points so far
CP1=zeros(1);
CP2=zeros(1);
CP3=zeros(1);
for w=1:NumbVecs
    for ii=StartVec:NumbTrVecs
        for kcp=1:k
            LengthCPtoX = sqrt(((x(w,1)-cp(kcp,1))^2)+((x(w,2)-cp(kcp,2))^2));
            LengthTPtoX = sqrt(((x(w,1)-tp(ii,1))^2)+((x(w,2)-tp(ii,2))^2));
            if LengthCPtoX > LengthTPtoX
                CP1=sqrt(((x(w,1)-cp(1,1))^2)+((x(w,2)-cp(1,2))^2));
                CP2=sqrt(((x(w,1)-cp(2,1))^2)+((x(w,2)-cp(2,2))^2));
                CP3=sqrt(((x(w,1)-cp(3,1))^2)+((x(w,2)-cp(3,2))^2));
                maxCP=max([CP1 CP2 CP3]);
                if maxCP==CP1
                    cp(1,1)=tp(ii,1);
                    cp(1,2)=tp(ii,2);
                end
                if maxCP==CP2
                    cp(2,1)=tp(ii,1);
                    cp(2,2)=tp(ii,2);
                end
                if maxCP==CP3
                    cp(3,1)=tp(ii,1);
                    cp(3,2)=tp(ii,2);
                end
                if cp(1,1)==tp(ii,1) || cp(2,1)==tp(ii,1) || cp(3,1)==tp(ii,1)
                    break
                end
            end
            %Closest points classified
            cpc=[cp zeros(3,1)];
            %Checking what classes the three closest points belong to and put
            %the vector in the class which most of the closest points are included in.
            for jj=1:100
                for kcp1=1:k
                    if cp(kcp1,1)==tp(jj,1) && cp(kcp1,2)==tp(jj,2)
                        cpc(kcp1,3)=tpc(jj,3);
                    end
                end
            end
        end
    end
    %Sums the labels of the cp's and gives the x a label
    if (cpc(1,3)+cpc(2,3)+cpc(3,3)) >= 1
        xc(w,3)=1;
    else
        xc(w,3)=-1;
    end
    if xc(w,3)==1
        x1(w,1)=xc(w,1);
        x1(w,2)=xc(w,2);
    else
        x2(w,1)=xc(w,1);
        x2(w,2)=xc(w,2);
    end
end
figure
scatter(x1(:,1),x1(:,2),'g')
hold on
scatter(x2(:,1),x2(:,2),'k')
hold on
scatter(tp(1:50,1),tp(1:50,2),'s','r')
hold on
scatter(tp(51:100,1),tp(51:100,2),'s','b')
xlim([-1 4])
ylim([-1 4])
Jan, 22 Sep 2017
Edited: Jan, 22 Sep 2017
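You can still speed it up by skipping the SQRT before searching for the max value: the square root is monotonically increasing, so the largest squared distance belongs to the same point as the largest distance. Apply the expensive square root to the result only. Note: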
x = rand(1, 1000);
[v,k] = max(sqrt(x));
% Equivalent but cheaper:
[v2,k] = max(x);
v = sqrt(v2);
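Applied to the distance comparison in the classifier, the same idea might look like the sketch below. The query point xq and the three candidate points are hypothetical values chosen for illustration, and the variable names are assumptions rather than names from the code above.

```matlab
% Hypothetical query point and three candidate neighbours (not from the thread).
xq = [1.0 1.2];
cp = [1.1 1.0; 2.0 2.1; 0.5 1.9];

% Squared Euclidean distances: the SUM of the squared coordinate differences.
d2 = sum(bsxfun(@minus, cp, xq) .^ 2, 2);

% The farthest candidate is the same with or without the square root.
[maxD2, farIdx] = max(d2);
[~, farIdxSqrt] = max(sqrt(d2));

% Take the square root once, and only if the actual distance is needed.
maxDist = sqrt(maxD2);
```

Both searches pick the second candidate, so the three sqrt calls per comparison can be dropped without changing which point gets replaced.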
The code can also be improved by vectorizing, e.g. by replacing
for kcp1=1:k
    if cp(kcp1,1)==tp(jj,1) && cp(kcp1,2)==tp(jj,2)
        cpc(kcp1,3)=tpc(jj,3);
    end
end
with
match = (cp(:, 1) == tp(jj,1) & cp(:, 2) == tp(jj,2));
cpc(match, 3) = tpc(jj,3);
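Going one step further, the loop over jj could probably be dropped as well by matching whole rows with ismember(..., 'rows'). The sketch below uses hypothetical toy data; it relies on the rows of cp being exact copies of rows of tp, which is the only reason the exact floating-point comparison is safe, and it assumes no duplicate training points.

```matlab
% Hypothetical toy data: 5 training points with labels +1 / -1.
tp = [0 0; 1 0; 0 1; 5 5; 6 5];
yy = [1; 1; 1; -1; -1];

% Three "closest points", copied verbatim from rows of tp.
cp = [5 5; 0 1; 6 5];

% loc(i) is the row of tp that equals cp(i, :); found(i) is false if none does.
[found, loc] = ismember(cp, tp, 'rows');
labels = yy(loc(found));

% Majority vote: with an odd k and labels of +1 / -1 the sum is never zero.
predicted = sign(sum(labels));
```

Storing the indices of the nearest training points during the search, instead of their coordinates, would make this lookup unnecessary altogether.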
