Package org.tribuo.interop
Class ExternalModel<T extends Output<T>,U,V>
java.lang.Object
org.tribuo.Model<T>
org.tribuo.interop.ExternalModel<T,U,V>
- Type Parameters:
T - The output subclass that this model operates on.
U - The internal representation of features.
V - The internal representation of outputs.
- All Implemented Interfaces:
com.oracle.labs.mlrg.olcut.provenance.Provenancable<ModelProvenance>, Serializable, ProtoSerializable<org.tribuo.protos.core.ModelProto>
- Direct Known Subclasses:
OCIModel, ONNXExternalModel, TensorFlowFrozenExternalModel, TensorFlowSavedModelExternalModel, XGBoostExternalModel
This is the base class for third-party models that are trained outside Tribuo and then loaded into Tribuo for prediction.
The batch size used for batch predictions defaults to DEFAULT_BATCH_SIZE.
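The class works as a template: a subclass supplies the feature conversion, prediction and output conversion hooks, and the base class decides when to call them. The dependency-free Java sketch below illustrates that call order for a single example under stated assumptions: PipelineSketch and LinearPipelineSketch are hypothetical names, sparse inputs are modelled as a TreeMap from external index to value rather than Tribuo's SparseVector, and the exact control flow inside ExternalModel may differ.

```java
import java.util.Map;
import java.util.TreeMap;

// Hypothetical sketch (not Tribuo API). The three abstract hooks mirror
// convertFeatures, externalPrediction and convertOutput; predict() fixes their order.
// Sparse inputs are modelled as TreeMap<externalIndex, value>.
abstract class PipelineSketch<U, V, P> {
    protected abstract U convertFeatures(TreeMap<Integer, Double> externalSparse);

    protected abstract V externalPrediction(U input);

    protected abstract P convertOutput(V output, int numValidFeatures);

    // Template method: subclasses supply only the conversion and prediction hooks.
    public final P predict(TreeMap<Integer, Double> externalSparse) {
        U input = convertFeatures(externalSparse);
        V output = externalPrediction(input);
        return convertOutput(output, externalSparse.size());
    }
}

// Toy "external" model: a fixed linear scorer over three features.
class LinearPipelineSketch extends PipelineSketch<double[], Double, String> {
    private final double[] weights = {0.5, -1.0, 2.0};

    @Override
    protected double[] convertFeatures(TreeMap<Integer, Double> externalSparse) {
        // Densify: features absent from the sparse map become 0.0.
        double[] dense = new double[weights.length];
        for (Map.Entry<Integer, Double> e : externalSparse.entrySet()) {
            dense[e.getKey()] = e.getValue();
        }
        return dense;
    }

    @Override
    protected Double externalPrediction(double[] input) {
        // Dot product of the dense input with the fixed weights.
        double score = 0.0;
        for (int i = 0; i < input.length; i++) {
            score += weights[i] * input[i];
        }
        return score;
    }

    @Override
    protected String convertOutput(Double output, int numValidFeatures) {
        return (output >= 0 ? "positive" : "negative") + " (" + numValidFeatures + " features)";
    }
}
```

For example, calling new LinearPipelineSketch().predict(x) with x = {0=2.0, 2=1.0} scores 0.5 * 2.0 + 2.0 * 1.0 = 3.0 and returns "positive (2 features)".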
-
Field Summary
Modifier and Type | Field | Description
static final int | DEFAULT_BATCH_SIZE | Default batch size for external model batch predictions.
protected final int[] | featureBackwardMapping | The backward mapping from the external indices to Tribuo's indices.
protected final int[] | featureForwardMapping | The forward mapping from Tribuo's indices to the external indices.
Fields inherited from class org.tribuo.Model
ALL_OUTPUTS, BIAS_FEATURE, featureIDMap, generatesProbabilities, name, outputIDInfo, provenance, provenanceOutput
Fields inherited from interface org.tribuo.protos.ProtoSerializable
DESERIALIZATION_METHOD_NAME, PROVENANCE_SERIALIZER
-
Constructor Summary
Modifier | Constructor | Description
protected | ExternalModel(String name, ModelProvenance provenance, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<T> outputIDInfo, boolean generatesProbabilities, Map<String, Integer> featureMapping) | Constructs an external model from a model trained outside of Tribuo.
protected | ExternalModel(String name, ModelProvenance provenance, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<T> outputIDInfo, int[] featureForwardMapping, int[] featureBackwardMapping, boolean generatesProbabilities) | Constructs an external model from a model trained outside of Tribuo.
-
Method Summary
Modifier and Type | Method | Description
protected abstract U | convertFeatures(SparseVector input) | Converts from a SparseVector using the external model's indices into the ingestion format for the external model.
protected abstract U | convertFeaturesList(List<SparseVector> input) | Converts from a list of SparseVectors using the external model's indices into the ingestion format for the external model.
protected abstract List<Prediction<T>> | convertOutput(V output, int[] numValidFeatures, List<Example<T>> examples) | Converts the output of the external model into a list of Predictions.
protected abstract Prediction<T> | convertOutput(V output, int numValidFeatures, Example<T> example) | Converts the output of the external model into a Prediction.
protected static ImmutableFeatureMap | createFeatureMap(Set<String> featureNames) | Creates an immutable feature map from a set of feature names.
protected static <T extends Output<T>> ImmutableOutputInfo<T> | createOutputInfo(OutputFactory<T> factory, Map<T, Integer> outputs) | Creates an output info from a set of outputs.
protected abstract V | externalPrediction(U input) | Runs the external model's prediction function.
int | getBatchSize() | Gets the current testing batch size.
Optional<Excuse<T>> | getExcuse(Example<T> example) | By default, third-party models don't return excuses.
protected List<Prediction<T>> | innerPredict(Iterable<Example<T>> examples) | Called by the base implementations of Model.predict(Iterable) and Model.predict(Dataset).
Prediction<T> | predict(Example<T> example) | Uses the model to predict the output for a single example.
void | setBatchSize(int batchSize) | Sets a new batch size.
protected static boolean | validateFeatureMapping(int[] featureForwardMapping, int[] featureBackwardMapping, ImmutableFeatureMap featureDomain) | Checks if the feature mappings are valid for the supplied feature map.
Methods inherited from class org.tribuo.Model
castModel, copy, copy, createDataCarrier, deserialize, deserializeFromFile, deserializeFromStream, generatesProbabilities, getExcuses, getFeatureIDMap, getName, getOutputIDInfo, getProvenance, getTopFeatures, predict, predict, serialize, serializeToFile, serializeToStream, setName, toString, validate
-
Field Details
-
DEFAULT_BATCH_SIZE
public static final int DEFAULT_BATCH_SIZE
Default batch size for external model batch predictions.
-
featureForwardMapping
protected final int[] featureForwardMapping
The forward mapping from Tribuo's indices to the external indices.
-
featureBackwardMapping
protected final int[] featureBackwardMapping
The backward mapping from the external indices to Tribuo's indices.
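As a rough illustration of how the two mapping arrays relate, the hypothetical helper below (IndexRemapSketch is not part of Tribuo) re-keys a sparse vector given as parallel index and value arrays. Passing an array shaped like featureForwardMapping converts Tribuo indices to external indices; passing one shaped like featureBackwardMapping converts them back.

```java
import java.util.TreeMap;

// Hypothetical helper (not Tribuo API) showing how the two index arrays are used:
// forward[tribuoIndex] == externalIndex and backward[externalIndex] == tribuoIndex.
final class IndexRemapSketch {
    private IndexRemapSketch() {}

    // Re-keys a sparse vector (parallel index/value arrays) through a mapping.
    // The TreeMap keeps the remapped indices sorted, as sparse vectors expect.
    static TreeMap<Integer, Double> remap(int[] indices, double[] values, int[] mapping) {
        if (indices.length != values.length) {
            throw new IllegalArgumentException("indices and values must have the same length");
        }
        TreeMap<Integer, Double> out = new TreeMap<>();
        for (int i = 0; i < indices.length; i++) {
            out.put(mapping[indices[i]], values[i]);
        }
        return out;
    }
}
```

With a forward mapping of {2, 0, 1}, the Tribuo-indexed vector {0: 5.0, 2: 7.0} becomes {1=7.0, 2=5.0} in the external index space, and the backward mapping {1, 2, 0} restores the original indices.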
-
-
Constructor Details
-
ExternalModel
protected ExternalModel(String name, ModelProvenance provenance, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<T> outputIDInfo, boolean generatesProbabilities, Map<String, Integer> featureMapping)
Constructs an external model from a model trained outside of Tribuo.
- Parameters:
name - The model name.
provenance - The model provenance.
featureIDMap - The feature domain.
outputIDInfo - The output domain.
generatesProbabilities - Whether this model generates probabilistic predictions.
featureMapping - The mapping from Tribuo's feature names to the model's feature indices.
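One plausible way to derive forward and backward index arrays from a name-based featureMapping is sketched below. It assumes that Tribuo's own feature ids are available as a name-to-id map (a stand-in for the ImmutableFeatureMap lookup) and that both id spaces are dense in 0..n-1; MappingFromNamesSketch is a hypothetical name, and the constructor's actual implementation is not shown in this document.

```java
import java.util.Map;

// Hypothetical sketch (not Tribuo API): derives forward/backward index arrays from
// a name -> external index map (like the featureMapping constructor argument) and a
// name -> Tribuo id map (standing in for the ImmutableFeatureMap lookup).
final class MappingFromNamesSketch {
    private MappingFromNamesSketch() {}

    // Returns {forward, backward}. Duplicate external indices are not detected here;
    // a separate validation pass would catch them.
    static int[][] build(Map<String, Integer> tribuoIds, Map<String, Integer> externalIds) {
        int n = tribuoIds.size();
        if (externalIds.size() != n) {
            throw new IllegalArgumentException("Feature sets differ in size");
        }
        int[] forward = new int[n];
        int[] backward = new int[n];
        for (Map.Entry<String, Integer> e : externalIds.entrySet()) {
            Integer tribuoId = tribuoIds.get(e.getKey());
            int externalId = e.getValue();
            if (tribuoId == null || tribuoId < 0 || tribuoId >= n || externalId < 0 || externalId >= n) {
                throw new IllegalArgumentException("Cannot map feature " + e.getKey());
            }
            forward[tribuoId] = externalId;
            backward[externalId] = tribuoId;
        }
        return new int[][] {forward, backward};
    }
}
```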
-
ExternalModel
protected ExternalModel(String name, ModelProvenance provenance, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<T> outputIDInfo, int[] featureForwardMapping, int[] featureBackwardMapping, boolean generatesProbabilities)
Constructs an external model from a model trained outside of Tribuo.
- Parameters:
name - The model name.
provenance - The model provenance.
featureIDMap - The feature domain.
outputIDInfo - The output domain.
featureForwardMapping - The mapping from Tribuo's indices to the model's indices.
featureBackwardMapping - The mapping from the model's indices to Tribuo's indices.
generatesProbabilities - Whether this model generates probabilistic predictions.
-
-
Method Details
-
predict
public Prediction<T> predict(Example<T> example)
Description copied from class: Model
Uses the model to predict the output for a single example.
predict does not mutate the example.
- Throws:
IllegalArgumentException - if the example has no features or no feature overlap with the model.
-
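The failure mode documented for predict can be sketched as a filtering step: drop the features the model does not know, then reject the example if nothing survives. In this hypothetical helper (OverlapCheckSketch is not Tribuo API), examples are name-to-value maps and the model's feature domain is a Set of names; the number of surviving features plausibly corresponds to what convertOutput receives as numValidFeatures, but that correspondence is an assumption.

```java
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;

// Hypothetical sketch (not Tribuo API): keeps only features known to the model and
// throws IllegalArgumentException when the example is empty or has no overlap.
final class OverlapCheckSketch {
    private OverlapCheckSketch() {}

    static TreeMap<String, Double> filterKnown(Map<String, Double> example, Set<String> modelFeatures) {
        if (example.isEmpty()) {
            throw new IllegalArgumentException("Example has no features");
        }
        TreeMap<String, Double> known = new TreeMap<>();
        for (Map.Entry<String, Double> e : example.entrySet()) {
            if (modelFeatures.contains(e.getKey())) {
                known.put(e.getKey(), e.getValue());
            }
        }
        if (known.isEmpty()) {
            throw new IllegalArgumentException("No feature overlap with the model");
        }
        return known;
    }
}
```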
innerPredict
protected List<Prediction<T>> innerPredict(Iterable<Example<T>> examples)
Description copied from class: Model
Called by the base implementations of Model.predict(Iterable) and Model.predict(Dataset).
- Overrides:
innerPredict in class Model<T extends Output<T>>
- Parameters:
examples - The examples to predict.
- Returns:
The results of the predictions, in the same order as the examples.
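A minimal sketch of batched prediction, assuming the examples arrive as a List and each batch is handed to one function call that stands in for the convertFeaturesList, externalPrediction and convertOutput chain. BatchingSketch and predictInBatches are hypothetical names; results are appended batch by batch so their order matches the input order, as the Returns clause requires, and the positivity check mirrors what setBatchSize documents.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.Function;

// Hypothetical sketch (not Tribuo API): splits the inputs into consecutive batches of
// at most batchSize, runs batchPredict on each, and concatenates the results in order.
final class BatchingSketch {
    private BatchingSketch() {}

    static <E, R> List<R> predictInBatches(List<E> examples, int batchSize,
                                           Function<List<E>, List<R>> batchPredict) {
        if (batchSize < 1) {
            throw new IllegalArgumentException("Batch size must be positive, found " + batchSize);
        }
        List<R> results = new ArrayList<>(examples.size());
        for (int start = 0; start < examples.size(); start += batchSize) {
            int end = Math.min(start + batchSize, examples.size());
            // subList is a view; the last batch may be shorter than batchSize.
            results.addAll(batchPredict.apply(examples.subList(start, end)));
        }
        return results;
    }
}
```

With five inputs and a batch size of 2, the batch function is called three times, on batches of sizes 2, 2 and 1.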
-
convertFeatures
protected abstract U convertFeatures(SparseVector input)
Converts from a SparseVector using the external model's indices into the ingestion format for the external model.
- Parameters:
input - The features using external indices.
- Returns:
The ingestion format for the external model.
-
convertFeaturesList
protected abstract U convertFeaturesList(List<SparseVector> input)
Converts from a list of SparseVectors using the external model's indices into the ingestion format for the external model.
- Parameters:
input - The features using external indices.
- Returns:
The ingestion format for the external model.
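The ingestion format U depends entirely on the external library. Assuming a runtime that accepts a dense row-major float matrix (common for tensor-based runtimes, but an assumption here), a list of sparse vectors could be converted as below. Sparse vectors are modelled as TreeMap<Integer, Double> keyed by external index, and DensifySketch is a hypothetical name.

```java
import java.util.List;
import java.util.Map;
import java.util.TreeMap;

// Hypothetical sketch (not Tribuo API): builds a float[batch][numFeatures] matrix,
// one row per sparse vector; features absent from a row stay 0.0f.
final class DensifySketch {
    private DensifySketch() {}

    static float[][] toDense(List<TreeMap<Integer, Double>> batch, int numFeatures) {
        float[][] dense = new float[batch.size()][numFeatures];
        for (int row = 0; row < batch.size(); row++) {
            for (Map.Entry<Integer, Double> e : batch.get(row).entrySet()) {
                dense[row][e.getKey()] = e.getValue().floatValue();
            }
        }
        return dense;
    }
}
```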
-
externalPrediction
protected abstract V externalPrediction(U input)
Runs the external model's prediction function.
- Parameters:
input - The input in the external model's format.
- Returns:
The output in the external model's format.
-
convertOutput
protected abstract Prediction<T> convertOutput(V output, int numValidFeatures, Example<T> example)
Converts the output of the external model into a Prediction.
- Parameters:
output - The output of the external model.
numValidFeatures - The number of valid features in the input.
example - The input example, used to construct the Prediction.
- Returns:
A Tribuo Prediction.
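For a classifier whose external output is a row of scores indexed like the output domain, convertOutput might pick the highest-scoring id and look up the corresponding output. This hypothetical sketch (ArgmaxOutputSketch is not Tribuo API) returns a plain String label rather than a Tribuo Prediction and ignores numValidFeatures; labelsById is an assumed array whose entry i holds the output with id i.

```java
// Hypothetical sketch (not Tribuo API): maps one row of raw scores to a label,
// where labelsById[i] is the output assigned id i in the output domain.
final class ArgmaxOutputSketch {
    private ArgmaxOutputSketch() {}

    static String argmaxLabel(float[] scores, String[] labelsById) {
        if (scores.length == 0 || scores.length != labelsById.length) {
            throw new IllegalArgumentException("Need exactly one score per known output");
        }
        int best = 0;
        for (int i = 1; i < scores.length; i++) {
            if (scores[i] > scores[best]) {
                best = i; // strictly greater, so ties keep the lowest id
            }
        }
        return labelsById[best];
    }
}
```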
-
convertOutput
protected abstract List<Prediction<T>> convertOutput(V output, int[] numValidFeatures, List<Example<T>> examples)
Converts the output of the external model into a list of Predictions.
- Parameters:
output - The output of the external model.
numValidFeatures - An array with the number of valid features in each example.
examples - The input examples, used to construct the Predictions.
- Returns:
A list of Tribuo Predictions.
-
getExcuse
public Optional<Excuse<T>> getExcuse(Example<T> example)
By default, third-party models don't return excuses.
-
getBatchSize
public int getBatchSize()
Gets the current testing batch size.
- Returns:
The batch size.
-
setBatchSize
public void setBatchSize(int batchSize)
Sets a new batch size.
- Parameters:
batchSize - The batch size to use.
- Throws:
IllegalArgumentException - if the batch size isn't positive.
-
createFeatureMap
protected static ImmutableFeatureMap createFeatureMap(Set<String> featureNames)
Creates an immutable feature map from a set of feature names. Each feature is unobserved.
- Parameters:
featureNames - The names of the features to create.
- Returns:
A feature map representing the feature names.
-
createOutputInfo
protected static <T extends Output<T>> ImmutableOutputInfo<T> createOutputInfo(OutputFactory<T> factory, Map<T, Integer> outputs)
Creates an output info from a set of outputs.
- Type Parameters:
T - The type of the outputs.
- Parameters:
factory - The output factory to use.
outputs - The outputs and ids to observe.
- Returns:
An immutable output info representing the outputs.
-
validateFeatureMapping
protected static boolean validateFeatureMapping(int[] featureForwardMapping, int[] featureBackwardMapping, ImmutableFeatureMap featureDomain)
Checks if the feature mappings are valid for the supplied feature map.
- Parameters:
featureForwardMapping - The forward feature mapping.
featureBackwardMapping - The backward feature mapping.
featureDomain - The feature domain.
- Returns:
True if the feature mappings are valid, that is, the forward and backward mappings form a bijection and are the same size as the feature domain.
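The validity condition can be checked in one pass: if both arrays have the domain's size, every forward entry is in range, and backward[forward[i]] == i for every i, then no two Tribuo indices share an external index, so forward is a bijection and backward is its inverse. The sketch below takes the domain size as an int instead of an ImmutableFeatureMap, and MappingValidationSketch is a hypothetical name; Tribuo's own check may be written differently.

```java
// Hypothetical sketch (not Tribuo API) of the bijection check described above.
final class MappingValidationSketch {
    private MappingValidationSketch() {}

    static boolean isValid(int[] forward, int[] backward, int domainSize) {
        if (forward.length != domainSize || backward.length != domainSize) {
            return false;
        }
        for (int i = 0; i < domainSize; i++) {
            int ext = forward[i];
            // Out-of-range indices, or a backward entry that doesn't undo forward, fail.
            if (ext < 0 || ext >= domainSize || backward[ext] != i) {
                return false;
            }
        }
        return true;
    }
}
```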
-