Modifier and Type | Class and Description |
---|---|
static class | KNNTrainer.Distance: The available distance functions. |
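The summary lists `KNNTrainer.Distance` without naming its constants. Here is a minimal sketch, assuming the enum selects among common vector distances. The hypothetical `DistanceSketch` enum below shows what such a choice amounts to. Its constant names and its `compute` method are illustrative assumptions, not Tribuo's API. No imports are needed. Cosine distance is undefined for all-zero vectors (it divides 0 by 0).

```java
/** Usage demo: prints every sketch distance between two 2-D points. */
class DistanceDemo {
    public static void main(String[] args) {
        double[] x = {1.0, 0.0};
        double[] y = {4.0, 4.0};
        for (DistanceSketch d : DistanceSketch.values()) {
            System.out.println(d + ": " + d.compute(x, y));
        }
    }
}

/**
 * Hypothetical distance-function enum for a k-NN learner.
 * Constant names and compute() are illustrative, not Tribuo's KNNTrainer.Distance.
 */
enum DistanceSketch {
    /** Sum of squared coordinate differences (no square root). */
    SQUARED_EUCLIDEAN {
        @Override
        double compute(double[] a, double[] b) {
            double sum = 0.0;
            for (int i = 0; i < a.length; i++) {
                double d = a[i] - b[i];
                sum += d * d;
            }
            return sum;
        }
    },
    /** Straight-line distance: square root of the squared Euclidean distance. */
    EUCLIDEAN {
        @Override
        double compute(double[] a, double[] b) {
            return Math.sqrt(SQUARED_EUCLIDEAN.compute(a, b));
        }
    },
    /** Manhattan (L1) distance: sum of absolute coordinate differences. */
    L1 {
        @Override
        double compute(double[] a, double[] b) {
            double sum = 0.0;
            for (int i = 0; i < a.length; i++) {
                sum += Math.abs(a[i] - b[i]);
            }
            return sum;
        }
    },
    /** Cosine distance: 1 minus the cosine similarity (NaN for zero vectors). */
    COSINE {
        @Override
        double compute(double[] a, double[] b) {
            double dot = 0.0, normA = 0.0, normB = 0.0;
            for (int i = 0; i < a.length; i++) {
                dot += a[i] * b[i];
                normA += a[i] * a[i];
                normB += b[i] * b[i];
            }
            return 1.0 - dot / (Math.sqrt(normA) * Math.sqrt(normB));
        }
    };

    /** Distance between two equal-length vectors; smaller means more similar. */
    abstract double compute(double[] a, double[] b);
}
```

Squared Euclidean ranks neighbours in the same order as Euclidean. It skips the square root, which is why k-NN implementations often offer both.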
Fields inherited from interface Trainer: DEFAULT_SEED
Constructor and Description |
---|
KNNTrainer(int k, KNNTrainer.Distance distance, int numThreads, EnsembleCombiner<T> combiner, KNNModel.Backend backend): Creates a K-NN trainer using the supplied parameters. |
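The constructor parameters describe the core k-NN procedure. The trainer finds the `k` stored examples nearest to a query under a distance function, then aggregates their outputs with a combiner. Below is a minimal single-threaded sketch of that pattern. It uses hypothetical names (`KnnSketch`, `squaredEuclidean`, `majorityVote`) and plain `String` labels in place of Tribuo's `Dataset`, `EnsembleCombiner` and `Output` types. It omits `numThreads` and `backend`, which this page does not describe further.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.function.ToDoubleBiFunction;

/** Usage demo: two well-separated clusters of labelled 2-D points. */
class KnnDemo {
    public static void main(String[] args) {
        List<double[]> xs = List.of(
            new double[]{0.0, 0.0}, new double[]{0.5, 0.2}, new double[]{0.1, 0.6},
            new double[]{5.0, 5.0}, new double[]{5.3, 4.8}, new double[]{4.7, 5.4});
        List<String> ys = List.of("A", "A", "A", "B", "B", "B");
        KnnSketch knn = new KnnSketch(3, KnnSketch::squaredEuclidean, KnnSketch::majorityVote);
        KnnSketch.Model model = knn.train(xs, ys);
        System.out.println(model.predict(new double[]{0.2, 0.3}));
        System.out.println(model.predict(new double[]{5.1, 5.0}));
    }
}

/** Hypothetical k-NN sketch; not Tribuo's KNNTrainer or KNNModel API. */
class KnnSketch {
    private final int k;
    private final ToDoubleBiFunction<double[], double[]> distance;
    private final Function<List<String>, String> combiner;

    KnnSketch(int k, ToDoubleBiFunction<double[], double[]> distance,
              Function<List<String>, String> combiner) {
        if (k < 1) throw new IllegalArgumentException("k must be positive");
        this.k = k;
        this.distance = distance;
        this.combiner = combiner;
    }

    /** k-NN is a lazy learner: "training" only stores copies of the examples. */
    Model train(List<double[]> features, List<String> labels) {
        return new Model(new ArrayList<>(features), new ArrayList<>(labels));
    }

    final class Model {
        private final List<double[]> features;
        private final List<String> labels;

        private Model(List<double[]> features, List<String> labels) {
            this.features = features;
            this.labels = labels;
        }

        /** Sorts stored examples by distance to x and combines the k nearest labels. */
        String predict(double[] x) {
            List<Integer> indices = new ArrayList<>();
            for (int i = 0; i < features.size(); i++) indices.add(i);
            indices.sort(Comparator.comparingDouble(i -> distance.applyAsDouble(features.get(i), x)));
            List<String> neighbourLabels = new ArrayList<>();
            for (int i = 0; i < Math.min(k, indices.size()); i++) {
                neighbourLabels.add(labels.get(indices.get(i)));
            }
            return combiner.apply(neighbourLabels);
        }
    }

    static double squaredEuclidean(double[] a, double[] b) {
        double sum = 0.0;
        for (int i = 0; i < a.length; i++) {
            double d = a[i] - b[i];
            sum += d * d;
        }
        return sum;
    }

    /** Majority vote; on a tie, the label that reached the top count first wins. */
    static String majorityVote(List<String> votes) {
        Map<String, Integer> counts = new HashMap<>();
        String best = null;
        int bestCount = 0;
        for (String v : votes) {
            int c = counts.merge(v, 1, Integer::sum);
            if (c > bestCount) {
                best = v;
                bestCount = c;
            }
        }
        return best;
    }
}
```

All of the distance work happens in `predict`. A multithreaded variant would presumably parallelise that loop over stored examples. This sketch keeps it sequential for clarity.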
Modifier and Type | Method and Description |
---|---|
int | getInvocationCount(): The number of times this trainer instance has had its train method invoked. |
TrainerProvenance | getProvenance() |
void | postConfig(): Used by the OLCUT configuration system, and should not be called by external code. |
String | toString() |
Model<T> | train(Dataset<T> examples, Map<String,com.oracle.labs.mlrg.olcut.provenance.Provenance> runProvenance): Trains a predictive model using the examples in the given data set. |
public KNNTrainer(int k, KNNTrainer.Distance distance, int numThreads, EnsembleCombiner<T> combiner, KNNModel.Backend backend)
Creates a K-NN trainer using the supplied parameters.
Parameters:
k - The number of nearest neighbours to consider.
distance - The distance function.
numThreads - The number of threads to use.
combiner - The combination function to aggregate the k predictions.
backend - The computational backend.
public void postConfig()
Used by the OLCUT configuration system, and should not be called by external code.
Specified by:
postConfig in interface com.oracle.labs.mlrg.olcut.config.Configurable
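public Model<T> train(Dataset<T> examples, Map<String,com.oracle.labs.mlrg.olcut.provenance.Provenance> runProvenance)
Description copied from interface: Trainer
Trains a predictive model using the examples in the given data set.
Specified by:
train in interface Trainer<T extends Output<T>>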
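public int getInvocationCount()
Description copied from interface: Trainer
The number of times this trainer instance has had its train method invoked.
This is used to determine how many times the trainer's RNG has been accessed to ensure replicability in the random number stream.
Specified by:
getInvocationCount in interface Trainer<T extends Output<T>>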
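The invocation count is tied to replicable random number streams. A minimal sketch, assuming a trainer that draws one per-run seed from a seeded `java.util.Random` on every `train` call, shows why counting calls helps. A fresh instance can replay the stream to the same position and reproduce a later run. `CountingTrainerSketch` and its `skipTo` method are hypothetical illustrations, not part of Tribuo. The seed `12345L` is an arbitrary choice.

```java
import java.util.Random;

/** Hypothetical trainer stand-in showing invocation counting for replicable RNG use. */
class CountingTrainerSketch {
    private final long seed;
    private Random rng;
    private int invocationCount = 0;

    CountingTrainerSketch(long seed) {
        this.seed = seed;
        this.rng = new Random(seed);
    }

    /** Stand-in for train(): consumes one RNG draw per call and bumps the counter. */
    synchronized long train() {
        long runSeed = rng.nextLong(); // seed a real training run would use
        invocationCount++;
        return runSeed;
    }

    synchronized int getInvocationCount() {
        return invocationCount;
    }

    /** Hypothetical: resets the RNG and replays n draws, as if train() had run n times. */
    synchronized void skipTo(int n) {
        rng = new Random(seed);
        for (int i = 0; i < n; i++) {
            rng.nextLong();
        }
        invocationCount = n;
    }

    public static void main(String[] args) {
        CountingTrainerSketch original = new CountingTrainerSketch(12345L);
        original.train();
        original.train();
        long third = original.train();

        // A fresh instance fast-forwarded past two runs reproduces the third run's seed.
        CountingTrainerSketch replay = new CountingTrainerSketch(12345L);
        replay.skipTo(original.getInvocationCount() - 1);
        System.out.println(original.getInvocationCount()); // prints 3
        System.out.println(third == replay.train()); // prints true
    }
}
```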
public TrainerProvenance getProvenance()
Specified by:
getProvenance in interface com.oracle.labs.mlrg.olcut.provenance.Provenancable<TrainerProvenance>
Copyright © 2015–2021 Oracle and/or its affiliates. All rights reserved.