kNN算法——幫你找到身邊最相近的人

摘要

本文簡單介紹最近鄰算法的基本思想以及具體python實現,並且分析了其優缺點及適用範圍,適合初學者理解與動手實踐。

新生開學了,部分大學按照興趣分配室友的新聞佔據了頭條,這其中涉及到機器學習算法的應用。此外,新生進入大學後,可能至少參加幾個學生組織或社團。社團是根據學生的興趣將它們分為不同的類別,那麼如何定義這些類別,或者區分各個組織之間的差別呢?我敢肯定,如果你問過運營這些社團的人,他們肯定不會說他們的社團和其它的社團相同,但在某種程度上是相似的。比如,老鄉會和高中同學會都有著同樣的生活方式;足球俱樂部和羽毛球協會對運動有著相同的興趣;科技創新協會和創業俱樂部有相近的的興趣等。也許讓你去衡量這些社團或組織所處理的事情或運行模式,你自己就可以確定哪些社團是自己感興趣的。但有一種算法能夠幫助你更好地做出決策,那就是k-Nearest Neighbors(NN)算法, 本文將使用學生社團來解釋k-NN算法的一些概念,該算法可以說是最簡單的機器學習算法,構建的模型僅包含存儲的訓練數據集。該算法對新數據點進行預測,就是在訓練數據集中找到最接近的數據點——其“最近鄰居”。

kNN算法——幫你找到身邊最相近的人

工作原理

在其最簡單的版本中,k-NN算法僅考慮一個最近鄰居,這個最近鄰居就是我們想要預測點的最近訓練數據點。然後,預測結果就是該訓練點的輸出。下圖說明構造的數據集分類情況。

kNN算法——幫你找到身邊最相近的人

從圖中可以看到,我們添加了三個新的數據點,用星星表示。對於三個點中的每一點,我們都標記了訓練集中離其最近的點,最近鄰算法的預測輸出就是標記的這點(用交叉顏色進行表示)。

同樣,我們也可以考慮任意數量k個鄰居,而不是隻考慮一個最近的鄰居。這是k-NN算法名稱的由來。在考慮多個鄰居時,我們使用投票的方式來分配標籤。這意味著對於每個測試點,我們計算有多少個鄰居屬於0類以及有多少個鄰居屬於1類。然後我們統計這些近鄰中屬於哪一類佔的比重大就將預測點判定為哪一類:換句話說,少數服從多數。以下示例使用了5個最近的鄰居:

kNN算法——幫你找到身邊最相近的人

同樣,將預測結果用交叉的顏色表示。從圖中可以看到,左上角的新數據點的預測與我們僅使用一個最近鄰居時的預測結果不相同。

雖然此圖僅展示了用於二分類的問題,但此方法可應用於具有任意數量類的數據集。對於多分類問題,同樣計算k個鄰居屬於哪些類,並進行數量統計,從中選取數量最多的類作為預測結果。

Scratch實現k-NN算法

以下是k-NN算法的偽代碼,用於對一個數據點進行分類(將其稱為A點):

對於數據集中的每一個點:

  • 首先,計算A點和當前點之間的距離;
  • 然後,按遞增順序對距離進行排序;
  • 其次,把距離最近的k個點作為A的最近鄰;
  • 之後,找到這些鄰居中的絕大多數類;
  • 最後,將絕大多數類返回作為我們對A類的預測;

Python實現代碼如下:

def knnclassify(A, dataset, labels, k):
datasetSize = dataset.shape[0]

# 計算A點和當前點之間的距離
diffMat = tile(A, (datasetSize, 1)) - dataset
sqDiffMat = diffMat ** 2
sqDistances = sqDiffMat.sum(axis=1)
distances = sqDistances ** 0.5

# 按照增序對距離排序
sortedDistIndices = distances.argsort()

# 選出距離最小的k個點
classCount = {}
for i in range(k):
voteIlabel = labels[sortedDistIndices[i]]
classCount[voteIlabel] = classCount.get(voteIlabel, 0) + 1

# 對這些點所處的類別按照頻次排序
sortedClassCount = sorted(classCount.iteritem(), key=operator.itemgetter(1), reverse=True)

return sortedClassCount[0][0]

