Constructs a learner class object for fitting generalized
random forest models with grf::regression_forest or
grf::probability_forest. As shown in the examples, for binary
classification the constructed learner returns the predicted
probabilities of class 2. For multi-class classification, it returns
an n × p matrix, where n is the number of observations and p is the
number of classes.
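The shape of the returned predictions can be illustrated by calling grf directly. The following is a minimal sketch, not the learner's actual implementation: it assumes the learner extracts its output from the `predictions` element returned by grf's `predict` method, and the simulated data and reduced `num.trees` are chosen purely for illustration.

```r
library(grf)

set.seed(1)
n <- 200
X <- cbind(x1 = rnorm(n), x2 = rnorm(n))
yb <- factor(rbinom(n, 1, plogis(X[, "x1"])))  # binary factor response

# Fit a probability forest directly with grf (fewer trees to keep it quick)
pf <- probability_forest(X, yb, num.trees = 500, num.threads = 1)

# predict() returns a list; `predictions` is a matrix with one row per
# observation and one column per class
pr <- predict(pf, newdata = X[1:6, ])$predictions
dim(pr)

# For a binary outcome, the second column holds the class 2 probabilities
p_class2 <- pr[, 2]
```

With more than two classes the full matrix is the natural return value, since no single column summarizes the prediction.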
learner_grf(
  formula,
  num.trees = 2000,
  min.node.size = 5,
  alpha = 0.05,
  sample.fraction = 0.5,
  num.threads = 1,
  model = "grf::regression_forest",
  info = model,
  learner.args = NULL,
  ...
)
formula: (formula) Formula specifying response and design matrix.
num.trees: Number of trees grown in the forest. Note: Getting accurate confidence intervals generally requires more trees than getting accurate predictions. Default is 2000.
min.node.size: A target for the minimum number of observations in each tree leaf. Note that nodes with size smaller than min.node.size can occur, as in the original randomForest package. Default is 5.
alpha: A tuning parameter that controls the maximum imbalance of a split. Default is 0.05.
sample.fraction: Fraction of the data used to build each tree. Note: If honesty = TRUE, these subsamples will further be cut by a factor of honesty.fraction. Default is 0.5.
num.threads: Number of threads used in training. Default is 1, as given in the usage above. (In grf itself, leaving the number of threads unset uses the maximum hardware concurrency.)
model: (character) grf model to estimate. Usually regression_forest (grf::regression_forest) or probability_forest (grf::probability_forest).
info: (character) Optional information to describe the instantiated learner object.
learner.args: (list) Additional arguments to learner$new().
...: Additional arguments to model.
learner object.
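As a hedged illustration of how the arguments combine, the sketch below sets some of the named tuning parameters and forwards one extra argument through `...`. It assumes that `learner_grf` is exported by the targeted package and that `honesty` (an argument of grf::regression_forest) is passed on unchanged; the data are simulated only for this example.

```r
library(targeted)  # assumption: the package that provides learner_grf

set.seed(1)
d <- data.frame(x1 = rnorm(300), x2 = rnorm(300))
d$y <- d$x1 * d$x2 + rnorm(300)

# num.trees and min.node.size are named arguments of learner_grf;
# honesty = FALSE is not, so it is forwarded to the grf model via `...`
lr <- learner_grf(
  y ~ x1 + x2,
  num.trees = 500,
  min.node.size = 10,
  honesty = FALSE
)
lr$estimate(d)
pr <- lr$predict(head(d))
```

Fewer trees shorten the fit but, as noted for num.trees, may not be enough when confidence intervals are the goal.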
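# simulate covariates and a nonlinear signal
n <- 5e2
x1 <- rnorm(n, sd = 2)
x2 <- rnorm(n)
lp <- x2 * x1 + cos(x1)  # interaction plus a nonlinear term
yb <- rbinom(n, 1, lava::expit(lp))  # binary response
y <- lp + rnorm(n, sd = sqrt(0.5))  # continuous response, noise variance 0.5
d <- data.frame(y, yb, x1, x2)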
# regression
lr <- learner_grf(y ~ x1 + x2)
lr$estimate(d)
lr$predict(head(d))
#> [1] 0.7899812 0.3753452 -1.2677767 -0.5914875 -0.4241299 0.8747728
# binary classification
lr <- learner_grf(as.factor(yb) ~ x1 + x2, model = "probability_forest")
lr$estimate(d)
lr$predict(head(d)) # predict class probabilities of class 2
#> [1] 0.5603971 0.5166182 0.3020196 0.4391959 0.2786883 0.6221883
# multi-class classification
lr <- learner_grf(Species ~ ., model = "probability_forest")
lr$estimate(iris)
lr$predict(head(iris))
#>         setosa   versicolor    virginica
#> [1,] 0.9992948 0.0003782828 0.0003269231
#> [2,] 0.9896742 0.0091674964 0.0011583333
#> [3,] 0.9983829 0.0013253968 0.0002916667
#> [4,] 0.9957496 0.0038587302 0.0003916667
#> [5,] 0.9993484 0.0004497114 0.0002019231
#> [6,] 0.9488142 0.0493099331 0.0018759158
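The multi-class probability matrix can be turned into class labels by picking the column with the highest probability in each row. This is a small sketch rather than a feature of the learner itself; it assumes `learner_grf` comes from the targeted package and that the returned matrix carries the class names as column names, as in the output above.

```r
library(targeted)  # assumption: the package that provides learner_grf

lr <- learner_grf(Species ~ ., model = "probability_forest")
lr$estimate(iris)
pr <- lr$predict(iris)

# index of the largest probability in each row, mapped to the class name
pred <- colnames(pr)[max.col(pr, ties.method = "first")]

# cross-tabulate predicted against observed classes
table(predicted = pred, observed = iris$Species)
```

Because the predictions are made on the data used for fitting, the resulting table is likely to be optimistic compared with held-out data.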