Source code for h2o.model.models.ordinal

# -*- encoding: utf-8 -*-
import h2o
from h2o.exceptions import H2OValueError
from h2o.frame import H2OFrame
from h2o.model import ModelBase
from h2o.model.extensions import has_extension
from h2o.utils.compatibility import *  # NOQA
from h2o.utils.typechecks import assert_is_type


[docs]class H2OOrdinalModel(ModelBase):
[docs] def confusion_matrix(self, data): """ Returns a confusion matrix based on H2O's default prediction threshold for a dataset. :param H2OFrame data: the frame with the prediction results for which the confusion matrix should be extracted. """ assert_is_type(data, H2OFrame) j = h2o.api("POST /3/Predictions/models/%s/frames/%s" % (self._id, data.frame_id)) return j["model_metrics"][0]["cm"]["table"]
[docs] def hit_ratio_table(self, train=False, valid=False, xval=False): """ Retrieve the Hit Ratios. If all are ``False`` (default), then return the training metric value. If more than one options is set to ``True``, then return a dictionary of metrics where the keys are "train", "valid", and "xval". :param train: If train is ``True``, then return the hit ratio value for the training data. :param valid: If valid is ``True``, then return the hit ratio value for the validation data. :param xval: If xval is ``True``, then return the hit ratio value for the cross validation data. :return: The hit ratio for this regression model. """ tm = ModelBase._get_metrics(self, train, valid, xval) m = {} for k, v in zip(list(tm.keys()), list(tm.values())): m[k] = None if v is None else v.hit_ratio_table() return list(m.values())[0] if len(m) == 1 else m
[docs] def mean_per_class_error(self, train=False, valid=False, xval=False): """ Retrieve the mean per class error across all classes If all are ``False`` (default), then return the training metric value. If more than one options is set to ``True``, then return a dictionary of metrics where the keys are "train", "valid", and "xval". :param bool train: If ``True``, return the ``mean_per_class_error`` value for the training data. :param bool valid: If ``True``, return the ``mean_per_class_error`` value for the validation data. :param bool xval: If ``True``, return the ``mean_per_class_error`` value for each of the cross-validated splits. :returns: The ``mean_per_class_error`` values for the specified key(s). """ tm = ModelBase._get_metrics(self, train, valid, xval) m = {} for k, v in zip(list(tm.keys()), list(tm.values())): m[k] = None if v is None else v.mean_per_class_error() return list(m.values())[0] if len(m) == 1 else m
[docs] def plot(self, timestep="AUTO", metric="AUTO", save_plot_path=None, **kwargs): """ Plots training set (and validation set if available) scoring history for an H2OOrdinalModel. The ``timestep`` and ``metric`` arguments are restricted to what is available in its scoring history. :param timestep: A unit of measurement for the x-axis. :param metric: A unit of measurement for the y-axis. :param save_plot_path: a path to save the plot via using matplotlib function savefig. :returns: Object that contains the resulting scoring history plot (can be accessed using ``result.figure()``). """ if not has_extension(self, 'ScoringHistory'): raise H2OValueError("Scoring history plot is not available for this type of model (%s)." % self.algo) self.scoring_history_plot(timestep=timestep, metric=metric, save_plot_path=save_plot_path, **kwargs)