Constructs a learner class object for fitting generalized random forest models with grf::regression_forest or grf::probability_forest. As shown in the examples, for binary classification the constructed learner returns the predicted probability of class 2 (the second factor level of the response). For multi-class classification it returns an n-by-k matrix of class probabilities, where n is the number of observations and k the number of classes.

learner_grf(
  formula,
  num.trees = 2000,
  min.node.size = 5,
  alpha = 0.05,
  sample.fraction = 0.5,
  num.threads = 1,
  model = "grf::regression_forest",
  info = model,
  learner.args = NULL,
  ...
)

Arguments

formula

(formula) Formula specifying response and design matrix.

num.trees

Number of trees grown in the forest. Note: Getting accurate confidence intervals generally requires more trees than getting accurate predictions. Default is 2000.

min.node.size

A target for the minimum number of observations in each tree leaf. Note that nodes with size smaller than min.node.size can occur, as in the original randomForest package. Default is 5.

alpha

A tuning parameter that controls the maximum imbalance of a split. Default is 0.05.

sample.fraction

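Fraction of the data used to build each tree. Note: If honesty = TRUE (a grf argument that can be passed via ...), each subsample is further divided by honesty.fraction: that fraction is used to choose the splits and the remainder to estimate the leaf values. Default is 0.5.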
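To make the note on honesty concrete, the arithmetic below assumes n = 500 observations, the default sample.fraction = 0.5, and grf's default honesty.fraction = 0.5. It is a back-of-the-envelope sketch only; grf's internal rounding of subsample sizes may differ by an observation.

```r
n <- 500
sample.fraction <- 0.5
honesty.fraction <- 0.5  # assumed: grf's default when honesty = TRUE

n.subsample <- n * sample.fraction          # observations drawn for each tree
n.split <- n.subsample * honesty.fraction   # half of them used to choose splits
n.estimate <- n.subsample - n.split         # the rest used to estimate leaf values

c(subsample = n.subsample, split = n.split, estimate = n.estimate)
```

Each tree therefore sees only about a quarter of the data when choosing its splits, which is one reason the defaults favour many trees.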

num.threads

Number of threads used in training. Default is 1. In grf itself the default (num.threads = NULL) uses the maximum hardware concurrency.

model

(character) Name of the grf model to fit, usually "grf::regression_forest" (the default) or "grf::probability_forest". As in the examples, the grf:: prefix may be omitted.

info

(character) Optional information to describe the instantiated learner object.

learner.args

(list) Additional arguments to learner$new().

...

Additional arguments passed to the model function (for example honesty or honesty.fraction).
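A minimal sketch of forwarding grf arguments through ..., assuming the targeted and grf packages are installed. honesty and seed are arguments of grf::regression_forest, not of learner_grf(); the sketch assumes they reach the model unchanged, as the description of ... states. The simulated data and num.trees = 500 are illustrative choices only.

```r
library(targeted)

set.seed(1)
n <- 200
d <- data.frame(x1 = rnorm(n), x2 = rnorm(n))
d$y <- d$x1 * d$x2 + rnorm(n)

# honesty and seed are not learner_grf() arguments; they are forwarded
# through `...` to grf::regression_forest (the default model)
lr <- learner_grf(y ~ x1 + x2, num.trees = 500, honesty = FALSE, seed = 42)
lr$estimate(d)
pr <- lr$predict(head(d))  # one prediction per row of head(d)
pr
```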

Value

A learner object.

Examples

library(targeted)
n <- 5e2
x1 <- rnorm(n, sd = 2)
x2 <- rnorm(n)
lp <- x2*x1 + cos(x1)
yb <- rbinom(n, 1, lava::expit(lp))
y <- lp + rnorm(n, sd = sqrt(0.5))
d <- data.frame(y, yb, x1, x2)

# regression
lr <- learner_grf(y ~ x1 + x2)
lr$estimate(d)
lr$predict(head(d))
#> [1]  0.7899812  0.3753452 -1.2677767 -0.5914875 -0.4241299  0.8747728

# binary classification
lr <- learner_grf(as.factor(yb) ~ x1 + x2, model = "probability_forest")
lr$estimate(d)
lr$predict(head(d)) # predicted probability of class 2 (yb = 1)
#> [1] 0.5603971 0.5166182 0.3020196 0.4391959 0.2786883 0.6221883

# multi-class classification
lr <- learner_grf(Species ~ ., model = "probability_forest")
lr$estimate(iris)
lr$predict(head(iris))
#>         setosa   versicolor    virginica
#> [1,] 0.9992948 0.0003782828 0.0003269231
#> [2,] 0.9896742 0.0091674964 0.0011583333
#> [3,] 0.9983829 0.0013253968 0.0002916667
#> [4,] 0.9957496 0.0038587302 0.0003916667
#> [5,] 0.9993484 0.0004497114 0.0002019231
#> [6,] 0.9488142 0.0493099331 0.0018759158
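The class-2 convention for binary responses can be related to grf's own output. This is a minimal sketch, assuming grf is installed, that calls grf::probability_forest directly without the learner wrapper: its predict() method returns a matrix with one column per factor level, and the second column holds the class-2 probability that the learner reports in the binary case. The simulated data and num.trees = 500 are illustrative assumptions.

```r
library(grf)

set.seed(1)
n <- 500
X <- matrix(rnorm(n * 2), ncol = 2)
# binary response coded as a factor with levels "0" and "1"
yb <- factor(rbinom(n, 1, plogis(X[, 1] * X[, 2])))

fit <- probability_forest(X, yb, num.trees = 500)
# one row per new observation, one column per class level
p <- predict(fit, X[1:6, , drop = FALSE])$predictions
dim(p)      # 6 rows, 2 columns
p[, 2]      # probability of the second level ("1"), i.e. class 2
rowSums(p)  # the class probabilities in each row sum to 1
```

With three or more classes the same matrix simply gains one column per level, which matches the multi-class output shown for iris.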