R:如何提高梯度提升模型的拟合度

我尝试使用iris包中的gbmgbm数据集拟合梯度增强模型(弱学习者的最大深度= 2棵树)。我将学习次数为M = 1000的迭代次数设置为learning.rate = 0.001。然后,我将结果与回归树的结果进行了比较(使用rpart)。但是,似乎回归树的性能优于梯度提升模型。这是什么原因呢?以及如何改善梯度增强模型的性能?我认为学习率0.001应该足以满足1000次迭代/增强树。

library(rpart)
library(gbm)
data(iris)

train.dat <- iris[1:100,]
test.dat <- iris[101:150,]

learning.rate <- 0.001
M <- 1000
gbm.model <- gbm(Sepal.Length ~ .,data = train.dat,distribution = "gaussian",n.trees = M,interaction.depth = 2,shrinkage = learning.rate,bag.fraction = 1,train.fraction = 1)
yhats.gbm <- predict(gbm.model,newdata = test.dat,n.trees = M)

tree.mod <- rpart(Sepal.Length ~ .,data = train.dat)
yhats.tree <- predict(tree.mod,newdata = test.dat)

> sqrt(mean((test.dat$Sepal.Length - yhats.gbm)^2))
[1] 1.209446
> sqrt(mean((test.dat$Sepal.Length - yhats.tree)^2))
[1] 0.6345438
xyhjojoy 回答:R:如何提高梯度提升模型的拟合度

在虹膜数据集中,有3种不同的物种,前50行是setosa,后50行是杂色,最后50行是弗吉尼亚。因此,我认为最好将行混合在一起,并让“种类”列相关。

library(ggplot2)
ggplot(iris,aes(x=Sepal.Width,y=Sepal.Length,col=Species)) + geom_point()

enter image description here

第二,您应该对几个重复进行此操作,以查看其不确定性。为此,我们可以使用插入符号,并且可以事先定义训练样本并提供固定的网格。我们感兴趣的是交叉验证训练中的错误,与您正在执行的操作类似:

set.seed(999)
idx = split(sample(nrow(iris)),1:nrow(iris) %% 3)
tr = trainControl(method="cv",index=idx)
this_grid = data.frame(interaction.depth=2,shrinkage=0.001,n.minobsinnode=10,n.trees=1000)

gbm_fit = train(Sepal.Width〜。,data = iris,method =“ gbm”, distribution =“ gaussian”,tuneGrid = tg,trControl = tr)

然后我们使用相同的样本来拟合rpart:

#the default for rpart
this_grid = data.frame(cp=0.01)
rpart_fit = train(Sepal.Width ~ .,data=iris,method="rpart",trControl=tr,tuneGrid=this_grid)

最后我们将它们进行比较,它们非常相似:

gbm_fit$resample
       RMSE  Rsquared       MAE Resample
1 0.3459311 0.5000575 0.2585884        0
2 0.3421506 0.4536114 0.2631338        1
3 0.3428588 0.5600722 0.2693837        2

       RMSE  Rsquared       MAE Resample
1 0.3492542 0.3791232 0.2695451        0
2 0.3320841 0.4276960 0.2550386        1
3 0.3284239 0.4343378 0.2570833        2

因此,我怀疑上面的示例中有些奇怪。同样,它始终取决于您的数据,对于某些数据(例如虹膜),rpart可能已经足够好了,因为存在非常强大的预测指标。另外,对于gbm等复杂模型,您很可能需要使用上述方法进行训练以找到最佳参数。

本文链接:https://www.f2er.com/2523900.html

大家都在问