Constructs a learner class object for fitting generalized
random forest models with grf::regression_forest or
grf::probability_forest. As shown in the examples, for binary
classification the constructed learner returns the predicted probability
of class 2 (the second factor level). For multi-class classification it
returns an n × p matrix of class probabilities, where n is the number of
observations and p the number of classes.
Usage
learner_grf(
formula,
num.trees = 2000,
min.node.size = 5,
alpha = 0.05,
sample.fraction = 0.5,
num.threads = 1,
model = "grf::regression_forest",
info = model,
learner.args = NULL,
...
)
Arguments
- formula
(formula) Formula specifying response and design matrix.
- num.trees
Number of trees grown in the forest. Note: Getting accurate confidence intervals generally requires more trees than getting accurate predictions. Default is 2000.
- min.node.size
A target for the minimum number of observations in each tree leaf. Note that nodes with size smaller than min.node.size can occur, as in the original randomForest package. Default is 5.
- alpha
A tuning parameter that controls the maximum imbalance of a split. Default is 0.05.
- sample.fraction
Fraction of the data used to build each tree. Note: If honesty = TRUE, these subsamples will further be cut by a factor of honesty.fraction. Default is 0.5.
- num.threads
Number of threads used in training. Default is 1 (single-threaded). Note that grf's own default is to use the maximum hardware concurrency.
- model
(character) grf model to estimate. Usually "grf::regression_forest" or "grf::probability_forest"; the examples also use the unqualified name "probability_forest".
- info
(character) Optional information to describe the instantiated learner object.
- learner.args
(list) Additional arguments to learner$new().
- ...
Additional arguments passed on to model.
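As a rough illustration of how sample.fraction interacts with honesty, the arithmetic below counts how many observations a single tree sees. It is a sketch under stated assumptions: it assumes grf's defaults honesty = TRUE and honesty.fraction = 0.5, which are grf arguments that would reach the model through ...; grf's internal rounding may differ slightly from the floor() used here.

```r
# Approximate per-tree sample sizes for an honest forest.
# Assumption: honesty = TRUE and honesty.fraction = 0.5 (grf defaults),
# forwarded to the grf model via `...`; rounding is illustrative only.
n <- 500                 # number of observations, as in the examples
sample.fraction <- 0.5   # learner_grf() default
honesty.fraction <- 0.5  # assumed grf default

per_tree <- floor(n * sample.fraction)            # subsample drawn for one tree
split_part <- floor(per_tree * honesty.fraction)  # used to choose the splits
estimate_part <- per_tree - split_part            # used to fill leaf estimates

sizes <- c(per_tree = per_tree, split = split_part, estimate = estimate_part)
sizes
```

With these settings each tree draws 250 observations, of which about 125 choose the splits and about 125 populate the leaves, which is why min.node.size applies to fewer observations than sample.fraction alone suggests.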
Value
A learner object.
Examples
n <- 5e2
x1 <- rnorm(n, sd = 2)
x2 <- rnorm(n)
lp <- x2*x1 + cos(x1)
yb <- rbinom(n, 1, lava::expit(lp))
y <- lp + rnorm(n, sd = 0.5**.5)
d <- data.frame(y, yb, x1, x2)
# regression
lr <- learner_grf(y ~ x1 + x2)
lr$estimate(d)
lr$predict(head(d))
#> [1] 0.7819701 -2.5837244 1.1191245 -2.6589655 0.3622218 0.7357379
# binary classification
lr <- learner_grf(as.factor(yb) ~ x1 + x2, model = "probability_forest")
lr$estimate(d)
lr$predict(head(d)) # predict class probabilities of class 2
#> [1] 0.8236664 0.1345047 0.8830038 0.1046585 0.6214976 0.7902968
# multi-class classification
lr <- learner_grf(Species ~ ., model = "probability_forest")
lr$estimate(iris)
lr$predict(head(iris))
#> setosa versicolor virginica
#> [1,] 0.9991906 0.0005125000 0.0002969188
#> [2,] 0.9856089 0.0128613095 0.0015297619
#> [3,] 0.9988476 0.0009142857 0.0002380952
#> [4,] 0.9970881 0.0023238095 0.0005880952
#> [5,] 0.9989144 0.0004886905 0.0005969188
#> [6,] 0.9451567 0.0515652431 0.0032780696
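The probabilities returned above can be turned into class labels with base R. The sketch below hard-codes the first three rows of the binary and multi-class outputs shown above so it runs on its own; the 0.5 cut-off for the binary case is an illustrative choice, not something learner_grf() applies.

```r
# Class-2 probabilities (level "1" of as.factor(yb)) from the binary example
p2 <- c(0.8236664, 0.1345047, 0.8830038)
p1 <- 1 - p2                          # probability of class 1 (level "0")
hard <- ifelse(p2 > 0.5, "1", "0")    # illustrative 0.5 threshold

# First three rows of the multi-class (iris) example output
pm <- matrix(
  c(0.9991906, 0.0005125000, 0.0002969188,
    0.9856089, 0.0128613095, 0.0015297619,
    0.9988476, 0.0009142857, 0.0002380952),
  nrow = 3, byrow = TRUE,
  dimnames = list(NULL, c("setosa", "versicolor", "virginica"))
)
# Pick the column with the highest probability in each row
labels <- colnames(pm)[max.col(pm, ties.method = "first")]
labels
```

max.col() with ties.method = "first" keeps the result deterministic; the default "random" would break ties arbitrarily.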
