Package org.tribuo.common.sgd
Class AbstractFMModel<T extends Output<T>>
java.lang.Object
org.tribuo.Model<T>
org.tribuo.common.sgd.AbstractSGDModel<T>
org.tribuo.common.sgd.AbstractFMModel<T>
- All Implemented Interfaces:
com.oracle.labs.mlrg.olcut.provenance.Provenancable<ModelProvenance>, Serializable, ProtoSerializable<org.tribuo.protos.core.ModelProto>
- Direct Known Subclasses:
FMClassificationModel, FMMultiLabelModel, FMRegressionModel
A quadratic factorization machine model trained using SGD.
It's an AbstractSGDModel operating on FMParameters.
See:
Rendle, S. Factorization Machines. 2010 IEEE International Conference on Data Mining.
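As a reference point for the rest of this page, the degree-2 factorization machine in Rendle's paper scores a feature vector x of length d using a bias w_0, linear weights w_i and one k-dimensional factor vector v_i per feature. The pairwise term can be rewritten so that it costs O(kd) instead of O(kd^2). The notation below is Rendle's, not Tribuo's; reading the biases, linear weights and factors exposed by this class (one set per output dimension) as w_0, w and V is an interpretation of the method descriptions on this page.

```latex
\hat{y}(\mathbf{x}) = w_0 + \sum_{i=1}^{d} w_i x_i
  + \sum_{i=1}^{d} \sum_{j=i+1}^{d} \langle \mathbf{v}_i, \mathbf{v}_j \rangle \, x_i x_j,
\qquad
\langle \mathbf{v}_i, \mathbf{v}_j \rangle = \sum_{f=1}^{k} v_{i,f} \, v_{j,f}

\sum_{i=1}^{d} \sum_{j=i+1}^{d} \langle \mathbf{v}_i, \mathbf{v}_j \rangle \, x_i x_j
  = \frac{1}{2} \sum_{f=1}^{k} \left[ \left( \sum_{i=1}^{d} v_{i,f} \, x_i \right)^{2}
  - \sum_{i=1}^{d} v_{i,f}^{2} \, x_i^{2} \right]
```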
-
Nested Class Summary
Nested classes/interfaces inherited from class org.tribuo.common.sgd.AbstractSGDModel
AbstractSGDModel.PredAndActive
-
Field Summary
Fields inherited from class org.tribuo.common.sgd.AbstractSGDModel
addBias, modelParameters
Fields inherited from class org.tribuo.Model
ALL_OUTPUTS, BIAS_FEATURE, featureIDMap, generatesProbabilities, name, outputIDInfo, provenance, provenanceOutput
Fields inherited from interface org.tribuo.protos.ProtoSerializable
DESERIALIZATION_METHOD_NAME, PROVENANCE_SERIALIZER
-
Constructor Summary
- protected AbstractFMModel(String name, ModelProvenance provenance, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<T> outputIDInfo, FMParameters parameters, boolean generatesProbabilities): Constructs a factorization machine model trained via SGD.
-
Method Summary
- ai.onnx.proto.OnnxMl.ModelProto exportONNXModel(String domain, long modelVersion): Exports this Model as an ONNX protobuf.
- getBiasesCopy(): Returns a copy of the output dimension biases.
- protected abstract String getDimensionName(int index): Gets the name of the indexed output dimension.
- getExcuse: Factorization machines don't provide excuses; use an explainer instead.
- Tensor[] getFactorsCopy(): Returns a copy of the factors.
- getLinearWeightsCopy(): Returns a copy of the linear weights.
- getTopFeatures(int n): Gets the top n features for each output dimension.
- protected abstract String onnxModelName()
- protected abstract ONNXNode onnxOutput(ONNXNode input): Takes the unnormalized ONNX output of this model and applies an appropriate normalizer from the concrete class.
- writeONNXGraph(ONNXRef<?> input)
Methods inherited from class org.tribuo.common.sgd.AbstractSGDModel
getModelParameters, predictSingle
Methods inherited from class org.tribuo.Model
castModel, copy, copy, createDataCarrier, deserialize, deserializeFromFile, deserializeFromStream, generatesProbabilities, getExcuses, getFeatureIDMap, getName, getOutputIDInfo, getProvenance, innerPredict, predict, predict, predict, serialize, serializeToFile, serializeToStream, setName, toString, validate
-
Constructor Details
-
AbstractFMModel
protected AbstractFMModel(String name, ModelProvenance provenance, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<T> outputIDInfo, FMParameters parameters, boolean generatesProbabilities)
Constructs a factorization machine model trained via SGD.
- Parameters:
name - The model name.
provenance - The model provenance.
featureIDMap - The feature domain.
outputIDInfo - The output domain.
parameters - The model parameters.
generatesProbabilities - Does this model generate probabilities?
-
-
Method Details
-
getTopFeatures
Gets the top n features for each output dimension. Note that the feature rankings are based only on the linear portion of the factorization machine, not on the pairwise factor terms.
- Specified by:
getTopFeatures in class Model<T extends Output<T>>
- Parameters:
n - The number of features to return. If this value is less than 0, all features are returned for each output dimension.
- Returns:
- A map from string outputs to an ordered list of pairs of feature names and weights associated with that feature in the factorization machine.
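A hedged sketch of the ranking described above, using plain arrays: each output dimension's features are ordered using only its linear weights, as the note says, and the top n are kept (all of them when n is negative). Ordering by absolute weight is an assumption made here, since this page does not state the sort key; `TopFeaturesSketch` and its `weights[dim][feature]` layout are hypothetical and not Tribuo's implementation.

```java
import java.util.*;

/** Hypothetical illustration; not part of Tribuo. */
class TopFeaturesSketch {
    /**
     * Ranks features per output dimension using only linear weights.
     * Assumption: ranking by absolute weight, largest first.
     *
     * @param weights      weights[dim][feature], one row per output dimension
     * @param featureNames feature names indexed by feature id
     * @param dimNames     output dimension names indexed by dimension id
     * @param n            number of features to keep; negative keeps all
     */
    static Map<String, List<Map.Entry<String, Double>>> topFeatures(
            double[][] weights, String[] featureNames, String[] dimNames, int n) {
        Map<String, List<Map.Entry<String, Double>>> result = new LinkedHashMap<>();
        for (int dim = 0; dim < weights.length; dim++) {
            List<Map.Entry<String, Double>> ranked = new ArrayList<>();
            for (int f = 0; f < featureNames.length; f++) {
                ranked.add(new AbstractMap.SimpleImmutableEntry<>(featureNames[f], weights[dim][f]));
            }
            // Sort by descending magnitude of the linear weight.
            ranked.sort((a, b) -> Double.compare(Math.abs(b.getValue()), Math.abs(a.getValue())));
            int limit = n < 0 ? ranked.size() : Math.min(n, ranked.size());
            result.put(dimNames[dim], new ArrayList<>(ranked.subList(0, limit)));
        }
        return result;
    }

    public static void main(String[] args) {
        double[][] weights = {{0.5, -2.0, 1.0}, {0.1, 0.2, -0.3}};
        String[] features = {"a", "b", "c"};
        String[] dims = {"pos", "neg"};
        // prints {pos=[b=-2.0, c=1.0], neg=[c=-0.3, b=0.2]}
        System.out.println(topFeatures(weights, features, dims, 2));
    }
}
```

Because the factor terms are ignored, a feature that matters mainly through its interactions can rank low here even if it strongly affects predictions; an explainer is the better tool for that case.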
-
getLinearWeightsCopy
Returns a copy of the linear weights.
- Returns:
- The linear weights.
-
getBiasesCopy
Returns a copy of the output dimension biases.
- Returns:
- The biases.
-
getFactorsCopy
Returns a copy of the factors. There is one factor matrix per output dimension. The first dimension of each factor matrix is the factor dimension, and the second is the number of features.
- Returns:
- The factors.
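A minimal sketch of how the three parameter groups could combine for a single output dimension, written with plain arrays rather than Tribuo's Tensor types. It assumes the factor layout described for getFactorsCopy (factor dimension first, features second), so `factors[f][i]` holds factor f of feature i. `FMScoreSketch`, `score` and `naiveScore` are hypothetical names, and this is not Tribuo's internal code. The fast pairwise term uses Rendle's O(kd) identity and is checked against the naive double loop.

```java
/** Hypothetical illustration of quadratic FM scoring; not part of Tribuo. */
class FMScoreSketch {
    /** Fast score: bias + linear term + pairwise term via per-factor sums. */
    static double score(double bias, double[] linear, double[][] factors, double[] x) {
        double result = bias;
        // Linear term: sum_i w_i * x_i.
        for (int i = 0; i < x.length; i++) {
            result += linear[i] * x[i];
        }
        // Pairwise term: 0.5 * sum_f [ (sum_i v_{i,f} x_i)^2 - sum_i (v_{i,f} x_i)^2 ].
        for (double[] factorRow : factors) {
            double sum = 0.0;
            double sumOfSquares = 0.0;
            for (int i = 0; i < x.length; i++) {
                double term = factorRow[i] * x[i];
                sum += term;
                sumOfSquares += term * term;
            }
            result += 0.5 * (sum * sum - sumOfSquares);
        }
        return result;
    }

    /** Naive O(k * d^2) version: explicit dot product for every feature pair i < j. */
    static double naiveScore(double bias, double[] linear, double[][] factors, double[] x) {
        double result = bias;
        for (int i = 0; i < x.length; i++) {
            result += linear[i] * x[i];
        }
        for (int i = 0; i < x.length; i++) {
            for (int j = i + 1; j < x.length; j++) {
                double dot = 0.0;
                for (double[] factorRow : factors) {
                    dot += factorRow[i] * factorRow[j];
                }
                result += dot * x[i] * x[j];
            }
        }
        return result;
    }

    public static void main(String[] args) {
        double bias = 0.1;
        double[] linear = {0.5, -0.25, 1.0};
        double[][] factors = {{0.2, 0.4, -0.1}, {1.0, 0.0, 0.5}}; // k = 2 factors, d = 3 features
        double[] x = {1.0, 2.0, 0.5};
        System.out.println(score(bias, linear, factors, x));
        System.out.println(naiveScore(bias, linear, factors, x));
    }
}
```

Computing the pairwise term through per-factor sums is what keeps prediction linear in the number of features for a fixed factor dimension; the naive loop is only there as a cross-check.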
-
getExcuse
Factorization machines don't provide excuses; use an explainer instead.
-
getDimensionName
Gets the name of the indexed output dimension.
- Parameters:
index - The output dimension index.
- Returns:
- The name of the requested output dimension.
-
onnxOutput
Takes the unnormalized ONNX output of this model and applies an appropriate normalizer from the concrete class.
- Parameters:
input - Unnormalized ONNX leaf node.
- Returns:
- Normalized ONNX leaf node.
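This page does not say which normalizer each concrete subclass applies. As an illustration only, the sketch below shows two common choices applied to a vector of unnormalized scores: a softmax, which suits single-label classification, and an element-wise logistic sigmoid, which suits multi-label outputs (a regression model would typically leave its scores unnormalized). `NormalizerSketch` is a hypothetical name, and the code works on plain doubles rather than ONNX nodes.

```java
/** Hypothetical illustration of output normalizers; not part of Tribuo. */
class NormalizerSketch {
    /** Softmax: exponentiate and rescale so the outputs sum to 1. */
    static double[] softmax(double[] scores) {
        // Subtract the maximum first so Math.exp cannot overflow.
        double max = Double.NEGATIVE_INFINITY;
        for (double s : scores) {
            max = Math.max(max, s);
        }
        double[] out = new double[scores.length];
        double total = 0.0;
        for (int i = 0; i < scores.length; i++) {
            out[i] = Math.exp(scores[i] - max);
            total += out[i];
        }
        for (int i = 0; i < out.length; i++) {
            out[i] /= total;
        }
        return out;
    }

    /** Element-wise logistic sigmoid: each output is mapped into (0, 1) independently. */
    static double[] sigmoid(double[] scores) {
        double[] out = new double[scores.length];
        for (int i = 0; i < scores.length; i++) {
            out[i] = 1.0 / (1.0 + Math.exp(-scores[i]));
        }
        return out;
    }

    public static void main(String[] args) {
        double[] raw = {2.0, 0.0, -1.0};
        System.out.println(java.util.Arrays.toString(softmax(raw)));
        System.out.println(java.util.Arrays.toString(sigmoid(raw)));
    }
}
```

Keeping the normalizer out of the shared scoring code lets one factorization machine implementation serve classification, multi-label and regression outputs, which matches the abstract onnxOutput hook on this class.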
-
onnxModelName
- Returns:
- Name to write into the ONNX Model.
-
writeONNXGraph
- Parameters:
input - The input to the model graph.
- Returns:
- The output node of the model graph.
-
exportONNXModel
Exports this Model as an ONNX protobuf.
- Parameters:
domain - A reverse-DNS name to namespace the model (e.g., org.tribuo.classification.sgd.linear).
modelVersion - A version number for this model.
- Returns:
- The ONNX ModelProto representing this Tribuo Model.
-