Constructs a learner class object for xgboost::xgboost.
Usage
learner_xgboost(
formula,
max_depth = 2L,
eta = 1,
nrounds = 2L,
subsample = 1,
lambda = 1,
verbose = 0,
objective = "reg:squarederror",
info = paste("xgboost", objective),
learner.args = NULL,
...
)
Arguments
- formula
(formula) Formula specifying response and design matrix.
- max_depth
(integer) Maximum depth of a tree.
- eta
(numeric) Learning rate: the step-size shrinkage applied to the contribution of each new tree.
- nrounds
(integer) Maximum number of boosting iterations.
- subsample
(numeric) Subsample ratio of the training instances, i.e., the fraction of observations randomly sampled before growing each tree.
- lambda
(numeric) L2 regularization term on weights.
- verbose
(integer) If 0, xgboost stays silent. If 1, it prints information about performance. If 2, some additional information is printed. Note that setting verbose > 0 automatically engages the cb.print.evaluation(period=1) callback function.
- objective
(character) Specify the learning task and the corresponding learning objective. See xgboost::xgboost for all available options.
- info
(character) Optional information to describe the instantiated learner object.
- learner.args
(list) Additional arguments to learner$new().
- ...
Additional arguments to xgboost::xgboost.
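The tuning parameters eta, nrounds, subsample and lambda are easiest to read as knobs on a boosting loop. The base-R code below is a simplified sketch of that loop, not xgboost's algorithm. It assumes squared-error loss (as for the default objective "reg:squarederror"), trees with a single split (stumps), a search over decile cut points, and a starting value equal to the mean of the response. The functions fit_stump, predict_stump and boost_stumps are hypothetical helpers belonging to neither targeted nor xgboost. The leaf weight sum(r) / (count + lambda) follows the L2-regularized objective described in the xgboost documentation.

```r
# Simplified sketch of gradient boosting with stumps under squared-error loss.
# Hypothetical helpers; NOT the xgboost implementation (no gamma, no
# histogram splits, depth fixed at 1, start value assumed to be mean(y)).

# Fit one stump to residuals r; leaf weight = sum(r) / (count + lambda)
fit_stump <- function(x, r, lambda) {
  best <- list(gain = -Inf)
  for (j in seq_len(ncol(x))) {
    cuts <- unique(quantile(x[, j], probs = seq(0.1, 0.9, by = 0.1)))
    for (s in cuts) {
      left <- x[, j] <= s
      if (all(left) || !any(left)) next
      # structure score of the two candidate leaves (larger is better)
      gain <- sum(r[left])^2 / (sum(left) + lambda) +
        sum(r[!left])^2 / (sum(!left) + lambda)
      if (gain > best$gain) {
        best <- list(gain = gain, j = j, s = s,
                     wl = sum(r[left]) / (sum(left) + lambda),
                     wr = sum(r[!left]) / (sum(!left) + lambda))
      }
    }
  }
  best
}

predict_stump <- function(st, x) ifelse(x[, st$j] <= st$s, st$wl, st$wr)

boost_stumps <- function(x, y, nrounds = 2L, eta = 1, subsample = 1, lambda = 1) {
  pred <- rep(mean(y), length(y))  # assumption: start from the mean of y
  for (m in seq_len(nrounds)) {
    # draw a fraction `subsample` of the rows (without replacement) this round
    idx <- sample.int(length(y), size = ceiling(subsample * length(y)))
    r <- y - pred  # negative gradient of the loss 0.5 * (y - f)^2
    st <- fit_stump(x[idx, , drop = FALSE], r[idx], lambda)
    pred <- pred + eta * predict_stump(st, x)  # shrink each update by eta
  }
  pred
}

# usage on data simulated as in the examples
set.seed(1)
n <- 1e3
x <- cbind(x1 = rnorm(n, sd = 2), x2 = rnorm(n))
y <- x[, "x2"] * x[, "x1"] + cos(x[, "x1"]) + rnorm(n, sd = sqrt(0.5))
yhat <- boost_stumps(x, y, nrounds = 5, eta = 0.5, subsample = 0.8)
mean((y - yhat)^2)  # training mean squared error
```

Because each stump splits on a single variable, the sum of stumps is additive in x1 and x2 and cannot represent the x2*x1 interaction used in the examples; deeper trees (a larger max_depth, such as the default 2L) are what allow that. A smaller eta typically needs a larger nrounds to reach a comparable fit.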
Value
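A learner object, with methods such as $estimate() for fitting and $predict() for prediction, as used in the examples.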
Examples
# simulated data: the response depends on an x1-x2 interaction and on cos(x1)
n <- 1e3
x1 <- rnorm(n, sd = 2)
x2 <- rnorm(n)
lp <- x2*x1 + cos(x1)
yb <- rbinom(n, 1, lava::expit(lp))
y <- lp + rnorm(n, sd = sqrt(0.5))
d0 <- data.frame(y, yb, x1, x2)
# regression
lr <- learner_xgboost(y ~ x1 + x2, nrounds = 5)
lr$estimate(d0)
lr$predict(head(d0))
#> [1] -1.9827677 2.0304260 2.0304260 0.1772428 1.5559609 -3.3161051
# binary classification
lr <- learner_xgboost(yb ~ x1 + x2, nrounds = 5,
  objective = "binary:logistic")
lr$estimate(d0)
lr$predict(head(d0))
#> [1] 0.2974974 0.9214302 0.7824602 0.1971165 0.4704079 0.1971165
# multi-class classification
d0 <- iris
# xgboost expects class labels coded 0, ..., num_class - 1
d0$y <- as.numeric(d0$Species) - 1
lr <- learner_xgboost(y ~ ., objective = "multi:softprob", num_class = 3)
lr$estimate(d0)
lr$predict(head(d0))
#> [,1] [,2] [,3]
#> [1,] 0.9290679 0.03546607 0.03546607
#> [2,] 0.9290679 0.03546607 0.03546607
#> [3,] 0.9290679 0.03546607 0.03546607
#> [4,] 0.9290679 0.03546607 0.03546607
#> [5,] 0.9290679 0.03546607 0.03546607
#> [6,] 0.9290679 0.03546607 0.03546607
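The classification learners above return probabilities rather than labels. A minimal base-R sketch of turning them into labels follows. It copies values from the example output, and it assumes that column k of the "multi:softprob" prediction matrix holds the probability of label k - 1 (the coding used for y above) and that 0.5 is an acceptable threshold for the binary case; both are choices made for illustration.

```r
# Multi-class: pick the most probable column in each row.
# Rows copied from the "multi:softprob" example output.
p <- rbind(c(0.9290679, 0.03546607, 0.03546607),
           c(0.9290679, 0.03546607, 0.03546607))
stopifnot(all(abs(rowSums(p) - 1) < 1e-6))  # each row is a probability vector
cls <- max.col(p, ties.method = "first")     # index of the largest probability
levels(iris$Species)[cls]                    # map label k - 1 back to level k
#> [1] "setosa" "setosa"

# Binary: threshold the "binary:logistic" probabilities at 0.5
pb <- c(0.2974974, 0.9214302, 0.7824602, 0.1971165, 0.4704079, 0.1971165)
as.integer(pb > 0.5)
#> [1] 0 1 1 0 0 0
```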
