Package org.tribuo.interop
Class ExternalModel<T extends Output<T>,U,V>
java.lang.Object
org.tribuo.Model<T>
org.tribuo.interop.ExternalModel<T,U,V>
- Type Parameters:
T - The output subclass that this model operates on.
U - The internal representation of features.
V - The internal representation of outputs.
- All Implemented Interfaces:
com.oracle.labs.mlrg.olcut.provenance.Provenancable<ModelProvenance>, Serializable
- Direct Known Subclasses:
OCIModel, ONNXExternalModel, TensorFlowFrozenExternalModel, TensorFlowSavedModelExternalModel, XGBoostExternalModel
This is the base class for third-party models that are trained externally and
loaded into Tribuo for prediction.
Batch size defaults to DEFAULT_BATCH_SIZE.
-
Field Summary
Modifier and Type | Field | Description
static final int | DEFAULT_BATCH_SIZE | Default batch size for external model batch predictions.
protected final int[] | featureForwardMapping |
protected final int[] | featureBackwardMapping |
Fields inherited from class org.tribuo.Model
ALL_OUTPUTS, BIAS_FEATURE, featureIDMap, generatesProbabilities, name, outputIDInfo, provenance, provenanceOutput
-
Constructor Summary
Modifier | Constructor | Description
protected | ExternalModel(String name, ModelProvenance provenance, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<T> outputIDInfo, boolean generatesProbabilities, Map<String, Integer> featureMapping) | Constructs an external model from a model trained outside of Tribuo.
protected | ExternalModel(String name, ModelProvenance provenance, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<T> outputIDInfo, int[] featureForwardMapping, int[] featureBackwardMapping, boolean generatesProbabilities) | Constructs an external model from a model trained outside of Tribuo.
-
Method Summary
Modifier and Type | Method | Description
protected abstract U | convertFeatures(SparseVector input) | Converts from a SparseVector using the external model's indices into the ingestion format for the external model.
protected abstract U | convertFeaturesList(List<SparseVector> input) | Converts from a list of SparseVector using the external model's indices into the ingestion format for the external model.
protected abstract List<Prediction<T>> | convertOutput(V output, int[] numValidFeatures, List<Example<T>> examples) | Converts the output of the external model into a list of Predictions.
protected abstract Prediction<T> | convertOutput(V output, int numValidFeatures, Example<T> example) | Converts the output of the external model into a Prediction.
protected static ImmutableFeatureMap | createFeatureMap(Set<String> featureNames) | Creates an immutable feature map from a set of feature names.
protected static <T extends Output<T>> ImmutableOutputInfo<T> | createOutputInfo(OutputFactory<T> factory, Map<T, Integer> outputs) | Creates an output info from a set of outputs.
protected abstract V | externalPrediction(U input) | Runs the external model's prediction function.
int | getBatchSize() | Gets the current testing batch size.
Optional<Excuse<T>> | getExcuse(Example<T> example) | By default, third-party models don't return excuses.
protected List<Prediction<T>> | innerPredict(Iterable<Example<T>> examples) | Called by the base implementations of Model.predict(Iterable) and Model.predict(Dataset).
Prediction<T> | predict(Example<T> example) | Uses the model to predict the output for a single example.
void | setBatchSize(int batchSize) | Sets a new batch size.
Methods inherited from class org.tribuo.Model
castModel, copy, copy, generatesProbabilities, getExcuses, getFeatureIDMap, getName, getOutputIDInfo, getProvenance, getTopFeatures, predict, predict, setName, toString, validate
-
Field Details
-
DEFAULT_BATCH_SIZE
public static final int DEFAULT_BATCH_SIZE
Default batch size for external model batch predictions.
-
featureForwardMapping
protected final int[] featureForwardMapping
-
featureBackwardMapping
protected final int[] featureBackwardMapping
-
-
Constructor Details
-
ExternalModel
protected ExternalModel(String name, ModelProvenance provenance, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<T> outputIDInfo, boolean generatesProbabilities, Map<String, Integer> featureMapping)
Constructs an external model from a model trained outside of Tribuo.
- Parameters:
name - The model name.
provenance - The model provenance.
featureIDMap - The feature domain.
outputIDInfo - The output domain.
generatesProbabilities - Whether this model generates probabilistic predictions.
featureMapping - The mapping from Tribuo's feature names to the model's feature indices.
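To illustrate what the featureMapping argument encodes, here is a minimal sketch that turns a name-to-external-index map into a pair of index arrays. It is not Tribuo's implementation: the class FeatureMappingSketch and its methods are hypothetical, and it assumes that Tribuo ids follow the sorted order of the feature names and that the external indices are 0..n-1 without gaps.

```java
import java.util.Arrays;
import java.util.Map;
import java.util.TreeMap;

// Hypothetical sketch, not Tribuo code.
class FeatureMappingSketch {
    /** Returns forward[tribuoId] = external index. */
    static int[] forwardMapping(Map<String, Integer> featureMapping) {
        // Assumption: Tribuo ids follow sorted feature-name order; TreeMap iterates in that order.
        TreeMap<String, Integer> sorted = new TreeMap<>(featureMapping);
        int[] forward = new int[sorted.size()];
        int tribuoId = 0;
        for (int externalIdx : sorted.values()) {
            forward[tribuoId++] = externalIdx;
        }
        return forward;
    }

    /** Returns backward[externalIndex] = tribuoId; assumes external indices are 0..n-1 with no gaps. */
    static int[] backwardMapping(int[] forward) {
        int[] backward = new int[forward.length];
        for (int tribuoId = 0; tribuoId < forward.length; tribuoId++) {
            backward[forward[tribuoId]] = tribuoId;
        }
        return backward;
    }

    public static void main(String[] args) {
        int[] forward = forwardMapping(Map.of("height", 2, "age", 0, "weight", 1));
        System.out.println("forward:  " + Arrays.toString(forward));
        System.out.println("backward: " + Arrays.toString(backwardMapping(forward)));
    }
}
```

The two arrays produced here correspond in spirit to the featureForwardMapping and featureBackwardMapping fields, which is why the second constructor can accept them directly.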
-
ExternalModel
protected ExternalModel(String name, ModelProvenance provenance, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<T> outputIDInfo, int[] featureForwardMapping, int[] featureBackwardMapping, boolean generatesProbabilities)
Constructs an external model from a model trained outside of Tribuo.
- Parameters:
name - The model name.
provenance - The model provenance.
featureIDMap - The feature domain.
outputIDInfo - The output domain.
generatesProbabilities - Whether this model generates probabilistic predictions.
featureForwardMapping - The mapping from Tribuo's indices to the model's indices.
featureBackwardMapping - The mapping from the model's indices to Tribuo's indices.
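Because one array maps Tribuo's indices to the model's and the other maps them back, a caller building both by hand may want to confirm they are consistent before passing them in. The check below is a sketch under that assumption; MappingCheckSketch is a hypothetical helper, and nothing here claims that Tribuo performs this validation itself.

```java
// Hypothetical helper, not part of Tribuo.
class MappingCheckSketch {
    /** True if backward[forward[i]] == i for every Tribuo index i, i.e. the arrays are inverse permutations. */
    static boolean areInverse(int[] forward, int[] backward) {
        if (forward.length != backward.length) {
            return false;
        }
        for (int tribuoId = 0; tribuoId < forward.length; tribuoId++) {
            int externalIdx = forward[tribuoId];
            // Reject out-of-range external indices before using them as array offsets.
            if (externalIdx < 0 || externalIdx >= backward.length) {
                return false;
            }
            if (backward[externalIdx] != tribuoId) {
                return false;
            }
        }
        return true;
    }

    public static void main(String[] args) {
        System.out.println(areInverse(new int[]{2, 0, 1}, new int[]{1, 2, 0}));
        System.out.println(areInverse(new int[]{2, 0, 1}, new int[]{0, 1, 2}));
    }
}
```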
-
-
Method Details
-
predict
public Prediction<T> predict(Example<T> example)
Description copied from class: Model
Uses the model to predict the output for a single example.
predict does not mutate the example.
Throws IllegalArgumentException if the example has no features or no feature overlap with the model.
-
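The IllegalArgumentException documented for predict can be pictured with a small stand-alone sketch: features unknown to the model are dropped, the remaining ones are moved to the external model's indices, and an empty result is rejected. The types here are stand-ins chosen for illustration (a Map for the example, a TreeMap for the sparse vector); RemapSketch and its parameters are hypothetical and do not mirror Tribuo's internals.

```java
import java.util.Map;
import java.util.TreeMap;

// Hypothetical sketch, not Tribuo code.
class RemapSketch {
    /**
     * Returns the example's known features keyed by external index, sorted by index.
     * tribuoIds maps feature names to Tribuo ids; forward[tribuoId] is the external index.
     */
    static TreeMap<Integer, Double> toExternalIndices(Map<String, Double> example,
                                                      Map<String, Integer> tribuoIds,
                                                      int[] forward) {
        TreeMap<Integer, Double> external = new TreeMap<>();
        for (Map.Entry<String, Double> feature : example.entrySet()) {
            Integer tribuoId = tribuoIds.get(feature.getKey());
            if (tribuoId != null) {
                // Known feature: store its value under the external model's index.
                external.put(forward[tribuoId], feature.getValue());
            }
        }
        if (external.isEmpty()) {
            throw new IllegalArgumentException("Example has no feature overlap with the model");
        }
        return external;
    }

    public static void main(String[] args) {
        Map<String, Integer> ids = Map.of("a", 0, "b", 1);
        System.out.println(toExternalIndices(Map.of("a", 1.0, "z", 5.0), ids, new int[]{1, 0}));
    }
}
```

The size of the returned map plays the role of the numValidFeatures argument that convertOutput receives.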
innerPredict
protected List<Prediction<T>> innerPredict(Iterable<Example<T>> examples)
Description copied from class: Model
Called by the base implementations of Model.predict(Iterable) and Model.predict(Dataset).
- Overrides:
innerPredict in class Model<T extends Output<T>>
- Parameters:
examples - The examples to predict.
- Returns:
The results of the predictions, in the same order as the examples.
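One way to honor the "same order as the examples" guarantee while predicting in batches is to fill a buffer up to the batch size, predict it, and append the results before continuing. The sketch below assumes that batch predictions are chunked by the configured batch size; BatchingSketch, predictInBatches, and the batchPredictor callback are hypothetical names, not Tribuo API.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.Function;

// Hypothetical sketch, not Tribuo code.
class BatchingSketch {
    /** Applies batchPredictor to consecutive chunks of at most batchSize examples, preserving order. */
    static <E, P> List<P> predictInBatches(Iterable<E> examples, int batchSize,
                                           Function<List<E>, List<P>> batchPredictor) {
        List<P> predictions = new ArrayList<>();
        List<E> batch = new ArrayList<>(batchSize);
        for (E example : examples) {
            batch.add(example);
            if (batch.size() == batchSize) {
                predictions.addAll(batchPredictor.apply(batch));
                batch.clear(); // the buffer is reused, so the predictor must not keep a reference to it
            }
        }
        if (!batch.isEmpty()) {
            // Final partial batch.
            predictions.addAll(batchPredictor.apply(batch));
        }
        return predictions;
    }

    public static void main(String[] args) {
        List<String> out = BatchingSketch.<Integer, String>predictInBatches(
                List.of(1, 2, 3, 4, 5), 2,
                batch -> {
                    List<String> labels = new ArrayList<>();
                    for (Integer x : batch) {
                        labels.add("pred-" + x);
                    }
                    return labels;
                });
        System.out.println(out);
    }
}
```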
-
convertFeatures
protected abstract U convertFeatures(SparseVector input)
Converts from a SparseVector using the external model's indices into the ingestion format for the external model.
- Parameters:
input - The features using external indices.
- Returns:
The ingestion format for the external model.
-
convertFeaturesList
protected abstract U convertFeaturesList(List<SparseVector> input)
Converts from a list of SparseVector using the external model's indices into the ingestion format for the external model.
- Parameters:
input - The features using external indices.
- Returns:
The ingestion format for the external model.
-
externalPrediction
protected abstract V externalPrediction(U input)
Runs the external model's prediction function.
- Parameters:
input - The input in the external model's format.
- Returns:
The output in the external model's format.
-
convertOutput
protected abstract Prediction<T> convertOutput(V output, int numValidFeatures, Example<T> example)
Converts the output of the external model into a Prediction.
- Parameters:
output - The output of the external model.
numValidFeatures - The number of valid features in the input.
example - The input example, used to construct the Prediction.
- Returns:
A Tribuo Prediction.
-
convertOutput
protected abstract List<Prediction<T>> convertOutput(V output, int[] numValidFeatures, List<Example<T>> examples)
Converts the output of the external model into a list of Predictions.
- Parameters:
output - The output of the external model.
numValidFeatures - An array with the number of valid features in each example.
examples - The input examples, used to construct the Predictions.
- Returns:
A list of Tribuo Predictions.
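The abstract hooks convertFeatures, externalPrediction, and convertOutput suggest a template-method pipeline: convert the input, run the external model, convert the result. The sketch below shows that shape with stand-in types only. PipelineSketch and ToyLinearModel are hypothetical, a double[] replaces SparseVector, a String replaces Prediction, and the fixed weights are invented for the example.

```java
// Hypothetical stand-in for the pattern, not the Tribuo class.
// U is the external model's input format, V is its raw output format.
abstract class PipelineSketch<U, V> {
    protected abstract U convertFeatures(double[] features);

    protected abstract V externalPrediction(U input);

    protected abstract String convertOutput(V output);

    /** Template method: the three hooks always run in this order. */
    final String predict(double[] features) {
        U input = convertFeatures(features);
        V raw = externalPrediction(input);
        return convertOutput(raw);
    }
}

// Toy "external model": a fixed linear scorer that expects exactly two float features.
class ToyLinearModel extends PipelineSketch<float[], Float> {
    private final float[] weights = {0.5f, -1.0f};

    @Override
    protected float[] convertFeatures(double[] features) {
        float[] converted = new float[features.length];
        for (int i = 0; i < features.length; i++) {
            converted[i] = (float) features[i];
        }
        return converted;
    }

    @Override
    protected Float externalPrediction(float[] input) {
        float score = 0f;
        for (int i = 0; i < weights.length; i++) {
            score += weights[i] * input[i];
        }
        return score;
    }

    @Override
    protected String convertOutput(Float output) {
        return output >= 0f ? "positive" : "negative";
    }

    public static void main(String[] args) {
        System.out.println(new ToyLinearModel().predict(new double[]{4.0, 1.0}));
    }
}
```

Keeping U and V as type parameters lets each subclass pick the native formats of its own runtime while the prediction flow stays in one place.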
-
getExcuse
public Optional<Excuse<T>> getExcuse(Example<T> example)
By default, third-party models don't return excuses.
-
getBatchSize
public int getBatchSize()
Gets the current testing batch size.
- Returns:
The batch size.
-
setBatchSize
public void setBatchSize(int batchSize)
Sets a new batch size.
Throws IllegalArgumentException if the batch size isn't positive.
- Parameters:
batchSize - The batch size to use.
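The documented contract of setBatchSize (reject non-positive values, otherwise store the new size) can be sketched in a few lines. BatchSizeHolder is a hypothetical stand-in for the model, and the constant's value is a placeholder chosen for this sketch, not a value taken from this page.

```java
// Hypothetical stand-in, not Tribuo code.
class BatchSizeHolder {
    // Placeholder value for the sketch; the page does not state the real default.
    static final int DEFAULT_BATCH_SIZE = 16;

    private int batchSize = DEFAULT_BATCH_SIZE;

    int getBatchSize() {
        return batchSize;
    }

    void setBatchSize(int batchSize) {
        if (batchSize <= 0) {
            // Validate before assigning, so a rejected call leaves the old value in place.
            throw new IllegalArgumentException("Batch size must be positive, found " + batchSize);
        }
        this.batchSize = batchSize;
    }

    public static void main(String[] args) {
        BatchSizeHolder holder = new BatchSizeHolder();
        holder.setBatchSize(32);
        System.out.println(holder.getBatchSize());
    }
}
```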
-
createFeatureMap
protected static ImmutableFeatureMap createFeatureMap(Set<String> featureNames)
Creates an immutable feature map from a set of feature names.
Each feature is unobserved.
- Parameters:
featureNames - The names of the features to create.
- Returns:
A feature map representing the feature names.
-
createOutputInfo
protected static <T extends Output<T>> ImmutableOutputInfo<T> createOutputInfo(OutputFactory<T> factory, Map<T, Integer> outputs)
Creates an output info from a set of outputs.
- Type Parameters:
T - The type of the outputs.
- Parameters:
factory - The output factory to use.
outputs - The outputs and ids to observe.
- Returns:
An immutable output info representing the outputs.
-