Class AbstractFMModel<T extends Output<T>>
java.lang.Object
org.tribuo.Model<T>
org.tribuo.common.sgd.AbstractSGDModel<T>
org.tribuo.common.sgd.AbstractFMModel<T>
- All Implemented Interfaces:
com.oracle.labs.mlrg.olcut.provenance.Provenancable<ModelProvenance>, Serializable
- Direct Known Subclasses:
FMClassificationModel, FMMultiLabelModel, FMRegressionModel
A quadratic factorization machine model trained using SGD.
It's an AbstractSGDModel operating on FMParameters.
See:
Rendle, S. Factorization Machines. 2010 IEEE International Conference on Data Mining.
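For reference, the degree-2 factorization machine from Rendle (2010) scores an input x with n features using a bias w_0, linear weights w_i and k-dimensional factor vectors v_i. Reading this class's accessors as one bias, one linear weight vector and one factor matrix per output dimension (our interpretation of this page, which does not state the formula itself), each output dimension evaluates:

```latex
\hat{y}(\mathbf{x}) = w_0 + \sum_{i=1}^{n} w_i x_i
  + \sum_{i=1}^{n} \sum_{j=i+1}^{n} \langle \mathbf{v}_i, \mathbf{v}_j \rangle \, x_i x_j,
\qquad
\langle \mathbf{v}_i, \mathbf{v}_j \rangle = \sum_{f=1}^{k} v_{i,f} \, v_{j,f}

\sum_{i=1}^{n} \sum_{j=i+1}^{n} \langle \mathbf{v}_i, \mathbf{v}_j \rangle \, x_i x_j
  = \frac{1}{2} \sum_{f=1}^{k} \left[ \left( \sum_{i=1}^{n} v_{i,f} \, x_i \right)^{2}
  - \sum_{i=1}^{n} v_{i,f}^{2} \, x_i^{2} \right]
```

The second identity is what lets the pairwise term be computed in O(kn) time rather than O(kn^2).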
Nested Class Summary
Nested classes/interfaces inherited from class org.tribuo.common.sgd.AbstractSGDModel
AbstractSGDModel.PredAndActive
Field Summary
Fields inherited from class org.tribuo.common.sgd.AbstractSGDModel
addBias, modelParameters
Fields inherited from class org.tribuo.Model
ALL_OUTPUTS, BIAS_FEATURE, featureIDMap, generatesProbabilities, name, outputIDInfo, provenance, provenanceOutput
Constructor Summary
- protected AbstractFMModel(String name, ModelProvenance provenance, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<T> outputIDInfo, FMParameters parameters, boolean generatesProbabilities): Constructs a factorization machine model trained via SGD.
Method Summary
- OnnxMl.ModelProto exportONNXModel(String domain, long modelVersion): Exports this Model as an ONNX protobuf.
- DenseVector getBiasesCopy(): Returns a copy of the output dimension biases.
- protected abstract String getDimensionName(int index): Gets the name of the indexed output dimension.
- Optional<Excuse<T>> getExcuse(Example<T> example): Factorization machines don't provide excuses; use an explainer instead.
- Tensor[] getFactorsCopy(): Returns a copy of the factors.
- DenseMatrix getLinearWeightsCopy(): Returns a copy of the linear weights.
- Map<String,List<Pair<String,Double>>> getTopFeatures(int n): Gets the top n features for each output dimension.
- protected abstract String onnxModelName(): Returns the name to write into the ONNX model.
- protected abstract ONNXNode onnxOutput(ONNXNode input): Takes the unnormalized ONNX output of this model and applies an appropriate normalizer from the concrete class.
- ONNXNode writeONNXGraph(ONNXRef<?> input): Writes the model graph onto the supplied input and returns its output node.
Methods inherited from class org.tribuo.common.sgd.AbstractSGDModel
getModelParameters, predictSingle
Methods inherited from class org.tribuo.Model
castModel, copy, copy, generatesProbabilities, getExcuses, getFeatureIDMap, getName, getOutputIDInfo, getProvenance, innerPredict, predict, predict, predict, setName, toString, validate
Constructor Details
AbstractFMModel
protected AbstractFMModel(String name, ModelProvenance provenance, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<T> outputIDInfo, FMParameters parameters, boolean generatesProbabilities)
Constructs a factorization machine model trained via SGD.
- Parameters:
name - The model name.
provenance - The model provenance.
featureIDMap - The feature domain.
outputIDInfo - The output domain.
parameters - The model parameters.
generatesProbabilities - Does this model generate probabilities?

Method Details
getTopFeatures
Gets the top n features for each output dimension.
Note that the feature rankings are based only on the linear portion of the factorization machine.
- Specified by:
getTopFeatures in class Model<T extends Output<T>>
- Parameters:
n - The number of features to return. If this value is less than 0, all features are returned for each output dimension.
- Returns:
- A map from string outputs to an ordered list of pairs of feature names and weights associated with that feature in the factorization machine.
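A minimal sketch of the ranking described above, written against plain arrays rather than Tribuo types. It assumes that features are ranked by the absolute value of their linear weight and that the linear weights are laid out as [outputDimension][feature]; both are assumptions made for illustration, and the names TopFeaturesSketch, topFeatures and topFeaturesPerDimension are hypothetical. As the note says, the factor matrices play no part in the ranking.

```java
import java.util.AbstractMap.SimpleEntry;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

public class TopFeaturesSketch {

    /** Ranks one output dimension's features by |linear weight|; n < 0 returns every feature. */
    public static List<Map.Entry<String, Double>> topFeatures(String[] featureNames, double[] weights, int n) {
        List<Map.Entry<String, Double>> pairs = new ArrayList<>();
        for (int i = 0; i < featureNames.length; i++) {
            pairs.add(new SimpleEntry<>(featureNames[i], weights[i]));
        }
        // Assumption: a larger absolute weight marks a more important feature.
        pairs.sort(Comparator.comparingDouble(
                (Map.Entry<String, Double> e) -> Math.abs(e.getValue())).reversed());
        int limit = n < 0 ? pairs.size() : Math.min(n, pairs.size());
        return new ArrayList<>(pairs.subList(0, limit));
    }

    /** Builds the output-name to ranked-features map using only the linear weights. */
    public static Map<String, List<Map.Entry<String, Double>>> topFeaturesPerDimension(
            String[] dimensionNames, String[] featureNames, double[][] linearWeights, int n) {
        Map<String, List<Map.Entry<String, Double>>> result = new LinkedHashMap<>();
        for (int d = 0; d < dimensionNames.length; d++) {
            result.put(dimensionNames[d], topFeatures(featureNames, linearWeights[d], n));
        }
        return result;
    }

    public static void main(String[] args) {
        String[] features = {"A", "B", "C", "D"};
        double[][] weights = {{0.1, -2.0, 0.7, 1.5}, {0.3, 0.2, -0.9, 0.0}};
        // prints {yes=[B=-2.0, D=1.5], no=[C=-0.9, A=0.3]}
        System.out.println(topFeaturesPerDimension(new String[] {"yes", "no"}, features, weights, 2));
    }
}
```

A stable sort is used, so features with equal absolute weights keep their original feature-index order.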

getLinearWeightsCopy
Returns a copy of the linear weights.
- Returns:
- The linear weights.

getBiasesCopy
Returns a copy of the output dimension biases.
- Returns:
- The biases.

getFactorsCopy
Returns a copy of the factors. There is one factor matrix per output dimension. The first factor matrix dimension is the factor dimension, the second is the number of features.
- Returns:
- The factors.
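To make the factor layout concrete, here is a plain-Java sketch that scores a single output dimension from a bias, a linear weight vector and a factor matrix shaped [factorDimension][numFeatures], matching the description above. It illustrates the degree-2 model from Rendle (2010) and is not Tribuo's own prediction code: the class and method names are hypothetical, and dense double arrays stand in for Tribuo's DenseVector, DenseMatrix and Tensor types. The fast path uses the O(kn) reformulation of the pairwise term, while naiveScore is a direct O(kn^2) check.

```java
import java.util.Random;

public class FMScoreSketch {

    /**
     * Scores one output dimension of a degree-2 factorization machine.
     * factors[f][i] is component f of feature i's factor vector (v_{i,f}).
     */
    public static double score(double bias, double[] linear, double[][] factors, double[] x) {
        double result = bias;
        for (int i = 0; i < x.length; i++) {
            result += linear[i] * x[i];
        }
        // Pairwise term: 0.5 * sum_f [ (sum_i v_{i,f} x_i)^2 - sum_i (v_{i,f} x_i)^2 ]
        double pairwise = 0.0;
        for (double[] row : factors) {
            double sum = 0.0;
            double sumOfSquares = 0.0;
            for (int i = 0; i < x.length; i++) {
                double term = row[i] * x[i];
                sum += term;
                sumOfSquares += term * term;
            }
            pairwise += sum * sum - sumOfSquares;
        }
        return result + 0.5 * pairwise;
    }

    /** Reference version: explicitly sums <v_i, v_j> x_i x_j over all pairs i < j. */
    public static double naiveScore(double bias, double[] linear, double[][] factors, double[] x) {
        double result = bias;
        for (int i = 0; i < x.length; i++) {
            result += linear[i] * x[i];
        }
        for (int i = 0; i < x.length; i++) {
            for (int j = i + 1; j < x.length; j++) {
                double dot = 0.0;
                for (double[] row : factors) {
                    dot += row[i] * row[j];
                }
                result += dot * x[i] * x[j];
            }
        }
        return result;
    }

    public static void main(String[] args) {
        Random rng = new Random(42);
        int numFeatures = 6;
        int factorDim = 3;
        double[] linear = new double[numFeatures];
        double[][] factors = new double[factorDim][numFeatures];
        double[] x = new double[numFeatures];
        for (int i = 0; i < numFeatures; i++) {
            linear[i] = rng.nextGaussian();
            x[i] = rng.nextGaussian();
            for (int f = 0; f < factorDim; f++) {
                factors[f][i] = rng.nextGaussian();
            }
        }
        System.out.printf("fast=%.6f naive=%.6f%n",
                score(0.5, linear, factors, x), naiveScore(0.5, linear, factors, x));
    }
}
```

Both methods should agree up to floating-point rounding; the fast form is the one worth using in practice because its cost grows linearly with the number of features.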

getExcuse
Factorization machines don't provide excuses; use an explainer instead.

getDimensionName
Gets the name of the indexed output dimension.
- Parameters:
index - The output dimension index.
- Returns:
- The name of the requested output dimension.
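
onnxOutput
Takes the unnormalized ONNX output of this model and applies an appropriate normalizer from the concrete class.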
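onnxOutput follows a template-method split: the shared base class produces unnormalized scores and each concrete subclass decides how to normalize them. The sketch below mirrors that split in plain Java rather than ONNX. The class names are hypothetical, and pairing softmax, sigmoid and identity with multi-class, multi-label and regression outputs is an illustrative assumption, not a statement of which normalizer each Tribuo subclass actually applies.

```java
import java.util.Arrays;

public class NormalizerHookSketch {

    /** Base class: owns the unnormalized scores and defers normalization to subclasses. */
    abstract static class ScoreModel {
        /** Hook analogous in spirit to onnxOutput: maps raw scores to final outputs. */
        protected abstract double[] normalize(double[] raw);

        /** Copies the raw scores so subclasses can safely normalize in place. */
        public final double[] output(double[] raw) {
            return normalize(Arrays.copyOf(raw, raw.length));
        }
    }

    /** Multi-class style: softmax, shifted by the max score for numerical stability. */
    static final class SoftmaxModel extends ScoreModel {
        @Override
        protected double[] normalize(double[] s) {
            double max = Arrays.stream(s).max().orElse(0.0);
            double total = 0.0;
            for (int i = 0; i < s.length; i++) {
                s[i] = Math.exp(s[i] - max);
                total += s[i];
            }
            for (int i = 0; i < s.length; i++) {
                s[i] /= total;
            }
            return s;
        }
    }

    /** Multi-label style: an independent sigmoid per output dimension. */
    static final class SigmoidModel extends ScoreModel {
        @Override
        protected double[] normalize(double[] s) {
            for (int i = 0; i < s.length; i++) {
                s[i] = 1.0 / (1.0 + Math.exp(-s[i]));
            }
            return s;
        }
    }

    /** Regression style: the raw scores are already the prediction. */
    static final class IdentityModel extends ScoreModel {
        @Override
        protected double[] normalize(double[] s) {
            return s;
        }
    }

    public static void main(String[] args) {
        double[] raw = {1.0, 2.0, 3.0};
        System.out.println(Arrays.toString(new SoftmaxModel().output(raw)));
        System.out.println(Arrays.toString(new SigmoidModel().output(raw)));
        System.out.println(Arrays.toString(new IdentityModel().output(raw)));
    }
}
```

Keeping the hook abstract means the shared scoring code never needs to know which output type it is serving.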

onnxModelName
- Returns:
- The name to write into the ONNX model.

writeONNXGraph
- Parameters:
input - The input to the model graph.
- Returns:
- The output node of the model graph.

exportONNXModel
Exports this Model as an ONNX protobuf.
- Parameters:
domain - A reverse-DNS name to namespace the model (e.g., org.tribuo.classification.sgd.linear).
modelVersion - A version number for this model.
- Returns:
- The ONNX ModelProto representing this Tribuo Model.