Class AbstractFMTrainer<T extends Output<T>,U,V extends AbstractFMModel<T>>

java.lang.Object
org.tribuo.common.sgd.AbstractSGDTrainer<T,U,V,FMParameters>
org.tribuo.common.sgd.AbstractFMTrainer<T,U,V>
All Implemented Interfaces:
com.oracle.labs.mlrg.olcut.config.Configurable, com.oracle.labs.mlrg.olcut.provenance.Provenancable<TrainerProvenance>, Trainer<T>, WeightedExamples
Direct Known Subclasses:
FMClassificationTrainer, FMMultiLabelTrainer, FMRegressionTrainer

public abstract class AbstractFMTrainer<T extends Output<T>,U,V extends AbstractFMModel<T>> extends AbstractSGDTrainer<T,U,V,FMParameters>
A trainer that fits a quadratic factorization machine model using stochastic gradient descent (SGD).

It is an AbstractSGDTrainer operating on FMParameters.

See:

 Rendle, S.
 Factorization machines.
 2010 IEEE International Conference on Data Mining
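For reference, the degree-2 ("quadratic") model in the cited paper predicts y(x) = w0 + sum_i w_i*x_i + sum_{i<j} <v_i, v_j>*x_i*x_j. Here each v_i is a factor vector of length k (the factorizedDimSize field below). The following is a minimal sketch of that textbook formula over dense arrays. The class and method names are hypothetical, and this is not Tribuo's internal representation of FMParameters. It evaluates the pairwise term with Rendle's O(k*n) identity, sum_{i<j} <v_i, v_j> x_i x_j = 0.5 * sum_f [(sum_i v_{i,f} x_i)^2 - sum_i (v_{i,f} x_i)^2], and includes a naive O(k*n^2) version for comparison.

```java
/** Hypothetical sketch of quadratic factorization machine prediction (Rendle, 2010). */
final class FMPredictSketch {
    /**
     * @param w0 global bias
     * @param w  linear weights, one per feature
     * @param v  factor matrix, v[i][f] for feature i and factor f (assumes v.length == x.length)
     * @param x  dense feature vector
     */
    static double predict(double w0, double[] w, double[][] v, double[] x) {
        double linear = w0;
        for (int i = 0; i < x.length; i++) {
            linear += w[i] * x[i];
        }
        int k = v.length == 0 ? 0 : v[0].length;
        double pairwise = 0.0;
        for (int f = 0; f < k; f++) {
            double sum = 0.0;   // sum_i v[i][f] * x[i]
            double sumSq = 0.0; // sum_i (v[i][f] * x[i])^2
            for (int i = 0; i < x.length; i++) {
                double t = v[i][f] * x[i];
                sum += t;
                sumSq += t * t;
            }
            pairwise += 0.5 * (sum * sum - sumSq);
        }
        return linear + pairwise;
    }

    /** Direct O(k n^2) evaluation of the same model, used to check the identity above. */
    static double predictNaive(double w0, double[] w, double[][] v, double[] x) {
        double y = w0;
        for (int i = 0; i < x.length; i++) {
            y += w[i] * x[i];
            for (int j = i + 1; j < x.length; j++) {
                double dot = 0.0;
                for (int f = 0; f < v[i].length; f++) {
                    dot += v[i][f] * v[j][f];
                }
                y += dot * x[i] * x[j];
            }
        }
        return y;
    }

    public static void main(String[] args) {
        double[] w = {1, 2};
        double[][] v = {{1, 0}, {2, 1}};
        double[] x = {1, 3};
        // linear part 0.5 + 1*1 + 2*3 = 7.5; pairwise <v_0, v_1> * x_0 * x_1 = 2 * 3 = 6
        System.out.println(predict(0.5, w, v, x));      // prints 13.5
        System.out.println(predictNaive(0.5, w, v, x)); // prints 13.5
    }
}
```

The reformulation is what makes the model practical on wide, sparse inputs, because the cost grows linearly in the number of features rather than quadratically.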
 
  • Field Details

    • factorizedDimSize

      @Config(mandatory=true, description="The size of the factorized feature representation.") protected int factorizedDimSize
    • variance

      @Config(mandatory=true, description="The variance of the initializer.") protected double variance
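The two fields determine the shape and the starting spread of the factor matrix. factorizedDimSize is the number of columns k. variance sets how widely the initial factors are scattered. The initial factors must not all be zero. The gradient of the pairwise term with respect to v_{i,f} is x_i * (sum_j v_{j,f} x_j - v_{i,f} x_i), which is zero everywhere when every factor is zero. The sketch below assumes a zero-mean Gaussian initializer with the given variance. The Gaussian choice, the class name, and the numFeatures parameter are hypothetical, and Tribuo's FMParameters may initialise its factors differently.

```java
import java.util.Random;

/** Hypothetical sketch: random initialisation of an FM factor matrix. */
final class FMInitSketch {
    /**
     * Draws a numFeatures x factorizedDimSize matrix with i.i.d. N(0, variance) entries.
     * Assumption: a zero-mean Gaussian initializer; not necessarily what Tribuo does.
     */
    static double[][] initFactors(int numFeatures, int factorizedDimSize, double variance, long seed) {
        if (factorizedDimSize <= 0) {
            throw new IllegalArgumentException("factorizedDimSize must be positive");
        }
        if (variance <= 0.0) {
            throw new IllegalArgumentException("variance must be positive");
        }
        Random rng = new Random(seed);
        double stdDev = Math.sqrt(variance); // nextGaussian() is N(0, 1), so scale by the standard deviation
        double[][] v = new double[numFeatures][factorizedDimSize];
        for (double[] row : v) {
            for (int f = 0; f < factorizedDimSize; f++) {
                row[f] = rng.nextGaussian() * stdDev;
            }
        }
        return v;
    }

    /** Population variance over all entries, used to sanity-check the initializer. */
    static double sampleVariance(double[][] m) {
        double sum = 0.0;
        double sumSq = 0.0;
        long n = 0;
        for (double[] row : m) {
            for (double e : row) {
                sum += e;
                sumSq += e * e;
                n++;
            }
        }
        double mean = sum / n;
        return sumSq / n - mean * mean;
    }

    public static void main(String[] args) {
        double[][] v = initFactors(1000, 50, 0.1, 12345L);
        // the sample variance should land close to the requested 0.1
        System.out.printf("shape %d x %d, sample variance %.4f%n", v.length, v[0].length, sampleVariance(v));
    }
}
```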
  • Constructor Details

    • AbstractFMTrainer

      protected AbstractFMTrainer(StochasticGradientOptimiser optimiser, int epochs, int loggingInterval, int minibatchSize, long seed, int factorizedDimSize, double variance)
      Constructs an SGD trainer for a factorization machine.
      Parameters:
      optimiser - The gradient optimiser to use.
      epochs - The number of epochs (complete passes through the training data).
      loggingInterval - Log the loss every loggingInterval iterations. If -1, nothing is logged.
      minibatchSize - The number of examples in each minibatch.
      seed - A seed for the random number generator, used to shuffle the examples before each epoch.
      factorizedDimSize - Size of the factorized feature representation.
      variance - The variance of the initializer.
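The epochs, minibatchSize, seed and loggingInterval parameters describe a standard shuffled minibatch SGD loop. The minimal sketch below shows how they could interact. It assumes that one iteration is counted per minibatch update and that the last minibatch of an epoch may be shorter. The step callback is a hypothetical stand-in for the gradient computation plus the StochasticGradientOptimiser update. The class and method names are illustrative only, and this is not Tribuo's AbstractSGDTrainer implementation.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;
import java.util.function.ToDoubleFunction;

/** Hypothetical sketch of an epoch/minibatch SGD driver loop. */
final class SGDLoopSketch {
    /**
     * Runs shuffled minibatch passes over example indices 0..numExamples-1.
     * step applies one (hypothetical) gradient update to a minibatch and returns its loss.
     * Returns the iteration numbers at which the loss was logged.
     */
    static List<Integer> train(int numExamples, int epochs, int loggingInterval, int minibatchSize,
                               long seed, ToDoubleFunction<List<Integer>> step) {
        if (minibatchSize <= 0) {
            throw new IllegalArgumentException("minibatchSize must be positive");
        }
        if (loggingInterval != -1 && loggingInterval <= 0) {
            throw new IllegalArgumentException("loggingInterval must be positive or -1");
        }
        Random rng = new Random(seed); // one seeded RNG, so the sequence of shuffles is reproducible
        List<Integer> order = new ArrayList<>();
        for (int i = 0; i < numExamples; i++) {
            order.add(i);
        }
        List<Integer> loggedAt = new ArrayList<>();
        int iteration = 0;
        double lossSinceLog = 0.0;
        for (int epoch = 0; epoch < epochs; epoch++) {
            Collections.shuffle(order, rng); // new example order before every epoch
            for (int start = 0; start < numExamples; start += minibatchSize) {
                int end = Math.min(start + minibatchSize, numExamples); // last minibatch may be short
                // subList is a view; step must not keep it after returning
                lossSinceLog += step.applyAsDouble(order.subList(start, end));
                iteration++;
                if (loggingInterval != -1 && iteration % loggingInterval == 0) {
                    System.out.printf("iteration %d: mean minibatch loss %.4f%n",
                            iteration, lossSinceLog / loggingInterval);
                    loggedAt.add(iteration);
                    lossSinceLog = 0.0;
                }
            }
        }
        return loggedAt;
    }

    public static void main(String[] args) {
        // 10 examples in minibatches of 4 -> 3 updates per epoch, 6 updates over 2 epochs
        List<Integer> logged = train(10, 2, 3, 4, 1L, batch -> batch.size());
        System.out.println("logged at iterations " + logged); // prints "logged at iterations [3, 6]"
    }
}
```

Reusing a single seeded Random across epochs gives each epoch a different order while keeping the whole run reproducible for a fixed seed.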
    • AbstractFMTrainer

      protected AbstractFMTrainer()
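      For OLCUT: the no-argument constructor used when the trainer is built from configuration, with the @Config fields populated by OLCUT.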
  • Method Details