下面讓我們深入研究下上述代碼:

  • 函數knnclassify需要4個輸入參數:要分類的輸入向量稱為A,稱為dataSet的訓練樣例的完整矩陣,稱為labels的標籤向量,以及k——在投票中使用的最近鄰居的數量。
  • 使用歐幾里德距離計算A和當前點之間的距離。
  • 按照遞增順序對距離進行排序。
  • 從中選出k個最近距離來對A類進行投票。
  • 之後,獲取classCount字典並將其分解為元組列表,然後按元組中的第2項對元組進行排序。由於排序的順序是相反的,因此我們選擇從最大到最小(設置reverse)。
  • 最後,返回最頻繁出現的類別標籤。

Scikit-Learn實現k-NN算法

Scikit-Learn是一個機器學習工具箱,內部集成了很多機器學習算法。現在讓我們看一下如何使用Scikit-learn實現kNN算法。代碼如下:

from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
# 導入iris數據集
iris = datasets.load_iris()
X = iris.data
y = iris.target
# 將其按照一定的比例劃分為訓練集和測試集(random_state=0 保證每次運行分割得到一樣的訓練集和測試集)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
# 設定鄰居個數
clf = KNeighborsClassifier(n_neighbors=5)
# 擬合訓練數據
clf.fit(X_train, y_train)
# 對測試集進行預測
predictions = clf.predict(X_test)
print("Test set predictions: {}".format(predictions))
# 評估模型性能

accuracy = clf.score(X_test, y_test)
print("Test set accuracy: {:.2f}".format(accuracy))

下面讓我們來看看上述代碼:

  • 首先,生成鳶尾屬植物數據集;
  • 然後,將數據拆分為訓練和測試集,以評估泛化性能;
  • 之後,將鄰居數量(k)指定為5;
  • 接下來,使用訓練集來擬合分類器;
  • 為了對測試數據進行預測,對於測試集中的每個數據點,都要使用該方法計算訓練集中的最近鄰居,並找到其中最頻繁出現的類;
  • 最後,通過使用測試數據和測試標籤調用score函數來評估模型的泛化能力;

模型運行完畢,測試集上得到97%的準確度,這意味著模型在測試數據集中97%的樣本都正確地預測出類別;

kNN算法——幫你找到身邊最相近的人

優點和缺點

一般而言,k-NN分類器有兩個重要參數:鄰居數量以及數據點之間的距離計算方式。

  • 在實踐應用中,一般使用少數3個或5個鄰居時效果通常會很好。當然,應該根據具體情況調整這個參數;
  • 選擇正確的距離測量方法可能有些困難。一般情況下,都是使用歐幾里德距離,歐幾里得距離在許多設置中效果都不錯;

k-NN的優勢之一是該模型非常易於理解,並且通常無需進行大量參數調整的情況下就能獲得比較不錯的性能表現。在考慮使用更高級的技術之前,使用此算法是一種很好的基線方法。k-NN模型的建立通常會比較快,但是當訓練集非常大時(無論是特徵數還是樣本數量),預測時耗費的時間會很多。此外,使用k-NN算法時,對數據進行預處理非常重要。該方法通常在具有許多特徵(數百或更多)的數據集上表現不佳,並且對於大多數特徵在大多數情況下為0的數據集(所謂的稀疏數據集)而言尤其糟糕。

結論

k-NN算法是一種簡單有效的數據分類方法,它是基於實例學習的一種機器學習算法,需要通過數據實例來執行機器學習算法,該算法必須攜帶完整的數據集。而對於大型的數據集,需要耗費比較大的存儲。此外,還需要計算數據庫中每個數據點距離預測點的的距離,這個過程會很麻煩,且耗時多。另一個缺點是k-NN算法不能夠讓你瞭解數據的基礎結構,無法知道每個類別的“平均”或“範例”具體是什麼樣子。

因此,雖然k-NN算法易於理解,但由於預測速度慢且無法處理多特徵問題,因此在實踐中並不常用。

  • Peter Harrington的機器學習(2012)
  • Sarah GuidoAndreas Muller 使用Python進行機器學習簡介(2016)


作者信息

James Le,機器學習工程師

LinkedIn:http://www.linkedin.com/in/khanhnamle94

本文由阿里云云棲社區組織翻譯。

文章原標題《k-Nearest Neighbors: Who are close to you》,譯者:海棠,審校:Uncle_LLD。


分享到:


相關文章: