R6 class for prediction models
R6 class for prediction models
Provides standardized estimation and prediction methods
info
Optional information/name of the model
formals
List with formal arguments of estimation and prediction functions
formula
Formula specifying response and design matrix
args
Additional arguments specified during initialization
fit
Active binding returning the estimated model object
new()
Create a new prediction model object
ml_model$new(
formula = NULL,
estimate,
predict = predict,
predict.args = NULL,
info = NULL,
specials,
response.arg = "y",
x.arg = "x",
...
)
formula
formula specifying outcome and design matrix
estimate
function for fitting the model (must be a function of the response, 'y', and the design matrix, 'x'; alternatively, a function with a single 'formula' argument)
predict
prediction function (must be a function of the model object, 'object', and the new design matrix, 'newdata')
predict.args
optional arguments to prediction function
info
optional description of the model
specials
optional additional terms (weights, offset, id, subset, ...) passed to 'estimate'
response.arg
name of response argument
x.arg
name of design matrix argument
...
optional arguments to fitting function
estimate()
Estimation method
predict()
Prediction method
update()
Update formula
print()
Print method
design()
Extract design matrix (features) from data
# Load the iris data set (from 'datasets') used throughout these examples
data("iris")
# Constructor for a probability-forest prediction model from 'grf'.
#   formula: model formula (response ~ features)
#   ...:     passed on to ml_model$new(), i.e. extra arguments for the
#            fitting function (e.g. num.trees, mtry)
# Returns an (unfitted) ml_model object.
rf <- function(formula, ...) {
  ml_model$new(
    formula,
    info = "grf::probability_forest",
    estimate = function(x, y, ...) grf::probability_forest(X = x, Y = y, ...),
    # grf's predict() returns a list; keep only the class probabilities
    predict = function(object, newdata) predict(object, newdata)$predictions,
    ...
  )
}
# Tuning grid: every combination of num.trees, mtry and the two formulas
# (2 x 3 x 2 = 12 parameter sets)
args <- expand.list(
  num.trees = c(100, 200),
  mtry = 1:3,
  formula = c(Species ~ ., Species ~ Sepal.Length + Sepal.Width)
)
# One unfitted ml_model per parameter set
models <- lapply(args, function(par) do.call(rf, par))
# R6 objects have reference semantics: fit a copy so that models[[1]]
# itself is left untouched
x <- models[[1]]$clone()
x$estimate(iris)
predict(x, newdata = head(iris))
#> setosa versicolor virginica
#> [1,] 0.9871017 0.005227273 0.007670996
#> [2,] 0.9658442 0.026746753 0.007409091
#> [3,] 0.9675108 0.009080087 0.023409091
#> [4,] 0.9760942 0.015996753 0.007909091
#> [5,] 0.9871017 0.005227273 0.007670996
#> [6,] 0.9357020 0.045131313 0.019166667
# Reduce Ex. timing
# NOTE(review): the marker above presumably flags this step (cross-validating
# all 12 models on iris) for exclusion from timed example runs.
a <- targeted::cv(models, data = iris)
# Cross-validated scores next to the tuning parameters of each model
cbind(coef(a), attr(args, "table"))
#> brier -logscore num.trees mtry
#> model1 0.10562162 0.2273141 100 1
#> model2 0.10104547 0.2201344 200 1
#> model3 0.09139703 0.1888983 100 2
#> model4 0.09161848 0.1917914 200 2
#> model5 0.08864399 0.1766376 100 3
#> model6 0.08702378 0.1736402 200 3
#> model7 0.34668502 0.5648823 100 1
#> model8 0.34768644 0.5681522 200 1
#> model9 0.34083376 0.5487346 100 2
#> model10 0.34571569 0.5574617 200 2
#> model11 0.34304641 0.5518997 100 3
#> model12 0.34078208 0.5518433 200 3
#> formula
#> model1 Species ~ .
#> model2 Species ~ .
#> model3 Species ~ .
#> model4 Species ~ .
#> model5 Species ~ .
#> model6 Species ~ .
#> model7 Species ~ Sepal.Length + Sepal.Width
#> model8 Species ~ Sepal.Length + Sepal.Width
#> model9 Species ~ Sepal.Length + Sepal.Width
#> model10 Species ~ Sepal.Length + Sepal.Width
#> model11 Species ~ Sepal.Length + Sepal.Width
#> model12 Species ~ Sepal.Length + Sepal.Width
# Model specified without a formula: the fitting function takes the response
# 'y' and design matrix 'x' directly (least squares via lm.fit), and
# prediction is the matrix product of new design matrix and coefficients.
ff <- ml_model$new(
  estimate = function(y, x) lm.fit(x = x, y = y),
  predict = function(object, newdata) newdata %*% object$coefficients
)
# Usage sketch; 'y' and 'x' here stand for a user-supplied response vector
# and numeric design matrix (not the cloned model 'x' created above):
## tmp <- ff$estimate(y, x=x)
## ff$predict(x)