printf("%c의 블로그", 'Molkka');

코딩공부/딥러닝

[Python] K-NN 알고리즘 구현하기

mol_kka 2022. 3. 24. 17:17

 

#K-NN 알고리즘이란?

: K-Nearest Neighbor, K-최근접 이웃 알고리즘

: 지도학습 알고리즘에 해당하며 분류 또는 회귀 모두 사용 가능

: 특정 공간 내 입력된 데이터와 가장 가까운 K개의 요소를 찾아 더 많이 일치하는 곳으로 분류하는 알고리즘

 

 
검증 표본(초록색 원)을 첫 번째 파랑 네모의 항목이나 빨강 삼각형의 두 번째 항목으로 분류하려고 할 때, 만약 “k = 3” (실선으로 그려진 원)이면 두 번째 항목으로 할당되어야 한다. 왜냐하면 2개의 삼각형과 1개의 사각형만이 안쪽 원 안에 있기 때문이다. 만약 “k = 5” (점선으로 그려진 원)이면 첫 번째 항목으로 분류되어야 한다. (바깥쪽 원 안에 있는 3개의 사각형 vs. 2개의 삼각형). [출처: 위키피디아]

 

 

 

 

#Python으로 K-NN 알고리즘 구현해보기

import random #데이터 생성을 위해 램덤모듈 사용
import numpy as np

#랜덤으로 데이터 입력(학습)
r = [] #방울토마토 1
b = [] #토마토 0
for i in range(50):
    #크기가 1~10 사이에 있고, 무게가 50~100 사이에 있으면 방울토마토
    r.append([random.randint(1,10),random.randint(50,100),1])
    #크기가 7~20 사이에 있고, 무게가 80~120 사이에 있으면 토마토
    b.append([random.randint(7,20),random.randint(80,120),0])
    
#점x와 점y의 거리 구하는 함수
def distance(x,y):
    return np.sqrt(pow((x[0]-y[0]),2)+pow((x[1]-y[1]),2))

#knn알고리즘
def knn(x,y,k):
    result=[]
    cnt=0
    for i in range(len(y)):
        result.append([distance(x,y[i]),y[i][2]])
    result.sort()
    for i in range(k):
        if(result[i][1]==i):
            cnt+=1
    if(cnt > (k/2)):
        print("이것은 방울토마토")
    else:
        print("이것은 토마토")

size = input("크기>>")
weight = input("무게>>")
num = input("k>>") #K요소
new = [int(size), int(weight)]

knn(new, r+b, int(num))

 

 

 

#K-NN 알고리즘 결과 그래프로 표시해보기

import matplotlib.pyplot as plt
%matplotlib inline
rr = np.array(r)
bb = np.array(b)
for i,j in rr[:,:2]:
    plt.plot(i,j,'or')
for i,j in bb[:,:2]:
    plt.plot(i,j,'ob')
plt.plot(int(size), int(weight), 'og')
plt.show()

>>실행 결과

초록색 점이 새로 입력한 데이터이다. 크기 8, 무게 85로 입력했더니 방울토마토로 분류했다.

결과는 컴퓨터가 어떤 데이터를 학습했고, 사용자가 어떤 값을 입력했냐에 따라 달라진다.

극단 값을 입력하면 정확히 분류해낸다.

경계값은 데이터 품질에 따라 정밀도가 달라질 것으로 예상된다.