public class KMeansTrainer extends Object implements Trainer<ClusterID>
This trainer is slightly contorted to fit the Tribuo Trainer and Model APIs: the cluster assignments can only be retrieved from the model after training, and retrieving them requires re-evaluating each example.
The trainer has a parameterised distance function and a selectable number of threads for the training step. The thread pool is local to each invocation of train, so multiple trainings can run concurrently.
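Because the pool is local to each call, concurrent invocations share no threads or mutable state. The following is a minimal plain-Java sketch of that pattern, not Tribuo's implementation. It runs Lloyd's algorithm on dense double[] rows with squared Euclidean distance and uses the first k rows as initial centroids, so it assumes k is at most the number of rows. It also creates and shuts down a ForkJoinPool inside the call. The class LocalPoolKMeans and its methods are hypothetical.

```java
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ForkJoinPool;
import java.util.stream.IntStream;

/** Hypothetical sketch: Lloyd's K-Means with a thread pool owned by a single train call. */
final class LocalPoolKMeans {
    private LocalPoolKMeans() {}

    static double[][] train(double[][] data, int k, int iterations, int numThreads) {
        int dims = data[0].length;
        double[][] centroids = new double[k][];
        for (int c = 0; c < k; c++) {
            centroids[c] = data[c].clone(); // naive initialisation: the first k rows
        }
        // Created per call, so concurrent calls never share this pool or any mutable state.
        ForkJoinPool fjp = new ForkJoinPool(numThreads);
        try {
            for (int it = 0; it < iterations; it++) {
                // Assignment step in parallel. A parallel stream started from a task in fjp
                // normally runs in fjp; this is common JDK behaviour, not a documented guarantee.
                int[] assignment = fjp.submit(() -> IntStream.range(0, data.length).parallel()
                        .map(i -> nearest(centroids, data[i])).toArray()).get();
                // Update step: each centroid becomes the mean of the points assigned to it.
                double[][] sums = new double[k][dims];
                int[] counts = new int[k];
                for (int i = 0; i < data.length; i++) {
                    counts[assignment[i]]++;
                    for (int d = 0; d < dims; d++) {
                        sums[assignment[i]][d] += data[i][d];
                    }
                }
                for (int c = 0; c < k; c++) {
                    if (counts[c] > 0) { // an empty cluster keeps its previous centroid
                        for (int d = 0; d < dims; d++) {
                            sums[c][d] /= counts[c];
                        }
                        centroids[c] = sums[c];
                    }
                }
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException("Training interrupted", e);
        } catch (ExecutionException e) {
            throw new IllegalStateException("Training failed", e.getCause());
        } finally {
            fjp.shutdown();
        }
        return centroids;
    }

    /** Index of the centroid with the smallest squared Euclidean distance to x. */
    static int nearest(double[][] centroids, double[] x) {
        int best = 0;
        double bestDist = Double.POSITIVE_INFINITY;
        for (int c = 0; c < centroids.length; c++) {
            double dist = 0.0;
            for (int d = 0; d < x.length; d++) {
                double diff = x[d] - centroids[c][d];
                dist += diff * diff;
            }
            if (dist < bestDist) {
                bestDist = dist;
                best = c;
            }
        }
        return best;
    }
}
```

Like the trainer described above, train returns only the centroids. Recovering each example's cluster means calling nearest on it again.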
See:
J. Friedman, T. Hastie, & R. Tibshirani. "The Elements of Statistical Learning" Springer 2001.
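For reference, the alternating procedure covered in that text can be summarised as follows for squared Euclidean distance. Here x_i is an example, m_k a centroid, C(i) the cluster assigned to example i, and K the number of centroids. The per-example weights w_i are an assumption motivated by the weights parameter of mStep. With all w_i = 1 the objective reduces to the usual unweighted one.

```latex
W(C) = \sum_{k=1}^{K} \sum_{i\,:\,C(i)=k} w_i \,\lVert x_i - m_k \rVert^2

\text{Assignment step:}\quad C(i) = \operatorname*{arg\,min}_{1 \le k \le K} \lVert x_i - m_k \rVert^2

\text{Update step:}\quad m_k = \frac{\sum_{i\,:\,C(i)=k} w_i\, x_i}{\sum_{i\,:\,C(i)=k} w_i}
```

Neither step can increase W. The procedure therefore typically settles at a local, not necessarily global, minimum, and the iterations parameter caps how many rounds are run.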
Modifier and Type | Class and Description |
---|---|
static class | KMeansTrainer.Distance: Possible distance functions. |
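The Distance enum selects how examples are compared with centroids. The stand-in below assumes the Euclidean, cosine and L1 options found in Tribuo 4.x releases. The constant names of KMeansTrainer.Distance may differ in your version. DistanceSketch is a hypothetical plain-Java enum, and it defines cosine distance as one minus cosine similarity.

```java
/** Hypothetical stand-in for KMeansTrainer.Distance; the constant names are assumptions. */
enum DistanceSketch {
    EUCLIDEAN {
        @Override
        double compute(double[] a, double[] b) {
            double sum = 0.0;
            for (int i = 0; i < a.length; i++) {
                double diff = a[i] - b[i];
                sum += diff * diff;
            }
            return Math.sqrt(sum);
        }
    },
    L1 {
        @Override
        double compute(double[] a, double[] b) {
            double sum = 0.0;
            for (int i = 0; i < a.length; i++) {
                sum += Math.abs(a[i] - b[i]);
            }
            return sum;
        }
    },
    COSINE {
        @Override
        double compute(double[] a, double[] b) {
            double dot = 0.0;
            double normA = 0.0;
            double normB = 0.0;
            for (int i = 0; i < a.length; i++) {
                dot += a[i] * b[i];
                normA += a[i] * a[i];
                normB += b[i] * b[i];
            }
            // One minus cosine similarity; undefined (NaN) if either vector is all zeros.
            return 1.0 - dot / (Math.sqrt(normA) * Math.sqrt(normB));
        }
    };

    /** Distance between two dense vectors of equal length. */
    abstract double compute(double[] a, double[] b);
}
```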
Fields inherited from interface Trainer: DEFAULT_SEED
Constructor and Description |
---|
KMeansTrainer(int centroids, int iterations, KMeansTrainer.Distance distanceType, int numThreads, long seed): Constructs a K-Means trainer using the supplied parameters. |
Modifier and Type | Method and Description |
---|---|
int | getInvocationCount(): The number of times this trainer instance has had its train method invoked. |
TrainerProvenance | getProvenance() |
protected static DenseVector[] | initialiseCentroids(int centroids, Dataset<ClusterID> examples, ImmutableFeatureMap featureMap, SplittableRandom rng): Initialisation method called at the start of each train call. |
protected void | mStep(ForkJoinPool fjp, DenseVector[] centroidVectors, Map<Integer,List<Integer>> clusterAssignments, SparseVector[] data, double[] weights) |
void | postConfig() |
String | toString() |
KMeansModel | train(Dataset<ClusterID> dataset): Trains a predictive model using the examples in the given data set. |
KMeansModel | train(Dataset<ClusterID> examples, Map<String,com.oracle.labs.mlrg.olcut.provenance.Provenance> runProvenance): Trains a predictive model using the examples in the given data set. |
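The summary above describes initialiseCentroids only as the initialisation method called at the start of each train call. Its feature map parameter is described as the feature map to use for centroid sampling. One strategy consistent with that description is to draw each centroid coordinate uniformly between the observed minimum and maximum of the corresponding feature. The sketch below implements that idea with SplittableRandom over dense arrays. The per-feature range approach is an assumption, not a statement of what Tribuo does, and RangeInitialiser is a hypothetical name.

```java
import java.util.Arrays;
import java.util.SplittableRandom;

/** Hypothetical sketch: samples each centroid coordinate within that feature's observed range. */
final class RangeInitialiser {
    private RangeInitialiser() {}

    static double[][] initialiseCentroids(int centroids, double[][] data, SplittableRandom rng) {
        int dims = data[0].length;
        double[] min = new double[dims];
        double[] max = new double[dims];
        Arrays.fill(min, Double.POSITIVE_INFINITY);
        Arrays.fill(max, Double.NEGATIVE_INFINITY);
        // First pass: record the per-feature range (standing in for feature-map statistics).
        for (double[] row : data) {
            for (int d = 0; d < dims; d++) {
                min[d] = Math.min(min[d], row[d]);
                max[d] = Math.max(max[d], row[d]);
            }
        }
        // Second pass: draw each coordinate uniformly from [min, max).
        double[][] result = new double[centroids][dims];
        for (int c = 0; c < centroids; c++) {
            for (int d = 0; d < dims; d++) {
                // nextDouble(origin, bound) requires origin < bound, so constant features are copied.
                result[c][d] = min[d] < max[d] ? rng.nextDouble(min[d], max[d]) : min[d];
            }
        }
        return result;
    }
}
```

Passing in the RNG, as initialiseCentroids does, makes the sampled centroids reproducible for a given seed.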
public KMeansTrainer(int centroids, int iterations, KMeansTrainer.Distance distanceType, int numThreads, long seed)
Constructs a K-Means trainer using the supplied parameters.
Parameters:
centroids - The number of centroids to use.
iterations - The maximum number of iterations.
distanceType - The distance function.
numThreads - The number of threads.
seed - The random seed.
public void postConfig()
Specified by: postConfig in interface com.oracle.labs.mlrg.olcut.config.Configurable
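In OLCUT-configurable classes, postConfig runs after the configuration system has injected field values. Derived state is therefore usually built there, and a programmatic constructor can call it too. The sketch below assumes that pattern, with postConfig creating the random number generator from the seed. TrainerConfigSketch is a hypothetical plain-Java stand-in, not Tribuo's KMeansTrainer. It omits the distance parameter for brevity and performs argument checks that the documentation above does not promise.

```java
import java.util.SplittableRandom;

/** Hypothetical stand-in: the constructor stores parameters, postConfig derives state from them. */
final class TrainerConfigSketch {
    // Non-final so that a configuration system could inject them before calling postConfig.
    private int centroids;
    private int iterations;
    private int numThreads;
    private long seed;
    private SplittableRandom rng;

    TrainerConfigSketch(int centroids, int iterations, int numThreads, long seed) {
        this.centroids = centroids;
        this.iterations = iterations;
        this.numThreads = numThreads;
        this.seed = seed;
        postConfig(); // programmatic construction reuses the same initialisation path
    }

    /** Validates the parameters and builds the RNG from the seed. */
    void postConfig() {
        if (centroids < 1 || iterations < 1 || numThreads < 1) {
            throw new IllegalArgumentException("centroids, iterations and numThreads must be positive");
        }
        rng = new SplittableRandom(seed);
    }

    SplittableRandom rng() {
        return rng;
    }
}
```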
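public KMeansModel train(Dataset<ClusterID> examples, Map<String,com.oracle.labs.mlrg.olcut.provenance.Provenance> runProvenance)
Description copied from interface: Trainer
Trains a predictive model using the examples in the given data set.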
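public KMeansModel train(Dataset<ClusterID> dataset)
Description copied from interface: Trainer
Trains a predictive model using the examples in the given data set.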
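Both train overloads share the same summary. The single-argument form simply carries no run provenance. One common way to express that relationship, assumed here rather than taken from the Tribuo source, is for the one-argument overload to delegate with an empty map. SketchTrainer is a hypothetical simplification in which provenance values are plain strings.

```java
import java.util.Collections;
import java.util.Map;

/** Hypothetical simplification of a trainer with an optional run-provenance argument. */
interface SketchTrainer<D, M> {
    /** Trains using extra run-level provenance (simplified here to string values). */
    M train(D dataset, Map<String, String> runProvenance);

    /** Convenience overload: delegates with an empty, immutable provenance map. */
    default M train(D dataset) {
        return train(dataset, Collections.emptyMap());
    }
}
```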
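public int getInvocationCount()
Description copied from interface: Trainer
The number of times this trainer instance has had its train method invoked. This is used to determine how many times the trainer's RNG has been accessed to ensure replicability in the random number stream.
Specified by: getInvocationCount in interface Trainer<ClusterID>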
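The following is a minimal sketch of how an invocation count can support replicable random streams under concurrent calls. It assumes the trainer splits its SplittableRandom once per call inside a synchronized block. Because the sequence of splits is deterministic for a given seed, the n-th call on any trainer built with that seed receives the same stream. InvocationCountingTrainer is a hypothetical name, and its train method only returns a draw instead of fitting a model.

```java
import java.util.SplittableRandom;

/** Hypothetical sketch of per-invocation RNG splitting with an invocation counter. */
final class InvocationCountingTrainer {
    private final SplittableRandom rng;
    private int trainInvocationCounter = 0;

    InvocationCountingTrainer(long seed) {
        this.rng = new SplittableRandom(seed);
    }

    /** Stands in for train: takes a private RNG for this call and returns its first draw. */
    long train() {
        SplittableRandom localRng;
        synchronized (this) {
            localRng = rng.split(); // the n-th split is the same for every trainer with this seed
            trainInvocationCounter++;
        }
        // A real trainer would use localRng, e.g. for centroid initialisation, outside the lock.
        return localRng.nextLong();
    }

    synchronized int getInvocationCount() {
        return trainInvocationCounter;
    }
}
```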
protected static DenseVector[] initialiseCentroids(int centroids, Dataset<ClusterID> examples, ImmutableFeatureMap featureMap, SplittableRandom rng)
Initialisation method called at the start of each train call.
Parameters:
centroids - The number of centroids to create.
examples - The dataset to use.
featureMap - The feature map to use for centroid sampling.
rng - The RNG to use.
Returns:
A DenseVector array of centroids.
protected void mStep(ForkJoinPool fjp, DenseVector[] centroidVectors, Map<Integer,List<Integer>> clusterAssignments, SparseVector[] data, double[] weights)
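mStep has no description here. Its signature takes a ForkJoinPool, the current centroids, a map from cluster index to example indices, the data, and per-example weights. This suggests it recomputes each centroid from its assigned examples. The sketch below does that as a weighted mean, with one parallel task per cluster, using dense arrays in place of DenseVector and SparseVector. It is an assumption about the method's role, not its implementation, and WeightedMStep is a hypothetical name.

```java
import java.util.List;
import java.util.Map;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ForkJoinPool;

/** Hypothetical weighted M-step: replaces each centroid with the weighted mean of its members. */
final class WeightedMStep {
    private WeightedMStep() {}

    static void mStep(ForkJoinPool fjp, double[][] centroids,
                      Map<Integer, List<Integer>> clusterAssignments,
                      double[][] data, double[] weights) {
        try {
            fjp.submit(() -> clusterAssignments.entrySet().parallelStream().forEach(entry -> {
                int cluster = entry.getKey();
                double[] mean = new double[centroids[cluster].length];
                double totalWeight = 0.0;
                for (int idx : entry.getValue()) {
                    totalWeight += weights[idx];
                    for (int d = 0; d < mean.length; d++) {
                        mean[d] += weights[idx] * data[idx][d];
                    }
                }
                if (totalWeight > 0.0) { // leave empty clusters where they were
                    for (int d = 0; d < mean.length; d++) {
                        mean[d] /= totalWeight;
                    }
                    centroids[cluster] = mean; // each task writes a distinct slot
                }
            })).get(); // get() also makes the tasks' writes visible to the caller
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException("M-step interrupted", e);
        } catch (ExecutionException e) {
            throw new IllegalStateException("M-step failed", e.getCause());
        }
    }
}
```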
public TrainerProvenance getProvenance()
Specified by: getProvenance in interface com.oracle.labs.mlrg.olcut.provenance.Provenancable<TrainerProvenance>
Copyright © 2015–2021 Oracle and/or its affiliates. All rights reserved.