[Rd] Error with user defined split function in rpart (PR#7895)
wheelerb at imsweb.com
Wed May 25 16:15:29 CEST 2005
Full_Name: Bill Wheeler
Version: 2.0.1
OS: Windows 2000
Submission from: (NULL) (67.130.36.229)
The program below reproduces the error. I am calling rpart with a
user-defined split function for a binary response variable and one continuous
independent variable. The split function works for some datasets but fails
for others.
The error is:
Error in "$<-.data.frame"(`*tmp*`, "yval2", value = c(0, 15, 10, 0.6, :
replacement has 5 rows, data has 1
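The message is raised by `$<-.data.frame`, which rejects a new column whose number of rows differs from the data frame's. The five values in the call appear to match what temp1 (below) returns as the label of the root node: class 0, with 15 zeros and 10 ones out of nobs = 25, so the value cut off after 0.6 would be 10/25 = 0.4. The minimal sketch below does not use rpart (the data frame `frame` and its column `n` are made up for illustration). It reproduces the same message with a one-row data frame and shows that the same values are accepted as a 1 x 5 matrix. Assuming rpart keeps one row per node in its frame, a one-row frame would mean the fitted tree has only the root node, so whether the error appears may depend on whether any split is accepted for a given dataset.

```r
# One-row data frame, standing in for a node table with a single (root) row
frame <- data.frame(n = 25)

# The root-node label: class 0, 15 zeros, 10 ones, 15/25, 10/25
label <- c(0, 15, 10, 0.6, 0.4)

# A bare length-5 vector is treated as 5 rows and rejected
res <- tryCatch({
  frame$yval2 <- label
  "ok"
}, error = function(e) conditionMessage(e))
print(res)  # in an English locale: "replacement has 5 rows, data has 1"

# The same values as a 1 x 5 matrix fit the single row
frame$yval2 <- matrix(label, nrow = 1)
print(dim(frame$yval2))
```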
#
# Test out the "user mode" functions, with a binary response
#
rm(list=ls(all=TRUE))
options(warn = 1)
library(rpart)
set.seed(7)
nobs <- 25
# Binary response y in {0, 1} and one continuous predictor x in (0, 2)
mydata <- data.frame(y = floor(runif(nobs, min = 0, max = 2)),
                     x = runif(nobs, min = 0, max = 2))
################################################################
# The 'evaluation' function. Called once per node.
# Produce a label (1 or more elements long) for labeling each node,
# and a deviance. The latter is
# - of length 1
# - equal to 0 if the node is "pure" in some sense (unsplittable)
# - does not need to be a deviance: any measure that gets larger
# as the node is less acceptable is fine.
# - the measure underlies cost-complexity pruning, however
temp1 <- function(y, wt, parms) {
print("***** START: TEMP1 *****");
n <- length(y);
# Get the number of y's in each category
sumyEqual0 <- sum(y == 0);
sumyEqual1 <- sum(y == 1);
# Get the proportion of 0's and 1's
p0 <- sumyEqual0/n;
p1 <- sumyEqual1/n;
if (p0 >= p1) {
dev = sumyEqual1;
} else {
dev = sumyEqual0;
}
# Get the vector of labels
labels <- matrix(nrow=1, ncol=5);
# labels[1] is the fitted y category
# labels[2] is sum(y == 0)
# labels[3] is sum(y == 1)
# labels[4] is sum(y == 0)/n
# labels[5] is sum(y == 1)/n
if (p0 >= p1) {
labels[1] = 0;
} else {
labels[1] = 1;
}
labels[2] <- sumyEqual0;
labels[3] <- sumyEqual1;
labels[4] <- sumyEqual0/n;
labels[5] <- sumyEqual1/n;
ret <- list(label=labels, deviance=dev)
print("***** END: TEMP1 *****");
ret
}
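# Sketch, not part of the original report (eval_compact is a hypothetical
# name): the same node summary as temp1 in vector form. For a binary y the
# deviance is the number of observations outside the majority class, i.e.
# min(number of 0's, number of 1's), with ties assigned to class 0 as in
# temp1. The label is a plain length-5 vector instead of a 1 x 5 matrix, and
# wt is accepted only to keep the (y, wt, parms) signature used by temp1.
eval_compact <- function(y, wt, parms) {
  n0 <- sum(y == 0)  # number of 0's in the node
  n1 <- sum(y == 1)  # number of 1's in the node
  n <- length(y)
  list(label = c(if (n0 >= n1) 0 else 1, n0, n1, n0 / n, n1 / n),
       deviance = min(n0, n1))
}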
# The split function, where most of the work occurs.
# Called once per split variable per node.
# If continuous=T
# The actual x variable is ordered
# y is supplied in the sort order of x, with no missings,
# return two vectors of length (n-1):
# goodness = goodness of the split, larger numbers are better.
# 0 = couldn't find any worthwhile split
# the ith value of goodness evaluates splitting obs 1:i vs (i+1):n
# direction= -1 = send "x < cutpoint" to the left side of the tree
# 1 = send "x < cutpoint" to the right
# this is not a big deal, but making larger "mean y's" move towards
# the right of the tree, as we do here, seems to make it easier to
# read
# If continuous=F, x is a set of integers defining the groups for an
# unordered predictor. In this case:
# direction = a vector of length m= "# groups". It asserts that the
# best split can be found by lining the groups up in this order
# and going from left to right, so that only m-1 splits need to
# be evaluated rather than 2^(m-1)
# goodness = m-1 values, as before.
#
# The reason for returning a vector of goodness is that the C routine
# enforces the "minbucket" constraint. It selects the best return value
# that is not too close to an edge.
temp2 <- function(y, wt, x, parms, continuous) {
print("***** START: TEMP2 *****");
n <- length(y)
# For binary y, get the weighted proportion of 1's to the left of each split
temp <- cumsum(y*wt)[-n]
left.wt <- cumsum(wt)[-n]
right.wt <- sum(wt) - left.wt
lp <- temp/left.wt
# rsum[i] = weighted number of 1's to the right of split i, in obs (i+1):n
rsum <- matrix(nrow=1, ncol=n-1, data=0);
for (i in seq(1, n-1))
{
for (j in seq(i+1, n))
{
rsum[i] <- rsum[i] + y[j]*wt[j];
}
}
rp <- rsum/right.wt
lprop <- 1 - lp;  # proportion of 0's on the left of each split
rprop <- rp;      # proportion of 1's on the right of each split
# Get the direction
direc <- matrix(nrow=1, ncol=length(lp), data=1);
for (i in seq(1, length(lp)))
{
if (lprop[i] >= rprop[i])
direc[i] <- -1;
}
goodness <- (lprop + rprop);
ret <- list(goodness= goodness, direction=direc)
print("***** END: TEMP2 *****");
ret
}
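# Two sketches related to temp2 (neither is part of the original report;
# rsum_vectorized and order_groups_by_mean are hypothetical names).
#
# 1) For each split i, the double loop in temp2 accumulates the 1's in
#    observations (i+1):n. That is the total minus the running sum up to i
#    (the same running sum that temp holds), so it can be computed in O(n)
#    instead of O(n^2). Assuming rpart passes unit weights when none are
#    supplied, as in this report, it gives the same values as rsum.
rsum_vectorized <- function(y, wt = rep(1, length(y))) {
  n <- length(y)
  sum(y * wt) - cumsum(y * wt)[-n]
}
#
# 2) For the continuous=F case, one common way to choose the group order
#    returned in direction is to sort the groups by their (weighted) mean
#    response; only the m-1 left-to-right splits in that order are then
#    evaluated.
order_groups_by_mean <- function(y, wt, x) {
  # weighted mean of y within each integer group code in x
  means <- tapply(y * wt, x, sum) / tapply(wt, x, sum)
  as.integer(names(means)[order(means)])
}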
# The init function:
# fix up y to deal with offsets
# return a dummy parms list
# numresp is the number of values produced by the eval routine's "label"
# numy is the number of columns for y
# summary is a function used to print one line in summary.rpart
# text is a function used to put text on the plot in text.rpart
# In general, this function would also check for bad data; see rpart.poisson
# for instance.
temp3 <- function(y, offset, parms, wt) {
print("***** START: TEMP3 *****");
if (!is.null(offset)) y <- y-offset
ret <- list(y=y, parms=0, numresp=5, numy=1,
summary= function(yval, dev, wt, ylevel, digits ) {
paste(" mean=", format(signif(yval, digits)),
", MSE=" , format(signif(dev/wt, digits)),
sep='')
},
text= function(yval, dev, wt, ylevel, digits, n, use.n ) {
if(use.n) {paste(formatg(yval,digits),"\nn=", n,sep="")}
else{paste(formatg(yval,digits))}
})
print("***** END: TEMP3 *****");
ret
}
alist <- list(eval=temp1, split=temp2, init=temp3);
fit1 <- rpart(y ~ ., data=mydata, method=alist, control=list(cp=0));