Class LibSVMClassificationTrainer

java.lang.Object
org.tribuo.common.libsvm.LibSVMTrainer<Label>
org.tribuo.classification.libsvm.LibSVMClassificationTrainer
All Implemented Interfaces:
com.oracle.labs.mlrg.olcut.config.Configurable, com.oracle.labs.mlrg.olcut.provenance.Provenancable<TrainerProvenance>, WeightedLabels, Trainer<Label>

public class LibSVMClassificationTrainer extends LibSVMTrainer<Label> implements WeightedLabels
A trainer for classification models that uses LibSVM.

Note: the train method is synchronized on LibSVMTrainer.class because LibSVM uses a single global RNG. The lock does not guarantee reproducibility if other code in the same JVM calls LibSVM directly, since that code never takes the lock, but it avoids locking on classes Tribuo does not control.
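The locking pattern in the note can be illustrated with a small self-contained sketch. GlobalRngLibrary and SketchTrainer are hypothetical stand-ins: the first plays the role of a library holding one static RNG, the second a trainer that reseeds and uses that RNG inside a block synchronized on a shared class monitor. Code that calls GlobalRngLibrary.fit without taking the same lock could interleave with a training run, which is the limitation the note describes.

```java
import java.util.Arrays;
import java.util.Random;

public class GlobalRngLockSketch {
    // Hypothetical stand-in for a library that keeps one static RNG for all callers.
    static final class GlobalRngLibrary {
        static final Random RAND = new Random();

        // Draws values from the shared RNG, mimicking the stochastic steps of training.
        static long[] fit(int draws) {
            long[] out = new long[draws];
            for (int i = 0; i < draws; i++) {
                out[i] = RAND.nextLong();
            }
            return out;
        }
    }

    // Hypothetical trainer; its Class object is the monitor every trainer shares.
    static class SketchTrainer {
        private final long seed;

        SketchTrainer(long seed) {
            this.seed = seed;
        }

        long[] train(int draws) {
            // Reseeding and drawing happen atomically with respect to any other
            // caller that also synchronizes on SketchTrainer.class.
            synchronized (SketchTrainer.class) {
                GlobalRngLibrary.RAND.setSeed(seed);
                return GlobalRngLibrary.fit(draws);
            }
        }
    }

    public static void main(String[] args) throws InterruptedException {
        SketchTrainer trainer = new SketchTrainer(42L);
        long[][] results = new long[2][];
        Thread t1 = new Thread(() -> results[0] = trainer.train(5));
        Thread t2 = new Thread(() -> results[1] = trainer.train(5));
        t1.start();
        t2.start();
        t1.join();
        t2.join();
        // Both runs reseed under the lock, so they see identical draws.
        System.out.println(Arrays.equals(results[0], results[1]));
    }
}
```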

See:

 Chang CC, Lin CJ.
 "LIBSVM: a library for Support Vector Machines"
 ACM Transactions on Intelligent Systems and Technology (TIST), 2011.
 
for the nu-SVC algorithm:
 Schölkopf B, Smola A, Williamson R, Bartlett P L.
 "New support vector algorithms"
 Neural Computation, 2000, pp. 1207-1245.
 
and for the original algorithm:
 Cortes C, Vapnik V.
 "Support-Vector Networks"
 Machine Learning, 1995.
 
  • Constructor Details

    • LibSVMClassificationTrainer

      protected LibSVMClassificationTrainer()
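      For OLCUT. The configuration system instantiates the trainer with this no-argument constructor and then calls postConfig().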
    • LibSVMClassificationTrainer

      public LibSVMClassificationTrainer(SVMParameters<Label> parameters)
      Constructs a classification LibSVM trainer using the specified parameters and Trainer.DEFAULT_SEED.
      Parameters:
      parameters - The SVM parameters.
    • LibSVMClassificationTrainer

      public LibSVMClassificationTrainer(SVMParameters<Label> parameters, long seed)
      Constructs a classification LibSVM trainer using the specified parameters and seed.
      Parameters:
      parameters - The SVM parameters.
      seed - The RNG seed for LibSVM's internal RNG.
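How a constructor seed could turn into per-run seeds for LibSVM's RNG is sketched below. This assumes, without the page confirming it, that the trainer keeps a java.util.SplittableRandom seeded from the constructor argument and splits it once per training call, which would match the localRNG parameter of trainModels. PLACEHOLDER_DEFAULT_SEED is a hypothetical stand-in for Trainer.DEFAULT_SEED, and SeedDerivationSketch is not a Tribuo class.

```java
import java.util.Arrays;
import java.util.SplittableRandom;

public class SeedDerivationSketch {
    // Hypothetical placeholder for Trainer.DEFAULT_SEED.
    static final long PLACEHOLDER_DEFAULT_SEED = 12345L;

    private final SplittableRandom rng;

    // Mirrors the one-argument constructor: falls back to the default seed.
    SeedDerivationSketch() {
        this(PLACEHOLDER_DEFAULT_SEED);
    }

    // Mirrors the two-argument constructor: uses the caller's seed.
    SeedDerivationSketch(long seed) {
        this.rng = new SplittableRandom(seed);
    }

    // Each training call splits off an independent generator and draws the long
    // that would reseed the library's global RNG for that run.
    synchronized long nextTrainingSeed() {
        SplittableRandom localRNG = rng.split();
        return localRNG.nextLong();
    }

    public static void main(String[] args) {
        SeedDerivationSketch a = new SeedDerivationSketch(1L);
        SeedDerivationSketch b = new SeedDerivationSketch(1L);
        long[] seqA = {a.nextTrainingSeed(), a.nextTrainingSeed(), a.nextTrainingSeed()};
        long[] seqB = {b.nextTrainingSeed(), b.nextTrainingSeed(), b.nextTrainingSeed()};
        // Same constructor seed, same sequence of per-run seeds.
        System.out.println(Arrays.equals(seqA, seqB));
    }
}
```

Splitting rather than reusing one generator keeps each training run's randomness independent of how many values earlier runs consumed.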
  • Method Details

    • postConfig

      public void postConfig()
      Used by the OLCUT configuration system, and should not be called by external code.
      Specified by:
      postConfig in interface com.oracle.labs.mlrg.olcut.config.Configurable
      Overrides:
      postConfig in class LibSVMTrainer<Label>
    • createModel

      protected LibSVMModel<Label> createModel(ModelProvenance provenance, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<Label> outputIDInfo, List<libsvm.svm_model> models)
      Description copied from class: LibSVMTrainer
      Constructs the appropriate subtype of LibSVMModel for the prediction task.
      Specified by:
      createModel in class LibSVMTrainer<Label>
      Parameters:
      provenance - The model provenance.
      featureIDMap - The feature id map.
      outputIDInfo - The output id info.
      models - The LibSVM models.
      Returns:
      An implementation of LibSVMModel.
    • trainModels

      protected List<libsvm.svm_model> trainModels(libsvm.svm_parameter curParams, int numFeatures, libsvm.svm_node[][] features, double[][] outputs, SplittableRandom localRNG)
      Description copied from class: LibSVMTrainer
      Train all the LibSVM instances necessary for this dataset.
      Specified by:
      trainModels in class LibSVMTrainer<Label>
      Parameters:
      curParams - The LibSVM parameters.
      numFeatures - The number of features in this dataset.
      features - The features themselves.
      outputs - The outputs.
      localRNG - The RNG to use for seeding LibSVM's RNG.
      Returns:
      A list of LibSVM models.
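LibSVM's classification modes handle more than two classes internally with the one-against-one scheme described in the LIBSVM paper cited above: for k classes it trains k(k-1)/2 binary classifiers, one per pair of labels, and predicts by voting. We infer, though this page does not state it, that the classification trainer therefore returns a single svm_model from trainModels rather than one per label. The sketch only enumerates the label-ID pairs; OneVsOneSketch and pairs are hypothetical names.

```java
import java.util.ArrayList;
import java.util.List;

public class OneVsOneSketch {
    // Enumerates the (i, j) label-ID pairs, with i < j, for which a
    // one-against-one scheme trains a binary classifier.
    static List<int[]> pairs(int numClasses) {
        List<int[]> out = new ArrayList<>();
        for (int i = 0; i < numClasses; i++) {
            for (int j = i + 1; j < numClasses; j++) {
                out.add(new int[]{i, j});
            }
        }
        return out;
    }

    public static void main(String[] args) {
        int k = 4;
        List<int[]> p = pairs(k);
        // Expect k * (k - 1) / 2 binary subproblems.
        System.out.println(p.size() + " == " + (k * (k - 1) / 2));
    }
}
```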
    • extractData

      protected com.oracle.labs.mlrg.olcut.util.Pair<libsvm.svm_node[][],double[][]> extractData(Dataset<Label> data, ImmutableOutputInfo<Label> outputInfo, ImmutableFeatureMap featureMap)
      Description copied from class: LibSVMTrainer
      Extracts the features and Outputs in LibSVM's format.
      Specified by:
      extractData in class LibSVMTrainer<Label>
      Parameters:
      data - The input data.
      outputInfo - The output info.
      featureMap - The feature info.
      Returns:
      The features and outputs.
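A minimal sketch of the kind of conversion extractData performs, under stated assumptions: each example becomes a row of (index, value) nodes sorted by feature ID, mirroring the public index and value fields of libsvm.svm_node, and the outputs array has a single row holding each example's label ID as a double. Node, Example and Extracted are hypothetical stand-ins for svm_node, a Tribuo Example and the returned Pair; we also assume feature IDs are used directly as LibSVM indices and that features missing from the feature map are dropped.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;

public class ExtractDataSketch {
    // Mirrors the two public fields of libsvm.svm_node.
    static final class Node {
        final int index;
        final double value;
        Node(int index, double value) { this.index = index; this.value = value; }
    }

    // Hypothetical input row: feature name -> value, plus a label name.
    static final class Example {
        final Map<String, Double> features;
        final String label;
        Example(Map<String, Double> features, String label) { this.features = features; this.label = label; }
    }

    // Hypothetical result holder standing in for the Pair returned by extractData.
    static final class Extracted {
        final Node[][] features;
        final double[][] outputs;
        Extracted(Node[][] features, double[][] outputs) { this.features = features; this.outputs = outputs; }
    }

    static Extracted extract(List<Example> data, Map<String, Integer> featureIds,
                             Map<String, Integer> labelIds) {
        Node[][] nodes = new Node[data.size()][];
        // One output row for classification: outputs[0][i] is example i's label ID.
        double[][] outputs = new double[1][data.size()];
        for (int i = 0; i < data.size(); i++) {
            Example ex = data.get(i);
            // TreeMap keeps the row in ascending feature-ID order, as LibSVM's sparse dot product expects.
            TreeMap<Integer, Double> row = new TreeMap<>();
            for (Map.Entry<String, Double> e : ex.features.entrySet()) {
                Integer id = featureIds.get(e.getKey());
                if (id != null) { // features absent from the feature map are dropped
                    row.put(id, e.getValue());
                }
            }
            List<Node> rowNodes = new ArrayList<>();
            for (Map.Entry<Integer, Double> e : row.entrySet()) {
                rowNodes.add(new Node(e.getKey(), e.getValue()));
            }
            nodes[i] = rowNodes.toArray(new Node[0]);
            outputs[0][i] = labelIds.get(ex.label);
        }
        return new Extracted(nodes, outputs);
    }

    public static void main(String[] args) {
        Map<String, Integer> featureIds = Map.of("height", 0, "width", 1);
        Map<String, Integer> labelIds = Map.of("cat", 0, "dog", 1);
        List<Example> data = List.of(
                new Example(Map.of("width", 2.0, "height", 1.5), "dog"),
                new Example(Map.of("height", 0.5, "colour", 3.0), "cat"));
        Extracted result = extract(data, featureIds, labelIds);
        System.out.println(result.features[0][0].index + " " + result.features[0].length);
        System.out.println(result.features[1].length + " " + result.outputs[0][0]);
    }
}
```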
    • setupParameters

      protected libsvm.svm_parameter setupParameters(ImmutableOutputInfo<Label> outputIDInfo)
      Description copied from class: LibSVMTrainer
      Constructs the svm_parameter. Most of the time this is a no-op, but classification overrides it to incorporate label weights if they exist.
      Overrides:
      setupParameters in class LibSVMTrainer<Label>
      Parameters:
      outputIDInfo - The output info.
      Returns:
      The svm_parameter to use for training.
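LibSVM expresses per-class weights through three fields of svm_parameter: nr_weight, weight_label (class IDs) and weight (multipliers applied to C for those classes; unlisted classes keep a multiplier of 1). The sketch below shows one way label weights could be translated into those parallel arrays. WeightParams mirrors only those three fields, label names stand in for Label objects, and skipping labels absent from the weight map is our assumption about how missing weights are handled.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

public class LabelWeightSketch {
    // Mirrors the three weight-related fields of libsvm.svm_parameter.
    static final class WeightParams {
        int nr_weight;
        int[] weight_label;
        double[] weight;
    }

    // Translates label-name weights into parallel (label ID, weight) arrays.
    // Labels without an entry are skipped, leaving LibSVM's default multiplier of 1.
    static WeightParams toLibSvmWeights(Map<String, Float> labelWeights,
                                        Map<String, Integer> labelIds) {
        List<Integer> ids = new ArrayList<>();
        List<Double> values = new ArrayList<>();
        for (Map.Entry<String, Integer> e : labelIds.entrySet()) {
            Float w = labelWeights.get(e.getKey());
            if (w != null) {
                ids.add(e.getValue());
                values.add(w.doubleValue());
            }
        }
        WeightParams p = new WeightParams();
        p.nr_weight = ids.size();
        p.weight_label = new int[ids.size()];
        p.weight = new double[ids.size()];
        for (int i = 0; i < ids.size(); i++) {
            p.weight_label[i] = ids.get(i);
            p.weight[i] = values.get(i);
        }
        return p;
    }

    public static void main(String[] args) {
        Map<String, Integer> labelIds = Map.of("cat", 0, "dog", 1, "bird", 2);
        WeightParams p = toLibSvmWeights(Map.of("bird", 5.0f), labelIds);
        System.out.println(p.nr_weight + " " + p.weight_label[0] + " " + p.weight[0]);
        // An empty weight map leaves no per-class weights, i.e. weighting is off.
        System.out.println(toLibSvmWeights(Map.of(), labelIds).nr_weight);
    }
}
```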
    • setLabelWeights

      public void setLabelWeights(Map<Label,Float> weights)
      Description copied from interface: WeightedLabels
      Sets the label weights used by this trainer.

      Supply Collections.emptyMap() to turn off label weights.

      Specified by:
      setLabelWeights in interface WeightedLabels
      Parameters:
      weights - A map from Label instances to weight values.
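One common way to choose the weights passed to setLabelWeights on imbalanced data is inverse-frequency ("balanced") weighting, w_c = n / (k * n_c), where n is the number of examples, k the number of distinct labels and n_c the count of label c. This heuristic is our suggestion, not something the trainer computes for you. The sketch keys the map by label name to stay dependency-free; with the real trainer each key would need to be wrapped as a Label, e.g. new Label("spam"). BalancedWeightsSketch is a hypothetical name.

```java
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class BalancedWeightsSketch {
    // Computes n / (k * n_c) for each distinct label c in the training labels.
    static Map<String, Float> balancedWeights(List<String> labels) {
        Map<String, Integer> counts = new HashMap<>();
        for (String l : labels) {
            counts.merge(l, 1, Integer::sum);
        }
        int n = labels.size();
        int k = counts.size();
        Map<String, Float> weights = new HashMap<>();
        for (Map.Entry<String, Integer> e : counts.entrySet()) {
            weights.put(e.getKey(), (float) n / (k * e.getValue()));
        }
        return weights;
    }

    public static void main(String[] args) {
        List<String> labels = List.of("spam", "ham", "ham", "ham");
        // The rare label gets the larger weight, so its errors cost more during training.
        System.out.println(balancedWeights(labels));
    }
}
```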