[R] Rpart decision tree

Achim Zeileis Achim.Zeileis at uibk.ac.at
Wed Feb 15 20:41:34 CET 2012


On Wed, 15 Feb 2012, dlofaro wrote:

> Dear R-Users,
> I'm an R beginner and I have a similar problem:
> I fitted a survival tree with rpart and I'm trying to plot it with the partykit
> package via the function as.party, but it gives me this error:
>
>> plot(as.party(rpartSurv))
> Error in plot.constparty(as.party(rpartSurv)) :
>  node_surv not yet implemented/
>
> Indeed, node_surv is not among the objects reported by the library command:
>
>> library(partykit)
>
> Attaching package: ‘partykit’
> The following object(s) are masked from ‘package:party’:
>    ctree, ctree_control, edge_simple, node_barplot, node_boxplot,
> node_inner, node_terminal
>
> How can I implement the node_surv plot?

One needs to adapt the old party:::node_surv code to the new 
infrastructure in "partykit". We haven't done so yet, but I put together a 
quick prototype for you. If you source the code below (including setting 
the "grapcon_generator" class), then the following should work:
plot(as.party(rpartSurv), terminal_panel = node_surv)

Hope that helps,
Z

## node_surv(): panel-generating ("grapcon_generator") function for
## partykit's plot.party() that draws a Kaplan-Meier survival curve in
## each terminal node of a survival tree.
##
## Usage (after sourcing this code):
##   plot(as.party(rpartSurv), terminal_panel = node_surv)
##
## Arguments:
##   obj    - the party object being plotted; obj$fitted must hold a "Surv"
##            response in its "(response)" element.
##   ylines - width, in lines of text, reserved to the left of each panel
##            for the y-axis.
##   id     - if TRUE, the panel title shows the node id as well as n.
##   ...    - further arguments passed on to survival::survfit().
##
## Value: a function of a single node that draws its panel with grid.
node_surv <- function(obj,
                      ylines = 2,
                      id = TRUE, ...)
{
    ## response of the whole tree; it fixes the time axis shared by all nodes
    y <- obj$fitted[["(response)"]]
    stopifnot(inherits(y, "Surv"))
    if (!requireNamespace("survival", quietly = TRUE)) {
        stop("package 'survival' is required by node_surv()", call. = FALSE)
    }

    ## common scales for all panels: survival probability in [0, 1] and time
    ## from 0 to the largest time value in the whole tree's response
    yscale <- c(0, 1)
    xscale <- c(0, max(y[, 1], na.rm = TRUE))

    ## Kaplan-Meier fit via the public survfit() interface instead of the
    ## internal survival:::survfitKM(), whose argument names (casewt vs.
    ## weights) are not stable across survival versions
    mysurvfit <- function(y, weights, ...) {
        survival::survfit(y ~ 1, weights = weights, ...)
    }

    ## dostep(x, y): expand (x, y) coordinates into the vertices of a step
    ## function (modified from survival's print.survfit); returns
    ## list(x = , y = ) ready for grid.lines()
    dostep <- function(x, y) {
        ## drop a leading NA pair
        if (is.na(x[1] + y[1])) {
            x <- x[-1]
            y <- y[-1]
        }
        n <- length(x)
        if (n > 2) {
            ## collapse horizontal runs such as
            ## (1, .2), (1.4, .2), (1.8, .2), (2.3, .2), (2.9, .2), (3, .1)
            ## to (1, .2), (3, .1): they are slow to draw and smear the
            ## look of the line type
            dupy <- c(TRUE, diff(y[-n]) != 0, TRUE)
            n2 <- sum(dupy)

            ## repeat every kept x but the first and every kept y but the
            ## last, giving horizontal-then-vertical steps
            list(x = rep(x[dupy], c(1, rep(2, n2 - 1))),
                 y = rep(y[dupy], c(rep(2, n2 - 1), 1)))
        } else if (n == 1) {
            list(x = x, y = y)
        } else {
            list(x = x[c(1, 2, 2)], y = y[c(1, 1, 2)])
        }
    }

    ## panel function: draws the Kaplan-Meier curve of one node
    rval <- function(node) {

        ## response and case weights of the observations in this node
        nid <- id_node(node)
        dat <- data_party(obj, nid)
        yn <- dat[["(response)"]]
        wn <- dat[["(weights)"]]
        ## NROW() counts observations; length() of a Surv object may count
        ## matrix cells instead, depending on the survival version
        if (is.null(wn)) wn <- rep(1, NROW(yn))

        ## Kaplan-Meier curve in this node, anchored at S(0) = 1 (as
        ## plot.survfit does) so the curve starts at time 0
        km <- mysurvfit(yn, weights = wn, ...)
        a <- dostep(c(0, km$time), c(1, km$surv))

        ## layout: title row on top; y-axis margin, plot, right margin below
        top_vp <- viewport(layout = grid.layout(nrow = 2, ncol = 3,
                               widths = unit(c(ylines, 1, 1),
                                             c("lines", "null", "lines")),
                               heights = unit(c(1, 1), c("lines", "null"))),
                           width = unit(1, "npc"),
                           height = unit(1, "npc") - unit(2, "lines"),
                           name = paste0("node_surv", nid))

        pushViewport(top_vp)
        grid.rect(gp = gpar(fill = "white", col = 0))

        ## main title
        pushViewport(viewport(layout.pos.col = 2, layout.pos.row = 1))
        mainlab <- if (id) {
            paste0("Node ", nid, " (n = ", sum(wn), ")")
        } else {
            paste0("n = ", sum(wn))
        }
        grid.text(mainlab)
        popViewport()

        plot_vp <- viewport(layout.pos.col = 2, layout.pos.row = 2,
                            xscale = xscale, yscale = yscale,
                            name = paste0("node_surv", nid, "plot"))
        pushViewport(plot_vp)
        ## a$x is in time units: draw in native coordinates so the curve
        ## agrees with the shared x-axis (a$x / max(a$x) in npc would stretch
        ## every node's curve over the full panel width)
        grid.lines(a$x, a$y, default.units = "native")
        grid.xaxis()
        grid.yaxis()
        grid.rect(gp = gpar(fill = "transparent"))
        upViewport(2)
    }

    rval
}
class(node_surv) <- "grapcon_generator"



More information about the R-help mailing list