Class AbstractFMModel<T extends Output<T>>

java.lang.Object
  org.tribuo.Model<T>
    org.tribuo.common.sgd.AbstractSGDModel<T>
      org.tribuo.common.sgd.AbstractFMModel<T>
All Implemented Interfaces:
com.oracle.labs.mlrg.olcut.provenance.Provenancable<ModelProvenance>, Serializable, ProtoSerializable<org.tribuo.protos.core.ModelProto>
Direct Known Subclasses:
FMClassificationModel, FMMultiLabelModel, FMRegressionModel

public abstract class AbstractFMModel<T extends Output<T>> extends AbstractSGDModel<T>
A quadratic factorization machine model trained using SGD.

It's an AbstractSGDModel operating on FMParameters.

See:

 Rendle, S. "Factorization Machines." 2010 IEEE International Conference on Data Mining.
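Rendle's quadratic model scores an example x for one output dimension as ŷ(x) = w₀ + Σᵢ wᵢxᵢ + Σᵢ<ⱼ ⟨vᵢ, vⱼ⟩ xᵢxⱼ, where each feature i has a linear weight wᵢ and a factor vector vᵢ of length k. The pairwise sum can be rewritten as ½ Σ_f [(Σᵢ v_{f,i}xᵢ)² − Σᵢ v_{f,i}²xᵢ²], which costs O(kn) instead of O(kn²). The following is a minimal plain-Java sketch of that scoring rule, not Tribuo's internal code: the class name FMScoreSketch is hypothetical, dense arrays stand in for FMParameters and Tribuo's sparse feature vectors, and the factor matrix is assumed to be laid out as [factor dimension][feature], matching the layout described for getFactorsCopy.

```java
/**
 * Hypothetical sketch of a degree-2 factorization machine score for a single
 * output dimension (Rendle, 2010). Not Tribuo's implementation: plain dense
 * arrays stand in for FMParameters and sparse feature vectors.
 */
final class FMScoreSketch {
    private FMScoreSketch() {}

    /**
     * O(k * n) score using Rendle's reformulation of the pairwise term.
     *
     * @param bias    the output dimension's bias w0
     * @param linear  linear weights, one per feature
     * @param factors factor matrix laid out [factor dimension][feature] (assumption)
     * @param x       dense feature vector
     */
    static double score(double bias, double[] linear, double[][] factors, double[] x) {
        double result = bias;
        // Linear term: sum_i w_i * x_i
        for (int i = 0; i < x.length; i++) {
            result += linear[i] * x[i];
        }
        // Pairwise term: 0.5 * sum_f [ (sum_i v_fi x_i)^2 - sum_i (v_fi x_i)^2 ]
        double pairwise = 0.0;
        for (double[] row : factors) {
            double sum = 0.0;
            double sumOfSquares = 0.0;
            for (int i = 0; i < x.length; i++) {
                double term = row[i] * x[i];
                sum += term;
                sumOfSquares += term * term;
            }
            pairwise += sum * sum - sumOfSquares;
        }
        return result + 0.5 * pairwise;
    }

    /** Naive O(k * n^2) version of the same model, used only to check the identity. */
    static double naiveScore(double bias, double[] linear, double[][] factors, double[] x) {
        double result = bias;
        for (int i = 0; i < x.length; i++) {
            result += linear[i] * x[i];
        }
        for (int i = 0; i < x.length; i++) {
            for (int j = i + 1; j < x.length; j++) {
                // <v_i, v_j>: dot product of columns i and j of the factor matrix
                double dot = 0.0;
                for (double[] row : factors) {
                    dot += row[i] * row[j];
                }
                result += dot * x[i] * x[j];
            }
        }
        return result;
    }

    public static void main(String[] args) {
        double[] x = {1.0, 2.0, 3.0};
        double[] w = {1.0, -1.0, 2.0};
        double[][] v = {{1.0, 0.0, 1.0}, {0.0, 1.0, 1.0}};
        System.out.println(score(0.5, w, v, x));      // 14.5
        System.out.println(naiveScore(0.5, w, v, x)); // 14.5
    }
}
```

Both methods agree on the values in main; the naive version only exists to check the O(kn) rewrite and would not be used for prediction.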
 
  • Constructor Details

    • AbstractFMModel

      protected AbstractFMModel(String name, ModelProvenance provenance, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<T> outputIDInfo, FMParameters parameters, boolean generatesProbabilities)
      Constructs a factorization machine model trained via SGD.
      Parameters:
      name - The model name.
      provenance - The model provenance.
      featureIDMap - The feature domain.
      outputIDInfo - The output domain.
      parameters - The model parameters.
      generatesProbabilities - Does this model generate probabilities?
  • Method Details

    • getTopFeatures

      public Map<String,List<com.oracle.labs.mlrg.olcut.util.Pair<String,Double>>> getTopFeatures(int n)
      Gets the top n features for each output dimension.

      Note that the feature rankings are based only on the linear portion of the factorization machine; the pairwise factor interactions are not taken into account.

      Specified by:
      getTopFeatures in class Model<T extends Output<T>>
      Parameters:
      n - The number of features to return. If this value is less than 0, all features are returned for each output dimension.
      Returns:
      A map from each output name to an ordered list of (feature name, weight) pairs, where each weight is that feature's linear weight in the factorization machine.
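A minimal dependency-free sketch of this kind of per-output ranking, assuming the linear weights are available as a [output dimension][feature] array. Tribuo itself returns olcut Pair<String,Double> values; Map.Entry is used here only to avoid the dependency. Ranking by absolute weight, and the class name TopFeaturesSketch, are assumptions for illustration rather than details stated on this page.

```java
import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

/** Hypothetical sketch: rank features per output dimension by linear weight only. */
final class TopFeaturesSketch {
    private TopFeaturesSketch() {}

    /**
     * @param outputNames  name of each output dimension
     * @param featureNames name of each feature
     * @param linear       linear weights laid out [output dimension][feature]
     * @param n            number of features to keep; a negative value keeps all
     */
    static Map<String, List<Map.Entry<String, Double>>> topFeatures(
            String[] outputNames, String[] featureNames, double[][] linear, int n) {
        Map<String, List<Map.Entry<String, Double>>> result = new LinkedHashMap<>();
        for (int d = 0; d < outputNames.length; d++) {
            List<Map.Entry<String, Double>> pairs = new ArrayList<>();
            for (int i = 0; i < featureNames.length; i++) {
                // Only the linear weight is used; factor vectors are ignored.
                pairs.add(new AbstractMap.SimpleImmutableEntry<>(featureNames[i], linear[d][i]));
            }
            // Assumption: largest absolute weight first.
            pairs.sort(Comparator.comparingDouble(
                    (Map.Entry<String, Double> e) -> Math.abs(e.getValue())).reversed());
            int limit = n < 0 ? pairs.size() : Math.min(n, pairs.size());
            result.put(outputNames[d], new ArrayList<>(pairs.subList(0, limit)));
        }
        return result;
    }
}
```

For example, with weights {0.1, -3.0, 2.0} for an output "A" and n = 2, the sketch keeps features 1 and 2 in that order, because the sign of a weight does not affect its rank.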
    • getLinearWeightsCopy

      public DenseMatrix getLinearWeightsCopy()
      Returns a copy of the linear weights.
      Returns:
      The linear weights.
    • getBiasesCopy

      public DenseVector getBiasesCopy()
      Returns a copy of the output dimension biases.
      Returns:
      The biases.
    • getFactorsCopy

      public Tensor[] getFactorsCopy()
      Returns a copy of the factors. There is one factor matrix per output dimension: the first dimension of each matrix is the factor dimension and the second is the number of features.
      Returns:
      The factors.
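Given that layout, feature i's factor vector is column i of an output dimension's factor matrix, and the learned interaction weight between features i and j is the dot product of their columns. The sketch below assumes plain double[][] arrays in place of the returned Tensor objects; converting a Tensor to an array is not shown, and the class name FactorInteractionSketch is hypothetical.

```java
/**
 * Hypothetical sketch: read pairwise interaction weights out of one output
 * dimension's factor matrix, laid out [factor dimension][feature].
 */
final class FactorInteractionSketch {
    private FactorInteractionSketch() {}

    /** Interaction weight <v_i, v_j> between features i and j. */
    static double interaction(double[][] factors, int i, int j) {
        double dot = 0.0;
        for (double[] row : factors) {
            // row[i] is one component of feature i's factor vector
            dot += row[i] * row[j];
        }
        return dot;
    }

    /** Copies feature i's factor vector, i.e. column i of the matrix. */
    static double[] factorVector(double[][] factors, int i) {
        double[] v = new double[factors.length];
        for (int f = 0; f < factors.length; f++) {
            v[f] = factors[f][i];
        }
        return v;
    }
}
```

Because the matrix is indexed by factor dimension first, extracting one feature's vector touches every row; storing the transpose would make that a single row read at the cost of making the per-factor sums in scoring less contiguous.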
    • getExcuse

      public Optional<Excuse<T>> getExcuse(Example<T> example)
      Factorization machines don't provide excuses; use an explainer instead.
      Specified by:
      getExcuse in class Model<T extends Output<T>>
      Parameters:
      example - The input example.
      Returns:
      Optional.empty.
    • getDimensionName

      protected abstract String getDimensionName(int index)
      Gets the name of the indexed output dimension.
      Parameters:
      index - The output dimension index.
      Returns:
      The name of the requested output dimension.
    • onnxOutput

      protected abstract ONNXNode onnxOutput(ONNXNode input)
      Takes the unnormalized ONNX output of this model and applies an appropriate normalizer from the concrete class.
      Parameters:
      input - Unnormalized ONNX leaf node.
      Returns:
      Normalized ONNX leaf node.
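To illustrate what a normalizer does to the unnormalized scores, the sketch below shows two common choices in plain Java rather than as ONNX nodes: softmax, which turns the score vector into a distribution over mutually exclusive outputs, and an element-wise sigmoid, which scores each output independently. Which normalizer each concrete subclass (FMClassificationModel, FMMultiLabelModel, FMRegressionModel) actually emits is not stated on this page; pairing softmax with classification and sigmoid with multi-label is an assumption, and a regression model may apply no normalization at all. The class name NormalizerSketch is hypothetical.

```java
import java.util.Arrays;

/** Hypothetical sketch of two score normalizers, outside of any ONNX graph. */
final class NormalizerSketch {
    private NormalizerSketch() {}

    /** Numerically stable softmax: subtracts the max score before exponentiating. */
    static double[] softmax(double[] scores) {
        double max = Arrays.stream(scores).max().orElse(0.0);
        double[] out = new double[scores.length];
        double total = 0.0;
        for (int i = 0; i < scores.length; i++) {
            out[i] = Math.exp(scores[i] - max);
            total += out[i];
        }
        for (int i = 0; i < out.length; i++) {
            out[i] /= total;
        }
        return out;
    }

    /** Independent logistic sigmoid for each output dimension. */
    static double[] sigmoid(double[] scores) {
        double[] out = new double[scores.length];
        for (int i = 0; i < scores.length; i++) {
            out[i] = 1.0 / (1.0 + Math.exp(-scores[i]));
        }
        return out;
    }
}
```

Softmax outputs always sum to 1, whereas sigmoid outputs need not, which is why the two suit single-label and multi-label outputs respectively.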
    • onnxModelName

      protected abstract String onnxModelName()
      Returns:
      Name to write into the ONNX Model.
    • writeONNXGraph

      public ONNXNode writeONNXGraph(ONNXRef<?> input)
      Writes this Model into the OnnxMl.GraphProto.Builder inside the input's ONNXContext.
      Parameters:
      input - The input to the model graph.
      Returns:
      the output node of the model graph.
    • exportONNXModel

      public ai.onnx.proto.OnnxMl.ModelProto exportONNXModel(String domain, long modelVersion)
      Exports this Model as an ONNX protobuf.
      Parameters:
      domain - A reverse-DNS name to namespace the model (e.g., org.tribuo.classification.sgd.linear).
      modelVersion - A version number for this model.
      Returns:
      The ONNX ModelProto representing this Tribuo Model.