[R] Training nnet in two ways, trying to understand the performance difference - with (I hope!) commented, minimal, self-contained, reproducible code
Tony Breyal
tony.breyal at googlemail.com
Wed Feb 18 15:32:33 CET 2009
Hmm, further investigation shows that two different fitting criteria are used.
Why did nnet choose different fits when the data are essentially
the same (a two-level factor response in nn1 and a 0/1 numeric response in nn2)?
# uses an entropy fit (maximum conditional likelihood)
> nn1
a 57-3-1 network with 178 weights
inputs: make address all num3d our over [etc...]
output(s): type
options were - entropy fitting decay=0.1
# uses the default least squares fit
> nn2
a 57-3-1 network with 178 weights
inputs: make address all num3d our over [etc...]
output(s): as.numeric(type) - 1
options were - decay=0.1
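The fit printed for each network follows from the type of the response. According to the nnet documentation, a factor response produces a classification network, and with two levels that means a single output fitted by entropy (maximum conditional likelihood). A numeric response such as `as.numeric(type) - 1` gets the default least-squares fit unless `entropy = TRUE` is passed to `nnet()`. The base-R sketch below evaluates the two criteria side by side on made-up targets `y` and outputs `p` (hypothetical values, not taken from either network). It leaves out the `decay` penalty and is not nnet's internal code.

```r
# Sketch only: compare the two fitting criteria named in print(nn1) / print(nn2).
# y: hypothetical 0/1 targets; p: hypothetical network outputs in (0, 1).
y <- c(1, 0, 1)
p <- c(0.9, 0.2, 0.6)

# Least squares: sum of squared residuals (nnet's default criterion)
sse <- sum((y - p)^2)

# Entropy: negative Bernoulli log-likelihood of y given p
# (used for a two-level factor response, or when entropy = TRUE)
entropy <- -sum(y * log(p) + (1 - y) * log(1 - p))

print(c(least_squares = sse, entropy = entropy))
```

Entropy penalises a confidently wrong output (for example `p` near 0 when `y` is 1) far more heavily than squared error does. So nn1 and nn2 can settle on different weights even from identical data, on top of starting from different random initial weights.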
Again, many thanks for any help.
Tony
On 18 Feb, 11:40, Tony Breyal <tony.bre... at googlemail.com> wrote:
> Dear all,
>
> Objective: I am trying to learn about neural networks. I want to see
> if I can train an artificial neural network model to discriminate
> between spam and nonspam emails.
>
> Problem: I created my own model (example 1 below) and got an error of
> about 7.7%. I created the same model using the Rattle package (example
> 2 below, based on Rattle's log script) and got a much lower error of
> about 0.073%.
>
> Question 1: I don't understand why the Rattle script gives a better
> result. I must be doing something wrong in my own script
> (example 1) and would appreciate some insight :-)
>
> Question 2: As Rattle gives a much better result, I would be happy to
> use its R code instead of my own. How can I interpret its
> predictions as either 'spam' or 'nonspam'? I have looked
> at the type='class' argument in ?predict.nnet, but I believe it doesn't
> apply to this situation.
>
> Below I give commented, minimal, self-contained and reproducible code.
> (If you ignore the output, it really is only a few lines of code and
> therefore minimal, I believe.)
>
> ## load library
>
> >library(nnet)
>
> ## Load in spam dataset from package kernlab
>
> >data(list = "spam", package = "kernlab")
> >set.seed(42)
> >my.sample <- sample(nrow(spam), 3221)
> >spam.train <- spam[my.sample, ]
> >spam.test <- spam[-my.sample, ]
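Negative indexing in `spam[-my.sample, ]` keeps exactly the rows that were not drawn. Since the spam data in kernlab has 4601 observations, the test set should contain 4601 - 3221 = 1380 rows, which matches the totals of the error matrices further down. A small base-R check of that arithmetic, done on row indices only so that kernlab need not be loaded, might look like this:

```r
# Sketch of the train/test split above using row indices only
n_total <- 4601                      # rows in kernlab's spam data
set.seed(42)
train_idx <- sample(n_total, 3221)   # same call shape as my.sample
test_idx  <- setdiff(seq_len(n_total), train_idx)  # rows kept by spam[-my.sample, ]

# The two index sets should be disjoint and together cover every row
stopifnot(length(intersect(train_idx, test_idx)) == 0,
          length(train_idx) + length(test_idx) == n_total)
print(length(test_idx))
```

Since R 3.6.0 the default algorithm behind `sample()` has changed, so `set.seed(42)` reproduces the R 2.8.1 split only if `RNGkind(sample.kind = "Rounding")` is called first. The row counts are the same either way.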
>
> ## Example 1 - my own code
> # train artificial neural network (nn1)
> >( nn1 <- nnet(type~., data=spam.train, size=3, decay=0.1, maxit=1000) )
>
> # predict spam.test dataset on nn1
> >( nn1.pr.test <- predict(nn1, spam.test, type='class') )
>
> [1] "spam" "spam" "spam" "spam" "nonspam" "spam" "spam"
> [etc...]
> # error matrix
> >(nn1.test.tab<-table(spam.test$type, nn1.pr.test, dnn=c('Actual', 'Predicted')))
>
>          Predicted
> Actual    nonspam spam
>   nonspam     778   43
>   spam         63  496
> # Calculate overall error percentage ~ 7.68%
> >(nn1.test.perf <- 100 * (nn1.test.tab[2] + nn1.test.tab[3]) / sum(nn1.test.tab))
>
> [1] 7.68116
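The expression `nn1.test.tab[2] + nn1.test.tab[3]` works because R stores a table column by column, so elements 2 and 3 of a 2 x 2 table are its two off-diagonal cells (63 and 43 here, giving 106 / 1380). Below is a sketch of a slightly more general version that takes one minus the share on the diagonal, so it does not depend on cell positions. It uses the counts printed above, and `error_pct` is a hypothetical helper name:

```r
# Error matrix for nn1 as printed above (rows = Actual, columns = Predicted).
# matrix() fills column by column, just as R stores a table.
tab <- matrix(c(778, 63, 43, 496), nrow = 2,
              dimnames = list(Actual = c("nonspam", "spam"),
                              Predicted = c("nonspam", "spam")))

# Hypothetical helper: percentage of cases off the diagonal
error_pct <- function(tab) {
  stopifnot(nrow(tab) == ncol(tab))      # only meaningful for a square table
  100 * (1 - sum(diag(tab)) / sum(tab))  # misclassified share, in percent
}

print(error_pct(tab))
```

The same function applies unchanged to `nn1.test.tab`, and to any square error matrix with more than two classes.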
>
> ## Example 2 - code based on rattles log script
> # train artificial neural network
> >nn2<-nnet(as.numeric(type)-1~., data=spam.train, size=3, decay=0.1, maxit=1000)
>
> # predict spam.test dataset on nn2.
> # ?predict.nnet does have the argument type='class', but I can't use
> # that here as an option
> >nn2.pr.test <- predict(nn2, spam.test)
>
> [,1]
> 3 0.984972396013
> 4 0.931149225918
> 10 0.930001139978
> 13 0.923271300707
> 21 0.102282256315
> [etc...]
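As far as I can tell, `predict.nnet` supports `type='class'` only for networks fitted to a factor response. nn2 was trained on `as.numeric(type) - 1`, so it returns a single score per email, where values near 1 mean 'spam'. This assumes the factor levels are ordered nonspam, spam, as the nn1 error matrix suggests. A minimal sketch that turns such scores into labels with a 0.5 cut-off is shown below. The threshold and the helper name `to_class` are assumptions for illustration, and the scores are the first five values printed above:

```r
# First five nn2 scores printed above; in practice use nn2.pr.test[, 1]
scores <- c(0.984972396013, 0.931149225918, 0.930001139978,
            0.923271300707, 0.102282256315)

# Hypothetical helper: 0 = 'nonspam', 1 = 'spam' (from as.numeric(type) - 1),
# with an assumed cut-off of 0.5
to_class <- function(p, threshold = 0.5) {
  factor(ifelse(p > threshold, "spam", "nonspam"),
         levels = c("nonspam", "spam"))
}

pred_class <- to_class(scores)
print(pred_class)
```

Applied to the full predictions, something like `table(spam.test$type, to_class(nn2.pr.test[, 1]), dnn = c('Actual', 'Predicted'))` should give an error matrix directly comparable to nn1's.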
> # error matrix
> >( nn2.test.tab <- round(100*table(nn2.pr.test, spam.test$type, dnn=c("Predicted", "Actual"))/length(nn2.pr.test)) )
> Actual
> Predicted nonspam spam
> -0.741896935969825 0 0
> -0.706473834678304 0 0
> -0.595327594045746 0 0
> [etc...]
> # Calculate overall error percentage. To be honest, I am not sure how this
> # line works, and I think it should be multiplied by 100. I got this from
> # Rattle's log script.
> >(function(x){return((x[1,2]+x[2,1])/sum(x))})(table(nn2.pr.test, spam.test$type, dnn=c("Predicted", "Actual")))
> [1] 0.0007246377
> # I'm guessing the above should be ~0.072%
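The Rattle-style calculation does not give an error rate here. `nn2.pr.test` holds continuous scores, so `table(nn2.pr.test, spam.test$type)` has one row per distinct score rather than one row per class, which explains the long list of rows printed above. As a result, `x[1,2] + x[2,1]` only inspects two cells from the first two rows. The printed 0.0007246377 equals 1/1380: it counts a single test email out of 1380 and says nothing about the overall error. The sketch below reproduces the effect with five made-up scores. It then thresholds them at an assumed 0.5 before tabulating:

```r
# Hypothetical scores and true classes, for illustration only
scores <- c(0.98, 0.93, 0.10, 0.60, 0.20)
actual <- factor(c("spam", "spam", "nonspam", "nonspam", "spam"),
                 levels = c("nonspam", "spam"))

# Rattle-style: one row per distinct score, so [1,2] and [2,1] are just two
# cells from the first two rows of a 5 x 2 table
raw_tab <- table(scores, actual)
rattle_style <- (raw_tab[1, 2] + raw_tab[2, 1]) / sum(raw_tab)

# Threshold first (assumed cut-off 0.5) to get a proper 2 x 2 error matrix
predicted <- factor(ifelse(scores > 0.5, "spam", "nonspam"),
                    levels = c("nonspam", "spam"))
class_tab <- table(predicted, actual)
error_rate <- (class_tab[1, 2] + class_tab[2, 1]) / sum(class_tab)

print(dim(raw_tab))
print(c(rattle_style = rattle_style, error_rate = error_rate))
```

Once the table is 2 x 2, the same off-diagonal formula is valid, and multiplying by 100 gives a percentage comparable to the 7.68% from example 1.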
>
> I know the above probably seems complicated, but any help that can be
> offered would be much appreciated.
>
> Thank you kindly in advance,
> Tony
>
> OS = Windows Vista Ultimate, running R in admin mode
> > sessionInfo()
>
> R version 2.8.1 (2008-12-22)
> i386-pc-mingw32
>
> locale:
> LC_COLLATE=English_United Kingdom.1252;LC_CTYPE=English_United Kingdom.1252;LC_MONETARY=English_United Kingdom.1252;LC_NUMERIC=C;LC_TIME=English_United Kingdom.1252
>
> attached base packages:
> [1] grid stats graphics grDevices utils datasets methods base
>
> other attached packages:
> [1] RGtk2_2.12.8 vcd_1.2-2 colorspace_1.0-0 MASS_7.2-45 rattle_2.4.8 nnet_7.2-45
>
> loaded via a namespace (and not attached):
> [1] tools_2.8.1
>
> ______________________________________________
> R-h... at r-project.org mailing list https://stat.ethz.ch/mailman/listinfo/r-help
> PLEASE do read the posting guide http://www.R-project.org/posting-guide.html
> and provide commented, minimal, self-contained, reproducible code.