public abstract class LibSVMTrainer<T extends Output<T>> extends Object implements Trainer<T>
Note that the train method is synchronized on LibSVMTrainer.class because LibSVM uses a global RNG.
This does not guarantee reproducibility if LibSVM is also called directly in the same JVM as Tribuo, but it
avoids locking on classes that Tribuo does not control.
See:
Chang CC, Lin CJ. "LIBSVM: a library for Support Vector Machines". ACM Transactions on Intelligent Systems and Technology (TIST), 2011.
for the nu-svm algorithm:
Schölkopf B, Smola A, Williamson R, Bartlett P L. "New support vector algorithms". Neural Computation, 2000, 1207-1245.
and for the original algorithm:
Cortes C, Vapnik V. "Support-Vector Networks". Machine Learning, 1995.
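The synchronization note at the start of this class description can be illustrated with a small, self-contained sketch. `GlobalRngLibrary` is a hypothetical stand-in for a library that keeps one process-wide static RNG (it is not LibSVM's API); the sketch seeds that shared RNG and trains inside a single critical section that locks on a class the caller owns.

```java
import java.util.Arrays;
import java.util.Random;
import java.util.SplittableRandom;

public class GlobalRngLockSketch {

    // Hypothetical stand-in for a third-party library with one shared, static RNG.
    static final class GlobalRngLibrary {
        static final Random rand = new Random();

        // Consumes randomness from the shared RNG, as a training routine might.
        static long[] train(int n) {
            long[] draws = new long[n];
            for (int i = 0; i < n; i++) {
                draws[i] = rand.nextLong();
            }
            return draws;
        }
    }

    // Seeds the shared RNG and runs training in one critical section.
    // The lock is on our own class, not on the library's class.
    static long[] trainReproducibly(SplittableRandom localRNG, int n) {
        synchronized (GlobalRngLockSketch.class) {
            GlobalRngLibrary.rand.setSeed(localRNG.nextLong());
            return GlobalRngLibrary.train(n);
        }
    }

    public static void main(String[] args) {
        long[] a = trainReproducibly(new SplittableRandom(42L), 3);
        long[] b = trainReproducibly(new SplittableRandom(42L), 3);
        System.out.println(Arrays.equals(a, b)); // prints true
    }
}
```

Code that touches the shared RNG without taking the same lock can still interleave with a training run, which is why the note says this approach is insufficient when the library is also used directly in the same JVM.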
Modifier and Type | Field and Description |
---|---|
protected libsvm.svm_parameter | parameters - The SVM parameters suitable for use by LibSVM. |
protected SVMType<T> | svmType - The type of SVM algorithm. |
Fields inherited from interface Trainer: DEFAULT_SEED
Modifier | Constructor and Description |
---|---|
protected | LibSVMTrainer() - For olcut. |
protected | LibSVMTrainer(SVMParameters<T> parameters, long seed) - Constructs a LibSVMTrainer from the parameters. |
Modifier and Type | Method and Description |
---|---|
protected abstract LibSVMModel<T> | createModel(ModelProvenance provenance, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<T> outputIDInfo, List<libsvm.svm_model> models) - Construct the appropriate subtype of LibSVMModel for the prediction task. |
static <T extends Output<T>> libsvm.svm_node[] | exampleToNodes(Example<T> example, ImmutableFeatureMap featureIDMap, List<libsvm.svm_node> features) - Convert the example into an array of svm_node which represents a sparse feature vector. |
protected abstract com.oracle.labs.mlrg.olcut.util.Pair<libsvm.svm_node[][],double[][]> | extractData(Dataset<T> data, ImmutableOutputInfo<T> outputInfo, ImmutableFeatureMap featureMap) - Extracts the features and Outputs in LibSVM's format. |
int | getInvocationCount() - The number of times this trainer instance has had its train method invoked. |
TrainerProvenance | getProvenance() |
void | postConfig() - Used by the OLCUT configuration system, and should not be called by external code. |
protected libsvm.svm_parameter | setupParameters(ImmutableOutputInfo<T> info) - Constructs the svm_parameter. |
String | toString() |
LibSVMModel<T> | train(Dataset<T> examples) - Trains a predictive model using the examples in the given data set. |
LibSVMModel<T> | train(Dataset<T> examples, Map<String,com.oracle.labs.mlrg.olcut.provenance.Provenance> runProvenance) - Trains a predictive model using the examples in the given data set. |
protected abstract List<libsvm.svm_model> | trainModels(libsvm.svm_parameter curParams, int numFeatures, libsvm.svm_node[][] features, double[][] outputs, SplittableRandom localRNG) - Train all the LibSVM instances necessary for this dataset. |
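protected libsvm.svm_parameter parameters
The SVM parameters suitable for use by LibSVM.

protected LibSVMTrainer()
For olcut.

protected LibSVMTrainer(SVMParameters<T> parameters, long seed)
Constructs a LibSVMTrainer from the parameters.
Parameters:
parameters - The SVM parameters.

public void postConfig()
Used by the OLCUT configuration system, and should not be called by external code.
Specified by:
postConfig in interface com.oracle.labs.mlrg.olcut.config.Configurable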
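public LibSVMModel<T> train(Dataset<T> examples)
Description copied from interface: Trainer
Trains a predictive model using the examples in the given data set.

public LibSVMModel<T> train(Dataset<T> examples, Map<String,com.oracle.labs.mlrg.olcut.provenance.Provenance> runProvenance)
Description copied from interface: Trainer
Trains a predictive model using the examples in the given data set.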
protected abstract LibSVMModel<T> createModel(ModelProvenance provenance, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<T> outputIDInfo, List<libsvm.svm_model> models)
Construct the appropriate subtype of LibSVMModel for the prediction task.
Parameters:
provenance - The model provenance.
featureIDMap - The feature id map.
outputIDInfo - The output id info.
models - The svm models.

protected abstract List<libsvm.svm_model> trainModels(libsvm.svm_parameter curParams, int numFeatures, libsvm.svm_node[][] features, double[][] outputs, SplittableRandom localRNG)
Train all the LibSVM instances necessary for this dataset.
Parameters:
curParams - The LibSVM parameters.
numFeatures - The number of features in this dataset.
features - The features themselves.
outputs - The outputs.
localRNG - The RNG to use for seeding LibSVM's RNG.

protected abstract com.oracle.labs.mlrg.olcut.util.Pair<libsvm.svm_node[][],double[][]> extractData(Dataset<T> data, ImmutableOutputInfo<T> outputInfo, ImmutableFeatureMap featureMap)
Extracts the features and Outputs in LibSVM's format.
Parameters:
data - The input data.
outputInfo - The output info.
featureMap - The feature info.

protected libsvm.svm_parameter setupParameters(ImmutableOutputInfo<T> info)
Constructs the svm_parameter.
Parameters:
info - The output info.

public int getInvocationCount()
Description copied from interface: Trainer
The number of times this trainer instance has had its train method invoked.
This is used to determine how many times the trainer's RNG has been accessed to ensure replicability in the random number stream.
Specified by:
getInvocationCount in interface Trainer<T extends Output<T>>
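The role of getInvocationCount in replicating the random number stream can be sketched as follows. This is a minimal sketch, assuming the trainer keeps a master SplittableRandom seeded at construction and splits it exactly once per train call to obtain the localRNG; the class and the setInvocationCount helper are hypothetical and not part of this page's API.

```java
import java.util.SplittableRandom;

public class InvocationCountSketch {
    private final long seed;
    private SplittableRandom rng;
    private int invocationCount = 0;

    public InvocationCountSketch(long seed) {
        this.seed = seed;
        this.rng = new SplittableRandom(seed);
    }

    // Derives the RNG for one training run and records that the master RNG was used.
    public synchronized SplittableRandom nextLocalRNG() {
        invocationCount++;
        return rng.split();
    }

    public synchronized int getInvocationCount() {
        return invocationCount;
    }

    // Hypothetical helper: rebuilds the master RNG and replays the recorded number
    // of splits, so the next derived stream matches an uninterrupted trainer's.
    public synchronized void setInvocationCount(int count) {
        rng = new SplittableRandom(seed);
        for (int i = 0; i < count; i++) {
            rng.split();
        }
        invocationCount = count;
    }

    public static void main(String[] args) {
        InvocationCountSketch original = new InvocationCountSketch(42L);
        original.nextLocalRNG();
        original.nextLocalRNG();
        long expected = original.nextLocalRNG().nextLong();

        InvocationCountSketch restored = new InvocationCountSketch(42L);
        restored.setInvocationCount(2);
        System.out.println(restored.nextLocalRNG().nextLong() == expected); // prints true
    }
}
```

Because each call consumes the master RNG by exactly one split, the stream used by the n-th training run depends only on the seed and n, which is the quantity the invocation count records.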
public static <T extends Output<T>> libsvm.svm_node[] exampleToNodes(Example<T> example, ImmutableFeatureMap featureIDMap, List<libsvm.svm_node> features)
Convert the example into an array of svm_node which represents a sparse feature vector.
If there are collisions in the feature ids then the values are summed.
Type Parameters:
T - The type of the output.
Parameters:
example - The example to convert.
featureIDMap - The feature id map which holds the indices.
features - A buffer to use.

public TrainerProvenance getProvenance()
Specified by:
getProvenance in interface com.oracle.labs.mlrg.olcut.provenance.Provenancable<TrainerProvenance>
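The collision rule stated for exampleToNodes (values whose feature ids collide are summed) can be sketched without LibSVM. In this sketch `Node` is a hypothetical class mirroring the index/value pair of an svm_node, a plain `Map<String, Integer>` stands in for the feature id map, and skipping features that have no id is an assumption made for illustration.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

public class SparseNodeSketch {
    // Hypothetical sparse entry mirroring an svm_node's index and value.
    static final class Node {
        final int index;
        double value;

        Node(int index, double value) {
            this.index = index;
            this.value = value;
        }
    }

    // Converts (featureName, value) pairs into index-sorted nodes, reusing the
    // buffer. Values whose names map to the same id are summed; names with no id
    // are skipped (an assumption of this sketch).
    static Node[] toNodes(List<Map.Entry<String, Double>> features,
                          Map<String, Integer> featureIds,
                          List<Node> buffer) {
        buffer.clear();
        for (Map.Entry<String, Double> f : features) {
            Integer id = featureIds.get(f.getKey());
            if (id != null) {
                buffer.add(new Node(id, f.getValue()));
            }
        }
        buffer.sort((a, b) -> Integer.compare(a.index, b.index));
        List<Node> merged = new ArrayList<>();
        for (Node n : buffer) {
            Node last = merged.isEmpty() ? null : merged.get(merged.size() - 1);
            if (last != null && last.index == n.index) {
                last.value += n.value; // collision: sum into the existing entry
            } else {
                merged.add(new Node(n.index, n.value));
            }
        }
        return merged.toArray(new Node[0]);
    }

    public static void main(String[] args) {
        Map<String, Integer> ids = Map.of("a", 0, "b", 1, "c", 1);
        List<Map.Entry<String, Double>> example = List.of(
                Map.entry("b", 2.0), Map.entry("a", 1.0),
                Map.entry("c", 3.0), Map.entry("unknown", 9.0));
        for (Node n : toNodes(example, ids, new ArrayList<>())) {
            System.out.println(n.index + ":" + n.value); // prints 0:1.0 then 1:5.0
        }
    }
}
```

Sorting before merging keeps the output in ascending index order; LibSVM's Java kernel code walks two node arrays in step when computing dot products, so it expects indices in that order.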
Copyright © 2015–2021 Oracle and/or its affiliates. All rights reserved.