public class EasyPredictModelWrapper
extends java.lang.Object
implements java.io.Serializable
An EasyPredictModelWrapper.ErrorConsumer can be registered in the process of EasyPredictModelWrapper.Config creation.
Advanced scoring features are disabled by default for performance reasons. Configuration flags allow the user to additionally output:
- leaf node assignment,
- GLRM reconstructed matrix,
- staged probabilities,
- prediction contributions (SHAP values).
Deprecation note: The total number of unknown categorical variables is now accessible by registering a CountingErrorConsumer.
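The deprecation note describes an observer that counts unknown categorical values. The following is a minimal, self-contained sketch of that idea only. The interface `UnknownLevelObserver`, the class `CountingObserver`, and all method names are hypothetical stand-ins chosen for illustration; they are not the library's ErrorConsumer or CountingErrorConsumer API.

```java
import java.util.HashMap;
import java.util.Map;

public class CountingConsumerSketch {

    /** Hypothetical observer interface; the real ErrorConsumer methods may differ. */
    public interface UnknownLevelObserver {
        void unseenCategorical(String columnName, Object value);
    }

    /** Hypothetical counting observer: tallies unknown levels per column. */
    public static class CountingObserver implements UnknownLevelObserver {
        private final Map<String, Long> counts = new HashMap<>();

        @Override
        public void unseenCategorical(String columnName, Object value) {
            // Increment the counter for this column, starting from 1.
            counts.merge(columnName, 1L, Long::sum);
        }

        /** Total number of unknown categorical values seen so far. */
        public long total() {
            long sum = 0;
            for (long c : counts.values()) {
                sum += c;
            }
            return sum;
        }

        /** Number of unknown values seen for one column (0 if none). */
        public long countFor(String columnName) {
            return counts.getOrDefault(columnName, 0L);
        }
    }

    public static void main(String[] args) {
        CountingObserver observer = new CountingObserver();
        // In a real setup the scoring code would call the observer; here we call it by hand.
        observer.unseenCategorical("color", "teal");
        observer.unseenCategorical("size", "XXL");
        System.out.println("unknown levels seen: " + observer.total());
    }
}
```

Keeping the counter in an observer rather than in the wrapper itself lets callers decide whether they want counts, logging, or a hard failure.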
See the top-of-tree master version of this file on GitHub.
Modifier and Type | Class and Description |
---|---|
static class | EasyPredictModelWrapper.Config: Configuration builder for instantiating a Wrapper. |
static class | EasyPredictModelWrapper.ErrorConsumer: Observer interface with methods corresponding to errors during the prediction. |
Constructor and Description |
---|
EasyPredictModelWrapper(EasyPredictModelWrapper.Config config): Create a wrapper for a generated model. |
EasyPredictModelWrapper(GenModel model): Create a wrapper for a generated model. |
Modifier and Type | Method and Description |
---|---|
protected double[] | fillRawData(RowData data, double[] rawData) |
java.lang.String[] | getContributionNames(): Returns names of contributions for prediction results with contributions enabled. |
boolean | getEnableContributions() |
boolean | getEnableGLRMReconstruct() |
boolean | getEnableLeafAssignment() |
boolean | getEnableStagedProbabilities() |
java.lang.String | getHeader(): Header related to AutoEncoder output; its exact purpose is undocumented. |
GenModel | getModel() |
ModelCategory | getModelCategory(): Get the category (type) of model. |
java.lang.String[] | getResponseDomainValues(): Get the array of levels for the response column. |
java.lang.String[] | leafNodeAssignment(RowData data) |
SharedTreeMojoModel.LeafNodeAssignments | leafNodeAssignmentExtended(RowData data) |
protected double[] | preamble(ModelCategory c, RowData data) |
protected double[] | preamble(ModelCategory c, RowData data, double offset) |
AbstractPrediction | predict(RowData data) |
protected double[] | predict(RowData data, double offset, double[] preds) |
AbstractPrediction | predict(RowData data, ModelCategory mc): Make a prediction on a new data point. |
AnomalyDetectionPrediction | predictAnomalyDetection(RowData data): Make a prediction on a new data point using an Anomaly Detection model. |
AutoEncoderModelPrediction | predictAutoEncoder(RowData data): Make a prediction on a new data point using an AutoEncoder model. |
BinomialModelPrediction | predictBinomial(RowData data): Make a prediction on a new data point using a Binomial model. |
BinomialModelPrediction | predictBinomial(RowData data, double offset): Make a prediction on a new data point using a Binomial model. |
ClusteringModelPrediction | predictClustering(RowData data): Make a prediction on a new data point using a Clustering model. |
float[] | predictContributions(RowData data) |
FeatureContribution[] | predictContributions(RowData data, int topN, int bottomN, boolean compareAbs): Calculate and sort Shapley values. |
CoxPHModelPrediction | predictCoxPH(RowData data) |
CoxPHModelPrediction | predictCoxPH(RowData data, double offset) |
DimReductionModelPrediction | predictDimReduction(RowData data): Make a prediction on a new data point using a Dimension Reduction model (PCA, GLRM). |
KLimeModelPrediction | predictKLime(RowData data) |
MultinomialModelPrediction | predictMultinomial(RowData data): Make a prediction on a new data point using a Multinomial model. |
MultinomialModelPrediction | predictMultinomial(RowData data, double offset): Make a prediction on a new data point using a Multinomial model. |
OrdinalModelPrediction | predictOrdinal(RowData data): Make a prediction on a new data point using an Ordinal model. |
OrdinalModelPrediction | predictOrdinal(RowData data, double offset): Make a prediction on a new data point using an Ordinal model. |
double[] | predictRaw(RowData data, double offset): Make a prediction on a new data point. |
RegressionModelPrediction | predictRegression(RowData data): Make a prediction on a new data point using a Regression model. |
RegressionModelPrediction | predictRegression(RowData data, double offset): Make a prediction on a new data point using a Regression model. |
TargetEncoderPrediction | predictTargetEncoding(RowData data): Perform target encoding based on TargetEncoderMojoModel. |
UpliftBinomialModelPrediction | predictUpliftBinomial(RowData data): Make a prediction on a new data point using an Uplift Binomial model. |
Word2VecPrediction | predictWord2Vec(RowData data): Look up word embeddings for a given word (or set of words). |
float[] | predictWord2Vec(java.lang.String[] sentence): Calculate an aggregated word embedding for a given input sentence (sequence of words). |
SortedClassProbability[] | sortByDescendingClassProbability(BinomialModelPrediction p): A helper function to return an array of binomial class probabilities for a prediction in sorted order. |
TargetEncoderPrediction | transformWithTargetEncoding(RowData data): Deprecated. Use predictTargetEncoding(RowData) instead. |
KeyValue[] | varimp(): See varimp(int). Returns an array of all variables in the model, sorted in descending order of relative importance. |
KeyValue[] | varimp(int n) |
public final GenModel m
public EasyPredictModelWrapper(EasyPredictModelWrapper.Config config)
config
- The wrapper configuration
public EasyPredictModelWrapper(GenModel model)
model
- The generated model
public boolean getEnableLeafAssignment()
public boolean getEnableGLRMReconstruct()
public boolean getEnableStagedProbabilities()
public boolean getEnableContributions()
public AbstractPrediction predict(RowData data, ModelCategory mc) throws PredictException
data
- A new data point. Unknown or missing column name is treated as a NaN or ignored. Column names are case-sensitive.
PredictException
public double[] predictRaw(RowData data, double offset) throws PredictException
data
- A new data point. Column names are case-sensitive.
offset
- Value of offset (use 0 if the model was trained without offset).
PredictException
- if prediction cannot be made (e.g. input is invalid)
public AbstractPrediction predict(RowData data) throws PredictException
PredictException
public java.lang.String[] getContributionNames()
public AutoEncoderModelPrediction predictAutoEncoder(RowData data) throws PredictException
data
- A new data point. Unknown or missing column name is treated as a NaN or ignored. Column names are case-sensitive.
PredictException
public DimReductionModelPrediction predictDimReduction(RowData data) throws PredictException
data
- A new data point. Unknown column name is treated as a NaN. Column names are case-sensitive.
PredictException
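The predictWord2Vec(java.lang.String[] sentence) method is described as calculating an aggregated word embedding for a sentence. The sketch below shows one plausible aggregation, assuming the aggregate is the element-wise mean over words that have an embedding and that a sentence with no known words yields NaN values. Both assumptions, along with the class name and the `Map`-based vocabulary, are illustrative and not confirmed by this page.

```java
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;

public class SentenceEmbeddingSketch {

    /**
     * Averages the embeddings of all words found in the vocabulary.
     * Assumption: unknown words are skipped; no known words gives an all-NaN vector.
     */
    public static float[] aggregate(Map<String, float[]> vocabulary, String[] sentence, int vecSize) {
        float[] result = new float[vecSize];
        int known = 0;
        for (String word : sentence) {
            float[] embedding = vocabulary.get(word);
            if (embedding == null) {
                continue; // out-of-vocabulary word, contributes nothing
            }
            for (int i = 0; i < vecSize; i++) {
                result[i] += embedding[i];
            }
            known++;
        }
        if (known == 0) {
            Arrays.fill(result, Float.NaN);
            return result;
        }
        for (int i = 0; i < vecSize; i++) {
            result[i] /= known;
        }
        return result;
    }

    public static void main(String[] args) {
        Map<String, float[]> vocabulary = new HashMap<>();
        vocabulary.put("good", new float[] {1f, 3f});
        vocabulary.put("day", new float[] {3f, 5f});
        float[] v = aggregate(vocabulary, new String[] {"good", "day"}, 2);
        System.out.println(Arrays.toString(v));
    }
}
```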
public float[] predictWord2Vec(java.lang.String[] sentence) throws PredictException
sentence
- array of words forming a sentence
PredictException
- if model is not a WordEmbedding model
public Word2VecPrediction predictWord2Vec(RowData data) throws PredictException
data
- RowData structure; every key with a String value will be translated to an embedding. Note: the keys' only purpose is to link the output embedding to the input word.
PredictException
- if model is not a WordEmbedding model
public AnomalyDetectionPrediction predictAnomalyDetection(RowData data) throws PredictException
data
- A new data point. Unknown column name is treated as a NaN. Column names are case-sensitive.
PredictException
public BinomialModelPrediction predictBinomial(RowData data) throws PredictException
data
- A new data point. Unknown or missing column name is treated as a NaN or ignored. Column names are case-sensitive.
PredictException
public BinomialModelPrediction predictBinomial(RowData data, double offset) throws PredictException
data
- A new data point. Unknown or missing column name is treated as a NaN or ignored. Column names are case-sensitive.
offset
- An offset for the prediction.
PredictException
public UpliftBinomialModelPrediction predictUpliftBinomial(RowData data) throws PredictException
data
- A new data point. Unknown or missing column name is treated as a NaN or ignored. Column names are case-sensitive.
PredictException
@Deprecated public TargetEncoderPrediction transformWithTargetEncoding(RowData data) throws PredictException
Deprecated. Use predictTargetEncoding(RowData) instead.
PredictException
public TargetEncoderPrediction predictTargetEncoding(RowData data) throws PredictException
data
- RowData structure with data for which we want to produce transformations. Unknown column name is treated as a NaN. Column names are case-sensitive.
PredictException
public java.lang.String[] leafNodeAssignment(RowData data) throws PredictException
PredictException
public SharedTreeMojoModel.LeafNodeAssignments leafNodeAssignmentExtended(RowData data) throws PredictException
PredictException
public MultinomialModelPrediction predictMultinomial(RowData data) throws PredictException
data
- A new data point. Unknown or missing column name is treated as a NaN or ignored. Column names are case-sensitive.
PredictException
public MultinomialModelPrediction predictMultinomial(RowData data, double offset) throws PredictException
data
- A new data point. Unknown or missing column name is treated as a NaN or ignored. Column names are case-sensitive.
offset
- Prediction offset
PredictException
public OrdinalModelPrediction predictOrdinal(RowData data) throws PredictException
data
- A new data point. Unknown or missing column name is treated as a NaN or ignored. Column names are case-sensitive.
PredictException
public OrdinalModelPrediction predictOrdinal(RowData data, double offset) throws PredictException
data
- A new data point. Unknown or missing column name is treated as a NaN or ignored. Column names are case-sensitive.
offset
- Prediction offset
PredictException
public SortedClassProbability[] sortByDescendingClassProbability(BinomialModelPrediction p)
p
- The prediction.
public ClusteringModelPrediction predictClustering(RowData data) throws PredictException
data
- A new data point. Unknown or missing column name is treated as a NaN or ignored. Column names are case-sensitive.
PredictException
public RegressionModelPrediction predictRegression(RowData data) throws PredictException
data
- A new data point. Unknown or missing column name is treated as a NaN or ignored. Column names are case-sensitive.
PredictException
public RegressionModelPrediction predictRegression(RowData data, double offset) throws PredictException
data
- A new data point. Unknown or missing column name is treated as a NaN or ignored. Column names are case-sensitive.
offset
- Prediction offset
PredictException
public KLimeModelPrediction predictKLime(RowData data) throws PredictException
data
- A new data point. Unknown or missing column name is treated as a NaN or ignored. Column names are case-sensitive.
PredictException
public CoxPHModelPrediction predictCoxPH(RowData data, double offset) throws PredictException
PredictException
public CoxPHModelPrediction predictCoxPH(RowData data) throws PredictException
PredictException
public float[] predictContributions(RowData data) throws PredictException
PredictException
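The predictContributions overload that takes topN, bottomN and compareAbs returns a sorted selection of SHAP values plus the bias term. The sketch below is one reading of the documented parameter rules, not the library's implementation. The `Contribution` class is a hypothetical stand-in for FeatureContribution, the bias is assumed to be the last array element, and two behaviours are assumptions of this sketch: the bottom entries are listed lowest first, and a request for at least as many entries as there are features returns everything in descending order.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;

public class ContributionSelectionSketch {

    /** Hypothetical name/value pair; not the library's FeatureContribution class. */
    public static final class Contribution {
        public final String name;
        public final float value;

        public Contribution(String name, float value) {
            this.name = name;
            this.value = value;
        }

        @Override
        public String toString() {
            return name + "=" + value;
        }
    }

    private static float key(Contribution c, boolean compareAbs) {
        return compareAbs ? Math.abs(c.value) : c.value;
    }

    /** Assumes the last element of names/values is the bias term. */
    public static List<Contribution> select(String[] names, float[] values,
                                            int topN, int bottomN, boolean compareAbs) {
        int n = values.length - 1; // number of feature contributions, bias excluded
        List<Contribution> sorted = new ArrayList<>();
        for (int i = 0; i < n; i++) {
            sorted.add(new Contribution(names[i], values[i]));
        }
        Comparator<Contribution> descending =
                (a, b) -> Float.compare(key(b, compareAbs), key(a, compareAbs));
        sorted.sort(descending);

        List<Contribution> out = new ArrayList<>();
        if (topN < 0 || (bottomN >= 0 && (long) topN + bottomN >= n)) {
            out.addAll(sorted);                       // everything, descending
        } else if (bottomN < 0) {
            for (int i = n - 1; i >= 0; i--) {
                out.add(sorted.get(i));               // everything, ascending
            }
        } else {
            out.addAll(sorted.subList(0, topN));      // topN highest
            for (int i = n - 1; i >= n - bottomN; i--) {
                out.add(sorted.get(i));               // bottomN lowest, lowest first (assumption)
            }
        }
        out.add(new Contribution(names[n], values[n])); // bias is always appended
        return out;
    }

    public static void main(String[] args) {
        String[] names = {"age", "income", "tenure", "BiasTerm"};
        float[] values = {0.5f, -0.8f, 0.1f, 0.2f};
        System.out.println(select(names, values, 1, 1, false));
    }
}
```

With compareAbs set to true, a large negative contribution ranks ahead of a small positive one, which is usually what is wanted when asking which features mattered most.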
public FeatureContribution[] predictContributions(RowData data, int topN, int bottomN, boolean compareAbs) throws PredictException
data
- A new data point. Unknown or missing column name is treated as a NaN or ignored. Column names are case-sensitive.
topN
- Return only the topN highest contributions + bias.
If topN<0 then sort all SHAP values in descending order.
If topN<0 && bottomN<0 then sort all SHAP values in descending order.
bottomN
- Return only the bottomN lowest contributions + bias.
If topN and bottomN are defined together then return an array of topN + bottomN + bias.
If bottomN<0 then sort all SHAP values in ascending order.
If topN<0 && bottomN<0 then sort all SHAP values in descending order.
compareAbs
- True to compare absolute values of contributions.
PredictException
- When data cannot be properly translated to raw data.
public KeyValue[] varimp()
See varimp(int).
Returns an array of all variables in the model, sorted in descending order of relative importance.
public KeyValue[] varimp(int n)
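The varimp methods return variables sorted by descending relative importance. The sketch below illustrates that ordering under the assumption that varimp(int n) limits the result to the n most important variables, which this page does not state explicitly. The `Pair` class is a hypothetical stand-in for KeyValue.

```java
import java.util.Arrays;

public class VarimpSketch {

    /** Hypothetical key/value pair standing in for the library's KeyValue type. */
    public static final class Pair {
        public final String key;
        public final double value;

        public Pair(String key, double value) {
            this.key = key;
            this.value = value;
        }
    }

    /** Sorts variables by descending importance and keeps at most n of them. */
    public static Pair[] topN(String[] names, double[] importances, int n) {
        Pair[] all = new Pair[names.length];
        for (int i = 0; i < names.length; i++) {
            all[i] = new Pair(names[i], importances[i]);
        }
        Arrays.sort(all, (a, b) -> Double.compare(b.value, a.value));
        int keep = Math.min(Math.max(n, 0), all.length); // clamp n into [0, length]
        return Arrays.copyOf(all, keep);
    }

    public static void main(String[] args) {
        Pair[] top = topN(new String[] {"a", "b", "c"}, new double[] {0.2, 0.7, 0.1}, 2);
        for (Pair p : top) {
            System.out.println(p.key + " " + p.value);
        }
    }
}
```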
public GenModel getModel()
public ModelCategory getModelCategory()
public java.lang.String[] getResponseDomainValues()
public java.lang.String getHeader()
protected double[] preamble(ModelCategory c, RowData data) throws PredictException
PredictException
protected double[] preamble(ModelCategory c, RowData data, double offset) throws PredictException
PredictException
protected double[] fillRawData(RowData data, double[] rawData) throws PredictException
PredictException
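The data parameter descriptions above state that column names are case-sensitive and that unknown or missing columns are treated as NaN or ignored. A minimal sketch of that row-to-array conversion follows, for numeric columns only. It assumes a plain `Map` in place of RowData and a hypothetical `fill` helper; categorical encoding and error reporting are left out, and this is not the library's fillRawData implementation.

```java
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;

public class RawRowSketch {

    /**
     * Builds a raw numeric row in model column order.
     * Columns absent from the input stay NaN; input keys the model does not know are ignored.
     */
    public static double[] fill(String[] modelColumns, Map<String, Object> row) {
        double[] raw = new double[modelColumns.length];
        Arrays.fill(raw, Double.NaN); // default for missing columns
        for (int i = 0; i < modelColumns.length; i++) {
            Object value = row.get(modelColumns[i]); // exact, case-sensitive lookup
            if (value instanceof Number) {
                raw[i] = ((Number) value).doubleValue();
            } else if (value instanceof String) {
                // Numeric strings are parsed; a malformed number throws NumberFormatException.
                raw[i] = Double.parseDouble((String) value);
            }
        }
        return raw;
    }

    public static void main(String[] args) {
        Map<String, Object> row = new HashMap<>();
        row.put("AGE", 68);
        row.put("psa", "1.4"); // wrong case, so the model column "PSA" stays NaN
        System.out.println(Arrays.toString(fill(new String[] {"AGE", "PSA"}, row)));
    }
}
```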
protected double[] predict(RowData data, double offset, double[] preds) throws PredictException
PredictException