## Script to explore properties of decision trees in 2D
## Author: Sylvain Robert
## FS 2014, 16.05.2014
## Content:
## 1. training
## 2. pruning
## 3. variance of trees
## 4. bagging

###################################################################
## 1. training

.pardefault <- par(no.readonly = TRUE)

## load the data:
banana <- read.table(file = 'banana.dat', sep = ',')

## subsample:
set.seed(11)
banana <- banana[sample.int(nrow(banana), 200), ]

## format:
colnames(banana) <- c('X1', 'X2', 'Y')
banana[, 'Y'] <- as.factor(banana[, 'Y'])

## plot:
plot(X2 ~ X1, data = banana, col = c('pink', 'palegreen3')[Y], pch = 19)

## overfit a tree:
require(rpart)
tree <- rpart(Y ~ ., data = banana,
              control = rpart.control(cp = 0.0, minsplit = 10))

## nice plot:
require(rpart.plot)
prp(tree, extra = 1, type = 1,
    box.col = c('pink', 'palegreen3')[tree$frame$yval])

## function to plot the 2D partitioning of a tree:
tree.2d <- function(tree, data = banana, col = c('pink', 'palegreen3'), ...) {
    ## predict the class on a fine grid:
    x1 <- seq(-2.5, 2.5, length.out = 200)
    x2 <- seq(-2.5, 2.5, length.out = 200)
    newdat <- expand.grid(X1 = x1, X2 = x2)
    y.pred <- predict(tree, newdata = newdat, type = 'class')
    ## color the grid by predicted class:
    z <- matrix(as.numeric(y.pred), nrow = length(x1), ncol = length(x2))
    image(x1, x2, z, col = col, ...)
    ## overlay the data:
    points(X2 ~ X1, data = data, col = col[Y], pch = 19)
    points(X2 ~ X1, data = data)
}

## partitioning induced by the overfitted tree:
tree.2d(tree, axes = FALSE, asp = 1)

###################################################################
## 2. PRUNING:

## explore trees of decreasing complexity (== increasing cp):
complexities <- sort(unique(tree$frame$complexity))

for (i in seq_along(complexities)) {
    complexity <- complexities[i]
    cols <- ifelse(tree$frame$complexity > complexity, 1, "darkgray")
    par(mfrow = c(2, 1), mar = c(.5, 0, 1, 0))
    ## plot the tree, shading the branches that are pruned at this cp:
    prp(tree, col = cols, branch.col = cols, split.col = cols,
        main = paste('CP =', round(complexity, 3)),
        box.col = c('pink', 'palegreen3')[tree$frame$yval])
    p.tree <- prune.rpart(tree, cp = complexity)
    tree.2d(p.tree, asp = 1, axes = FALSE)
    devAskNewPage(TRUE)
}
devAskNewPage(FALSE)

## Select the optimal tree:
par(.pardefault)

### 1 std-error rule:
plotcp(tree)
abline(v = 7, col = 'red')
### choose tree size = 8 ==> 7 splits
printcp(tree)

## optimal cp:
opt.cp <- tree$cptable[7, 'CP']

## prune the tree:
opt.tree <- prune.rpart(tree, cp = opt.cp)

## plot it:
prp(opt.tree, extra = 1, type = 1,
    box.col = c('pink', 'palegreen3')[opt.tree$frame$yval])
tree.2d(opt.tree, asp = 1, axes = FALSE)
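## (Optional check, not part of the original script.) A rough sanity check of
## the pruning step: compare the misclassification rate of the overfitted tree
## and of the pruned tree on the full 'banana.dat' file. Our 200 training
## points are a subsample of this file, so this is only an approximate test
## error, not a proper hold-out estimate.
banana.check <- read.table(file = 'banana.dat', sep = ',')
colnames(banana.check) <- c('X1', 'X2', 'Y')
banana.check[, 'Y'] <- as.factor(banana.check[, 'Y'])
misclass <- function(fit, dat) {
    ## fraction of points whose predicted class differs from the true class
    mean(as.character(predict(fit, newdata = dat, type = 'class')) !=
         as.character(dat$Y))
}
misclass(tree, banana.check)      ## overfitted tree
misclass(opt.tree, banana.check)  ## pruned tree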
###################################################################
## 3. Large variance of trees

## load the full dataset:
banana.full <- read.table(file = 'banana.dat', sep = ',')
colnames(banana.full) <- c('X1', 'X2', 'Y')
banana.full[, 'Y'] <- as.factor(banana.full[, 'Y'])

## Repeat this as many times as you want:
## observe how different the fitted trees are -- high variance!
## If you increase n, they change less (of course).
n <- 200 # 800
resample <- 'yes'
par(.pardefault)
while (!resample %in% c('n', 'no')) {
    ## sample the data:
    banana <- banana.full[sample.int(nrow(banana.full), n), ]
    ## fit a tree:
    tree <- rpart(Y ~ ., data = banana,
                  control = rpart.control(cp = 0.0, minsplit = 10))
    ## automatic 1 std-error rule:
    min.ind <- which.min(tree$cptable[, "xerror"])
    min.lim <- tree$cptable[min.ind, "xerror"] + tree$cptable[min.ind, "xstd"]
    cp.opt <- tree$cptable[tree$cptable[, "xerror"] < min.lim, "CP"][1]
    ## pruning:
    opt.tree <- prune.rpart(tree, cp = cp.opt)
    ## plotting:
    par(mfrow = c(2, 1), mar = c(0.5, 0, 1, 0))
    prp(opt.tree, extra = 1, type = 1,
        main = paste('CP =', round(cp.opt, 3)),
        box.col = c('pink', 'palegreen3')[opt.tree$frame$yval])
    tree.2d(opt.tree, asp = 1, axes = FALSE, xlim = c(-3, 3), ylim = c(-3, 3))
    print('Do you want to continue? Y/n')
    resample <- readline()
}

###################################################################
## 4. BAGGING of trees:

## A single tree has a lot of variance, as we just saw.
## A natural idea is therefore to average the decisions of many trees,
## each trained on a different bootstrap sample of the data.
## We explore this idea here.
## Pushing it further leads to random forests, but
## that will be for another time!

## Function to train one tree on the subsample of the data given by
## ind.training and return the predicted probability of each class
## on a test set:
one.tree <- function(data, ind.training, x.test) {
    tree <- rpart(Y ~ ., data = data[ind.training, ],
                  control = rpart.control(cp = 0, minsplit = 1))
    ## choose the optimal cp according to the 1-std-error rule:
    min.ind <- which.min(tree$cptable[, "xerror"])
    min.lim <- tree$cptable[min.ind, "xerror"] + tree$cptable[min.ind, "xstd"]
    cp.opt <- tree$cptable[tree$cptable[, "xerror"] < min.lim, "CP"][1]
    tree.sample <- prune.rpart(tree, cp = cp.opt)
    predict(tree.sample, newdata = x.test, type = 'prob')
}

## grid:
x1 <- seq(-2.5, 2.5, length.out = 100)
x2 <- seq(-2.5, 2.5, length.out = 100)
newdat <- expand.grid(X1 = x1, X2 = x2)

## choose a particular dataset:
banana <- banana.full[sample.int(nrow(banana.full), 200), ]

## format:
colnames(banana) <- c('X1', 'X2', 'Y')
banana[, 'Y'] <- as.factor(banana[, 'Y'])
y.colors <- c('pink', 'palegreen3')[banana$Y]

## Train B trees on B different bootstrap samples:
## try different values of B
B <- 100 # 1, 10, 100, 500, 1000
n <- 200
many.tree <- replicate(B, one.tree(banana,
                                   sample.int(nrow(banana), n, replace = TRUE),
                                   x.test = newdat))

## average over the B repetitions:
mean.tree <- apply(many.tree, c(1, 2), mean)

## plot the majority class:
## observe how the partitioning becomes more and more 'curvy' and
## follows the 'banana' shape of the data better and better!
y.pred <- ifelse(mean.tree[, 1] > 0.5, 1, 2)
z <- matrix(as.numeric(y.pred), nrow = length(x1), ncol = length(x2))
par(.pardefault)
image(x1, x2, z, col = c('pink', 'palegreen3'),
      main = paste('Bag of trees of size =', B))

## plot the probabilities (intensity of the color = probability of that class):
rbPal <- colorRampPalette(c('palegreen3', 'pink'))
y.prob <- rbPal(10)[as.numeric(cut(mean.tree[, 1], breaks = 10))]
z.prob <- matrix(as.numeric(mean.tree[, 1]), nrow = length(x1), ncol = length(x2))
image(x1, x2, z.prob, col = rbPal(20),
      main = paste('Bag of trees of size =', B))

## overlay with a plot of the original data:
points(X2 ~ X1, data = banana, col = y.colors, pch = 19)
points(X2 ~ X1, data = banana)
## more errors occur in the zones of lower probability:
## low probability == more uncertainty
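## (Optional follow-up, not part of the original script.) A rough sketch to
## quantify the benefit of bagging: compare the misclassification rate of a
## single pruned tree with that of a bag of B trees on banana.full. The points
## of banana.full that were not drawn into 'banana' act as a crude (and
## optimistic) hold-out set. Reuses one.tree(), banana, banana.full and B
## from above.
prob.single <- one.tree(banana, seq_len(nrow(banana)), x.test = banana.full)
prob.bag <- apply(replicate(B, one.tree(banana,
                                        sample.int(nrow(banana), nrow(banana),
                                                   replace = TRUE),
                                        x.test = banana.full)),
                  c(1, 2), mean)
## turn probabilities into class labels, using the same majority rule as above
## (the columns of the probability matrix are the class levels of Y):
levs <- colnames(prob.single)
pred.lab <- function(prob) levs[ifelse(prob[, 1] > 0.5, 1, 2)]
mean(pred.lab(prob.single) != as.character(banana.full$Y))  ## single pruned tree
mean(pred.lab(prob.bag)    != as.character(banana.full$Y))  ## bagged trees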