Either training_frame, target_col, and is_classification, or resumed_model and resume_method, must be specified. Any other parameter that is not specified is determined by the Driverless AI platform; see also dai.suggest_model_params.
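The rule above can be expressed as a small check. The following base-R sketch uses a hypothetical helper, train_call_mode(), which is not part of the dai package and never contacts a Driverless AI server. It only reports which of the two calling conventions a set of arguments matches. When both are supplied it gives precedence to resumed_model, which is an assumption based on the note that unset parameters are taken from the resumed model.

```r
# Hypothetical helper, NOT part of the dai package: it mirrors the rule that
# dai.train() needs either training_frame, target_col and is_classification,
# or resumed_model and resume_method.
train_call_mode <- function(training_frame = NULL, target_col = NULL,
                            is_classification = NULL,
                            resumed_model = NULL, resume_method = "same") {
  # Resuming: a model to resume from plus a resume method (default "same").
  if (!is.null(resumed_model) && !is.null(resume_method)) {
    return("resume")
  }
  # New experiment: all three required arguments are present.
  if (!is.null(training_frame) && !is.null(target_col) &&
      !is.null(is_classification)) {
    return("new experiment")
  }
  stop("specify training_frame, target_col and is_classification, ",
       "or resumed_model and resume_method")
}

# Placeholder strings stand in for a DAIFrame and a DAIModel here.
train_call_mode(training_frame = "iris_dai", target_col = "Species",
                is_classification = TRUE)
train_call_mode(resumed_model = "simple_model", resume_method = "refit")
```

A call that supplies only some of the new-experiment arguments, such as target_col alone, stops with an error.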

dai.train(
  training_frame = NULL,
  target_col = NULL,
  is_classification = NULL,
  is_timeseries = NULL,
  testing_frame = NULL,
  validation_frame = NULL,
  weight_col = NULL,
  fold_col = NULL,
  time_col = NULL,
  scorer = NULL,
  cols_to_drop = NULL,
  accuracy = NULL,
  time = NULL,
  interpretability = NULL,
  time_groups_columns = NULL,
  unavailable_columns_at_prediction_time = NULL,
  time_period_in_seconds = NULL,
  num_prediction_periods = NULL,
  num_gap_periods = NULL,
  enable_gpus = TRUE,
  cols_imputation = NULL,
  config_overrides = NULL,
  seed = NULL,
  experiment_name = NULL,
  resumed_model = NULL,
  resume_method = "same",
  progress = getOption("dai.progress", TRUE)
)

Arguments

training_frame

DAIFrame to use to build the model.

target_col

The name of the target variable.

is_classification

Whether the predicted variable is categorical (TRUE) or numerical (FALSE).

is_timeseries

Whether the target variable is a time series or not (optional).

testing_frame

DAIFrame on which the final model is evaluated. It is not used for model training (optional).

validation_frame

DAIFrame used to validate the model during training (optional).

weight_col

Weights column name (optional).

fold_col

Fold column name (optional).

time_col

Name of the time column, which defines the time ordering in time-series problems (optional).

scorer

Name of one of the available scorers (optional).

cols_to_drop

A character vector of column names to be dropped from the data (optional).

accuracy

Accuracy setting [1-10] (optional).

time

Time setting [1-10] (optional).

interpretability

Interpretability setting [1-10] (optional).

time_groups_columns

List of column names that contribute to the time ordering (optional).

unavailable_columns_at_prediction_time

List of column names that will not be present in the testing dataset at prediction time (optional).

time_period_in_seconds

The length of the time period in seconds, used in time-series problems (optional).

num_prediction_periods

Forecast horizon for time-series problems, in units of time periods (optional).

num_gap_periods

Number of time periods after the end of the training data before the forecast starts (optional).

enable_gpus

Whether to use GPUs (optional).

config_overrides

Driverless AI configuration overrides applied to this experiment only, in TOML string format (optional).

seed

The random number generator's seed (optional).

experiment_name

Display name of the newly started experiment (optional).

resumed_model

Model used for retraining, re-ensembling, or starting from a checkpoint (optional). You may also want to set the resume_method parameter. Any parameter not set here is taken from the resumed model.

resume_method

How to resume resumed_model: train a new model with the same parameters ("same", the default), restart from the last checkpoint, or refit the final pipeline ("refit").

progress

Whether to display a progress bar (optional).
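The config_overrides argument takes TOML-formatted text, and the compliant-model example below passes a character vector of `key = value` lines. The following minimal base-R sketch builds such text from R values. It uses a hypothetical helper, toml_overrides(), which is not part of the dai package, and it assumes each override is a single top-level TOML key with a scalar string, numeric, or logical value.

```r
# Hypothetical helper, NOT part of the dai package: turns named R values into
# newline-separated TOML assignments for dai.train(config_overrides = ...).
# Assumes scalar values and string values that contain no single quotes.
toml_overrides <- function(...) {
  x <- list(...)
  vals <- vapply(x, function(v) {
    if (is.character(v)) {
      sprintf("'%s'", v)              # TOML literal string
    } else if (is.logical(v)) {
      tolower(as.character(v))        # TOML booleans are lowercase
    } else {
      as.character(v)                 # integers and floats
    }
  }, character(1))
  paste(names(x), vals, sep = " = ", collapse = "\n")
}

# The same two settings as in the compliant-model example.
overrides <- toml_overrides(recipe = "compliant",
                            check_distribution_shift = "off")
cat(overrides, "\n")
```

The resulting single string could then be passed as config_overrides = overrides; whether your dai version prefers one string or a character vector is worth checking.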

Value

DAIModel

Details

For time-series experiments, pay attention to how the time period, forecast horizon, and gap are defined; see the Driverless AI documentation for a detailed explanation. If time_period_in_seconds is not set, the period is determined automatically, and the forecast horizon (num_prediction_periods) is then interpreted in units of that automatically determined period. It is therefore advisable to set time_period_in_seconds to the cadence of the observations in your time series, e.g. 3600 for hourly data or 86400 for daily data.
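The following base-R sketch illustrates this advice with two hypothetical helpers that are not part of the dai package. estimate_period_seconds() takes the median spacing between consecutive distinct timestamps as the cadence; the median keeps a few missing observations from changing the estimate. forecast_window() shows which timestamps a forecast would cover, assuming that the forecast starts num_gap_periods periods after the last training timestamp and spans num_prediction_periods periods.

```r
# Hypothetical helpers, NOT part of the dai package.

# Cadence of a series in seconds: median gap between consecutive distinct
# timestamps. Date values are converted to midnight UTC, POSIXct values are
# used as seconds since the epoch.
estimate_period_seconds <- function(times) {
  secs <- sort(unique(as.numeric(as.POSIXct(times, tz = "UTC"))))
  if (length(secs) < 2) stop("need at least two distinct timestamps")
  stats::median(diff(secs))
}

# Forecast window implied by the gap and horizon, under the assumption that
# the first forecast period follows num_gap_periods skipped periods.
forecast_window <- function(last_train_time, period_seconds,
                            num_gap_periods, num_prediction_periods) {
  t0 <- as.POSIXct(last_train_time, tz = "UTC")
  list(
    first = t0 + (num_gap_periods + 1) * period_seconds,
    last  = t0 + (num_gap_periods + num_prediction_periods) * period_seconds
  )
}

# Hypothetical weekly dates: the estimated cadence is 7 * 86400 seconds.
weeks <- seq(as.Date("2012-01-06"), by = "week", length.out = 10)
period <- estimate_period_seconds(weeks)
# One prediction period and no gap: the forecast covers the week right
# after the last training date.
forecast_window(max(weeks), period, num_gap_periods = 0,
                num_prediction_periods = 1)
```

For weekly data such as the Walmart example, the estimate equals 3600 * 24 * 7, the value passed as time_period_in_seconds there.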

See also

dai.suggest_model_params

Examples

dai.connect(uri = 'http://127.0.0.1:12345', username = 'h2oai', password = 'h2oai')
iris_dai <- as.DAIFrame(iris, progress = FALSE)
# Simple model with minimal parameters
simple_model <- dai.train(training_frame = iris_dai,
                          target_col = 'Species',
                          is_classification = TRUE,
                          is_timeseries = FALSE,
                          time = 1, accuracy = 1, interpretability = 10,
                          progress = FALSE)
print(simple_model)
# \donttest{
# More complex model that may take more time to train
model <- dai.train(training_frame = iris_dai,
                   target_col = 'Species',
                   is_classification = TRUE,
                   is_timeseries = FALSE,
                   time = 5, accuracy = 5, interpretability = 5,
                   progress = FALSE)
print(model)
# Custom config to enable the compliant recipe (see config.toml for details) and
# disable the distribution shift check.
compliant_model <- dai.train(training_frame = iris_dai,
                             target_col = 'Species',
                             is_classification = TRUE,
                             is_timeseries = FALSE,
                             time = 1, accuracy = 1,
                             config_overrides = c("recipe = 'compliant'", "check_distribution_shift = 'off'"),
                             progress = FALSE)
print(compliant_model)

# Refit the final pipeline
refit_model <- dai.train(resumed_model = simple_model, resume_method = 'refit')

# New model with the same parameters
same_model <- dai.train(resumed_model = simple_model, resume_method = 'same')

# New model with the same parameters except the time
new_model <- dai.train(resumed_model = simple_model, resume_method = 'same', time = 2)

# \dontshow{
walmart_train <- dai.create_dataset(dai:::find_file('data/walmart_tts_small_train.zip'), progress = FALSE)
walmart_test <- dai.create_dataset(dai:::find_file('data/walmart_tts_small_test.zip'), progress = FALSE)
# }
# A TS model to forecast 1 week into the future with no gap
ts_model <- dai.train(training_frame = walmart_train,
                      testing_frame = walmart_test,
                      target_col = 'Weekly_Sales',
                      is_classification = FALSE,
                      is_timeseries = TRUE,
                      seed = 25,
                      progress = FALSE,
                      accuracy = 1, time = 1, interpretability = 10,
                      time_period_in_seconds = 3600 * 24 * 7,  # one week
                      num_prediction_periods = 1,
                      num_gap_periods = 0,
                      time_col = 'Date')
# \dontshow{
dai.rm(ts_model, refit_model, same_model, new_model, compliant_model, model,
       simple_model, iris_dai, walmart_train, walmart_test)
# }
# }