Constructs a learner class object for xgboost::xgboost.

learner_xgboost(
  formula,
  max_depth = 2L,
  eta = 1,
  nrounds = 2L,
  subsample = 1,
  lambda = 1,
  verbose = 0,
  objective = "reg:squarederror",
  info = paste("xgboost", objective),
  learner.args = NULL,
  ...
)

Arguments

formula

(formula) Formula specifying response and design matrix.

max_depth

(integer) Maximum depth of a tree.

eta

(numeric) Learning rate.

nrounds

(integer) Maximum number of boosting iterations.

subsample

(numeric) Subsample ratio of the training instance.

lambda

(numeric) L2 regularization term on weights.

verbose

If 0, xgboost will stay silent. If 1, it will print information about performance. If 2, some additional information will be printed out. Note that setting verbose > 0 automatically engages the cb.print.evaluation(period=1) callback function.

objective

(character) Specifies the learning task and the corresponding learning objective. See xgboost::xgboost for all available options.

info

(character) Optional information to describe the instantiated learner object.

learner.args

(list) Additional arguments to learner$new().

...

Additional arguments passed on to xgboost::xgboost (e.g., num_class when using a multi-class objective).

Value

A learner object.

Examples

# Simulate n observations with two covariates and two outcomes
# (continuous y and binary yb) sharing the same linear predictor lp.
# NOTE: no seed is set, so the simulated data and the printed predictions
# below differ between runs; call set.seed() first for reproducible output.
n  <- 1e3
x1 <- rnorm(n, sd = 2)
x2 <- rnorm(n)
# Nonlinear signal: an x1-x2 interaction plus a cosine term in x1.
lp <- x2*x1 + cos(x1)
# Binary outcome drawn with success probability lava::expit(lp).
yb <- rbinom(n, 1, lava::expit(lp))
# Continuous outcome: lp plus Gaussian noise with variance 0.5 (sd = 0.5^0.5).
y <-  lp + rnorm(n, sd = 0.5**.5)
d0 <- data.frame(y, yb, x1, x2)

# regression
# Uses the default objective ("reg:squarederror") with 5 boosting rounds.
lr <- learner_xgboost(y ~ x1 + x2, nrounds = 5)
# Fit the learner on d0, then predict the first six rows.
lr$estimate(d0)
lr$predict(head(d0))
#> [1] 1.1927135 0.8144226 1.2940172 0.2145363 1.2940172 1.2940172

# binary classification
# With objective = "binary:logistic" the predictions are probabilities
# for the binary outcome yb.
lr <- learner_xgboost(
  yb ~ x1 + x2,
  nrounds = 5,
  objective = "binary:logistic"
)
lr$estimate(d0)
lr$predict(head(d0))
#> [1] 0.8284993 0.6372980 0.8420117 0.8284993 0.8420117 0.8420117

# multi-class classification
# Outcome coded 0, 1, 2 (xgboost expects class labels in 0, ..., num_class - 1).
d0 <- iris
d0$y <- as.numeric(d0$Species) - 1
# Drop Species after deriving y from it: `y ~ .` would otherwise put it in
# the design matrix and leak the outcome into the predictors.
d0$Species <- NULL

# num_class is forwarded to xgboost::xgboost through `...`.
lr <- learner_xgboost(y ~ ., objective = "multi:softprob", num_class = 3)
lr$estimate(d0)
lr$predict(head(d0))
# Returns a 6 x 3 matrix of class probabilities (one column per class,
# each row summing to 1).