Package org.tribuo.classification.libsvm
Class LibSVMClassificationTrainer
java.lang.Object
org.tribuo.common.libsvm.LibSVMTrainer<Label>
org.tribuo.classification.libsvm.LibSVMClassificationTrainer
- All Implemented Interfaces:
com.oracle.labs.mlrg.olcut.config.Configurable, com.oracle.labs.mlrg.olcut.provenance.Provenancable<TrainerProvenance>, WeightedLabels, Trainer<Label>
A trainer for classification models that uses LibSVM.
Note that the train method is synchronized on LibSVMTrainer.class because LibSVM uses a global RNG.
This lock is insufficient to ensure reproducibility if LibSVM is also used directly in the same JVM as Tribuo, but it
avoids locking on classes that Tribuo does not control.
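The locking pattern described above can be sketched with plain JDK classes. This is only an illustration under stated assumptions: GlobalRngLockSketch, GLOBAL_RNG and this train method are hypothetical stand-ins, not Tribuo or LibSVM code. A shared java.util.Random plays the role of LibSVM's global RNG, and a SplittableRandom supplies the seed.

```java
import java.util.Random;
import java.util.SplittableRandom;

/** Illustrative stand-in only; none of this is Tribuo or LibSVM code. */
final class GlobalRngLockSketch {
    // Hypothetical stand-in for a library-wide RNG such as the one inside LibSVM.
    private static final Random GLOBAL_RNG = new Random();

    /** Hypothetical "training" call: reseeds the shared RNG, then draws from it. */
    static double train(SplittableRandom localRng) {
        // Lock on a class this code owns, in the same spirit as locking on LibSVMTrainer.class.
        synchronized (GlobalRngLockSketch.class) {
            GLOBAL_RNG.setSeed(localRng.nextLong());
            double sum = 0.0;
            for (int i = 0; i < 5; i++) {
                sum += GLOBAL_RNG.nextDouble();
            }
            return sum;
        }
    }

    public static void main(String[] args) throws InterruptedException {
        double[] results = new double[2];
        Thread first = new Thread(() -> results[0] = train(new SplittableRandom(42L)));
        Thread second = new Thread(() -> results[1] = train(new SplittableRandom(42L)));
        first.start();
        second.start();
        first.join();
        second.join();
        // Each call reseeds and draws while holding the lock, so equal seeds give equal results.
        System.out.println("Same result: " + (results[0] == results[1]));
    }
}
```

Any code that used GLOBAL_RNG without taking the same lock could still interleave with train, which is the analogue of calling LibSVM directly in the same JVM.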
See:
Chang CC, Lin CJ. "LIBSVM: a library for Support Vector Machines" ACM transactions on intelligent systems and technology (TIST), 2011.
for the nu-svc algorithm:
Schölkopf B, Smola A, Williamson R, Bartlett P L. "New support vector algorithms" Neural Computation, 2000, 1207-1245.
and for the original algorithm:
Cortes C, Vapnik V. "Support-Vector Networks" Machine Learning, 1995.
-
Field Summary
Fields inherited from class org.tribuo.common.libsvm.LibSVMTrainer
parameters, svmType
Fields inherited from interface org.tribuo.Trainer
DEFAULT_SEED, INCREMENT_INVOCATION_COUNT
-
Constructor Summary
Modifier / Constructor / Description
protected LibSVMClassificationTrainer()
- For OLCUT.
LibSVMClassificationTrainer(SVMParameters<Label> parameters)
- Constructs a classification LibSVM trainer using the specified parameters and Trainer.DEFAULT_SEED.
LibSVMClassificationTrainer(SVMParameters<Label> parameters, long seed)
- Constructs a classification LibSVM trainer using the specified parameters and seed.
-
Method Summary
Modifier and Type / Method / Description
protected LibSVMModel<Label> createModel(ModelProvenance provenance, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<Label> outputIDInfo, List<libsvm.svm_model> models)
- Construct the appropriate subtype of LibSVMModel for the prediction task.
protected com.oracle.labs.mlrg.olcut.util.Pair<libsvm.svm_node[][],double[][]> extractData(Dataset<Label> data, ImmutableOutputInfo<Label> outputInfo, ImmutableFeatureMap featureMap)
- Extracts the features and Outputs in LibSVM's format.
void postConfig()
- Used by the OLCUT configuration system, and should not be called by external code.
void setLabelWeights(Map<Label, Float> weights)
- Sets the label weights used by this trainer.
protected libsvm.svm_parameter setupParameters(ImmutableOutputInfo<Label> outputIDInfo)
- Constructs the svm_parameter.
protected List<libsvm.svm_model> trainModels(libsvm.svm_parameter curParams, int numFeatures, libsvm.svm_node[][] features, double[][] outputs, SplittableRandom localRNG)
- Train all the LibSVM instances necessary for this dataset.
Methods inherited from class org.tribuo.common.libsvm.LibSVMTrainer
exampleToNodes, getInvocationCount, getProvenance, setInvocationCount, toString, train, train, train
-
Constructor Details
-
LibSVMClassificationTrainer
protected LibSVMClassificationTrainer()
For OLCUT.
-
LibSVMClassificationTrainer
public LibSVMClassificationTrainer(SVMParameters<Label> parameters)
Constructs a classification LibSVM trainer using the specified parameters and Trainer.DEFAULT_SEED.
- Parameters:
parameters
- The SVM parameters.
-
LibSVMClassificationTrainer
public LibSVMClassificationTrainer(SVMParameters<Label> parameters, long seed)
Constructs a classification LibSVM trainer using the specified parameters and seed.
- Parameters:
parameters
- The SVM parameters.
seed
- The RNG seed for LibSVM's internal RNG.
-
-
Method Details
-
postConfig
public void postConfig()
Used by the OLCUT configuration system, and should not be called by external code.
- Specified by:
postConfig in interface com.oracle.labs.mlrg.olcut.config.Configurable
- Overrides:
postConfig in class LibSVMTrainer<Label>
-
createModel
protected LibSVMModel<Label> createModel(ModelProvenance provenance, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<Label> outputIDInfo, List<libsvm.svm_model> models)
Description copied from class: LibSVMTrainer
Construct the appropriate subtype of LibSVMModel for the prediction task.
- Specified by:
createModel in class LibSVMTrainer<Label>
- Parameters:
provenance
- The model provenance.
featureIDMap
- The feature id map.
outputIDInfo
- The output id info.
models
- The svm models.
- Returns:
- An implementation of LibSVMModel.
-
trainModels
protected List<libsvm.svm_model> trainModels(libsvm.svm_parameter curParams, int numFeatures, libsvm.svm_node[][] features, double[][] outputs, SplittableRandom localRNG)
Description copied from class: LibSVMTrainer
Train all the LibSVM instances necessary for this dataset.
- Specified by:
trainModels in class LibSVMTrainer<Label>
- Parameters:
curParams
- The LibSVM parameters.
numFeatures
- The number of features in this dataset.
features
- The features themselves.
outputs
- The outputs.
localRNG
- The RNG to use for seeding LibSVM's RNG.
- Returns:
- A list of LibSVM models.
-
extractData
protected com.oracle.labs.mlrg.olcut.util.Pair<libsvm.svm_node[][],double[][]> extractData(Dataset<Label> data, ImmutableOutputInfo<Label> outputInfo, ImmutableFeatureMap featureMap)
Description copied from class: LibSVMTrainer
Extracts the features and Outputs in LibSVM's format.
- Specified by:
extractData in class LibSVMTrainer<Label>
- Parameters:
data
- The input data.
outputInfo
- The output info.
featureMap
- The feature info.
- Returns:
- The features and outputs.
-
setupParameters
protected libsvm.svm_parameter setupParameters(ImmutableOutputInfo<Label> outputIDInfo)
Description copied from class: LibSVMTrainer
Constructs the svm_parameter. Most of the time this is a no-op, but classification overrides it to incorporate label weights if they exist.
- Overrides:
setupParameters in class LibSVMTrainer<Label>
- Parameters:
outputIDInfo
- The output info.
- Returns:
- The svm_parameters to use for training.
-
setLabelWeights
public void setLabelWeights(Map<Label, Float> weights)
Description copied from interface: WeightedLabels
Sets the label weights used by this trainer.
Supply Collections.emptyMap() to turn off label weights.
- Specified by:
setLabelWeights in interface WeightedLabels
- Parameters:
weights
- A map from Label instances to weight values.
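As a rough illustration of how a label weight map could be folded into LibSVM-style parameters, the sketch below uses only JDK classes. It is not Tribuo's implementation. LabelWeightSketch, WeightFields and apply are hypothetical names; WeightFields merely mirrors the nr_weight, weight_label and weight fields of libsvm.svm_parameter. String keys stand in for Label instances, and the default weight of 1.0 for unlisted labels is an assumption.

```java
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.Map;

/** Illustrative stand-in only; this is not Tribuo's implementation. */
final class LabelWeightSketch {
    /** Hypothetical holder mirroring nr_weight, weight_label and weight in libsvm.svm_parameter. */
    static final class WeightFields {
        int nrWeight = 0;
        int[] weightLabel = new int[0];
        double[] weight = new double[0];
    }

    /** Builds the weight fields, or leaves them empty when no weights are supplied. */
    static WeightFields apply(Map<String, Float> weights, Map<String, Integer> labelIds) {
        WeightFields fields = new WeightFields();
        if (weights.isEmpty()) {
            // An empty map turns label weighting off.
            return fields;
        }
        fields.nrWeight = labelIds.size();
        fields.weightLabel = new int[labelIds.size()];
        fields.weight = new double[labelIds.size()];
        int i = 0;
        for (Map.Entry<String, Integer> entry : labelIds.entrySet()) {
            fields.weightLabel[i] = entry.getValue();
            // Assumption: labels without an explicit weight default to 1.0.
            fields.weight[i] = weights.getOrDefault(entry.getKey(), 1.0f);
            i++;
        }
        return fields;
    }

    public static void main(String[] args) {
        Map<String, Integer> labelIds = new LinkedHashMap<>();
        labelIds.put("ham", 0);
        labelIds.put("spam", 1);
        WeightFields on = apply(Collections.singletonMap("spam", 2.5f), labelIds);
        WeightFields off = apply(Collections.emptyMap(), labelIds);
        System.out.println(on.nrWeight + " weighted, " + off.nrWeight + " when turned off");
    }
}
```

With a real trainer the map would be keyed by Label and passed to setLabelWeights. The sketch only shows the shape of the data that a method like setupParameters would need to produce.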
-