Class FMClassificationModel

All Implemented Interfaces:
com.oracle.labs.mlrg.olcut.provenance.Provenancable<ModelProvenance>, Serializable, ONNXExportable, ProtoSerializable<org.tribuo.protos.core.ModelProto>

public class FMClassificationModel extends AbstractFMModel<Label> implements ONNXExportable
The inference-time version of a factorization machine classifier trained using SGD.

See:

 Rendle, S.
 Factorization machines.
 2010 IEEE International Conference on Data Mining
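The model cited above scores an input with a second-order factorization machine. Below is a minimal Java sketch of that score. It uses Rendle's identity, which reduces the pairwise term to O(kn) time:

y(x) = w0 + sum_i w_i x_i + 0.5 * sum_f [ (sum_i v_{i,f} x_i)^2 - sum_i v_{i,f}^2 x_i^2 ]

This is not Tribuo's implementation. The dense arrays and the names FmScore and score are assumptions made for readability.

```java
/**
 * Hypothetical sketch of the second-order factorization machine score (Rendle, 2010).
 * Dense arrays are assumed; this is not Tribuo's implementation.
 */
final class FmScore {
    /**
     * @param w0 global bias
     * @param w  linear weights, one per feature
     * @param v  factor matrix, v[i][f] is factor f of feature i
     * @param x  dense feature vector
     */
    static double score(double w0, double[] w, double[][] v, double[] x) {
        int n = x.length;
        int k = v[0].length;
        // Bias plus linear term.
        double result = w0;
        for (int i = 0; i < n; i++) {
            result += w[i] * x[i];
        }
        // Pairwise term in O(k * n): 0.5 * sum_f [(sum_i v_if x_i)^2 - sum_i (v_if x_i)^2].
        double pairwise = 0.0;
        for (int f = 0; f < k; f++) {
            double sum = 0.0;
            double sumSq = 0.0;
            for (int i = 0; i < n; i++) {
                double t = v[i][f] * x[i];
                sum += t;
                sumSq += t * t;
            }
            pairwise += sum * sum - sumSq;
        }
        return result + 0.5 * pairwise;
    }

    public static void main(String[] args) {
        double w0 = 0.5;
        double[] w = {1.0, -2.0, 0.5};
        double[][] v = {{1.0, 0.0}, {0.5, 1.0}, {0.0, 2.0}};
        double[] x = {1.0, 1.0, 2.0};
        System.out.println(score(w0, w, v, x)); // prints 5.0
    }
}
```

With these toy parameters, the bias plus linear part is 0.5. The pairwise part is 4.5, which equals the explicit sum of <v_i, v_j> x_i x_j over all feature pairs. The program therefore prints 5.0.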
 
  • Field Details

    • CURRENT_VERSION

      public static final int CURRENT_VERSION
      Protobuf serialization version.
  • Method Details

    • deserializeFromProto

      public static FMClassificationModel deserializeFromProto(int version, String className, com.google.protobuf.Any message) throws com.google.protobuf.InvalidProtocolBufferException
      Deserialization factory.
      Parameters:
      version - The serialized object version.
      className - The class name.
      message - The serialized data.
      Returns:
      The deserialized object.
      Throws:
      com.google.protobuf.InvalidProtocolBufferException - If the protobuf could not be parsed from the message.
    • predict

      public Prediction<Label> predict(Example<Label> example)
      Description copied from class: Model
      Uses the model to predict the output for a single example.

      predict does not mutate the example.

      Throws IllegalArgumentException if the example has no features or no feature overlap with the model.

      Specified by:
      predict in class Model<Label>
      Parameters:
      example - the example to predict.
      Returns:
      the result of the prediction.
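      The predict contract described above can be illustrated with a minimal, self-contained sketch. The input is only read, and an IllegalArgumentException is thrown when the example is empty or shares no features with the model. The sketch assumes one independent factorization machine per label and returns the highest-scoring label name. Tribuo's predict returns a Prediction<Label>, not a label string. The class FmPredictSketch, its dense parameter arrays, and the Map-based example are hypothetical stand-ins for Tribuo's Example, feature map, and Prediction types, not the library's code.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

/** Hypothetical sketch: one factorization machine per label, argmax over the scores. */
final class FmPredictSketch {
    private final Map<String, Integer> featureIds; // feature name -> column index
    private final String[] labels;                 // labels[j] names output dimension j
    private final double[] biases;                 // biases[j]
    private final double[][] weights;              // weights[j][i]
    private final double[][][] factors;            // factors[j][i][f]

    FmPredictSketch(Map<String, Integer> featureIds, String[] labels,
                    double[] biases, double[][] weights, double[][][] factors) {
        this.featureIds = featureIds;
        this.labels = labels;
        this.biases = biases;
        this.weights = weights;
        this.factors = factors;
    }

    /** Returns the highest-scoring label; the input map is only read, never modified. */
    String predict(Map<String, Double> example) {
        if (example.isEmpty()) {
            throw new IllegalArgumentException("Example has no features");
        }
        // Keep only the features the model knows about; unknown names are skipped.
        List<Integer> ids = new ArrayList<>();
        List<Double> vals = new ArrayList<>();
        for (Map.Entry<String, Double> e : example.entrySet()) {
            Integer id = featureIds.get(e.getKey());
            if (id != null) {
                ids.add(id);
                vals.add(e.getValue());
            }
        }
        if (ids.isEmpty()) {
            throw new IllegalArgumentException("Example has no feature overlap with the model");
        }
        int best = 0;
        double bestScore = Double.NEGATIVE_INFINITY;
        for (int j = 0; j < labels.length; j++) {
            int k = factors[j][0].length;
            double score = biases[j];
            double[] sum = new double[k];
            double[] sumSq = new double[k];
            for (int n = 0; n < ids.size(); n++) {
                int i = ids.get(n);
                double x = vals.get(n);
                score += weights[j][i] * x;
                for (int f = 0; f < k; f++) {
                    double t = factors[j][i][f] * x;
                    sum[f] += t;
                    sumSq[f] += t * t;
                }
            }
            // O(k * n) pairwise interaction term.
            for (int f = 0; f < k; f++) {
                score += 0.5 * (sum[f] * sum[f] - sumSq[f]);
            }
            if (score > bestScore) {
                bestScore = score;
                best = j;
            }
        }
        return labels[best];
    }

    public static void main(String[] args) {
        FmPredictSketch model = new FmPredictSketch(
                Map.of("a", 0, "b", 1),
                new String[] {"neg", "pos"},
                new double[] {0.0, 0.0},
                new double[][] {{1.0, 0.0}, {0.0, 1.0}},
                new double[][][] {{{0.0}, {0.0}}, {{1.0}, {1.0}}});
        System.out.println(model.predict(Map.of("a", 1.0, "b", 2.0))); // pos (4.0 vs 1.0)
        System.out.println(model.predict(Map.of("a", 3.0)));           // neg (3.0 vs 0.0)
    }
}
```

      Consider an example containing only a feature the model never saw, such as Map.of("c", 1.0). It has no overlap with the model, so predict throws instead of returning a score built from the biases alone.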
    • serialize

      public org.tribuo.protos.core.ModelProto serialize()
      Description copied from interface: ProtoSerializable
      Serializes this object to a protobuf.
      Specified by:
      serialize in interface ProtoSerializable<org.tribuo.protos.core.ModelProto>
      Overrides:
      serialize in class Model<Label>
      Returns:
      The protobuf.
    • copy

      protected FMClassificationModel copy(String newName, ModelProvenance newProvenance)
      Description copied from class: Model
      Copies a model, replacing its provenance and name with the supplied values.

      Used to provide the provenance removal functionality.

      Specified by:
      copy in class Model<Label>
      Parameters:
      newName - The new name.
      newProvenance - The new provenance.
      Returns:
      A copy of the model.
    • getDimensionName

      protected String getDimensionName(int index)
      Description copied from class: AbstractFMModel
      Gets the name of the indexed output dimension.
      Specified by:
      getDimensionName in class AbstractFMModel<Label>
      Parameters:
      index - The output dimension index.
      Returns:
      The name of the requested output dimension.
    • onnxModelName

      protected String onnxModelName()
      Specified by:
      onnxModelName in class AbstractFMModel<Label>
      Returns:
      Name to write into the ONNX Model.
    • onnxOutput

      protected ONNXNode onnxOutput(ONNXNode input)
      Description copied from class: AbstractFMModel
      Takes the unnormalized ONNX output of this model and applies an appropriate normalizer from the concrete class.
      Specified by:
      onnxOutput in class AbstractFMModel<Label>
      Parameters:
      input - Unnormalized ONNX leaf node.
      Returns:
      Normalized ONNX leaf node.
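      For a classifier whose scores are meant to be probabilities, a common normalizer is a softmax over the per-label scores, which is what an ONNX Softmax node computes. This page does not state which normalizer the class applies, and it likely depends on how the model was trained. The sketch below is therefore an assumption. It is a plain-Java, numerically stable softmax showing the arithmetic such a node performs, and the class name SoftmaxSketch is hypothetical.

```java
import java.util.Arrays;

/** Hypothetical sketch of the arithmetic an ONNX Softmax node applies to raw scores. */
final class SoftmaxSketch {
    /** Numerically stable softmax: subtracting the max score avoids overflow in Math.exp. */
    static double[] softmax(double[] scores) {
        double max = Double.NEGATIVE_INFINITY;
        for (double s : scores) {
            max = Math.max(max, s);
        }
        double[] probs = new double[scores.length];
        double total = 0.0;
        for (int i = 0; i < scores.length; i++) {
            probs[i] = Math.exp(scores[i] - max);
            total += probs[i];
        }
        for (int i = 0; i < probs.length; i++) {
            probs[i] /= total;
        }
        return probs;
    }

    public static void main(String[] args) {
        // Equal scores give a uniform distribution, even when the raw values are huge.
        System.out.println(Arrays.toString(softmax(new double[] {1000.0, 1000.0}))); // [0.5, 0.5]
    }
}
```

      Subtracting the maximum does not change the result, because the common factor exp(-max) cancels in the ratio. It keeps Math.exp from overflowing to infinity on large unnormalized scores.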