Package org.tribuo.common.sgd
Class AbstractSGDModel<T extends Output<T>>
java.lang.Object
org.tribuo.Model<T>
org.tribuo.common.sgd.AbstractSGDModel<T>
- All Implemented Interfaces:
com.oracle.labs.mlrg.olcut.provenance.Provenancable<ModelProvenance>, Serializable, ProtoSerializable<org.tribuo.protos.core.ModelProto>
- Direct Known Subclasses:
AbstractFMModel, AbstractLinearSGDModel
A model trained using stochastic gradient descent (SGD).
See:
Bottou L. "Large-Scale Machine Learning with Stochastic Gradient Descent" Proceedings of COMPSTAT, 2010.
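The following is a minimal sketch of the plain SGD update described by Bottou (2010), applied to least-squares linear regression on dense arrays: for each example, the weights move against the gradient of that single example's loss. It does not use Tribuo's classes. The class name `SgdSketch`, the squared loss, the fixed learning rate, and the per-epoch shuffle are illustrative assumptions, not a description of what `AbstractSGDModel` subclasses actually do.

```java
import java.util.Random;

/** Illustrative SGD sketch (not Tribuo's implementation). */
public class SgdSketch {

    /**
     * Trains weights with plain SGD on the squared loss 0.5 * (pred - y)^2.
     * The returned array has one weight per feature plus a final bias weight.
     * Assumptions: fixed learning rate, dense inputs, shuffled order per epoch.
     */
    static double[] train(double[][] xs, double[] ys, double learningRate, int epochs, long seed) {
        int dim = xs[0].length;
        double[] w = new double[dim + 1]; // extra slot holds the bias weight
        Random rng = new Random(seed);
        int[] order = new int[xs.length];
        for (int i = 0; i < order.length; i++) {
            order[i] = i;
        }
        for (int epoch = 0; epoch < epochs; epoch++) {
            // Fisher-Yates shuffle so each epoch visits the examples in a random order
            for (int i = order.length - 1; i > 0; i--) {
                int j = rng.nextInt(i + 1);
                int tmp = order[i];
                order[i] = order[j];
                order[j] = tmp;
            }
            for (int idx : order) {
                double pred = predict(w, xs[idx]);
                double grad = pred - ys[idx]; // derivative of 0.5 * (pred - y)^2 w.r.t. pred
                for (int k = 0; k < dim; k++) {
                    w[k] -= learningRate * grad * xs[idx][k];
                }
                w[dim] -= learningRate * grad; // the bias feature is implicitly 1.0
            }
        }
        return w;
    }

    /** Linear score: bias plus the dot product of weights and features. */
    static double predict(double[] w, double[] x) {
        double sum = w[x.length]; // bias term
        for (int k = 0; k < x.length; k++) {
            sum += w[k] * x[k];
        }
        return sum;
    }

    public static void main(String[] args) {
        // Noise-free data generated from y = 2x + 1
        double[][] xs = {{0.0}, {1.0}, {2.0}, {3.0}};
        double[] ys = {1.0, 3.0, 5.0, 7.0};
        double[] w = train(xs, ys, 0.05, 500, 42L);
        System.out.printf("w=%.3f b=%.3f%n", w[0], w[1]);
    }
}
```

On this noise-free data the true weights are a fixed point of every per-example update, so the iterates should settle near w = 2, b = 1.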
Nested Class Summary
Modifier and Type | Class | Description
protected static final class | AbstractSGDModel.PredAndActive | A nominal tuple used to capture the prediction and the number of active features used by the model.
Field Summary
Modifier and Type | Field | Description
protected boolean | addBias | Whether the model adds a bias feature to the feature vector.
protected FeedForwardParameters | modelParameters | The weights for this model.
Fields inherited from class org.tribuo.Model
ALL_OUTPUTS, BIAS_FEATURE, featureIDMap, generatesProbabilities, name, outputIDInfo, provenance, provenanceOutput
Fields inherited from interface org.tribuo.protos.ProtoSerializable
DESERIALIZATION_METHOD_NAME, PROVENANCE_SERIALIZER
Constructor Summary
Modifier | Constructor | Description
protected | AbstractSGDModel(String name, ModelProvenance provenance, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<T> outputIDInfo, FeedForwardParameters weights, boolean generatesProbabilities, boolean addBias) | Constructs a model trained via SGD.
Method Summary
Modifier and Type | Method | Description
FeedForwardParameters | getModelParameters() | Returns a copy of the model parameters.
protected AbstractSGDModel.PredAndActive | predictSingle(Example<T> example) | Generates the dense vector prediction from the supplied example.
Methods inherited from class org.tribuo.Model
castModel, copy, copy, createDataCarrier, deserialize, deserializeFromFile, deserializeFromStream, generatesProbabilities, getExcuse, getExcuses, getFeatureIDMap, getName, getOutputIDInfo, getProvenance, getTopFeatures, innerPredict, predict, predict, predict, serialize, serializeToFile, serializeToStream, setName, toString, validate
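Field Details
modelParameters
protected FeedForwardParameters modelParameters
The weights for this model.
addBias
protected boolean addBias
Whether the model adds a bias feature to the feature vector.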
Constructor Details
AbstractSGDModel
protected AbstractSGDModel(String name, ModelProvenance provenance, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<T> outputIDInfo, FeedForwardParameters weights, boolean generatesProbabilities, boolean addBias)
Constructs a model trained via SGD.
- Parameters:
name - The model name.
provenance - The model provenance.
featureIDMap - The feature domain.
outputIDInfo - The output domain.
weights - The model weights.
generatesProbabilities - Does this model generate probabilities?
addBias - Should the model add a bias feature to the feature vector?
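To illustrate what the `addBias` flag controls, here is a minimal sketch of appending a constant bias feature to a sparse feature vector before taking a dot product with a weight row. `Model` does declare a `BIAS_FEATURE` field, but the placement used here (the bias occupies the index just past the last real feature id, with value 1.0) is an assumption for illustration, and `BiasSketch`, `SparseVec`, `toVector`, and `dot` are hypothetical names rather than Tribuo API.

```java
import java.util.Arrays;

/** Hypothetical sketch of an optional bias feature (not Tribuo's code). */
public class BiasSketch {

    /** Minimal sparse vector: parallel arrays of feature ids and values. */
    record SparseVec(int[] indices, double[] values) {}

    /**
     * Copies the example's features and, when addBias is true, appends a
     * constant 1.0 feature at index numFeatures (an assumed placement).
     */
    static SparseVec toVector(int[] indices, double[] values, int numFeatures, boolean addBias) {
        if (!addBias) {
            return new SparseVec(indices.clone(), values.clone());
        }
        int[] idx = Arrays.copyOf(indices, indices.length + 1);
        double[] vals = Arrays.copyOf(values, values.length + 1);
        idx[indices.length] = numFeatures; // bias id sits just past the last real feature id
        vals[values.length] = 1.0;         // the bias feature is always on
        return new SparseVec(idx, vals);
    }

    /** Dot product of a dense weight row with a sparse vector. */
    static double dot(double[] weights, SparseVec x) {
        double sum = 0.0;
        for (int i = 0; i < x.indices().length; i++) {
            sum += weights[x.indices()[i]] * x.values()[i];
        }
        return sum;
    }

    public static void main(String[] args) {
        int numFeatures = 3;
        // The weight row has numFeatures + 1 entries; the last one is the bias weight.
        double[] weights = {0.5, -1.0, 2.0, 0.25};
        SparseVec x = toVector(new int[] {0, 2}, new double[] {4.0, 1.0}, numFeatures, true);
        System.out.println(dot(weights, x));
    }
}
```

Treating the bias as an ordinary always-on feature lets the same dot-product and gradient code update it, instead of special-casing an intercept term.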
Method Details
predictSingle
protected AbstractSGDModel.PredAndActive predictSingle(Example<T> example)
Generates the dense vector prediction from the supplied example.
- Parameters:
example - The example to use for prediction.
- Returns:
The prediction and the number of features involved.
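A minimal sketch of the idea behind `predictSingle`: score one example against every output row of a weight matrix, producing a dense vector of scores, and count how many of the example's features the model actually recognised. The `PredAndActive` record below only mirrors the nested class's description; its field names, the rule that unknown features are skipped and not counted, the exclusion of the bias from the count, and the bias column position are all assumptions, and none of this is Tribuo's actual code.

```java
import java.util.Arrays;
import java.util.Map;

/** Hypothetical sketch of a single-example prediction (not Tribuo's code). */
public class PredictSingleSketch {

    /** Hypothetical stand-in for the nested PredAndActive tuple. */
    record PredAndActive(double[] prediction, int numActiveFeatures) {}

    /**
     * Computes one score per output (one row of weights per output).
     * Features missing from featureIds are ignored and not counted as active.
     * If addBias is true, the column at featureIds.size() is assumed to hold the bias weight.
     */
    static PredAndActive predictSingle(double[][] weights, Map<String, Integer> featureIds,
                                       String[] names, double[] values, boolean addBias) {
        int numOutputs = weights.length;
        double[] scores = new double[numOutputs];
        int active = 0;
        for (int i = 0; i < names.length; i++) {
            Integer id = featureIds.get(names[i]);
            if (id == null) {
                continue; // unseen feature: contributes nothing and is not counted
            }
            active++;
            for (int o = 0; o < numOutputs; o++) {
                scores[o] += weights[o][id] * values[i];
            }
        }
        if (addBias) {
            int biasCol = featureIds.size(); // assumed bias column, one past the last feature id
            for (int o = 0; o < numOutputs; o++) {
                scores[o] += weights[o][biasCol];
            }
        }
        return new PredAndActive(scores, active);
    }

    public static void main(String[] args) {
        Map<String, Integer> ids = Map.of("a", 0, "b", 1);
        double[][] weights = {
            {1.0, 2.0, 0.5},  // output 0: w_a, w_b, bias
            {-1.0, 0.0, 1.0}  // output 1: w_a, w_b, bias
        };
        PredAndActive p = predictSingle(weights, ids,
                new String[] {"a", "b", "unseen"}, new double[] {3.0, 1.0, 9.0}, true);
        System.out.println(Arrays.toString(p.prediction()) + " active=" + p.numActiveFeatures());
    }
}
```

Returning the active-feature count alongside the scores lets a caller detect examples that share few or no features with the training domain, where the prediction is driven mostly by the bias.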
getModelParameters
public FeedForwardParameters getModelParameters()
Returns a copy of the model parameters.
- Returns:
A copy of the model parameters.
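Returning a copy means callers cannot change the trained model by mutating what they receive. The sketch below shows that defensive-copy pattern with a hypothetical `Params` holder and `Model` class (not Tribuo's `FeedForwardParameters` or `Model`). It clones every row because `double[][].clone()` alone is shallow and would still share the inner arrays.

```java
/** Hypothetical sketch of a defensive-copy getter (not Tribuo's code). */
public class DefensiveCopySketch {

    /** Hypothetical parameter holder; not Tribuo's FeedForwardParameters. */
    static final class Params {
        final double[][] weights;

        Params(double[][] weights) {
            this.weights = weights;
        }

        /** Deep copy: clones every row so no inner array is shared. */
        Params copy() {
            double[][] rows = new double[weights.length][];
            for (int i = 0; i < weights.length; i++) {
                rows[i] = weights[i].clone();
            }
            return new Params(rows);
        }
    }

    /** Hypothetical model that only ever hands out copies of its parameters. */
    static final class Model {
        private final Params params;

        Model(Params params) {
            this.params = params;
        }

        /** Returns a copy so callers cannot mutate the model's own weights. */
        Params getModelParameters() {
            return params.copy();
        }

        double weight(int row, int col) {
            return params.weights[row][col];
        }
    }

    public static void main(String[] args) {
        Model model = new Model(new Params(new double[][] {{1.0, 2.0}}));
        Params leaked = model.getModelParameters();
        leaked.weights[0][0] = 100.0;           // mutate the copy...
        System.out.println(model.weight(0, 0)); // ...the model's weight is unchanged
    }
}
```

The trade-off is an allocation proportional to the parameter count on every call, which is why such getters are best called once and cached by the caller rather than invoked inside a loop.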