The k-Nearest Neighbors (kNN) algorithm is arguably the simplest machine learning algorithm. Building the model consists only of storing the training dataset. To make a prediction for a new data point, the algorithm finds the closest data points in the training dataset, its “nearest neighbors”.
k-Neighbors Classification
In its simplest version, the algorithm only considers exactly one nearest neighbor, which is the closest training data point to the point we want to make a prediction for. The prediction is then simply the known output for this training point. Figure forge_one_neighbor illustrates this for the case of classification on the forge dataset (mglearn package can be downloaded here):
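If the mglearn package is available, the figure can be reproduced roughly as follows, using its plot_knn_classification plotting helper:
- import mglearn
- import matplotlib.pyplot as plt
- # plot the forge dataset and the single closest neighbor of three test points
- mglearn.plots.plot_knn_classification(n_neighbors=1)
- plt.show()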
Here, we added three new data points, shown as crosses. For each of them, we marked the closest point in the training set. The prediction of the one-nearest-neighbor algorithm is the label of that point (shown by the color of the cross). Instead of considering only the closest neighbor, we can also consider an arbitrary number k of neighbors. This is where the name of the k-nearest neighbors algorithm comes from. When considering more than one neighbor, we use voting to assign a label: for each test point, we count how many neighbors are red and how many are blue, and then assign the more frequent class, in other words the majority class among the k neighbors. Below is an illustration using the three closest neighbors. Again, the prediction is shown as the color of the cross. You can see that the prediction for the point in the top left changed from the one obtained using only a single neighbor.
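The same mglearn helper can draw the three-neighbor version of the illustration:
- # same illustration, but voting over the three closest neighbors
- mglearn.plots.plot_knn_classification(n_neighbors=3)
- plt.show()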
While this illustration is for a binary classification problem, you can imagine this working with any number of classes. For more classes, we count how many neighbors belong to each class, and again predict the most common class. Now let’s look at how we can apply the k nearest neighbors algorithm using scikit-learn.
First, we split our data into a training and a test set, so we can evaluate generalization performance, as discussed in Chapter 1 (Introduction).
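A minimal sketch of this step, assuming the forge data comes from mglearn's make_forge helper mentioned above:
- import mglearn
- from sklearn.model_selection import train_test_split
- X, y = mglearn.datasets.make_forge()
- # hold out part of the data to measure generalization performance later
- X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)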
Next we import and instantiate the class. This is when we can set parameters, like the number of neighbors to use. Here, we set it to three.
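Continuing with the split above:
- from sklearn.neighbors import KNeighborsClassifier
- # the number of neighbors is the main parameter of the model
- clf = KNeighborsClassifier(n_neighbors=3)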
Now, we fit the classifier using the training set. For KNeighborsClassifier this means storing the dataset, so we can compute neighbors during prediction.
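- # "training" only stores X_train and y_train for lookups at prediction time
- clf.fit(X_train, y_train)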
To make predictions on the test data, we call the predict method. This computes the nearest neighbors in the training set and finds the most common class among these:
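- # predicted class labels for the held-out points
- print("Test set predictions: %s" % clf.predict(X_test))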
To evaluate how well our model generalizes, we can call the score method with the test data together with the test labels:
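- # fraction of test points whose class was predicted correctly
- print("Test set accuracy: %.2f" % clf.score(X_test, y_test))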
We see that our model is about 86% accurate, meaning it predicted the class correctly for 86% of the samples in the test dataset.
Analyzing KNeighborsClassifier
For two-dimensional datasets, we can also illustrate the prediction for all possible test points in the xy-plane. We color the plane red in regions where points would be assigned the red class, and blue otherwise. This lets us view the decision boundary, which is the divide between where the algorithm assigns the red class and where it assigns the blue class.
Here is a visualization of the decision boundary for one, three, and nine neighbors:
- import matplotlib.pyplot as plt
- import mglearn
- from sklearn.neighbors import KNeighborsClassifier
- X, y = mglearn.datasets.make_forge()
- fig, axes = plt.subplots(1, 3, figsize=(10, 3))
- for n_neighbors, ax in zip([1, 3, 9], axes):
-     # fit a classifier and draw its decision boundary on one subplot
-     clf = KNeighborsClassifier(n_neighbors=n_neighbors).fit(X, y)
-     mglearn.plots.plot_2d_separator(clf, X, fill=True, eps=0.5, ax=ax, alpha=.4)
-     ax.scatter(X[:, 0], X[:, 1], c=y, s=60, cmap=mglearn.cm2)
-     ax.set_title("%d neighbor(s)" % n_neighbors)
- plt.show()
As you can see in the figure on the left, using a single neighbor results in a decision boundary that follows the training data closely. Considering more and more neighbors leads to a smoother decision boundary. A smoother boundary corresponds to a simpler model. In other words, using few neighbors corresponds to high model complexity (as shown on the right side of the figure below), and using many neighbors corresponds to low model complexity (as shown on the left side of the figure below).
Figure model_complexity
Let’s investigate whether we can confirm the connection between model complexity and generalization that we discussed above. We will do this on the real-world Breast Cancer dataset. We begin by splitting the dataset into a training and a test set. Then we evaluate training and test set performance for different numbers of neighbors.
- from sklearn.datasets import load_breast_cancer
- from sklearn.model_selection import train_test_split
- from sklearn.neighbors import KNeighborsClassifier
- import matplotlib.pyplot as plt
- cancer = load_breast_cancer()
- X_train, X_test, y_train, y_test = train_test_split(
-     cancer.data, cancer.target, stratify=cancer.target, random_state=66)
- training_accuracy = []
- test_accuracy = []
- # try n_neighbors from 1 to 10
- neighbors_settings = range(1, 11)
- for n_neighbors in neighbors_settings:
-     # build the model
-     clf = KNeighborsClassifier(n_neighbors=n_neighbors)
-     clf.fit(X_train, y_train)
-     # record training set accuracy
-     training_accuracy.append(clf.score(X_train, y_train))
-     # record generalization accuracy
-     test_accuracy.append(clf.score(X_test, y_test))
- plt.plot(neighbors_settings, training_accuracy, label="training accuracy")
- plt.plot(neighbors_settings, test_accuracy, label="test accuracy")
- plt.xlabel("n_neighbors")
- plt.ylabel("accuracy")
- plt.legend()
- plt.show()
The plot shows the training and test set accuracy on the y-axis against the setting of n_neighbors on the x-axis. While real-world plots are rarely very smooth, we can still recognize some of the characteristics of overfitting and underfitting. As considering fewer neighbors corresponds to a more complex model, the plot is horizontally flipped relative to the illustration in Figure model_complexity.
Considering a single nearest neighbor, the prediction on the training set is perfect. Considering more neighbors makes the model simpler, and the training accuracy drops. The test set accuracy when using a single neighbor is lower than when using more neighbors, indicating that a single nearest neighbor leads to a model that is too complex. On the other hand, when considering 10 neighbors, the model is too simple and performance is even worse. The best performance is somewhere in the middle, around six neighbors.
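One way to read this off programmatically, as a small sketch that continues from the lists filled in the loop above:
- import numpy as np
- # position of the highest test accuracy, mapped back to the n_neighbors setting
- best_n_neighbors = neighbors_settings[int(np.argmax(test_accuracy))]
- print("best n_neighbors: %d" % best_n_neighbors)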
k-Neighbors Regression
There is also a regression variant of the k-nearest neighbors algorithm. Again, let’s start by using a single nearest neighbor, this time using the wave dataset. We added three test data points as green crosses on the x axis. The prediction using a single neighbor is just the target value of the nearest neighbor, shown as the blue cross:
- mglearn.plots.plot_knn_regression(n_neighbors=1)
Again, we can also use more than one nearest neighbor for regression. When using multiple nearest neighbors for regression, the prediction is the average (or mean) of the relevant neighbors:
- mglearn.plots.plot_knn_regression(n_neighbors=3)
The k nearest neighbors algorithm for regression is implemented in the KNeighborsRegressor class in scikit-learn. Using it looks much like the KNeighborsClassifier above:
- import mglearn
- from sklearn.neighbors import KNeighborsRegressor
- from sklearn.model_selection import train_test_split
- X, y = mglearn.datasets.make_wave(n_samples=40)
- # split the wave dataset into a training and a test set
- X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
- # instantiate the model, set the number of neighbors to consider to 3
- reg = KNeighborsRegressor(n_neighbors=3)
- # fit the model using the training data and training targets
- reg.fit(X_train, y_train)
- print("Test set predictions:\n%s" % reg.predict(X_test))
- # for a regressor, score returns the R^2 of the predictions
- print("Test set R^2: %.2f" % reg.score(X_test, y_test))
Analyzing k nearest neighbors regression
For our one-dimensional dataset, we can see what the predictions look like for all possible feature values. To do this, we create a test dataset consisting of many points on the line.
- import os
- dmode = os.environ.get('DISPLAY', '')
- if dmode:
-     import matplotlib.pyplot as plt
-     import numpy as np
-     fig, axes = plt.subplots(1, 3, figsize=(15, 4))
-     # create 1000 data points, evenly spaced between -3 and 3
-     line = np.linspace(-3, 3, 1000).reshape(-1, 1)
-     plt.suptitle("nearest_neighbor_regression")
-     for n_neighbors, ax in zip([1, 3, 9], axes):
-         # make predictions using 1, 3 or 9 neighbors
-         reg = KNeighborsRegressor(n_neighbors=n_neighbors).fit(X, y)
-         # training data, plus their positions marked along the x axis
-         ax.plot(X, y, 'o')
-         ax.plot(X, -3 * np.ones(len(X)), 'o')
-         # model predictions for all points on the line
-         ax.plot(line, reg.predict(line))
-         ax.set_title("%d neighbor(s)" % n_neighbors)
-     plt.show()
In the plots above, the blue points are again the responses for the training data, while the red line is the prediction made by the model for all points on the line. Using only a single neighbor, each point in the training set has an obvious influence on the predictions, and the predicted values go through all of the data points. This leads to a very unsteady prediction. Considering more neighbors leads to smoother predictions, but these do not fit the training data as well.
Strengths, weaknesses and parameters
In principle, there are two important parameters to the KNeighbors classifier: the number of neighbors and how you measure distance between data points. In practice, using a small number of neighbors like three or five often works well, but you should certainly adjust this parameter. Choosing the right distance measure is somewhat beyond the scope of this book. By default, Euclidean distance is used, which works well in many settings.
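For reference, both choices are exposed directly on the scikit-learn estimator; the sketch below uses 'manhattan' purely as an illustrative alternative to the default metric:
- from sklearn.neighbors import KNeighborsClassifier
- # n_neighbors controls model complexity; metric controls the distance measure
- # (the default 'minkowski' with p=2 is equivalent to Euclidean distance)
- clf = KNeighborsClassifier(n_neighbors=5, metric='manhattan')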
One of the strengths of nearest neighbors is that the model is very easy to understand, and often gives reasonable performance without a lot of adjustments. Using nearest neighbors is a good baseline method to try before considering more advanced techniques. Building the nearest neighbors model is usually very fast, but when your training set is very large (either in number of features or in number of samples) prediction can be slow.
When using nearest neighbors, it’s important to preprocess your data (see Chapter 3, Unsupervised Learning). Nearest neighbors often does not perform well on datasets with very many features, and in particular on sparse datasets, a common type of data in which there are many features but only a few of them are non-zero for any given data point.
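As an illustration of the preprocessing point, a common pattern is to chain a scaler with the classifier. A minimal sketch, assuming the breast cancer split (X_train, X_test, y_train, y_test) from the code above is still in scope:
- from sklearn.pipeline import make_pipeline
- from sklearn.preprocessing import MinMaxScaler
- from sklearn.neighbors import KNeighborsClassifier
- # rescale each feature to [0, 1] before computing distances, so features
- # measured on large scales do not dominate the distance calculation
- pipe = make_pipeline(MinMaxScaler(), KNeighborsClassifier(n_neighbors=3))
- pipe.fit(X_train, y_train)
- print("Test set accuracy: %.2f" % pipe.score(X_test, y_test))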
So while the nearest neighbors algorithm is easy to understand, it is not often used in practice, due to prediction being slow, and its inability to handle many features. The method we discuss next has neither of these drawbacks.
Supplement
* [ ML In Action ] Classifying with k-Nearest Neighbors