Class ExternalModel<T extends Output<T>,U,V>

java.lang.Object
org.tribuo.Model<T>
org.tribuo.interop.ExternalModel<T,U,V>
Type Parameters:
T - The output subclass that this model operates on.
U - The internal representation of features.
V - The internal representation of outputs.
All Implemented Interfaces:
com.oracle.labs.mlrg.olcut.provenance.Provenancable<ModelProvenance>, Serializable, ProtoSerializable<org.tribuo.protos.core.ModelProto>
Direct Known Subclasses:
OCIModel, ONNXExternalModel, TensorFlowFrozenExternalModel, TensorFlowSavedModelExternalModel, XGBoostExternalModel

public abstract class ExternalModel<T extends Output<T>,U,V> extends Model<T>
This is the base class for third-party models which are trained externally and loaded into Tribuo for prediction.

Batch size defaults to DEFAULT_BATCH_SIZE.

  • Field Details

    • DEFAULT_BATCH_SIZE

      public static final int DEFAULT_BATCH_SIZE
      Default batch size for external model batch predictions.
    • featureForwardMapping

      protected final int[] featureForwardMapping
      The forward mapping from Tribuo's indices to the external indices.
    • featureBackwardMapping

      protected final int[] featureBackwardMapping
      The backward mapping from the external indices to Tribuo's indices.
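validateFeatureMapping requires these two arrays to form a bijection over the feature domain, so they act as inverse permutations. The sketch below uses the forward mapping to move a sparse vector from Tribuo's index space into the external model's, re-sorting by the new indices. It assumes both arrays are dense int arrays over the same index range. The ForwardRemap class and its Sparse holder are hypothetical stand-ins, not Tribuo's SparseVector API.

```java
import java.util.Arrays;

/**
 * Hypothetical helper: moves a sparse vector from Tribuo's index space to an
 * external model's index space using a forward mapping (Tribuo -> external).
 */
final class ForwardRemap {
    /** Minimal sparse vector stand-in: parallel index and value arrays. */
    static final class Sparse {
        final int[] indices;
        final double[] values;

        Sparse(int[] indices, double[] values) {
            this.indices = indices;
            this.values = values;
        }
    }

    static Sparse remap(Sparse tribuo, int[] forwardMapping) {
        int n = tribuo.indices.length;
        Integer[] order = new Integer[n];
        for (int i = 0; i < n; i++) {
            order[i] = i;
        }
        // Sort positions by the external index their Tribuo index maps to.
        Arrays.sort(order, (a, b) -> Integer.compare(
                forwardMapping[tribuo.indices[a]], forwardMapping[tribuo.indices[b]]));
        int[] outIndices = new int[n];
        double[] outValues = new double[n];
        for (int i = 0; i < n; i++) {
            outIndices[i] = forwardMapping[tribuo.indices[order[i]]];
            outValues[i] = tribuo.values[order[i]];
        }
        return new Sparse(outIndices, outValues);
    }
}
```

Re-sorting keeps the remapped vector in ascending index order. The backward mapping would be used the same way to go from external indices back to Tribuo's.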
  • Constructor Details

    • ExternalModel

      protected ExternalModel(String name, ModelProvenance provenance, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<T> outputIDInfo, boolean generatesProbabilities, Map<String,Integer> featureMapping)
      Constructs an external model from a model trained outside of Tribuo.
      Parameters:
      name - The model name.
      provenance - The model provenance.
      featureIDMap - The feature domain.
      outputIDInfo - The output domain.
      generatesProbabilities - Whether this model generates probabilistic predictions.
      featureMapping - The mapping from Tribuo's feature names to the model's feature indices.
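The featureMapping argument is keyed on feature names, so the two index arrays have to be derived from it. The following is a possible sketch and not necessarily what this constructor does internally. It assumes the feature domain assigns ids 0 to n-1 and that the external indices cover the same range. tribuoIds is a hypothetical stand-in for the ImmutableFeatureMap's name-to-id lookup.

```java
import java.util.Map;

/**
 * Hypothetical sketch: derives forward (Tribuo id -> external index) and
 * backward (external index -> Tribuo id) arrays from two name-keyed maps.
 * Assumes both maps cover the same names and use ids 0..n-1; an id outside
 * that range raises ArrayIndexOutOfBoundsException.
 */
final class MappingBuilder {
    static int[][] build(Map<String, Integer> tribuoIds, Map<String, Integer> featureMapping) {
        if (tribuoIds.size() != featureMapping.size()) {
            throw new IllegalArgumentException("Feature maps differ in size");
        }
        int n = tribuoIds.size();
        int[] forward = new int[n];
        int[] backward = new int[n];
        for (Map.Entry<String, Integer> e : featureMapping.entrySet()) {
            Integer tribuoId = tribuoIds.get(e.getKey());
            if (tribuoId == null) {
                throw new IllegalArgumentException("Unknown feature: " + e.getKey());
            }
            int externalId = e.getValue();
            forward[tribuoId] = externalId;
            backward[externalId] = tribuoId;
        }
        return new int[][] {forward, backward};
    }
}
```

This sketch does not catch duplicate external indices. The result would still need a bijection check such as validateFeatureMapping.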
    • ExternalModel

      protected ExternalModel(String name, ModelProvenance provenance, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<T> outputIDInfo, int[] featureForwardMapping, int[] featureBackwardMapping, boolean generatesProbabilities)
      Constructs an external model from a model trained outside of Tribuo.
      Parameters:
      name - The model name.
      provenance - The model provenance.
      featureIDMap - The feature domain.
      outputIDInfo - The output domain.
      featureForwardMapping - The mapping from Tribuo's indices to the model's indices.
      featureBackwardMapping - The mapping from the model's indices to Tribuo's indices.
      generatesProbabilities - Whether this model generates probabilistic predictions.
  • Method Details

    • predict

      public Prediction<T> predict(Example<T> example)
      Description copied from class: Model
      Uses the model to predict the output for a single example.

      predict does not mutate the example.

      Throws IllegalArgumentException if the example has no features or no feature overlap with the model.

      Specified by:
      predict in class Model<T extends Output<T>>
      Parameters:
      example - the example to predict.
      Returns:
      the result of the prediction.
    • innerPredict

      protected List<Prediction<T>> innerPredict(Iterable<Example<T>> examples)
      Description copied from class: Model
      Called by the base implementations of Model.predict(Iterable) and Model.predict(Dataset).
      Overrides:
      innerPredict in class Model<T extends Output<T>>
      Parameters:
      examples - The examples to predict.
      Returns:
      The results of the predictions, in the same order as the examples.
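Multi-example prediction goes through innerPredict, and the class exposes a configurable batch size. This suggests that examples are fed to the external model in fixed-size chunks, although the page does not spell out the mechanism. The following is a hedged sketch of such order-preserving chunking. The Batcher class is hypothetical, and it applies the same positive-batch-size rule stated for setBatchSize.

```java
import java.util.ArrayList;
import java.util.List;

/**
 * Hypothetical sketch: splits an Iterable into consecutive batches of at most
 * batchSize elements, preserving the original order.
 */
final class Batcher {
    static <E> List<List<E>> batches(Iterable<E> items, int batchSize) {
        if (batchSize < 1) {
            throw new IllegalArgumentException("Batch size must be positive, found " + batchSize);
        }
        List<List<E>> out = new ArrayList<>();
        List<E> current = new ArrayList<>(batchSize);
        for (E item : items) {
            current.add(item);
            if (current.size() == batchSize) {
                out.add(current);
                current = new ArrayList<>(batchSize);
            }
        }
        if (!current.isEmpty()) {
            out.add(current); // final, possibly smaller batch
        }
        return out;
    }
}
```

Concatenating the per-batch prediction lists in order then yields results in the same order as the input examples, as innerPredict's contract requires.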
    • convertFeatures

      protected abstract U convertFeatures(SparseVector input)
      Converts a SparseVector that uses the external model's indices into the external model's ingestion format.
      Parameters:
      input - The features using external indices.
      Returns:
      The ingestion format for the external model.
    • convertFeaturesList

      protected abstract U convertFeaturesList(List<SparseVector> input)
      Converts a list of SparseVectors that use the external model's indices into the external model's ingestion format.
      Parameters:
      input - The features using external indices.
      Returns:
      The ingestion format for the external model.
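The ingestion format U depends on the backend. For illustration only, suppose a backend accepts a dense row-major float matrix. convertFeaturesList could then stack the sparse rows as shown below. BatchDenseConverter and Row are hypothetical names, and each row is assumed to already use the external model's indices.

```java
import java.util.List;

/**
 * Hypothetical convertFeaturesList-style helper for a backend that ingests a
 * dense row-major float[batch][numFeatures] matrix. Each row is a sparse
 * vector already expressed in the external model's index space.
 */
final class BatchDenseConverter {
    /** Minimal sparse row: parallel arrays of external indices and values. */
    static final class Row {
        final int[] indices;
        final double[] values;

        Row(int[] indices, double[] values) {
            this.indices = indices;
            this.values = values;
        }
    }

    static float[][] toDense(List<Row> rows, int numFeatures) {
        float[][] dense = new float[rows.size()][numFeatures]; // absent features stay 0.0f
        for (int r = 0; r < rows.size(); r++) {
            Row row = rows.get(r);
            for (int i = 0; i < row.indices.length; i++) {
                dense[r][row.indices[i]] = (float) row.values[i];
            }
        }
        return dense;
    }
}
```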
    • externalPrediction

      protected abstract V externalPrediction(U input)
      Runs the external model's prediction function.
      Parameters:
      input - The input in the external model's format.
      Returns:
      The output in the external model's format.
    • convertOutput

      protected abstract Prediction<T> convertOutput(V output, int numValidFeatures, Example<T> example)
      Converts the output of the external model into a Prediction.
      Parameters:
      output - The output of the external model.
      numValidFeatures - The number of valid features in the input.
      example - The input example, used to construct the Prediction.
      Returns:
      A Tribuo Prediction.
    • convertOutput

      protected abstract List<Prediction<T>> convertOutput(V output, int[] numValidFeatures, List<Example<T>> examples)
      Converts the output of the external model into a list of Predictions.
      Parameters:
      output - The output of the external model.
      numValidFeatures - An array with the number of valid features in each example.
      examples - The input examples, used to construct the Predictions.
      Returns:
      A list of Tribuo Predictions.
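Taken together, the abstract methods form a template. The base class fixes the order of the steps, and a subclass supplies the conversions. The toy below mimics that shape for the single-example path using simplified types. ToyExternalModel and ToyLinearModel are hypothetical and do not use Tribuo's Example, SparseVector or Prediction classes.

```java
/**
 * Simplified stand-in for the ExternalModel template. The base class fixes
 * the order of the steps; subclasses supply the three conversions. Sparse
 * input is reduced to parallel arrays and the prediction to a String.
 */
abstract class ToyExternalModel<U, V> {
    /** Sparse input -> the external model's ingestion format. */
    protected abstract U convertFeatures(int[] indices, double[] values);

    /** Runs the (toy) external model. */
    protected abstract V externalPrediction(U input);

    /** External output -> a Tribuo-style prediction (here just a label). */
    protected abstract String convertOutput(V output, int numValidFeatures);

    /** Single-example path: convert, run the external model, convert back. */
    public final String predict(int[] indices, double[] values) {
        U converted = convertFeatures(indices, values);
        V raw = externalPrediction(converted);
        // The number of active features is passed through to the output step.
        return convertOutput(raw, indices.length);
    }
}

/** Toy subclass: a dense linear scorer thresholded at zero. */
final class ToyLinearModel extends ToyExternalModel<double[], Double> {
    private final double[] weights;

    ToyLinearModel(double[] weights) {
        this.weights = weights.clone();
    }

    @Override
    protected double[] convertFeatures(int[] indices, double[] values) {
        double[] dense = new double[weights.length];
        for (int i = 0; i < indices.length; i++) {
            dense[indices[i]] = values[i];
        }
        return dense;
    }

    @Override
    protected Double externalPrediction(double[] input) {
        double score = 0.0;
        for (int i = 0; i < input.length; i++) {
            score += weights[i] * input[i];
        }
        return score;
    }

    @Override
    protected String convertOutput(Double output, int numValidFeatures) {
        return (output >= 0 ? "POSITIVE" : "NEGATIVE") + "/" + numValidFeatures;
    }
}
```

With weights {1.0, -2.0, 0.5}, the input with indices {0, 2} and values {1.0, 2.0} scores 2.0 and yields "POSITIVE/2". The batch path would follow the same pattern through convertFeaturesList and the list overload of convertOutput.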
    • getExcuse

      public Optional<Excuse<T>> getExcuse(Example<T> example)
      By default, third-party models don't return excuses.
      Specified by:
      getExcuse in class Model<T extends Output<T>>
      Parameters:
      example - The input example.
      Returns:
      Optional.empty.
    • getBatchSize

      public int getBatchSize()
      Gets the current testing batch size.
      Returns:
      The batch size.
    • setBatchSize

      public void setBatchSize(int batchSize)
      Sets a new batch size.

      Throws IllegalArgumentException if the batch size isn't positive.

      Parameters:
      batchSize - The batch size to use.
    • createFeatureMap

      protected static ImmutableFeatureMap createFeatureMap(Set<String> featureNames)
      Creates an immutable feature map from a set of feature names.

      Each feature is unobserved.

      Parameters:
      featureNames - The names of the features to create.
      Returns:
      A feature map representing the feature names.
    • createOutputInfo

      protected static <T extends Output<T>> ImmutableOutputInfo<T> createOutputInfo(OutputFactory<T> factory, Map<T,Integer> outputs)
      Creates an output info from a map of outputs to their ids.
      Type Parameters:
      T - The type of the outputs.
      Parameters:
      factory - The output factory to use.
      outputs - The outputs and ids to observe.
      Returns:
      An immutable output info representing the outputs.
    • validateFeatureMapping

      protected static boolean validateFeatureMapping(int[] featureForwardMapping, int[] featureBackwardMapping, ImmutableFeatureMap featureDomain)
      Checks if the feature mappings are valid for the supplied feature map.
      Parameters:
      featureForwardMapping - The forward feature mapping.
      featureBackwardMapping - The backward feature mapping.
      featureDomain - The feature domain.
      Returns:
      True if the feature mapping is valid (the forward and backward mappings are a bijection and the same size as the feature domain).
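The following is one way to express this bijection condition, as a hedged sketch rather than Tribuo's implementation. Suppose both arrays are the size of the domain, every forward entry is in range, and backward undoes forward. Then forward is injective and therefore a permutation, and backward is its inverse. Here, domainSize stands in for the ImmutableFeatureMap's size.

```java
/**
 * Sketch of a bijection check in the spirit of validateFeatureMapping.
 * domainSize is a stand-in for the feature domain's size.
 */
final class MappingValidator {
    static boolean isValid(int[] forward, int[] backward, int domainSize) {
        if (forward.length != domainSize || backward.length != domainSize) {
            return false;
        }
        for (int i = 0; i < domainSize; i++) {
            int ext = forward[i];
            // Each forward entry must be in range and map back to i.
            if (ext < 0 || ext >= domainSize || backward[ext] != i) {
                return false;
            }
        }
        return true;
    }
}
```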