
Constructs a learner class object for xgboost::xgboost.

Usage

learner_xgboost(
  formula,
  max_depth = 2L,
  learning_rate = 1,
  nrounds = 2L,
  subsample = 1,
  reg_lambda = 1,
  objective = "reg:squarederror",
  info = paste("xgboost", objective),
  learner.args = NULL,
  ...
)

Arguments

formula

(formula) Formula specifying the response and the design matrix.

max_depth

(integer) Maximum depth of a tree.

learning_rate

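(numeric) Learning rate (step-size shrinkage, called eta in xgboost): the contribution of each new tree is scaled by this factor. Smaller values generally require more boosting rounds.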

nrounds

(integer) Number of boosting iterations (rounds).

Note that the default number of boosting rounds is not tuned automatically. The optimal number varies widely between problems, so nrounds should usually be chosen for the data at hand, e.g. by cross-validation.

subsample

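(numeric) Subsample ratio of the training instances, i.e. the fraction of training rows randomly sampled in each boosting iteration. Values below 1 can help prevent overfitting.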

reg_lambda

(numeric) L2 regularization term on the leaf weights (called lambda in xgboost). Larger values make the model more conservative.

objective

(character) The learning task and corresponding learning objective, e.g. "reg:squarederror" (the default), "binary:logistic", or "multi:softprob". See xgboost::xgboost for all available options.

info

(character) Optional information to describe the instantiated learner object.

learner.args

(list) Additional arguments to learner$new().

...

Additional arguments to xgboost::xgboost, e.g. num_class for multi-class objectives (see Examples).
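As a sketch of how learning_rate, nrounds, max_depth and reg_lambda enter the fitted model, the standard gradient-boosting formulation from the XGBoost documentation is written out here; nothing in it is specific to this wrapper. Notation (introduced for illustration): η = learning_rate, T = nrounds, λ = reg_lambda, f_t is the tree added in round t with J leaves and leaf weights w_j, I_j is the set of training rows falling in leaf j, and g_i, h_i are the first and second derivatives of the loss l with respect to the prediction, evaluated at the previous round's prediction. γ is xgboost's separate gamma (min_split_loss) parameter, which this function does not expose as a named argument. For objectives such as "binary:logistic" the predictions ŷ are on the margin (log-odds) scale.

```latex
% Additive boosting with shrinkage (eta = learning_rate, T = nrounds)
\hat{y}_i^{(t)} = \hat{y}_i^{(t-1)} + \eta\, f_t(x_i), \qquad t = 1, \dots, T,
\qquad\text{so}\qquad
\hat{y}_i^{(T)} = \hat{y}_i^{(0)} + \eta \sum_{t=1}^{T} f_t(x_i).

% Regularized objective minimized when growing tree f_t (lambda = reg_lambda),
% where a tree of depth at most max_depth has J \le 2^{\texttt{max\_depth}} leaves
\mathcal{L}^{(t)} = \sum_{i=1}^{n} l\!\left(y_i,\; \hat{y}_i^{(t-1)} + f_t(x_i)\right)
  + \gamma J + \tfrac{1}{2}\,\lambda \sum_{j=1}^{J} w_j^{2}.

% Optimal leaf weight under the second-order approximation of l
w_j^{\ast} = -\frac{G_j}{H_j + \lambda},
\qquad G_j = \sum_{i \in I_j} g_i, \qquad H_j = \sum_{i \in I_j} h_i.
```

Because every tree is scaled by η, lowering learning_rate usually requires raising nrounds to reach a comparable fit, while a larger λ shrinks each leaf weight toward zero.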

Value

A learner object, fitted with $estimate() and used for prediction with $predict() (see Examples).
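Since the default nrounds is not tuned automatically, one simple way to choose it is to compare a few candidate values on a held-out split. This is an illustrative sketch, not a package feature: it assumes learner_xgboost() is provided by the targeted package (with xgboost installed), simulates data the same way as the Examples, and uses an arbitrary 70/30 split and candidate grid.

```r
# Sketch: choose nrounds by hold-out validation.
# Assumption: learner_xgboost() comes from the 'targeted' package, and
# lr$predict() returns a numeric vector for "reg:squarederror", as in the
# regression example on this page.
library(targeted)

set.seed(1)
n  <- 1e3
x1 <- rnorm(n, sd = 2)
x2 <- rnorm(n)
y  <- x2 * x1 + cos(x1) + rnorm(n, sd = sqrt(0.5))
d  <- data.frame(y, x1, x2)

# 70/30 split into training and validation rows (illustrative choice)
idx   <- sample(n, size = 0.7 * n)
train <- d[idx, ]
valid <- d[-idx, ]

# Candidate numbers of boosting rounds (illustrative grid)
grid <- c(2, 5, 10, 25, 50, 100)

# Fit one learner per candidate and record the validation mean squared error
val_mse <- vapply(grid, function(nr) {
  lr <- learner_xgboost(y ~ x1 + x2, nrounds = nr)
  lr$estimate(train)
  mean((valid$y - lr$predict(valid))^2)
}, numeric(1))

# Candidate with the smallest validation error
best_nrounds <- grid[which.min(val_mse)]
best_nrounds
```

With the default learning_rate = 1 each round takes a full step, so the validation error may stop improving after relatively few rounds; smaller learning rates typically need a larger grid. A single split is noisy, and k-fold cross-validation (for instance xgboost::xgb.cv() with its early_stopping_rounds argument, used on xgboost directly rather than through this wrapper) gives a more stable choice.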

Examples

n  <- 1e3
x1 <- rnorm(n, sd = 2)
x2 <- rnorm(n)
lp <- x2*x1 + cos(x1)
yb <- rbinom(n, 1, lava::expit(lp))
y  <- lp + rnorm(n, sd = sqrt(0.5))
d0 <- data.frame(y, yb, x1, x2)

# regression
lr <- learner_xgboost(y ~ x1 + x2, nrounds = 5)
lr$estimate(d0)
lr$predict(head(d0))
#> [1] -2.1982367  1.9023441  1.9023441 -1.1556151 -0.8474901 -1.5107955

# binary classification
lr <- learner_xgboost(yb ~ x1 + x2, nrounds = 5,
                      objective = "binary:logistic")
lr$estimate(d0)
lr$predict(head(d0))
#> [1] 0.1874139 0.8569067 0.6753441 0.2279280 0.6552737 0.2279280

# multi-class classification
d0 <- iris
# multi:softprob expects 0-based class labels 0, ..., num_class - 1
d0$y <- as.numeric(d0$Species) - 1

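# Note: `.` expands to every column except y, so Species (the label itself) is
# also used as a predictor here; list the four measurements explicitly to avoid this.
lr <- learner_xgboost(y ~ ., objective = "multi:softprob", num_class = 3)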
lr$estimate(d0)
lr$predict(head(d0))
#>           [,1]       [,2]       [,3]
#> [1,] 0.9290679 0.03546607 0.03546607
#> [2,] 0.9290679 0.03546607 0.03546607
#> [3,] 0.9290679 0.03546607 0.03546607
#> [4,] 0.9290679 0.03546607 0.03546607
#> [5,] 0.9290679 0.03546607 0.03546607
#> [6,] 0.9290679 0.03546607 0.03546607
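The multi-class learner returns a matrix of class probabilities with one column per class, as in the output above, and the binary learner returns estimated probabilities P(yb = 1 | x). A minimal base-R sketch for turning such predictions into labels follows; prob_to_label() and prob_to_binary() are hypothetical helpers, not part of the package, and the 0.5 cutoff is an illustrative choice.

```r
# Hypothetical helpers (not part of the package) for post-processing
# probability predictions like those returned by lr$predict().

# multi:softprob: rows are observations, columns are classes 0, 1, ..., K - 1.
# max.col() returns the 1-based column index of each row's largest entry;
# subtracting 1 maps it back to the 0-based label coding used for fitting.
# ties.method = "first" makes the result deterministic (the default is "random").
prob_to_label <- function(p) {
  max.col(p, ties.method = "first") - 1L
}

# binary:logistic: a vector of P(Y = 1 | x); classify by a cutoff.
prob_to_binary <- function(p, threshold = 0.5) {
  as.integer(p > threshold)
}

p_multi <- rbind(c(0.93, 0.035, 0.035),
                 c(0.10, 0.70,  0.20),
                 c(0.05, 0.15,  0.80))
prob_to_label(p_multi)
#> [1] 0 1 2

prob_to_binary(c(0.187, 0.857, 0.675))
#> [1] 0 1 1

# Map the 0-based labels back to the iris species names
levels(iris$Species)[prob_to_label(p_multi) + 1L]
```

Applied to the multi-class example, prob_to_label(lr$predict(d0)) would give 0-based labels comparable with d0$y.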