[R] predict function type class vs. prob
David Winsemius
dw|n@em|u@ @end|ng |rom comc@@t@net
Sat Sep 23 21:24:18 CEST 2023
That's embarrassing. Apologies for the garbled HTML posting. I'll see if
this is more readable:
On 9/23/23 05:30, Rui Barradas wrote:
> At 11:12 on 22/09/2023, Milbert, Sabine (LGL) wrote:
>> Dear R Help Team,
>>
>> My research group and I use R scripts for our multivariate data
>> screening routines. During routine use, we encountered some
>> inconsistencies within the predict() function of the R Stats Package.
In addition to Rui's correction to this misstatement, the caret package
is really a meta package that attempts to implement an umbrella
framework for a vast array of tools from a wide variety of sources. It
is an immense effort but not really a part of the core R project. The
correct place to file issues is found in the DESCRIPTION file:
URL: https://github.com/topepo/caret/
BugReports: https://github.com/topepo/caret/issues
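The same two fields can also be read from within R. This is a minimal
sketch, assuming the package of interest is installed; the helper name
`report_fields` is hypothetical (my own), while `packageDescription()`
comes from the utils package:

```r
# Hypothetical helper: return the URL and BugReports fields of an
# installed package's DESCRIPTION file (NA where a field is absent).
report_fields <- function(pkg) {
  d <- utils::packageDescription(pkg, fields = c("URL", "BugReports"))
  c(URL = d[["URL"]], BugReports = d[["BugReports"]])
}

# Works for any installed package; base packages typically leave
# these fields unset.
report_fields("stats")

# Only run the caret lookup if caret is actually installed.
if (requireNamespace("caret", quietly = TRUE)) {
  print(report_fields("caret"))
}
```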
If you use `str` on an object constructed with caret, you discover
that the model-specific `predict` code is not a free-standing function
in the main workspace but is embedded in the fit object itself, in its
`modelInfo` element. I think this holds rather generally across the
caret universe, so I expect that your fit objects can be examined with
this approach for the code that predict.train will use. Your description
of your analysis methods was rather incomplete, so after my code
demonstration I will add an appendix of the "svm" methods that might
have been specified. (Note that I do not see a caret "weights"
hyper-parameter for the "svmLinear" method, which actually uses code
from pkg:kernlab.)
library(caret)
svmFit <- train(Species ~ ., data = iris, method = "svmLinear",
                trControl = trainControl(method = "cv"))
class(svmFit)
#[1] "train" "train.formula"
str(predict(svmFit))
# Factor w/ 3 levels "setosa","versicolor",..: 1 1 1 1 1 1 1 1 1 1 ...
str(svmFit)
#---screen output-------------
List of 24
$ method : chr "svmLinear"
$ modelInfo :List of 13
..$ label : chr "Support Vector Machines with Linear Kernel"
..$ library : chr "kernlab"
..$ type : chr [1:2] "Regression" "Classification"
..$ parameters:'data.frame': 1 obs. of 3 variables:
.. ..$ parameter: chr "C"
.. ..$ class : chr "numeric"
.. ..$ label : chr "Cost"
..$ grid :function (x, y, len = NULL, search = "grid")
..$ loop : NULL
..$ fit :function (x, y, wts, param, lev, last, classProbs, ...)
..$ predict :function (modelFit, newdata, submodels = NULL)
..$ prob :function (modelFit, newdata, submodels = NULL)
..$ predictors:function (x, ...)
..$ tags : chr [1:5] "Kernel Method" "Support Vector Machines" "Linear Regression" "Linear Classifier" ...
..$ levels :function (x)
..$ sort :function (x)
$ modelType : chr "Classification"
# ---- large amount of screen output omitted------
# note that the class of svmFit$modelInfo$predict is 'function';
# its code, at least for this particular svm method (of which there are
# about 10), is:
svmFit$modelInfo$predict
#---- screen output ------
function (modelFit, newdata, submodels = NULL)
{
    svmPred <- function(obj, x) {
        hasPM <- !is.null(unlist(obj@prob.model))
        if (hasPM) {
            pred <- kernlab::lev(obj)[apply(kernlab::predict(obj,
                x, type = "probabilities"), 1, which.max)]
        }
        else pred <- kernlab::predict(obj, x)
        pred
    }
    out <- try(svmPred(modelFit, newdata), silent = TRUE)
    if (is.character(kernlab::lev(modelFit))) {
        if (class(out)[1] == "try-error") {
            warning("kernlab class prediction calculations failed; returning NAs")
            out <- rep("", nrow(newdata))
            out[seq(along = out)] <- NA
        }
    }
    else {
        if (class(out)[1] == "try-error") {
            warning("kernlab prediction calculations failed; returning NAs")
            out <- rep(NA, nrow(newdata))
        }
    }
    if (is.matrix(out))
        out <- out[, 1]
    out
}
<bytecode: 0x561277d4ec50>
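
To see whether the two kinds of prediction agree for a given fit,
something like the following could be tried. It is only a sketch under
assumptions: `top_class`, `toy`, `two` and `fit2` are names I made up,
the toy probabilities are invented, and the iris subset merely stands in
for your two-class data. As far as I know, predict.train accepts
type = "raw" (the default) or type = "prob" rather than "class", and
type = "prob" needs classProbs = TRUE in trainControl. The helper mirrors
the which.max step in the function printed above.

```r
# Class label implied by a probability table: the column holding the
# largest value in each row (ties go to the first column).
top_class <- function(prob) {
  prob <- as.matrix(prob)
  colnames(prob)[max.col(prob, ties.method = "first")]
}

# Toy check with made-up numbers (not taken from any real model).
toy <- data.frame(A = c(0.52, 0.30), B = c(0.48, 0.70))
top_class(toy)

# The caret part only runs if caret and kernlab are installed.
if (requireNamespace("caret", quietly = TRUE) &&
    requireNamespace("kernlab", quietly = TRUE)) {
  library(caret)
  set.seed(1)
  # Stand-in two-class data set (assumption: yours has a similar shape).
  two  <- droplevels(subset(iris, Species != "setosa"))
  fit2 <- train(Species ~ ., data = two, method = "svmLinear",
                trControl = trainControl(method = "cv", classProbs = TRUE))
  cls  <- predict(fit2, newdata = two)                 # type = "raw"
  prob <- predict(fit2, newdata = two, type = "prob")  # one column per class
  # Rows where the reported class is not the most probable class.
  bad  <- which(as.character(cls) != top_class(prob))
  print(data.frame(row = bad, class = cls[bad], prob[bad, , drop = FALSE]))
}
```

Any rows printed are cases where the reported class differs from the
class with the larger probability; those would be the ones worth
attaching to an issue report.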
--
David
Support Vector Machines with Boundrange String Kernel
(method = 'svmBoundrangeString')
For classification and regression using package kernlab with tuning
parameters:
    length (length, numeric)
    Cost (C, numeric)

Support Vector Machines with Class Weights
(method = 'svmRadialWeights')
For classification using package kernlab with tuning parameters:
    Sigma (sigma, numeric)
    Cost (C, numeric)
    Weight (Weight, numeric)

Support Vector Machines with Exponential String Kernel
(method = 'svmExpoString')
For classification and regression using package kernlab with tuning
parameters:
    lambda (lambda, numeric)
    Cost (C, numeric)

Support Vector Machines with Linear Kernel
(method = 'svmLinear')
For classification and regression using package kernlab with tuning
parameters:
    Cost (C, numeric)

Support Vector Machines with Linear Kernel
(method = 'svmLinear2')
For classification and regression using package e1071 with tuning
parameters:
    Cost (cost, numeric)

Support Vector Machines with Polynomial Kernel
(method = 'svmPoly')
For classification and regression using package kernlab with tuning
parameters:
    Polynomial Degree (degree, numeric)
    Scale (scale, numeric)
    Cost (C, numeric)

Support Vector Machines with Radial Basis Function Kernel
(method = 'svmRadial')
For classification and regression using package kernlab with tuning
parameters:
    Sigma (sigma, numeric)
    Cost (C, numeric)

Support Vector Machines with Radial Basis Function Kernel
(method = 'svmRadialCost')
For classification and regression using package kernlab with tuning
parameters:
    Cost (C, numeric)

Support Vector Machines with Radial Basis Function Kernel
(method = 'svmRadialSigma')
For classification and regression using package kernlab with tuning
parameters:
    Sigma (sigma, numeric)
    Cost (C, numeric)
Note: This SVM model tunes over the cost parameter and the RBF kernel
parameter sigma. In the latter case, using tuneLength will, at most,
evaluate six values of the kernel parameter. This enables a broad search
over the cost parameter and a relatively narrow search over sigma.

Support Vector Machines with Spectrum String Kernel
(method = 'svmSpectrumString')
For classification and regression using package kernlab with tuning
parameters:
    length (length, numeric)
    Cost (C, numeric)

>> Through internal research, we were unable to find the reason for this
>> and have decided to contact your help team with the following issue:
>>
>> The predict() function is used once to predict the class membership
>> of a new sample (type = "class") on a trained linear SVM model for
>> distinguishing two classes (using the caret package). It is then used
>> to also examine the probability of class membership (type = "prob").
>> Both are then presented in an R shiny output. Within the routine, we
>> noticed two samples (out of 100+) where the class prediction and
>> probability prediction did not match. The prediction probabilities of
>> one class (52%) did not match the class membership within the predict
>> function. We use the same seed and the discrepancy is reproducible in
>> this sample. The same problem did not occur in other trained models
>> (lda, random forest, radial SVM...).
>>
>> Is there a weighing of classes within the prediction function or is
>> the classification limit not at 50%/a majority vote? Or do you have
>> another explanation for this discrepancy, please let us know.
>>
>> PS: If this is an issue based on the model training function of the
>> caret package and therefore not your responsibility, please let us know.
>>
>> Thank you in advance for your support!
>>
>> Yours sincerely,
>> Sabine Milbert
>>
>> [[alternative HTML version deleted]]
>>
>> ______________________________________________
>> R-help@r-project.org mailing list -- To UNSUBSCRIBE and more, see
>> https://stat.ethz.ch/mailman/listinfo/r-help
>> PLEASE do read the posting guide
>> http://www.R-project.org/posting-guide.html
>> and provide commented, minimal, self-contained, reproducible code.
> Hello,
>
> I cannot tell what is going on but I would like to make a correction
> to your post.
>
> predict() is a generic function with methods for objects of several
> classes in many packages. In base package stats you will find methods
> for objects (fits) of class lm, glm and others, see ?predict.
>
> The method you are asking about is predict.train, defined in package
> caret, not in package stats.
> to see what predict method is being called, check
>
>
> class(your_fit)
>
>
> Hope this helps,
>
> Rui Barradas
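
Rui's point about dispatch can be checked directly. The sketch below
uses a plain lm fit on the built-in cars data purely as a stand-in
(the names `fit` and `m` are my own); the last step assumes caret is
installed and is skipped otherwise.

```r
# A base-R fit, used only to illustrate S3 dispatch.
fit <- lm(dist ~ speed, data = cars)

# The class vector decides which predict method is called.
class(fit)

# Look up the method that predict(fit) dispatches to, and the
# package (namespace) that defines it.
m <- getS3method("predict", class(fit)[1])
environmentName(environment(m))

# For a caret fit (class "train") the same lookup should point at
# caret rather than stats.
if (requireNamespace("caret", quietly = TRUE)) {
  print(environmentName(environment(getS3method("predict", "train"))))
}
```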