Class KNNTrainer<T extends Output<T>>
java.lang.Object
org.tribuo.common.nearest.KNNTrainer<T>
- All Implemented Interfaces:
com.oracle.labs.mlrg.olcut.config.Configurable, com.oracle.labs.mlrg.olcut.provenance.Provenancable<TrainerProvenance>, Trainer<T>
A Trainer for k-nearest neighbour models.
Nested Class Summary
Nested Classes
Modifier and Type | Class | Description
static enum | KNNTrainer.Distance | The available distance functions.
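As a rough illustration of what a distance function contributes to k-NN, the plain-Java sketch below computes three common distances between dense vectors. It does not use Tribuo: the enum `Distance` and its constants `L1`, `L2` and `COSINE` are hypothetical stand-ins chosen for this example, and defining cosine distance as 1 minus cosine similarity is an assumption about how such a metric might be exposed.

```java
public class DistanceSketch {
    // Hypothetical stand-in for KNNTrainer.Distance; the constant names are assumptions.
    enum Distance { L1, L2, COSINE }

    static double distance(Distance d, double[] a, double[] b) {
        if (a.length != b.length) {
            throw new IllegalArgumentException("Vectors must have the same length");
        }
        switch (d) {
            case L1: {
                // Manhattan distance: sum of absolute coordinate differences.
                double sum = 0.0;
                for (int i = 0; i < a.length; i++) {
                    sum += Math.abs(a[i] - b[i]);
                }
                return sum;
            }
            case L2: {
                // Euclidean distance: square root of the summed squared differences.
                double sum = 0.0;
                for (int i = 0; i < a.length; i++) {
                    double diff = a[i] - b[i];
                    sum += diff * diff;
                }
                return Math.sqrt(sum);
            }
            case COSINE: {
                // Cosine distance = 1 - cosine similarity (NaN if either vector is all zeros).
                double dot = 0.0, normA = 0.0, normB = 0.0;
                for (int i = 0; i < a.length; i++) {
                    dot += a[i] * b[i];
                    normA += a[i] * a[i];
                    normB += b[i] * b[i];
                }
                return 1.0 - dot / (Math.sqrt(normA) * Math.sqrt(normB));
            }
            default:
                throw new IllegalStateException("Unknown distance " + d);
        }
    }

    public static void main(String[] args) {
        double[] a = {0.0, 3.0};
        double[] b = {4.0, 0.0};
        System.out.println(distance(Distance.L1, a, b));     // 7.0
        System.out.println(distance(Distance.L2, a, b));     // 5.0
        System.out.println(distance(Distance.COSINE, a, b)); // 1.0 (orthogonal vectors)
    }
}
```

The choice matters in practice: L1 and L2 are sensitive to feature magnitude, while cosine distance only compares directions, so the same k can pick different neighbours under each.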
Field Summary
Fields inherited from interface org.tribuo.Trainer
DEFAULT_SEED
Constructor Summary
Constructors
Constructor | Description
KNNTrainer(int k, KNNTrainer.Distance distance, int numThreads, EnsembleCombiner<T> combiner, KNNModel.Backend backend) | Creates a K-NN trainer using the supplied parameters.
Method Summary
Modifier and Type | Method | Description
int | getInvocationCount() | The number of times this trainer instance has had its train method invoked.
TrainerProvenance | getProvenance() |
void | postConfig() | Used by the OLCUT configuration system, and should not be called by external code.
String | toString() |
Model<T> | train(Dataset<T> examples, Map<String, com.oracle.labs.mlrg.olcut.provenance.Provenance> runProvenance) | Trains a predictive model using the examples in the given data set.
Constructor Details
KNNTrainer
public KNNTrainer(int k, KNNTrainer.Distance distance, int numThreads, EnsembleCombiner<T> combiner, KNNModel.Backend backend)
Creates a K-NN trainer using the supplied parameters.
- Parameters:
k - The number of nearest neighbours to consider.
distance - The distance function.
numThreads - The number of threads to use.
combiner - The combination function used to aggregate the k predictions.
backend - The computational backend.
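To make the roles of k, numThreads and combiner concrete, here is a minimal plain-Java sketch of k-NN prediction; it is not Tribuo's implementation. It assumes dense double[] features, fixes the distance to squared Euclidean (which ranks neighbours in the same order as L2), splits the distance computation across a fixed pool of numThreads threads as one possible analogue of a parallel backend, and uses a majority vote as the combiner. The class KnnSketch and method predict are hypothetical names.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

public class KnnSketch {
    // Squared Euclidean distance: sufficient for ranking neighbours.
    static double squaredL2(double[] a, double[] b) {
        double sum = 0.0;
        for (int i = 0; i < a.length; i++) {
            double diff = a[i] - b[i];
            sum += diff * diff;
        }
        return sum;
    }

    // Hypothetical k-NN predict: parallel distances, top-k selection, majority vote.
    static String predict(double[][] train, String[] labels, double[] query, int k, int numThreads)
            throws InterruptedException, ExecutionException {
        double[] dists = new double[train.length];
        ExecutorService pool = Executors.newFixedThreadPool(numThreads);
        try {
            List<Future<?>> futures = new ArrayList<>();
            int chunk = (train.length + numThreads - 1) / numThreads;
            for (int t = 0; t < numThreads; t++) {
                final int start = t * chunk;
                final int end = Math.min(train.length, start + chunk);
                // Each task writes a disjoint slice of dists, so no locking is needed.
                futures.add(pool.submit(() -> {
                    for (int i = start; i < end; i++) {
                        dists[i] = squaredL2(train[i], query);
                    }
                }));
            }
            for (Future<?> f : futures) {
                f.get(); // wait for each task and surface any exception it threw
            }
        } finally {
            pool.shutdown();
        }
        // Sort training indices by distance to the query.
        Integer[] order = new Integer[train.length];
        for (int i = 0; i < order.length; i++) {
            order[i] = i;
        }
        Arrays.sort(order, Comparator.comparingDouble(i -> dists[i]));
        // Majority-vote combiner over the labels of the k closest examples.
        Map<String, Integer> votes = new HashMap<>();
        for (int n = 0; n < Math.min(k, order.length); n++) {
            votes.merge(labels[order[n]], 1, Integer::sum);
        }
        return votes.entrySet().stream()
                .max(Map.Entry.comparingByValue())
                .map(Map.Entry::getKey)
                .orElseThrow(() -> new IllegalArgumentException("empty training set"));
    }

    public static void main(String[] args) throws Exception {
        double[][] train = {{0, 0}, {0, 1}, {1, 0}, {5, 5}, {5, 6}, {6, 5}};
        String[] labels = {"a", "a", "a", "b", "b", "b"};
        System.out.println(predict(train, labels, new double[]{0.5, 0.5}, 3, 2)); // a
        System.out.println(predict(train, labels, new double[]{5.5, 5.2}, 3, 2)); // b
    }
}
```

k-NN does no fitting up front: training amounts to keeping the examples, and the cost is paid at prediction time, which is presumably why a thread count and a computational backend are constructor parameters. Ties in the vote are broken arbitrarily in this sketch.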
Method Details
postConfig
Used by the OLCUT configuration system, and should not be called by external code.
- Specified by:
postConfig in interface com.oracle.labs.mlrg.olcut.config.Configurable
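train
Trains a predictive model using the examples in the given data set.
- Specified by:
train in interface Trainer<T extends Output<T>>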
toString
getInvocationCount
Description copied from interface: Trainer
The number of times this trainer instance has had its train method invoked. This is used to determine how many times the trainer's RNG has been accessed to ensure replicability in the random number stream.
- Specified by:
getInvocationCount in interface Trainer<T extends Output<T>>
- Returns:
The number of train invocations.
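The replicability argument in the getInvocationCount description can be sketched as follows. This is a hypothetical plain-Java trainer, not Tribuo code: it assumes each train() call consumes exactly one draw from a seeded java.util.Random, so rebuilding the RNG from the seed and replaying the recorded number of invocations restores the random stream. SeededTrainer and setInvocationCount are illustrative names.

```java
import java.util.Random;

public class InvocationCountSketch {
    // Hypothetical trainer owning a seeded RNG; one draw is consumed per train() call.
    static final class SeededTrainer {
        private final long seed;
        private Random rng;
        private int invocationCount = 0;

        SeededTrainer(long seed) {
            this.seed = seed;
            this.rng = new Random(seed);
        }

        // Draws one value (e.g. a per-model seed) and records the invocation.
        synchronized long train() {
            invocationCount++;
            return rng.nextLong();
        }

        int getInvocationCount() {
            return invocationCount;
        }

        // Rebuilds the RNG from the seed and replays `count` invocations, so the
        // next train() call sees the same random value it would have originally.
        synchronized void setInvocationCount(int count) {
            rng = new Random(seed);
            invocationCount = 0;
            for (int i = 0; i < count; i++) {
                train();
            }
        }
    }

    public static void main(String[] args) {
        SeededTrainer original = new SeededTrainer(42L);
        original.train();
        original.train();
        long third = original.train();

        // A fresh trainer with the same seed, fast-forwarded to two invocations,
        // reproduces the third draw.
        SeededTrainer replay = new SeededTrainer(42L);
        replay.setInvocationCount(original.getInvocationCount() - 1);
        System.out.println(replay.train() == third); // true
    }
}
```

Replay only works if every invocation consumes the same number of draws; a trainer whose draw count depends on the data would need to record more state than a single counter.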
getProvenance