Package org.tribuo.classification.sgd.fm
Class FMClassificationTrainer
java.lang.Object
org.tribuo.common.sgd.AbstractSGDTrainer<T,U,V,FMParameters>
org.tribuo.common.sgd.AbstractFMTrainer<Label,Integer,FMClassificationModel>
org.tribuo.classification.sgd.fm.FMClassificationTrainer
All Implemented Interfaces:
com.oracle.labs.mlrg.olcut.config.Configurable, com.oracle.labs.mlrg.olcut.provenance.Provenancable<TrainerProvenance>, Trainer<Label>, WeightedExamples
A trainer for a classification factorization machine using SGD.
See:
Rendle, S. Factorization machines. 2010 IEEE International Conference on Data Mining
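Rendle's degree-2 model scores a feature vector x as ŷ(x) = w0 + Σ_i w_i·x_i + Σ_{i<j} ⟨v_i, v_j⟩·x_i·x_j. Each feature i has a factor vector v_i of length k (the factorizedDimSize parameter). The pairwise term can be rewritten as ½ Σ_f [(Σ_i v_{i,f}·x_i)² − Σ_i v_{i,f}²·x_i²], so the score costs O(k·n) rather than O(k·n²). The following is a minimal sketch of that scoring rule. It assumes dense arrays and a single output. `FmScoreSketch`, `score` and `scoreNaive` are hypothetical names, not Tribuo's API. For a multiclass problem, one such score would presumably be computed per label.

```java
/**
 * Minimal dense-array sketch of a degree-2 factorization machine score
 * (Rendle, 2010). Hypothetical helper, not Tribuo's implementation.
 */
final class FmScoreSketch {

    /**
     * Computes w0 + sum_i w[i]*x[i] + sum_{i<j} <v[i], v[j]> * x[i] * x[j]
     * in O(k * n) time using Rendle's reformulation of the pairwise term.
     *
     * @param w0 global bias
     * @param w  linear weights, one per feature
     * @param v  factor matrix; v[i][f] is factor f of feature i
     * @param x  dense feature vector
     */
    static double score(double w0, double[] w, double[][] v, double[] x) {
        double linear = 0.0;
        for (int i = 0; i < x.length; i++) {
            linear += w[i] * x[i];
        }
        int k = (v.length == 0) ? 0 : v[0].length;
        double pairwise = 0.0;
        for (int f = 0; f < k; f++) {
            double sum = 0.0;   // sum_i v[i][f] * x[i]
            double sumSq = 0.0; // sum_i (v[i][f] * x[i])^2
            for (int i = 0; i < x.length; i++) {
                double t = v[i][f] * x[i];
                sum += t;
                sumSq += t * t;
            }
            pairwise += sum * sum - sumSq;
        }
        return w0 + linear + 0.5 * pairwise;
    }

    /** Direct O(k * n^2) evaluation of the same model, used as a cross-check. */
    static double scoreNaive(double w0, double[] w, double[][] v, double[] x) {
        double total = w0;
        for (int i = 0; i < x.length; i++) {
            total += w[i] * x[i];
            for (int j = i + 1; j < x.length; j++) {
                double dot = 0.0;
                for (int f = 0; f < v[i].length; f++) {
                    dot += v[i][f] * v[j][f];
                }
                total += dot * x[i] * x[j];
            }
        }
        return total;
    }

    public static void main(String[] args) {
        double[] w = {0.5, -1.0, 2.0};
        double[][] v = {{1.0, 0.0}, {0.5, 1.0}, {-1.0, 2.0}};
        double[] x = {1.0, 2.0, 0.5};
        System.out.println("fast:  " + score(0.1, w, v, x));
        System.out.println("naive: " + scoreNaive(0.1, w, v, x));
    }
}
```

For the toy inputs in `main`, both methods give approximately 1.6, up to floating-point rounding. The naive version is included only to show that the reformulation computes the same quantity.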
Field Summary
Fields inherited from class org.tribuo.common.sgd.AbstractFMTrainer
factorizedDimSize, variance
Fields inherited from class org.tribuo.common.sgd.AbstractSGDTrainer
addBias, epochs, loggingInterval, minibatchSize, optimiser, rng, seed, shuffle
Fields inherited from interface org.tribuo.Trainer
DEFAULT_SEED, INCREMENT_INVOCATION_COUNT
Constructor Summary
FMClassificationTrainer(LabelObjective objective, StochasticGradientOptimiser optimiser, int epochs, int loggingInterval, int minibatchSize, long seed, int factorizedDimSize, double variance) - Constructs an SGD trainer for a factorization machine.
FMClassificationTrainer(LabelObjective objective, StochasticGradientOptimiser optimiser, int epochs, int loggingInterval, long seed, int factorizedDimSize, double variance) - Constructs an SGD trainer for a factorization machine.
FMClassificationTrainer(LabelObjective objective, StochasticGradientOptimiser optimiser, int epochs, long seed, int factorizedDimSize, double variance) - Constructs an SGD trainer for a factorization machine.
Method Summary
protected FMClassificationModel createModel(String name, ModelProvenance provenance, ImmutableFeatureMap featureMap, ImmutableOutputInfo<Label> outputInfo, FMParameters parameters) - Creates the appropriate model subclass for this subclass of AbstractSGDTrainer.
protected String getModelClassName() - Returns the class name of the model that's produced by this trainer.
protected SGDObjective<Integer> getObjective() - Returns the objective used by this trainer.
protected Integer getTarget(ImmutableOutputInfo<Label> outputInfo, Label output) - Extracts the appropriate training time representation from the supplied output.
String toString()
Methods inherited from class org.tribuo.common.sgd.AbstractFMTrainer
createParameters, getName, postConfig
Methods inherited from class org.tribuo.common.sgd.AbstractSGDTrainer
getInvocationCount, getProvenance, setInvocationCount, setShuffle, shuffleInPlace, train, train, train
Constructor Details
FMClassificationTrainer
public FMClassificationTrainer(LabelObjective objective, StochasticGradientOptimiser optimiser, int epochs, int loggingInterval, int minibatchSize, long seed, int factorizedDimSize, double variance)
Constructs an SGD trainer for a factorization machine.
Parameters:
objective - The objective function to optimise.
optimiser - The gradient optimiser to use.
epochs - The number of epochs (complete passes through the training data).
loggingInterval - Log the loss after this many iterations. If -1, nothing is logged.
minibatchSize - The number of examples in each minibatch.
seed - A seed for the random number generator, used to shuffle the examples before each epoch.
factorizedDimSize - Size of the factorized feature representation.
variance - The variance of the initializer.
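The factorizedDimSize and variance parameters together determine how the factor vectors start out. A common choice, and the one assumed here, is to draw every factor entry from a zero-mean Gaussian with the given variance. A seeded generator makes runs reproducible. The sketch below is hypothetical: `FmInitSketch` and `initFactors` are illustrative names, and the code does not claim to reproduce Tribuo's exact initialiser or random stream.

```java
import java.util.Arrays;
import java.util.Random;

/** Hypothetical sketch: sample an FM factor matrix from N(0, variance). */
final class FmInitSketch {

    /**
     * Returns a numFeatures x factorizedDimSize matrix whose entries are drawn
     * from a zero-mean Gaussian with the given variance. The RNG is seeded,
     * so the same arguments always yield the same matrix.
     */
    static double[][] initFactors(int numFeatures, int factorizedDimSize, double variance, long seed) {
        if (factorizedDimSize < 1) {
            throw new IllegalArgumentException("factorizedDimSize must be positive");
        }
        if (variance <= 0.0) {
            throw new IllegalArgumentException("variance must be positive");
        }
        Random rng = new Random(seed);
        double stdDev = Math.sqrt(variance); // nextGaussian() has unit variance
        double[][] factors = new double[numFeatures][factorizedDimSize];
        for (int i = 0; i < numFeatures; i++) {
            for (int f = 0; f < factorizedDimSize; f++) {
                factors[i][f] = rng.nextGaussian() * stdDev;
            }
        }
        return factors;
    }

    public static void main(String[] args) {
        double[][] factors = initFactors(5, 3, 0.01, 12345L);
        for (double[] row : factors) {
            System.out.println(Arrays.toString(row));
        }
    }
}
```

A small variance keeps the initial pairwise interactions close to zero. As a result, early training is dominated by the linear terms while the factors grow from near-random starting points.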
FMClassificationTrainer
public FMClassificationTrainer(LabelObjective objective, StochasticGradientOptimiser optimiser, int epochs, int loggingInterval, long seed, int factorizedDimSize, double variance)
Constructs an SGD trainer for a factorization machine. Sets the minibatch size to 1.
Parameters:
objective - The objective function to optimise.
optimiser - The gradient optimiser to use.
epochs - The number of epochs (complete passes through the training data).
loggingInterval - Log the loss after this many iterations. If -1, nothing is logged.
seed - A seed for the random number generator, used to shuffle the examples before each epoch.
factorizedDimSize - Size of the factorized feature representation.
variance - The variance of the initializer.
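The epochs, minibatchSize, seed and loggingInterval parameters describe a standard SGD schedule. The examples are reshuffled before each epoch, processed in minibatches, and the loss is reported every loggingInterval updates. The following hypothetical sketch shows only that bookkeeping. The gradient step is a caller-supplied callback, and the names `SgdLoopSketch`, `MinibatchStep` and `run` are assumptions rather than Tribuo classes. With minibatchSize set to 1, as in this overload, every update uses a single example.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

/**
 * Hypothetical sketch of the epoch, minibatch, shuffling and logging
 * bookkeeping implied by the constructor parameters. The actual gradient
 * update is supplied by the caller; this is not Tribuo's training loop.
 */
final class SgdLoopSketch {

    /** One optimiser step over a minibatch; returns that minibatch's loss. */
    interface MinibatchStep {
        double apply(List<Integer> exampleIndices);
    }

    /**
     * Runs the loop and returns the iteration numbers at which the loss was
     * logged. In this sketch any non-positive loggingInterval (such as -1)
     * disables logging.
     */
    static List<Integer> run(int numExamples, int epochs, int loggingInterval,
                             int minibatchSize, long seed, MinibatchStep step) {
        if (minibatchSize < 1) {
            throw new IllegalArgumentException("minibatchSize must be at least 1");
        }
        Random rng = new Random(seed);
        List<Integer> order = new ArrayList<>();
        for (int i = 0; i < numExamples; i++) {
            order.add(i);
        }
        List<Integer> loggedAt = new ArrayList<>();
        int iteration = 0;
        double lossSinceLastLog = 0.0;
        for (int epoch = 0; epoch < epochs; epoch++) {
            Collections.shuffle(order, rng); // seeded reshuffle before each epoch
            for (int start = 0; start < numExamples; start += minibatchSize) {
                int end = Math.min(start + minibatchSize, numExamples);
                // Pass a copy so the step never sees later reshuffles of `order`.
                lossSinceLastLog += step.apply(new ArrayList<>(order.subList(start, end)));
                iteration++;
                if (loggingInterval > 0 && iteration % loggingInterval == 0) {
                    System.out.printf("iteration %d: mean minibatch loss %.4f%n",
                            iteration, lossSinceLastLog / loggingInterval);
                    loggedAt.add(iteration);
                    lossSinceLastLog = 0.0;
                }
            }
        }
        return loggedAt;
    }

    public static void main(String[] args) {
        // minibatchSize = 1 mirrors the shorter constructor overloads.
        List<Integer> logged = run(20, 2, 10, 1, 42L, batch -> 0.5);
        System.out.println("logged at " + logged);
    }
}
```

Running `main` with 20 examples, 2 epochs, a minibatch size of 1 and a logging interval of 10 performs 40 single-example updates. It logs at iterations 10, 20, 30 and 40.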
FMClassificationTrainer
public FMClassificationTrainer(LabelObjective objective, StochasticGradientOptimiser optimiser, int epochs, long seed, int factorizedDimSize, double variance)
Constructs an SGD trainer for a factorization machine. Sets the minibatch size to 1 and the logging interval to 1000.
Parameters:
objective - The objective function to optimise.
optimiser - The gradient optimiser to use.
epochs - The number of epochs (complete passes through the training data).
seed - A seed for the random number generator, used to shuffle the examples before each epoch.
factorizedDimSize - Size of the factorized feature representation.
variance - The variance of the initializer.
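The two shorter overloads differ only in the defaults they fill in. This sketch assumes constructor chaining, an idiomatic way to express such defaults, in which each shorter constructor delegates to the next longer one. In the hypothetical `FmTrainerConfigSketch` below, plain strings stand in for the LabelObjective and StochasticGradientOptimiser arguments so that the example compiles without Tribuo.

```java
/**
 * Hypothetical stand-in for the trainer's configuration, showing how the two
 * shorter overloads can chain to the full constructor. Strings replace the
 * Tribuo objective and optimiser types so the sketch compiles on its own.
 */
final class FmTrainerConfigSketch {
    final String objective;  // stand-in for a LabelObjective
    final String optimiser;  // stand-in for a StochasticGradientOptimiser
    final int epochs;
    final int loggingInterval;
    final int minibatchSize;
    final long seed;
    final int factorizedDimSize;
    final double variance;

    /** Full form: every hyperparameter supplied explicitly. */
    FmTrainerConfigSketch(String objective, String optimiser, int epochs, int loggingInterval,
                          int minibatchSize, long seed, int factorizedDimSize, double variance) {
        this.objective = objective;
        this.optimiser = optimiser;
        this.epochs = epochs;
        this.loggingInterval = loggingInterval;
        this.minibatchSize = minibatchSize;
        this.seed = seed;
        this.factorizedDimSize = factorizedDimSize;
        this.variance = variance;
    }

    /** Minibatch size defaults to 1. */
    FmTrainerConfigSketch(String objective, String optimiser, int epochs, int loggingInterval,
                          long seed, int factorizedDimSize, double variance) {
        this(objective, optimiser, epochs, loggingInterval, 1, seed, factorizedDimSize, variance);
    }

    /** Minibatch size defaults to 1 and the logging interval to 1000. */
    FmTrainerConfigSketch(String objective, String optimiser, int epochs,
                          long seed, int factorizedDimSize, double variance) {
        this(objective, optimiser, epochs, 1000, seed, factorizedDimSize, variance);
    }

    @Override
    public String toString() {
        return "epochs=" + epochs + ", loggingInterval=" + loggingInterval
                + ", minibatchSize=" + minibatchSize + ", seed=" + seed
                + ", factorizedDimSize=" + factorizedDimSize + ", variance=" + variance;
    }

    public static void main(String[] args) {
        System.out.println(new FmTrainerConfigSketch("log-loss", "adagrad", 5, 42L, 8, 0.1));
    }
}
```

The overloads take 6, 7 and 8 arguments respectively, so a call always resolves to exactly one of them.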
Method Details
getTarget
protected Integer getTarget(ImmutableOutputInfo<Label> outputInfo, Label output)
Description copied from class: AbstractSGDTrainer
Extracts the appropriate training time representation from the supplied output.
Specified by:
getTarget in class AbstractSGDTrainer<Label,Integer,FMClassificationModel,FMParameters>
Parameters:
outputInfo - The output info to use.
output - The output to extract.
Returns:
The training time representation of the output.
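For classification the training-time representation is an Integer, so getTarget has to turn a Label into an index. Given the signature, a plausible reading is that it looks the label up in outputInfo's id mapping. The sketch below imitates that with a hypothetical `LabelIdSketch` class, which assigns contiguous ids to label names in sorted order and rejects unseen labels. The sorting and the exception are choices made for this sketch, not documented Tribuo behaviour.

```java
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.TreeSet;

/** Hypothetical stand-in for an immutable label-to-id mapping. */
final class LabelIdSketch {
    private final Map<String, Integer> ids;

    LabelIdSketch(List<String> labelsSeenInTraining) {
        Map<String, Integer> m = new HashMap<>();
        int next = 0;
        // TreeSet removes duplicates and gives a deterministic (sorted) id order.
        for (String label : new TreeSet<>(labelsSeenInTraining)) {
            m.put(label, next++);
        }
        this.ids = Collections.unmodifiableMap(m);
    }

    /** Analogue of getTarget: converts a label into the integer index used by the objective. */
    int getTarget(String label) {
        Integer id = ids.get(label);
        if (id == null) {
            throw new IllegalArgumentException("Unknown label: " + label);
        }
        return id;
    }

    int size() {
        return ids.size();
    }

    public static void main(String[] args) {
        LabelIdSketch info = new LabelIdSketch(List.of("spam", "ham", "spam", "eggs"));
        for (String label : List.of("eggs", "ham", "spam")) {
            System.out.println(label + " -> " + info.getTarget(label));
        }
    }
}
```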
getObjective
protected SGDObjective<Integer> getObjective()
Description copied from class: AbstractSGDTrainer
Returns the objective used by this trainer.
Specified by:
getObjective in class AbstractSGDTrainer<Label,Integer,FMClassificationModel,FMParameters>
Returns:
The SGDObjective used by this trainer.
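getObjective returns an SGDObjective<Integer>, so the objective receives the integer target from getTarget together with the model's per-label scores. As an illustration, the hypothetical `LogLossSketch` below implements a softmax cross-entropy (multiclass log loss) objective over such targets. It is similar in spirit to a LabelObjective such as Tribuo's LogMulticlass, but it is not that class's source, and its names are assumptions.

```java
import java.util.Arrays;

/**
 * Hypothetical sketch of a multiclass log-loss (softmax cross-entropy)
 * objective whose target is an integer label id. Not Tribuo's LogMulticlass.
 */
final class LogLossSketch {

    /** Loss value plus the gradient of the loss with respect to each label's score. */
    static final class LossAndGradient {
        final double loss;
        final double[] gradient;

        LossAndGradient(double loss, double[] gradient) {
            this.loss = loss;
            this.gradient = gradient;
        }
    }

    static LossAndGradient lossAndGradient(int target, double[] scores) {
        if (target < 0 || target >= scores.length) {
            throw new IllegalArgumentException("target out of range: " + target);
        }
        double max = Double.NEGATIVE_INFINITY;
        for (double s : scores) {
            max = Math.max(max, s);
        }
        // Softmax with the max subtracted so Math.exp cannot overflow.
        double[] probs = new double[scores.length];
        double sumExp = 0.0;
        for (int i = 0; i < scores.length; i++) {
            probs[i] = Math.exp(scores[i] - max);
            sumExp += probs[i];
        }
        for (int i = 0; i < probs.length; i++) {
            probs[i] /= sumExp;
        }
        // -log softmax(scores)[target], computed in log space to avoid log(0).
        double loss = Math.log(sumExp) - (scores[target] - max);
        // d loss / d score_i = p_i - 1[i == target]
        double[] gradient = probs.clone();
        gradient[target] -= 1.0;
        return new LossAndGradient(loss, gradient);
    }

    public static void main(String[] args) {
        LossAndGradient r = lossAndGradient(2, new double[] {1.0, 0.5, 2.0});
        System.out.println("loss = " + r.loss);
        System.out.println("gradient = " + Arrays.toString(r.gradient));
    }
}
```

The gradient p − onehot(target) always sums to zero, which makes a cheap sanity check when wiring such an objective into a training loop.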
createModel
protected FMClassificationModel createModel(String name, ModelProvenance provenance, ImmutableFeatureMap featureMap, ImmutableOutputInfo<Label> outputInfo, FMParameters parameters)
Description copied from class: AbstractSGDTrainer
Creates the appropriate model subclass for this subclass of AbstractSGDTrainer.
Specified by:
createModel in class AbstractSGDTrainer<Label,Integer,FMClassificationModel,FMParameters>
Parameters:
name - The model name.
provenance - The model provenance.
featureMap - The feature map.
outputInfo - The output info.
parameters - The model parameters.
Returns:
A new instance of the appropriate subclass of Model.
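createModel, together with getModelClassName, acts as a template-method hook. Presumably the shared AbstractSGDTrainer code runs the optimisation, and this subclass only decides which concrete model wraps the learned parameters. A stripped-down, hypothetical version of that split follows. `SgdTrainerSketch`, `FmClassificationTrainerSketch` and `FmModelSketch` are illustrative names, and the optimisation step is deliberately elided.

```java
/** Hypothetical model type that simply wraps the learned parameters. */
final class FmModelSketch {
    final String name;
    final double[][] factors;

    FmModelSketch(String name, double[][] factors) {
        this.name = name;
        this.factors = factors;
    }
}

/** Generic trainer: owns the (elided) SGD loop and defers model construction to subclasses. */
abstract class SgdTrainerSketch<P, M> {
    final M train(String name, P parameters) {
        // A real trainer would optimise `parameters` here before wrapping them.
        return createModel(name, parameters);
    }

    protected abstract M createModel(String name, P parameters);

    protected abstract String getModelClassName();
}

/** Subclass that only chooses the concrete model type. */
final class FmClassificationTrainerSketch extends SgdTrainerSketch<double[][], FmModelSketch> {
    @Override
    protected FmModelSketch createModel(String name, double[][] parameters) {
        return new FmModelSketch(name, parameters);
    }

    @Override
    protected String getModelClassName() {
        return FmModelSketch.class.getName();
    }

    public static void main(String[] args) {
        FmClassificationTrainerSketch trainer = new FmClassificationTrainerSketch();
        FmModelSketch model = trainer.train("fm-demo", new double[][] {{0.1, 0.2}});
        System.out.println(model.name + " built as " + trainer.getModelClassName());
    }
}
```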
getModelClassName
protected String getModelClassName()
Description copied from class: AbstractSGDTrainer
Returns the class name of the model that's produced by this trainer.
Specified by:
getModelClassName in class AbstractSGDTrainer<Label,Integer,FMClassificationModel,FMParameters>
Returns:
The model class name.
toString
public String toString()