In k-nearest neighbors, the shape of the neighborhood is usually circular. Discriminant Adaptive Nearest Neighbors (dann) is a variation of k-nearest neighbors in which the shape of the neighborhood is data driven: the neighborhood is elongated along class boundaries and shrunk in the orthogonal direction. See *Discriminant Adaptive Nearest Neighbor Classification* by Hastie and Tibshirani.
This package brings the dann package into the tidymodels ecosystem.
Models:
In this example, the data are simulated: the overall pattern is a circle inside a square.
# Chunk defaults for the README: echo code and render 10x10-inch figures.
knitr::opts_chunk$set(echo = TRUE, fig.width = 10, fig.height = 10)

# tidymodels components used below
library(parsnip)
library(rsample)
library(scales)
library(dials)
library(tune)
library(yardstick)
library(workflows)

# dann / sub_dann engines for parsnip
library(tidydann)

# data wrangling, plotting, and the simulated-data generator
library(dplyr, warn.conflicts = FALSE)
library(ggplot2)
library(mlbench)
# Create training data: 700 points in 2 dimensions from mlbench's
# "circle in a square" generator, converted to a tibble with predictors
# X1, X2 and class label Y.
set.seed(1)
circle_data <- mlbench.circle(700, 2) |>
  tibble::as_tibble()
colnames(circle_data) <- c("X1", "X2", "Y")

# Hold out 20% of the rows as a test set.
set.seed(42)
split <- initial_split(circle_data, prop = .80)
train <- training(split)
test <- testing(split)
# Scatter plot of the training points, coloured by class label.
train |>
  ggplot(mapping = aes(x = X1, y = X2, colour = as.factor(Y))) +
  geom_point() +
  labs(colour = "Y", title = "Train Data")
AUC is nearly perfect for these data.
# Fit dann on the two informative predictors with fixed hyperparameters.
model <- nearest_neighbor_adaptive(
  neighbors = 5,
  neighborhood = 50,
  matrix_diagonal = 1
) |>
  set_engine("dann") |>
  set_mode("classification") |>
  fit(formula = Y ~ X1 + X2, data = train)

# Class-probability predictions on the held-out test set.
test_predictions <- model |>
  predict(new_data = test, type = "prob")

# Attach the true labels so yardstick can score the predictions.
test_predictions <- test |>
  select(Y) |>
  bind_cols(test_predictions)

# Test-set ROC AUC, treating the first factor level of Y as the event.
test_predictions |>
  roc_auc(truth = Y, event_level = "first", .pred_1)
#> # A tibble: 1 × 3
#> .metric .estimator .estimate
#> <chr> <chr> <dbl>
#> 1 roc_auc binary 0.991
In general, dann struggles when unrelated variables are intermingled with informative ones. To deal with this, sub_dann projects the data onto a reduced subspace and then runs dann on that subspace. In the example below, there are 2 informative variables and 5 unrelated ones.
######################
# Circle data with unrelated variables
######################
# Create training data: same generator and seed as before (X1, X2, Y).
set.seed(1)
circle_data <- mlbench.circle(700, 2) |>
  tibble::as_tibble()
colnames(circle_data) <- c("X1", "X2", "Y")

# Add 5 unrelated variables, each uniform on [-1, 1] and independent of Y.
# n() equals the row count (700), so the random draws are unchanged.
circle_data <- circle_data |>
  mutate(
    U1 = runif(n(), -1, 1),
    U2 = runif(n(), -1, 1),
    U3 = runif(n(), -1, 1),
    U4 = runif(n(), -1, 1),
    U5 = runif(n(), -1, 1)
  )

# Hold out 20% of the rows as a test set.
set.seed(42)
split <- initial_split(circle_data, prop = .80)
train <- training(split)
test <- testing(split)
Without careful feature selection, dann’s performance suffers.
# Fit dann on every column except Y (2 informative + 5 unrelated predictors).
model <- nearest_neighbor_adaptive(
  neighbors = 5,
  neighborhood = 50,
  matrix_diagonal = 1
) |>
  set_engine("dann") |>
  set_mode("classification") |>
  fit(formula = Y ~ ., data = train)

# Class-probability predictions on the held-out test set.
test_predictions <- model |>
  predict(new_data = test, type = "prob")

# Attach the true labels so yardstick can score the predictions.
test_predictions <- test |>
  select(Y) |>
  bind_cols(test_predictions)

# Test-set ROC AUC, treating the first factor level of Y as the event.
test_predictions |>
  roc_auc(truth = Y, event_level = "first", .pred_1)
#> # A tibble: 1 × 3
#> .metric .estimator .estimate
#> <chr> <chr> <dbl>
#> 1 roc_auc binary 0.759
To handle the uninformative variables, a sub_dann model is trained with tuned hyperparameters.
# define workflow: every sub_dann hyperparameter is marked for tuning.
sub_dann_spec <-
  nearest_neighbor_adaptive(
    neighbors = tune(),
    neighborhood = tune(),
    matrix_diagonal = tune(),
    weighted = tune(),
    sphere = tune(),
    num_comp = tune()
  ) |>
  set_engine("sub_dann") |>
  set_mode("classification")

sub_dann_wf <- workflow() |>
  add_model(sub_dann_spec) |>
  add_formula(Y ~ .)

# define grid
set.seed(2)
# neighborhood's range is data dependent: finalize it from 20% of the
# training rows.
finalized_neighborhood <- neighborhood() |>
  get_n_frac(train, frac = .20)
# num_comp's range is capped by the number of predictors, i.e. every column
# except the outcome Y (the formula is Y ~ .). The original `train[-1]`
# dropped X1 instead of Y and only matched the count by coincidence.
finalized_num_comp <- num_comp() |>
  get_p(select(train, -Y))
grid <- grid_random(
  neighbors(),
  finalized_neighborhood,
  matrix_diagonal(),
  weighted(),
  sphere(),
  finalized_num_comp,
  size = 30,
  # a candidate cannot use more neighbors than its neighborhood holds
  filter = neighbors <= neighborhood
)

# tune with 5-fold cross-validation on the training set
set.seed(123)
cv <- vfold_cv(data = train, v = 5)
sub_dann_tune_res <- sub_dann_wf |>
  tune_grid(resamples = cv, grid = grid)

# keep the candidate with the highest cross-validated ROC AUC
best_model <- sub_dann_tune_res |>
  select_best(metric = "roc_auc")
Using the best hyperparameters found, a final model is fit on all of the training data. Test AUC improves from 0.759 (plain dann) to 0.985.
# retrain on all training data with the best hyperparameters, then evaluate
# once on the held-out test set defined by `split`.
final_model <-
  sub_dann_wf |>
  finalize_workflow(best_model) |>
  last_fit(split)

# Report only the test-set ROC AUC.
final_model |>
  collect_metrics() |>
  filter(.metric == "roc_auc") |>
  select(.metric, .estimator, .estimate)
#> # A tibble: 1 × 3
#> .metric .estimator .estimate
#> <chr> <chr> <dbl>
#> 1 roc_auc binary 0.985