Class AbstractSGDModel<T extends Output<T>>

java.lang.Object
  org.tribuo.Model<T>
    org.tribuo.common.sgd.AbstractSGDModel<T>
All Implemented Interfaces:
com.oracle.labs.mlrg.olcut.provenance.Provenancable<ModelProvenance>, Serializable, ProtoSerializable<org.tribuo.protos.core.ModelProto>
Direct Known Subclasses:
AbstractFMModel, AbstractLinearSGDModel

public abstract class AbstractSGDModel<T extends Output<T>> extends Model<T>
A model trained using stochastic gradient descent (SGD).

See:

 Bottou L.
 "Large-Scale Machine Learning with Stochastic Gradient Descent"
 Proceedings of COMPSTAT, 2010.
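As background for the Bottou reference, the core SGD idea is to visit training examples one at a time and, for each one, move the weights a small step against the gradient of that example's loss: w <- w - learningRate * grad(loss(w; x_i, y_i)). The standalone sketch below assumes squared loss, a single real-valued output, a fixed learning rate, and per-epoch shuffling. The class and method names are hypothetical and are not part of Tribuo's API, and Tribuo's trainers may differ (for example in loss function, learning-rate schedule, or optimiser). The constant `1.0` column in the data plays the role of a bias feature.

```java
import java.util.Random;

/** Hypothetical standalone sketch of plain SGD for least-squares linear regression (not Tribuo code). */
class SgdSketch {

    /** Dot product of the weight vector and a dense feature vector. */
    static double predict(double[] w, double[] x) {
        double sum = 0.0;
        for (int j = 0; j < w.length; j++) {
            sum += w[j] * x[j];
        }
        return sum;
    }

    /**
     * Runs SGD with squared loss l = 0.5 * (w.x - y)^2, whose gradient with respect to w is (w.x - y) * x.
     * Examples are visited in a freshly shuffled order each epoch.
     */
    static double[] train(double[][] xs, double[] ys, double learningRate, int epochs, long seed) {
        double[] w = new double[xs[0].length];
        int[] order = new int[xs.length];
        for (int i = 0; i < order.length; i++) {
            order[i] = i;
        }
        Random rng = new Random(seed);
        for (int epoch = 0; epoch < epochs; epoch++) {
            // Fisher-Yates shuffle of the example order.
            for (int i = order.length - 1; i > 0; i--) {
                int k = rng.nextInt(i + 1);
                int tmp = order[i];
                order[i] = order[k];
                order[k] = tmp;
            }
            for (int idx : order) {
                double error = predict(w, xs[idx]) - ys[idx];
                // One stochastic step using only this example's gradient.
                for (int j = 0; j < w.length; j++) {
                    w[j] -= learningRate * error * xs[idx][j];
                }
            }
        }
        return w;
    }

    public static void main(String[] args) {
        // Noise-free data from y = 2*x1 + 1; the second column is a constant bias feature.
        double[][] xs = {{0.0, 1.0}, {1.0, 1.0}, {2.0, 1.0}, {3.0, 1.0}};
        double[] ys = {1.0, 3.0, 5.0, 7.0};
        double[] w = train(xs, ys, 0.05, 2000, 42L);
        System.out.printf("w = [%.3f, %.3f]%n", w[0], w[1]);
    }
}
```

Because each update uses a single example, the per-step cost is independent of the dataset size, which is the property that makes SGD attractive for large-scale learning.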
 
  • Field Details

    • modelParameters

      protected FeedForwardParameters modelParameters
      The weights for this model.
    • addBias

      protected boolean addBias
      Should the model add a bias feature to the feature vector?
  • Constructor Details

    • AbstractSGDModel

      protected AbstractSGDModel(String name, ModelProvenance provenance, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<T> outputIDInfo, FeedForwardParameters weights, boolean generatesProbabilities, boolean addBias)
      Constructs a model trained via SGD.
      Parameters:
      name - The model name.
      provenance - The model provenance.
      featureIDMap - The feature domain.
      outputIDInfo - The output domain.
      weights - The model weights.
      generatesProbabilities - Does this model generate probabilities?
      addBias - Should the model add a bias feature to the feature vector?
  • Method Details

    • predictSingle

      protected AbstractSGDModel.PredAndActive predictSingle(Example<T> example)
      Generates the dense vector prediction from the supplied example.
      Parameters:
      example - The example to use for prediction.
      Returns:
      The dense prediction and the number of active features used to compute it.
    • getModelParameters

      public FeedForwardParameters getModelParameters()
      Returns a copy of the model parameters.
      Returns:
      A copy of the model parameters.
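For a linear SGD model, the dense prediction described under predictSingle can be pictured as a matrix-vector product between a weight matrix and the example's sparse features, with an extra bias column applied when addBias is set; getModelParameters returning a copy keeps callers from mutating the model. The sketch below is a standalone illustration under several assumptions: the weights are a single outputs-by-features array (plus one trailing bias column), features arrive as parallel index/value arrays, indices outside the feature domain are skipped, and the bias does not count as an active feature. `PredictionSketch`, its nested `PredAndActive`, and all field names are hypothetical; Tribuo's `FeedForwardParameters` and subclasses such as factorization machines may compute predictions differently.

```java
import java.util.Arrays;

/** Hypothetical sketch of the prediction path of a linear SGD model (not Tribuo's implementation). */
class PredictionSketch {

    /** Pairs the dense prediction with the number of active (in-domain) features used. */
    static final class PredAndActive {
        final double[] prediction;
        final int numActiveFeatures;

        PredAndActive(double[] prediction, int numActiveFeatures) {
            this.prediction = prediction;
            this.numActiveFeatures = numActiveFeatures;
        }
    }

    // weights[o][j]: weight of feature j for output o; when addBias is true the last column is the bias.
    private final double[][] weights;
    private final boolean addBias;
    private final int numFeatures;

    PredictionSketch(double[][] weights, boolean addBias) {
        this.weights = weights;
        this.addBias = addBias;
        this.numFeatures = weights[0].length - (addBias ? 1 : 0);
    }

    /**
     * Computes W * x for a sparse example given as parallel index/value arrays.
     * Indices outside the feature domain (unknown features) are skipped and not counted as active.
     * When addBias is true, the last weight column is multiplied by an implicit constant feature of 1.0.
     */
    PredAndActive predictSingle(int[] featureIndices, double[] featureValues) {
        double[] prediction = new double[weights.length];
        int active = 0;
        for (int k = 0; k < featureIndices.length; k++) {
            int j = featureIndices[k];
            if (j < 0 || j >= numFeatures) {
                continue; // unknown feature: not part of the model's domain
            }
            active++;
            for (int o = 0; o < weights.length; o++) {
                prediction[o] += weights[o][j] * featureValues[k];
            }
        }
        if (addBias) {
            for (int o = 0; o < weights.length; o++) {
                prediction[o] += weights[o][numFeatures]; // bias column, implicit feature value 1.0
            }
        }
        return new PredAndActive(prediction, active);
    }

    /** Returns a deep copy so callers cannot mutate the model's weights. */
    double[][] getModelParameters() {
        double[][] copy = new double[weights.length][];
        for (int o = 0; o < weights.length; o++) {
            copy[o] = Arrays.copyOf(weights[o], weights[o].length);
        }
        return copy;
    }

    public static void main(String[] args) {
        double[][] w = {{1.0, 2.0, 0.5}, {0.0, 1.0, -1.0}};
        PredictionSketch model = new PredictionSketch(w, true);
        // Feature 7 is outside the two-feature domain, so it is ignored.
        PredAndActive result = model.predictSingle(new int[] {0, 1, 7}, new double[] {3.0, 4.0, 9.0});
        // prints [11.5, 3.0] active=2
        System.out.println(Arrays.toString(result.prediction) + " active=" + result.numActiveFeatures);
    }
}
```

Returning a deep copy rather than the internal array trades a small allocation for the guarantee that code inspecting the parameters cannot silently change the model's predictions.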