dai.train.Rd
Either training_frame, target_col, and is_classification, or resumed_model and resume_method, need to be specified. The other parameters are determined by the Driverless AI platform if not specified; see also dai.suggest_model_params.
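The rule above (a new experiment needs training_frame, target_col, and is_classification; a resumed one needs resumed_model and resume_method) can be sketched as a plain R check. check_train_args() is a hypothetical helper, not part of the dai package and not how dai.train() validates its arguments internally. In particular, reporting "new" when both argument sets are complete is an assumption of this sketch.

```r
# Hypothetical helper (not part of the dai package): checks that a call
# supplies one of the two argument sets that dai.train() expects.
check_train_args <- function(training_frame = NULL, target_col = NULL,
                             is_classification = NULL,
                             resumed_model = NULL, resume_method = "same") {
  # New experiment: all three core arguments must be given
  new_experiment <- !is.null(training_frame) && !is.null(target_col) &&
    !is.null(is_classification)
  # Resumed experiment: a previous model plus a resume method
  # (resume_method defaults to "same", as in the dai.train() usage)
  resumed <- !is.null(resumed_model) && !is.null(resume_method)
  if (!new_experiment && !resumed) {
    stop("Specify either training_frame, target_col and is_classification, ",
         "or resumed_model and resume_method.")
  }
  # If both sets are complete, this sketch arbitrarily reports "new"
  if (new_experiment) "new" else "resume"
}

check_train_args(training_frame = data.frame(x = 1:3), target_col = "x",
                 is_classification = FALSE)
check_train_args(resumed_model = "previous_model", resume_method = "refit")
```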
dai.train(training_frame = NULL, target_col = NULL, is_classification = NULL,
  is_timeseries = NULL, testing_frame = NULL, validation_frame = NULL,
  weight_col = NULL, fold_col = NULL, time_col = NULL, scorer = NULL,
  cols_to_drop = NULL, accuracy = NULL, time = NULL, interpretability = NULL,
  time_groups_columns = NULL, unavailable_columns_at_prediction_time = NULL,
  time_period_in_seconds = NULL, num_prediction_periods = NULL,
  num_gap_periods = NULL, enable_gpus = TRUE, cols_imputation = NULL,
  config_overrides = NULL, seed = NULL, experiment_name = NULL,
  resumed_model = NULL, resume_method = "same",
  progress = getOption("dai.progress", TRUE))
Argument | Description |
---|---|
training_frame | DAIFrame to use to build the model. |
target_col | The name of the target variable. |
is_classification | Whether the target variable is categorical (TRUE) or numerical (FALSE). |
is_timeseries | Whether the target variable is a time series (TRUE) or not (FALSE). |
testing_frame | DAIFrame on which to evaluate the final model; it is not used for model training (optional). |
validation_frame | DAIFrame to use for model validation during training (optional). |
weight_col | Name of the weights column (optional). |
fold_col | Name of the fold column (optional). |
time_col | Name of the time column, containing the time ordering for time-series problems (optional). |
scorer | Name of one of the available scorers (optional). |
cols_to_drop | Character vector of column names to drop from the data (optional). |
accuracy | Accuracy setting [1-10] (optional). |
time | Time setting [1-10] (optional). |
interpretability | Interpretability setting [1-10] (optional). |
time_groups_columns | List of column names contributing to the time ordering (optional). |
unavailable_columns_at_prediction_time | List of column names that will not be present in the testing dataset at prediction time (optional). |
time_period_in_seconds | Length of one time period in seconds, used in time-series problems (optional). |
num_prediction_periods | Time-series forecast horizon, in time-period units (optional). |
num_gap_periods | Number of time periods after which the forecast starts (optional). |
enable_gpus | Whether to use GPUs (optional). |
config_overrides | Driverless AI config overrides for this experiment, in TOML string format (optional). |
seed | Seed for the random number generator (optional). |
experiment_name | Display name of the newly started experiment (optional). |
resumed_model | Model used for retraining, re-ensembling, or starting from a checkpoint (optional). You may also want to set the resume_method parameter. |
resume_method | How to resume the resumed_model, e.g. "same" (the default; a new model with the same parameters) or "refit" (refit the final pipeline). |
progress | Whether to display a progress bar (optional). |
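The config_overrides example passes a character vector with one TOML "key = value" line per setting. Below is a minimal sketch of building that vector from a named R list. to_config_overrides() is a hypothetical helper, not part of the dai package. It assumes scalar values only and quotes strings as TOML literal strings (single quotes), as the example does; the two setting names are taken from that example.

```r
# Hypothetical helper (not part of the dai package): turns a named list into
# the "key = value" TOML lines passed as config_overrides in the examples.
to_config_overrides <- function(settings) {
  stopifnot(is.list(settings), !is.null(names(settings)),
            all(nzchar(names(settings))))
  vapply(names(settings), function(key) {
    value <- settings[[key]]
    stopifnot(length(value) == 1)  # this sketch handles scalar settings only
    rendered <- if (is.logical(value)) {
      tolower(as.character(value))  # TOML booleans: true / false
    } else if (is.numeric(value)) {
      format(value, scientific = FALSE, trim = TRUE)
    } else {
      value <- as.character(value)
      # TOML literal strings cannot contain a single quote
      if (grepl("'", value, fixed = TRUE)) stop("single quote in value for ", key)
      paste0("'", value, "'")
    }
    paste(key, "=", rendered)
  }, character(1), USE.NAMES = FALSE)
}

overrides <- to_config_overrides(list(recipe = "compliant",
                                      check_distribution_shift = "off"))
print(overrides)
```

The resulting vector can then be passed as config_overrides = overrides, mirroring the compliant_model example.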
Returns a DAIModel.
For time-series experiments, pay attention to how the time period, forecasting horizon, and gap are defined; see the documentation for a detailed explanation. If you do not set time_period_in_seconds, the period is determined automatically, and the forecasting horizon is then influenced by that automatically determined period. It is therefore advisable to set time_period_in_seconds to the cadence of observations in your time series, e.g. 3600 for hourly series, 86400 for daily data, and so on.
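Because the forecast horizon and gap are counted in time periods, their length in seconds depends on time_period_in_seconds. The sketch below uses two hypothetical helpers, not part of the dai package. infer_period_seconds() estimates the cadence of a time column as the median spacing between distinct timestamps; this is an assumption about what "cadence" should mean for irregular data. horizon_seconds() converts num_gap_periods and num_prediction_periods into seconds.

```r
# Hypothetical helper: estimate the cadence of a time column (Date or POSIXct)
# as the median gap, in seconds, between consecutive distinct timestamps.
# Duplicates are dropped because grouped series repeat the same timestamps.
infer_period_seconds <- function(times) {
  secs <- sort(unique(as.numeric(as.POSIXct(times, tz = "UTC"))))
  stopifnot(length(secs) >= 2)
  median(diff(secs))
}

# Hypothetical helper: gap and horizon are given in time-period units, so
# their length in seconds is the period count times the period length.
horizon_seconds <- function(period, num_prediction_periods, num_gap_periods = 0) {
  c(gap = num_gap_periods * period,
    horizon = num_prediction_periods * period)
}

# Synthetic weekly dates
weekly <- seq(as.Date("2010-02-05"), by = "week", length.out = 10)
period <- infer_period_seconds(weekly)
print(period)  # 604800, i.e. 3600 * 24 * 7
print(horizon_seconds(period, num_prediction_periods = 1))

# If the period were taken to be one day instead, the same
# num_prediction_periods = 1 would cover only 86400 seconds.
print(horizon_seconds(86400, num_prediction_periods = 1))
```

This is why the time-series example below sets time_period_in_seconds = 3600 * 24 * 7 explicitly: with num_prediction_periods = 1 it pins the horizon to one week regardless of how the period would otherwise be detected.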
dai.connect(uri = 'http://127.0.0.1:12345', username = 'h2oai', password = 'h2oai')
iris_dai <- as.DAIFrame(iris, progress = FALSE)

# Simple model with minimal parameters
simple_model <- dai.train(training_frame = iris_dai, target_col = 'Species',
                          is_classification = TRUE, is_timeseries = FALSE,
                          time = 1, accuracy = 1, interpretability = 10,
                          progress = FALSE)
print(simple_model)
# \donttest{
# More complex model that may take more time to train
model <- dai.train(training_frame = iris_dai, target_col = 'Species',
                   is_classification = TRUE, is_timeseries = FALSE,
                   time = 5, accuracy = 5, interpretability = 5,
                   progress = FALSE)
print(model)

# Custom config to enable the compliant recipe (see config.toml for more
# details) and disable the distribution shift check
compliant_model <- dai.train(training_frame = iris_dai, target_col = 'Species',
                             is_classification = TRUE, is_timeseries = FALSE,
                             time = 1, accuracy = 1,
                             config_overrides = c("recipe = 'compliant'",
                                                  "check_distribution_shift = 'off'"),
                             progress = FALSE)
print(compliant_model)

# Refit the final pipeline
refit_model <- dai.train(resumed_model = simple_model, resume_method = 'refit')
# New model with the same parameters
same_model <- dai.train(resumed_model = simple_model, resume_method = 'same')
# New model with the same parameters except the time setting
new_model <- dai.train(resumed_model = simple_model, resume_method = 'same', time = 2)

# \dontshow{
walmart_train <- dai.create_dataset(dai:::find_file('data/walmart_tts_small_train.zip'),
                                    progress = FALSE)
walmart_test <- dai.create_dataset(dai:::find_file('data/walmart_tts_small_test.zip'),
                                   progress = FALSE)
# }
# A time-series model to forecast 1 week into the future with no gap
ts_model <- dai.train(training_frame = walmart_train, testing_frame = walmart_test,
                      target_col = 'Weekly_Sales', is_classification = FALSE,
                      is_timeseries = TRUE, seed = 25, progress = FALSE,
                      accuracy = 1, time = 1, interpretability = 10,
                      time_period_in_seconds = 3600 * 24 * 7,  # one week
                      num_prediction_periods = 1, num_gap_periods = 0,
                      time_col = 'Date')
# \dontshow{
dai.rm(ts_model, refit_model, same_model, new_model, compliant_model, model,
       simple_model, iris_dai, walmart_train, walmart_test)
# }
# }