8    K-Nearest Neighbors

K-nearest neighbor (KNN) is a very simple algorithm in which each observation is predicted based on its “similarity” to other observations. Unlike most methods in this book, KNN is a memory-based algorithm and cannot be summarized by a closed-form model. This means the training samples are required at run-time and predictions are made directly from the sample relationships. Consequently, KNNs are also known as lazy learners (Cunningham and Delany, 2007) and can be computationally inefficient. However, KNNs have been successful in a large number of business problems (see, for example, Jiang et al. (2012) and Mccord and Chuah (2011)) and are useful for preprocessing purposes as well (as was discussed in Section 3.3.2).

8.1    Prerequisites

For this chapter we’ll use the following packages:

# Helper packages
library(dplyr)    # for data wrangling
library(ggplot2)  # for awesome graphics
library(rsample)  # for data splitting
library(recipes)  # for feature engineering
library(tidyr)    # for gather() in the MNIST examples
library(purrr)    # for map_df() in the MNIST examples
library(stringr)  # for str_extract() in the MNIST examples
library(tibble)   # for tibble() and rownames_to_column()

# Modeling packages
library(caret)    # for training KNN models

To illustrate various concepts we’ll continue working with the ames_train and ames_test data sets created in Section 2.7; however, we’ll also illustrate the performance of KNNs on the employee attrition and MNIST data sets.

# Create training (70%) set for the rsample::attrition data
attrit <- attrition %>%
  mutate_if(is.ordered, factor, ordered = FALSE)
set.seed(123)  # for reproducibility
churn_split <- initial_split(attrit, prop = 0.7, strata = "Attrition")
churn_train <- training(churn_split)

# Import MNIST training data
mnist <- dslabs::read_mnist()
names(mnist)
## [1] "train" "test"

8.2    Measuring similarity

The KNN algorithm identifies k observations that are “similar” or nearest to the new record being predicted and then uses the average response value (regression) or the most common class (classification) of those k observations as the predicted output.
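To make these mechanics concrete, the following minimal sketch hand-rolls a single KNN prediction in base R. This is purely illustrative and is not the caret-based workflow used later in the chapter; knn_predict, train_x, train_y, and new_obs are hypothetical names introduced here for illustration.

# Minimal hand-rolled KNN prediction for one new observation (illustrative only).
# train_x: numeric matrix of (standardized) features; train_y: response vector.
knn_predict <- function(train_x, train_y, new_obs, k = 5) {
  # Euclidean distance from the new observation to every training sample
  dists <- sqrt(rowSums(sweep(train_x, 2, new_obs)^2))
  # Identify the k closest training samples
  nn <- order(dists)[1:k]
  if (is.numeric(train_y)) {
    mean(train_y[nn])                     # regression: average response
  } else {
    names(which.max(table(train_y[nn])))  # classification: most common class
  }
}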

For illustration, consider our Ames housing data. In real estate, Realtors determine what price they will list (or market) a home for based on “comps” (comparable homes). To identify comps, they look for homes that have very similar attributes to the one being sold. This can include similar features (e.g., square footage, number of rooms, and style of the home), location (e.g., neighborhood and school district), and many other attributes. The Realtor will look at the typical sale price of these comps and will usually list the new home at a very similar price to the prices these comps sold for.

As an example, Figure 8.1 maps 10 homes (blue) that are most similar to the home of interest (red). These homes are all relatively close to the target home and likely have similar characteristics (e.g., home style, size, and school district). Consequently, the Realtor would likely list the target home around the average price that these comps sold for. In essence, this is what the KNN algorithm will do.

8.2.1    Distance measures

How do we determine the similarity between observations (or homes as in Figure 8.1)? We use distance (or dissimilarity) metrics to compute the pairwise differences between observations. The most common distance measures are the Euclidean (8.1) and Manhattan (8.2) distance metrics, both of which measure the distance between observations $x_a$ and $x_b$ across all $P$ features.

FIGURE 8.1: The 10 nearest neighbors (blue) whose home attributes most closely resemble the house of interest (red).

\[
\sqrt{\sum_{j=1}^{P} \left(x_{aj} - x_{bj}\right)^2} \tag{8.1}
\]

\[
\sum_{j=1}^{P} \left|x_{aj} - x_{bj}\right| \tag{8.2}
\]

Euclidean distance is the most common and measures the straight-line distance between two samples (i.e., how the crow flies). Manhattan measures the point-to-point travel time (i.e., city block) and is commonly used for binary predictors (e.g., one-hot encoded 0/1 indicator variables). A simplified example is presented below and illustrated in Figure 8.2 where the distance measures are computed for the first two homes in ames_train and for only two features (Gr_Liv_Area & Year_Built).

two_houses <- ames_train %>%
  select(Gr_Liv_Area, Year_Built) %>%
  sample_n(2)

two_houses
## # A tibble: 2 x 2
##   Gr_Liv_Area Year_Built
##         <int>      <int>
## 1         896       1961
## 2        1511       2002

# Euclidean
dist(two_houses, method = "euclidean")
##     1
## 2 616

# Manhattan
dist(two_houses, method = "manhattan")
##     1
## 2 656

FIGURE 8.2: Euclidean (A) versus Manhattan (B) distance.

There are other metrics to measure the distance between observations. For example, the Minkowski distance is a generalization of the Euclidean and Manhattan distances and is defined as

\[
\left( \sum_{j=1}^{P} \left|x_{aj} - x_{bj}\right|^q \right)^{\frac{1}{q}} \tag{8.3}
\]

where q > 0 (Han et al., 2011). When q = 2 the Minkowski distance equals the Euclidean distance and when q = 1 it is equal to the Manhattan distance. The Mahalanobis distance is also an attractive measure to use since it accounts for the correlation between two variables (De Maesschalck et al., 2000).
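As a quick sanity check of these relationships (a small sketch added here, not part of the chapter's modeling workflow), dist() supports the Minkowski metric through its p argument, and the Mahalanobis distance between the two homes can be computed directly from the feature covariance matrix; estimating that covariance from the full training data is an assumption made for illustration.

x <- as.matrix(two_houses)

# Minkowski with p = 2 (q = 2 in Equation (8.3)) reproduces the Euclidean distance above
dist(x, method = "minkowski", p = 2)

# Minkowski with p = 1 (q = 1) reproduces the Manhattan distance above
dist(x, method = "minkowski", p = 1)

# Mahalanobis distance between the two homes, accounting for the
# covariance between Gr_Liv_Area and Year_Built
S <- cov(ames_train[, c("Gr_Liv_Area", "Year_Built")])
d <- x[1, ] - x[2, ]
sqrt(t(d) %*% solve(S) %*% d)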

8.2.2    Preprocessing

Due to the squaring in Equation (8.1), the Euclidean distance is more sensitive to outliers. Furthermore, most distance measures are sensitive to the scale of the features. Data with features that have different scales will bias the distance measures, as the predictors with the largest values will contribute most to the distance between two samples. For example, consider the three homes below: home1 is a four-bedroom home built in 2008, home2 is a two-bedroom home built in the same year, and home3 is a three-bedroom home built a decade earlier.
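The home1, home2, and home3 objects are not constructed in the text; the sketch below shows one way they might be pulled from ames_train. The id values 423, 424, and 6 come from the printed output below, and treating id as the row position in ames_train is an assumption.

# Hypothetical construction of the three example homes
home1 <- ames_train %>%
  mutate(id = row_number(), home = "home1") %>%
  filter(id == 423) %>%
  select(home, Bedroom_AbvGr, Year_Built, id)

home2 <- ames_train %>%
  mutate(id = row_number(), home = "home2") %>%
  filter(id == 424) %>%
  select(home, Bedroom_AbvGr, Year_Built, id)

home3 <- ames_train %>%
  mutate(id = row_number(), home = "home3") %>%
  filter(id == 6) %>%
  select(home, Bedroom_AbvGr, Year_Built, id)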

home1
## # A tibble: 1 x 4
##   home  Bedroom_AbvGr Year_Built    id
##   <chr>         <int>      <int> <int>
## 1 home1             4       2008   423

home2
## # A tibble: 1 x 4
##   home  Bedroom_AbvGr Year_Built    id
##   <chr>         <int>      <int> <int>
## 1 home2             2       2008   424

home3
## # A tibble: 1 x 4
##   home  Bedroom_AbvGr Year_Built    id
##   <chr>         <int>      <int> <int>
## 1 home3             3       1998     6

The Euclidean distance between home1 and home3 is larger than the distance between home1 and home2 because of the larger difference in Year_Built, even though home2 differs more in Bedroom_AbvGr.

features <- c("Bedroom_AbvGr", "Year_Built")

# distance between home 1 and 2
dist(rbind(home1[, features], home2[, features]))
##   1
## 2 2

# distance between home 1 and 3
dist(rbind(home1[, features], home3[, features]))
##    1
## 2 10

However, Year_Built has a much larger range (1875–2010) than Bedroom_AbvGr (0–8). And if you ask most people, especially families with kids, the difference between 2 and 4 bedrooms is much more significant than a 10 year difference in the age of a home. If we standardize these features, we see that the difference between home1 and home2’s standardized value for Bedroom_AbvGr is larger than the difference between home1 and home3’s Year_Built. And if we compute the Euclidean distance between these standardized home features, we see that now home1 and home3 are more similar than home1 and home2.
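The standardized objects printed below are likewise not constructed in the text. One way to produce them (a sketch using the recipes package loaded above, with the same id assumption as before) is to center and scale the two features on the full training data and then extract the same three rows:

# Center and scale the two features using the full training data (sketch)
std <- recipe(~ Bedroom_AbvGr + Year_Built, data = ames_train) %>%
  step_center(all_numeric()) %>%
  step_scale(all_numeric()) %>%
  prep(training = ames_train) %>%
  bake(new_data = ames_train) %>%
  mutate(id = row_number())

home1_std <- std %>% filter(id == 423) %>%
  mutate(home = "home1") %>% select(home, Bedroom_AbvGr, Year_Built, id)
home2_std <- std %>% filter(id == 424) %>%
  mutate(home = "home2") %>% select(home, Bedroom_AbvGr, Year_Built, id)
home3_std <- std %>% filter(id == 6) %>%
  mutate(home = "home3") %>% select(home, Bedroom_AbvGr, Year_Built, id)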

home1_std
## # A tibble: 1 x 4
##   home  Bedroom_AbvGr Year_Built    id
##   <chr>         <dbl>      <dbl> <int>
## 1 home1          1.38       1.21   423

home2_std
## # A tibble: 1 x 4
##   home  Bedroom_AbvGr Year_Built    id
##   <chr>         <dbl>      <dbl> <int>
## 1 home2         -1.03       1.21   424

home3_std
## # A tibble: 1 x 4
##   home  Bedroom_AbvGr Year_Built    id
##   <chr>         <dbl>      <dbl> <int>
## 1 home3         0.176      0.881     6

# distance between home 1 and 2
dist(rbind(home1_std[, features], home2_std[, features]))
##      1
## 2 2.42

# distance between home 1 and 3
dist(rbind(home1_std[, features], home3_std[, features]))
##      1
## 2 1.25

In addition to standardizing numeric features, all categorical features must be one-hot encoded or encoded using another method (e.g., ordinal encoding) so that all categorical features are represented numerically. Furthermore, the KNN method is very sensitive to noisy predictors since they cause similar samples to have larger magnitudes and variability in distance values. Consequently, removing irrelevant, noisy features often leads to significant improvement.
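For example, a preprocessing recipe for a KNN model on the Ames data that addresses scaling, encoding, and noisy features might look like the sketch below; ames_recipe is a hypothetical name, Sale_Price is assumed to be the response as elsewhere in the book, and the attrition blueprint in the next section follows the same pattern.

# Sketch of a KNN-friendly preprocessing recipe for the Ames data
ames_recipe <- recipe(Sale_Price ~ ., data = ames_train) %>%
  step_nzv(all_predictors()) %>%                                  # drop near-zero variance features
  step_dummy(all_nominal(), -all_outcomes(), one_hot = TRUE) %>%  # one-hot encode categoricals
  step_center(all_numeric(), -all_outcomes()) %>%                 # center numeric features
  step_scale(all_numeric(), -all_outcomes())                      # scale numeric features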

8.3    Choosing k

The performance of KNNs is very sensitive to the choice of k. This was illustrated in Section 2.5.3 where low values of k typically overfit and large values often underfit. At the extremes, when k = 1, we base our prediction on a single observation that has the closest distance measure. In contrast, when k = n, we are simply using the average (regression) or most common class (classification) across all training samples as our predicted value.

There is no general rule about the best k as it depends greatly on the nature of the data. For high-signal data with very few noisy (irrelevant) features, smaller values of k tend to work best. As more irrelevant features are involved, larger values of k are required to smooth out the noise. To illustrate, we saw in Section 3.8.3 that we optimized the RMSE for the ames_train data with k = 12. The ames_train data has 2054 observations, so such a small k likely suggests a strong signal exists. In contrast, the churn_train data has 1030 observations and Figure 8.3 illustrates that our loss function is not optimized until k = 271. Moreover, the maximum ROC value is 0.8078 and non-attriting employees make up 83.9% of the data. This suggests there is likely not a very strong signal in the Attrition data.


When using KNN for classification, it is best to assess odd values of k to avoid ties in the event the neighbors are split evenly across the response levels (e.g., with k = 2, one neighbor could have class "0" while the other has class "1").

# Create blueprint
blueprint <- recipe(Attrition ~ ., data = churn_train) %>%
  step_nzv(all_nominal()) %>%
  step_integer(contains("Satisfaction")) %>%
  step_integer(WorkLifeBalance) %>%
  step_integer(JobInvolvement) %>%
  step_dummy(all_nominal(), -all_outcomes(), one_hot = TRUE) %>%
  step_center(all_numeric(), -all_outcomes()) %>%
  step_scale(all_numeric(), -all_outcomes())

# Create a resampling method
cv <- trainControl(
  method = "repeatedcv",
  number = 10,
  repeats = 5,
  classProbs = TRUE,
  summaryFunction = twoClassSummary
)

# Create a hyperparameter grid search
hyper_grid <- expand.grid(
  k = floor(seq(1, nrow(churn_train) / 3, length.out = 20))
)

# Fit knn model and perform grid search
knn_grid <- train(
  blueprint,
  data = churn_train,
  method = "knn",
  trControl = cv,
  tuneGrid = hyper_grid,
  metric = "ROC"
)

ggplot(knn_grid)
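Once the search finishes, the selected k and the resampled performance for every value assessed can be pulled directly from the returned caret object (a brief usage note; exact values depend on the random seed):

# Optimal k selected by the grid search
knn_grid$bestTune

# Resampled ROC, sensitivity, and specificity for each k assessed
head(knn_grid$results)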

8.4    MNIST example

The MNIST data set is significantly larger than the Ames housing and attrition data sets. Because we want this example to run locally and in a reasonable amount of time (< 1 hour), we will train our initial models on a random sample of 10,000 rows from the training set.

FIGURE 8.3: Cross validated search grid results for Attrition training data where 20 values between 1 and 343 are assessed for k. When k = 1, the predicted value is based on a single observation that is closest to the target sample and when k = 343, the predicted value is based on the response with the largest proportion for 1/3 of the training sample.

set.seed(123)
index <- sample(nrow(mnist$train$images), size = 10000)
mnist_x <- mnist$train$images[index, ]
mnist_y <- factor(mnist$train$labels[index])

Recall that the MNIST data contains 784 features representing the darkness (0–255) of pixels in images of handwritten numbers (0–9). As stated in Section 8.2.2, KNN models can be severely impacted by irrelevant features. One culprit of this is zero, or near-zero variance features (see Section 3.4). Figure 8.4 illustrates that there are nearly 125 features that have zero variance and many more that have very little variation.

mnist_x %>%
  as.data.frame() %>%
  map_df(sd) %>%
  gather(feature, sd) %>%
  ggplot(aes(sd)) +
  geom_histogram(binwidth = 1)

Figure 8.5 shows which features are driving this concern. Images (A)–(C) illustrate typical handwritten numbers from the test set. Image (D) illustrates which features in our images have variability. The white in the center shows that the features that represent the center pixels have regular variability whereas the black exterior highlights that the features representing the edge pixels in our images have zero or near-zero variability. These features have low variability in pixel values because they are rarely drawn on.

FIGURE 8.4: Distribution of variability across the MNIST features. We see a significant number of zero variance features that should be removed.
FIGURE 8.5: Example images (A)-(C) from our data set and (D) highlights near-zero variance features around the edges of our images.

By identifying and removing these zero (or near-zero) variance features, we end up keeping 249 of the original 784 predictors. This can cause dramatic improvements to both the accuracy and speed of our algorithm, and removing these features upfront also reduces some of the overhead experienced by caret::train(). We also need to add column names to the feature matrices as these are required by caret.

# Rename features
colnames(mnist_x) <- paste0("V", 1:ncol(mnist_x))

# Remove near zero variance features manually
nzv <- nearZeroVar(mnist_x)
index <- setdiff(1:ncol(mnist_x), nzv)
mnist_x <- mnist_x[, index]

Next we perform our search grid. Since we are working with a larger data set, using resampling (e.g., k-fold cross validation) becomes costly. Moreover, as we have more data, our estimated error rate produced by a simple train vs. validation set becomes less biased and variable. Consequently, the following CV procedure (cv) uses 70% of our data to train and the remaining 30% for validation. We can adjust the number of times we do this which becomes similar to the bootstrap procedure discussed in Section 2.4.


Our hyperparameter grid search assesses 12 odd values of k ranging from 3 to 25 and takes approximately 3 minutes.

# Use train/validate resampling method
cv <- trainControl(
  method = "LGOCV",
  p = 0.7,
  number = 1,
  savePredictions = TRUE
)

# Create a hyperparameter grid search
hyper_grid <- expand.grid(k = seq(3, 25, by = 2))

# Execute grid search
knn_mnist <- train(
  mnist_x,
  mnist_y,
  method = "knn",
  tuneGrid = hyper_grid,
  preProc = c("center", "scale"),
  trControl = cv
)

ggplot(knn_mnist)

Figure 8.6 illustrates the grid search results; our best model used 3 nearest neighbors and provided an accuracy of 93.8%. Looking at the results for each class, we can see that 8s were the hardest to detect (lowest sensitivity), followed by 4s, 5s, 2s, and 3s. The digit the model most often predicts when it is wrong is 1, which has the lowest specificity.

Image
FIGURE 8.6: KNN search grid results for the MNIST data

# Create confusion matrix
cm <- confusionMatrix(knn_mnist$pred$pred, knn_mnist$pred$obs)
cm$byClass[, c(1:2, 11)]  # sensitivity, specificity, & balanced accuracy

##          Sensitivity Specificity Balanced Accuracy
## Class: 0       0.964       0.996             0.980
## Class: 1       0.992       0.984             0.988
## Class: 2       0.916       0.996             0.956
## Class: 3       0.916       0.992             0.954
## Class: 4       0.870       0.996             0.933
## Class: 5       0.915       0.991             0.953
## Class: 6       0.980       0.989             0.984
## Class: 7       0.933       0.990             0.961
## Class: 8       0.822       0.998             0.910
## Class: 9       0.933       0.985             0.959

Feature importance for KNNs is computed by finding the features with the smallest distance measure (see Equation (8.1)). Since the response variable in the MNIST data is multiclass, the variable importance scores below sort the features by maximum importance across the classes.

# Top 20 most important features
vi <- varImp(knn_mnist)
vi
## ROC curve variable importance
##
##   variables are sorted by maximum importance
##   only 20 most important variables shown (out of 249)
##

##       X0  X1   X2  X3     X4    X5   X6    X7

## V435  100.0  100.0  100.0  100.0  100.0  100.0  100.0  100.0

## V407  99.4  99.4  99.4  99.4  99.4  99.4  99.4  99.4

## V463  97.9  97.9  97.9  97.9  97.9  97.9  97.9  97.9

## V379  97.4  97.4  97.4  97.4  97.4  97.4  97.4  97.4

## V434  95.9  95.9  95.9  95.9  95.9  95.9  96.7  95.9

## V380  96.1  96.1  96.1  96.1  96.1  96.1  96.1  96.1

## V462  95.6  95.6  95.6  95.6  95.6  95.6  95.6  95.6

## V408  95.4  95.4  95.4  95.4  95.4  95.4  95.4  95.4

## V352  93.5  93.5  93.5  93.5  93.5  93.5  93.5  93.5

## V490  93.1  93.1  93.1  93.1  93.1  93.1  93.1  93.1

## V406  92.9  92.9  92.9  92.9  92.9  92.9  92.9  92.9

## V437  70.8  60.4  92.8  52.0  71.1  83.4  75.5  91.1

## V351  92.4  92.4  92.4  92.4  92.4  92.4  92.4  92.4

## V409  70.5  76.1  88.1  54.5  79.9  77.7  84.9  91.9

## V436  90.0  90.0  90.9  90.0  90.0  90.0  91.4  90.0

## V464  76.7  76.5  90.2  76.5  76.5  76.6  77.7  82.0

## V491  89.5  89.5  89.5  89.5  89.5  89.5  89.5  89.5

## V598  68.0  68.0  88.4  68.0  68.0  84.9  68.0  88.2

## V465  63.1  36.6  87.7  38.2  50.7  80.6  59.9  84.3

## V433  63.7  55.7  76.7  55.7  57.4  55.7  87.6  68.4

##    X8  X9

## V435 100.0 80.6

## V407 99.4 75.2

## V463 97.9 83.3

## V379 97.4 86.6

## V434 95.9 76.2

## V380 96.1 88.0

## V462 95.6 83.4

## V408 95.4 75.0

## V352 93.5 87.1

## V490 93.1 81.9

## V406 92.9 74.6

## V437 52.0 70.8

## V351 92.4 82.1

## V409 52.7 76.1

## V436 90.0 78.8

## V464 76.5 76.7

## V491 89.5 77.4

## V598 68.0 38.8

## V465 57.1 63.1

## V433 55.7 63.7

We can plot these results to get an understanding of what pixel features are driving our results. The image shows that the most influential features lie around the edges of numbers (outer white circle) and along the very center. This makes intuitive sense as many key differences between numbers lie in these areas. For example, the main difference between a 3 and an 8 is whether the left side of the number is enclosed.

# Get median value for feature importance
imp <- vi$importance %>%
  rownames_to_column(var = "feature") %>%
  gather(response, imp, -feature) %>%
  group_by(feature) %>%
  summarize(imp = median(imp))

# Create tibble for all edge pixels
edges <- tibble(
  feature = paste0("V", nzv),
  imp = 0
)

# Combine and plot
imp <- rbind(imp, edges) %>%
  mutate(ID = as.numeric(str_extract(feature, "\\d+"))) %>%
  arrange(ID)

image(matrix(imp$imp, 28, 28), col = gray(seq(0, 1, 0.05)),
      xaxt = "n", yaxt = "n")

We can look at a few of our correct (left column) and incorrect (right column) predictions in Figure 8.8. Looking at the incorrect predictions, we can rationalize some of the errors (e.g., the actual 4 that we predicted as a 1 has a strong vertical stroke relative to the rest of the number, and the actual 2 that we predicted as a 0 is blurry and not well defined).

FIGURE 8.7: Image heat map showing which features, on average, are most influential across all response classes for our KNN model.

# Get a few accurate predictions
set.seed(9)
good <- knn_mnist$pred %>%
  filter(pred == obs) %>%
  sample_n(4)

# Get a few inaccurate predictions
set.seed(9)
bad <- knn_mnist$pred %>%
  filter(pred != obs) %>%
  sample_n(4)

combine <- bind_rows(good, bad)

# Get original feature set with all pixel features
set.seed(123)
index <- sample(nrow(mnist$train$images), 10000)
X <- mnist$train$images[index, ]

# Plot results
par(mfrow = c(4, 2), mar = c(1, 1, 1, 1))
layout(matrix(seq_len(nrow(combine)), 4, 2, byrow = FALSE))
for (i in seq_len(nrow(combine))) {
  image(matrix(X[combine$rowIndex[i], ], 28, 28)[, 28:1],
        col = gray(seq(0, 1, 0.05)),
        main = paste("Actual:", combine$obs[i], " ",
                     "Predicted:", combine$pred[i]),
        xaxt = "n", yaxt = "n")
}

8.5    Final thoughts

KNN is a very simple and intuitive algorithm that can provide average to decent predictive power, especially when the response is dependent on the local structure of the features. However, a major drawback of KNNs is their computation time, which increases by n × p for each observation. Furthermore, since KNNs are lazy learners, they require the training data at prediction time, which limits their use for real-time modeling. Some work has been done to minimize this effect; for example, the FNN package (Beygelzimer et al., 2019) provides a collection of fast k-nearest neighbor search algorithms and applications such as cover-tree (Beygelzimer et al., 2006) and kd-tree (Robinson, 1981).
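As a rough illustration of these faster search strategies (a sketch assuming the FNN package is installed; it is not used elsewhere in this chapter), get.knnx() performs a kd-tree search for the nearest training homes to each test home, using two standardized Ames features and Sale_Price as the response:

# Fast nearest neighbor search with a kd-tree via the FNN package (sketch)
library(FNN)

# Two standardized features only, for brevity
train_x <- scale(as.matrix(ames_train[, c("Gr_Liv_Area", "Year_Built")]))
test_x  <- scale(as.matrix(ames_test[, c("Gr_Liv_Area", "Year_Built")]),
                 center = attr(train_x, "scaled:center"),
                 scale  = attr(train_x, "scaled:scale"))

# Indices and distances of the 10 nearest training homes for each test home
nn <- get.knnx(train_x, test_x, k = 10, algorithm = "kd_tree")

# A simple KNN regression prediction for the first test home
mean(ames_train$Sale_Price[nn$nn.index[1, ]])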

Although KNNs rarely provide the best predictive performance, they have many benefits, for example, in feature engineering and in data cleaning and preprocessing. We discussed KNN imputation in Section 3.3.2. Bruce and Bruce (2017) discuss another approach that uses KNNs to add a local knowledge feature: a KNN model is run to estimate the predicted output or class, and that prediction is used as a new feature for downstream modeling. However, this approach also invites more opportunities for target leakage.
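A rough sketch of this local-knowledge idea is shown below; it is an assumption about how the approach could be implemented with the blueprint and churn_train objects from Section 8.3 (not code from Bruce and Bruce (2017)), and it uses out-of-fold predictions to reduce the leakage risk just mentioned. The value k = 271 simply reuses the optimum found in Figure 8.3.

# Sketch: generate out-of-fold KNN class probabilities and append them
# as a new feature for downstream models (illustrative, not tuned)
knn_feat <- train(
  blueprint,
  data = churn_train,
  method = "knn",
  trControl = trainControl(method = "cv", number = 10,
                           classProbs = TRUE,
                           savePredictions = "final"),
  tuneGrid = data.frame(k = 271),
  metric = "Accuracy"
)

# Each training row appears exactly once in the held-out predictions;
# reorder them to match churn_train and add the probability of attrition
oof <- knn_feat$pred[order(knn_feat$pred$rowIndex), ]
churn_train$knn_prob_yes <- oof$Yes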

Other alternatives to traditional KNNs such as using invariant metrics, tangent distance metrics, and adaptive nearest neighbor methods are also discussed in Friedman et al. (2001) and are worth exploring.

FIGURE 8.8: Actual images from the MNIST data set along with our KNN model’s predictions. Left column illustrates a few accurate predictions and the right column illustrates a few inaccurate predictions.