Class KNNTrainer<T extends Output<T>>
java.lang.Object
org.tribuo.common.nearest.KNNTrainer<T>
All Implemented Interfaces:
com.oracle.labs.mlrg.olcut.config.Configurable, com.oracle.labs.mlrg.olcut.provenance.Provenancable<TrainerProvenance>, Trainer<T>
A Trainer for k-nearest neighbour models.
Nested Class Summary
Nested Classes
static enum KNNTrainer.Distance
The available distance functions.
Field Summary
Fields inherited from interface org.tribuo.Trainer
DEFAULT_SEED
-
Constructor Summary
Constructors
KNNTrainer(int k, KNNTrainer.Distance distance, int numThreads, EnsembleCombiner<T> combiner, KNNModel.Backend backend)
Creates a K-NN trainer using the supplied parameters.
Method Summary
int getInvocationCount()
The number of times this trainer instance has had its train method invoked.
void postConfig()
Used by the OLCUT configuration system, and should not be called by external code.
toString()
train(Dataset<T> examples, Map<String, com.oracle.labs.mlrg.olcut.provenance.Provenance> runProvenance)
Trains a predictive model using the examples in the given data set.
-
Constructor Details
-
KNNTrainer
public KNNTrainer(int k, KNNTrainer.Distance distance, int numThreads, EnsembleCombiner<T> combiner, KNNModel.Backend backend)
Creates a K-NN trainer using the supplied parameters.
Parameters:
k - The number of nearest neighbours to consider.
distance - The distance function.
numThreads - The number of threads to use.
combiner - The combination function to aggregate the k predictions.
backend - The computational backend.
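To make the roles of k, the distance function, and the combiner concrete, here is a minimal library-free sketch of k-nearest neighbour prediction. It is not Tribuo's implementation: the class KnnSketch and its methods are hypothetical, the distance is fixed to Euclidean (L2), and a simple majority vote stands in for the combiner. Threading and the computational backend are not modelled.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/** Hypothetical, library-free k-NN sketch; not Tribuo's implementation. */
public class KnnSketch {

    /** Euclidean (L2) distance between two equal-length vectors. */
    public static double l2(double[] a, double[] b) {
        double sum = 0.0;
        for (int i = 0; i < a.length; i++) {
            double d = a[i] - b[i];
            sum += d * d;
        }
        return Math.sqrt(sum);
    }

    /** Predicts a label by majority vote over the k training points closest to the query. */
    public static String predict(double[][] features, String[] labels, double[] query, int k) {
        // Index the training points so they can be ordered by distance to the query.
        List<Integer> idx = new ArrayList<>();
        for (int i = 0; i < features.length; i++) {
            idx.add(i);
        }
        idx.sort(Comparator.comparingDouble(i -> l2(features[i], query)));

        // Stand-in for the combiner: count votes among the k nearest neighbours.
        // On a tie, the label that reached the winning count first is kept.
        Map<String, Integer> votes = new HashMap<>();
        String best = null;
        int bestCount = 0;
        for (int n = 0; n < Math.min(k, idx.size()); n++) {
            String label = labels[idx.get(n)];
            int count = votes.merge(label, 1, Integer::sum);
            if (count > bestCount) {
                best = label;
                bestCount = count;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        double[][] x = {{0, 0}, {0, 1}, {5, 5}, {6, 5}};
        String[] y = {"a", "a", "b", "b"};
        // The three nearest neighbours of (0.2, 0.1) are labelled a, a, b.
        System.out.println(predict(x, y, new double[] {0.2, 0.1}, 3)); // prints "a"
    }
}
```

The sketch sorts all training points for every query, which is only reasonable for small data sets; the numThreads and backend parameters exist in the real trainer precisely because this search is the expensive step.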
-
-
Method Details
-
postConfig
Used by the OLCUT configuration system, and should not be called by external code.
Specified by:
postConfig in interface com.oracle.labs.mlrg.olcut.config.Configurable
-
train
-
toString
-
getInvocationCount
Description copied from interface: Trainer
The number of times this trainer instance has had its train method invoked.
This is used to determine how many times the trainer's RNG has been accessed to ensure replicability in the random number stream.
Specified by:
getInvocationCount in interface Trainer<T extends Output<T>>
Returns:
The number of train invocations.
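The link between the invocation count and replicability can be illustrated with a small sketch. It assumes a trainer that draws a fresh per-call RNG from a seeded stream each time train is invoked; the class CountingTrainerSketch and its train method are hypothetical stand-ins, not Tribuo code, and the returned long merely represents whatever randomness a real training run would consume.

```java
import java.util.SplittableRandom;

/** Hypothetical sketch of invocation counting; not Tribuo's implementation. */
public class CountingTrainerSketch {
    private final SplittableRandom rng;
    private int invocationCount = 0;

    public CountingTrainerSketch(long seed) {
        this.rng = new SplittableRandom(seed);
    }

    /**
     * Stand-in for train: splits a per-call RNG off the trainer's stream,
     * bumps the counter, and returns one value drawn from the per-call RNG.
     */
    public synchronized long train() {
        SplittableRandom localRng = rng.split();
        invocationCount++;
        return localRng.nextLong();
    }

    /** The number of times train has been called on this instance. */
    public synchronized int getInvocationCount() {
        return invocationCount;
    }

    public static void main(String[] args) {
        CountingTrainerSketch original = new CountingTrainerSketch(42L);
        original.train();
        original.train();
        long third = original.train();

        // To reproduce the third run, a fresh trainer with the same seed must
        // first advance its stream by (invocation count - 1) calls.
        CountingTrainerSketch replay = new CountingTrainerSketch(42L);
        for (int i = 0; i < original.getInvocationCount() - 1; i++) {
            replay.train();
        }
        System.out.println(replay.train() == third); // prints "true"
    }
}
```

Recording the count alongside the seed is what allows a later run to advance the stream to the same position before training.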
-
getProvenance
-