在训练方法中,tuneGrid和trControl之间是什么关系?

R中训练已知ML模型的首选方法是使用caret包及其通用的train方法。我的问题是tuneGridtrControl参数之间是什么关系?因为它们无疑是相关的,我无法通过阅读文档来弄清楚它们之间的关系...例如:

library(caret)  
# train and choose best model using cross validation
df <- ... # contains input data
control <- trainControl(method = "cv",number = 10,p = .9,allowParallel = TRUE)
fit <- train(y ~ .,method = "knn",data = df,tuneGrid = data.frame(k = seq(9,71,2)),trControl = control)

如果我在上面运行代码,那是怎么回事?如何将每个trainControl定义中包含90%数据的10个CV折叠与k的32个级别结合在一起?

更具体地说:

  • 参数k有32个级别。
  • 我也有10折简历。

k最近邻居模型是否训练了32 * 10次?还是其他?

a349158555 回答:在训练方法中,tuneGrid和trControl之间是什么关系?

是的,您是正确的。您将训练数据分为10组,例如1..10。从集合1开始,您使用全部2..10(训练数据的90%)训练模型,并在集合1上进行测试。对于集合2,集合3再次重复此过程。总共10次,您有32个要测试的k值,因此32 * 10 = 320。

您还可以使用trainControl中的returnResamp函数提取此简历结果。我将其简化为下面的k的3倍和4个值:

df <- mtcars
set.seed(100)
control <- trainControl(method = "cv",number = 3,p = .9,returnResamp="all")
fit <- train(mpg  ~ .,method = "knn",data = mtcars,tuneGrid = data.frame(k = 2:5),trControl = control)

resample_results = fit$resample
resample_results
       RMSE  Rsquared      MAE k Resample
1  3.502321 0.7772086 2.483333 2    Fold1
2  3.807011 0.7636239 2.861111 3    Fold1
3  3.592665 0.8035741 2.697917 4    Fold1
4  3.682105 0.8486331 2.741667 5    Fold1
5  2.473611 0.8665093 1.995000 2    Fold2
6  2.673429 0.8128622 2.210000 3    Fold2
7  2.983224 0.7120910 2.645000 4    Fold2
8  2.998199 0.7207914 2.608000 5    Fold2
9  2.094039 0.9620830 1.610000 2    Fold3
10 2.551035 0.8717981 2.113333 3    Fold3
11 2.893192 0.8324555 2.482500 4    Fold3
12 2.806870 0.8700533 2.368333 5    Fold3

# we manually calculate the mean RMSE for each parameter
tapply(resample_results$RMSE,resample_results$k,mean)
       2        3        4        5 
2.689990 3.010492 3.156360 3.162392

# and we can see it corresponds to the final fit result
fit$results
k     RMSE  Rsquared      MAE    RMSESD RsquaredSD     MAESD
1 2 2.689990 0.8686003 2.029444 0.7286489 0.09245494 0.4376844
2 3 3.010492 0.8160947 2.394815 0.6925154 0.05415954 0.4067066
3 4 3.156360 0.7827069 2.608472 0.3805227 0.06283697 0.1122577
4 5 3.162392 0.8131593 2.572667 0.4601396 0.08070670 0.1891581
本文链接:https://www.f2er.com/3054667.html

大家都在问