public class KMeansTrainer extends Object implements Trainer<ClusterID>
This trainer is slightly contorted to fit the Tribuo Trainer and Model API. The cluster assignments can only be retrieved from the model after training, and retrieving them requires re-evaluating each example.
The Trainer has a parameterised distance function and a selectable number of threads for the training step. The thread pool is local to each invocation of train, so multiple trainings can run concurrently.
Note: parallel training uses a ForkJoinPool, which requires that the Tribuo codebase is given the "modifyThread" and "modifyThreadGroup" privileges when running under a SecurityManager.
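The thread-pool arrangement described above can be sketched in plain Java. This is a minimal illustration under stated assumptions, not Tribuo's code. The class `LocalPoolTrainer` and its `train` method are hypothetical, and the "training" is a parallel sum standing in for the real iterations. Each call builds its own ForkJoinPool and shuts it down before returning, so concurrent calls never share worker threads.

```java
import java.util.concurrent.ForkJoinPool;
import java.util.stream.IntStream;

// Hypothetical trainer illustrating a thread pool scoped to a single train() call.
class LocalPoolTrainer {
    private final int numThreads;

    LocalPoolTrainer(int numThreads) {
        this.numThreads = numThreads;
    }

    // Stand-in for training: returns the sum of squares of the data.
    double train(double[] data) {
        // A fresh pool per call; null means "run sequentially on this thread".
        ForkJoinPool fjp = numThreads > 1 ? new ForkJoinPool(numThreads) : null;
        try {
            if (fjp == null) {
                return sumOfSquares(data, false);
            }
            // A parallel stream started from a task inside fjp runs on fjp's workers
            // rather than the common pool (widely relied on, not a documented guarantee).
            return fjp.submit(() -> sumOfSquares(data, true)).join();
        } finally {
            if (fjp != null) {
                fjp.shutdown(); // the pool does not outlive this invocation
            }
        }
    }

    private static double sumOfSquares(double[] data, boolean parallel) {
        IntStream idx = IntStream.range(0, data.length);
        if (parallel) {
            idx = idx.parallel();
        }
        return idx.mapToDouble(i -> data[i] * data[i]).sum();
    }
}
```

Because the pool is created and destroyed inside `train`, two threads can call `train` on the same or different instances without contending for a shared executor. The cost is pool start-up on every call, which is small relative to a full clustering run.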
See:
J. Friedman, T. Hastie, & R. Tibshirani. "The Elements of Statistical Learning". Springer, 2001.
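The reference above covers the standard batch (Lloyd's) k-means iteration. Each point is assigned to its nearest centroid, and each centroid then moves to the mean of its assigned points. This repeats until no assignment changes or an iteration cap is reached. The sketch below shows that loop on dense `double[][]` data with squared Euclidean distance. The class `LloydKMeans` is hypothetical and much simpler than the trainer documented here: it has no example weights, no threads, and caller-supplied initial centroids.

```java
import java.util.Arrays;

// Hypothetical, minimal Lloyd's-algorithm k-means on dense data.
class LloydKMeans {

    // Returns the index of the centroid closest to point (squared Euclidean).
    static int nearest(double[] point, double[][] centroids) {
        int best = 0;
        double bestDist = Double.POSITIVE_INFINITY;
        for (int c = 0; c < centroids.length; c++) {
            double d = 0.0;
            for (int j = 0; j < point.length; j++) {
                double diff = point[j] - centroids[c][j];
                d += diff * diff;
            }
            if (d < bestDist) {
                bestDist = d;
                best = c;
            }
        }
        return best;
    }

    // Runs at most maxIterations rounds, updating centroids in place,
    // and returns the final cluster index of each point.
    static int[] fit(double[][] data, double[][] centroids, int maxIterations) {
        int[] assignment = new int[data.length];
        Arrays.fill(assignment, -1);
        for (int iter = 0; iter < maxIterations; iter++) {
            // E-step: assign each point to its nearest centroid.
            boolean changed = false;
            for (int i = 0; i < data.length; i++) {
                int c = nearest(data[i], centroids);
                if (c != assignment[i]) {
                    assignment[i] = c;
                    changed = true;
                }
            }
            if (!changed) {
                break; // converged: no point switched cluster
            }
            // M-step: move each centroid to the mean of its points.
            int dims = centroids[0].length;
            double[][] sums = new double[centroids.length][dims];
            int[] counts = new int[centroids.length];
            for (int i = 0; i < data.length; i++) {
                counts[assignment[i]]++;
                for (int j = 0; j < dims; j++) {
                    sums[assignment[i]][j] += data[i][j];
                }
            }
            for (int c = 0; c < centroids.length; c++) {
                if (counts[c] > 0) { // leave empty clusters where they are
                    for (int j = 0; j < dims; j++) {
                        centroids[c][j] = sums[c][j] / counts[c];
                    }
                }
            }
        }
        return assignment;
    }
}
```

Calling `nearest` on each example against the final centroids roughly corresponds to the re-evaluation step mentioned at the top of this page, where cluster assignments are recovered from the trained model.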
For more on the optional k-means++ initialisation, see:
D. Arthur, S. Vassilvitskii. "K-Means++: The Advantages of Careful Seeding".
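k-means++ seeding, as described by Arthur and Vassilvitskii, picks the first centroid uniformly at random from the data. Each subsequent centroid is drawn with probability proportional to its squared distance from the nearest centroid already chosen. A minimal sketch follows. `PlusPlusInit` and `seed` are hypothetical names, the data are dense `double[][]` rows, and this is not Tribuo's implementation.

```java
import java.util.Random;

// Hypothetical k-means++ seeding on dense data.
class PlusPlusInit {

    static double sqDist(double[] a, double[] b) {
        double d = 0.0;
        for (int j = 0; j < a.length; j++) {
            double diff = a[j] - b[j];
            d += diff * diff;
        }
        return d;
    }

    // Chooses k starting centroids from data using the supplied RNG.
    static double[][] seed(double[][] data, int k, Random rng) {
        double[][] centroids = new double[k][];
        // First centroid: a uniformly random data point.
        centroids[0] = data[rng.nextInt(data.length)].clone();
        // minDist[i] = squared distance from point i to its nearest chosen centroid.
        double[] minDist = new double[data.length];
        for (int i = 0; i < data.length; i++) {
            minDist[i] = sqDist(data[i], centroids[0]);
        }
        for (int c = 1; c < k; c++) {
            double total = 0.0;
            for (double d : minDist) {
                total += d;
            }
            // Sample index i with probability minDist[i] / total.
            double target = rng.nextDouble() * total;
            int chosen = data.length - 1;
            double cumulative = 0.0;
            for (int i = 0; i < data.length; i++) {
                cumulative += minDist[i];
                if (cumulative > target) {
                    chosen = i;
                    break;
                }
            }
            centroids[c] = data[chosen].clone();
            // Refresh nearest-centroid distances with the new centroid.
            for (int i = 0; i < data.length; i++) {
                minDist[i] = Math.min(minDist[i], sqDist(data[i], centroids[c]));
            }
        }
        return centroids;
    }
}
```

Points that coincide with an existing centroid have zero weight, so they are never picked again unless every remaining point has zero weight. This spreading-out effect is what distinguishes k-means++ from purely random seeding.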
Modifier and Type | Class and Description |
---|---|
static class | KMeansTrainer.Distance: Possible distance functions. |
static class | KMeansTrainer.Initialisation: Possible initialization functions. |
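The nested Distance type selects the metric used when comparing points with centroids. The sketch below is a hypothetical stand-in written as a Java enum with a per-constant method. The constant names EUCLIDEAN, L1 and COSINE are assumptions about which metrics might be offered, and the formulas are the textbook definitions rather than Tribuo's code.

```java
// Hypothetical distance enum; constant names are illustrative assumptions.
enum SketchDistance {
    EUCLIDEAN {
        @Override
        double between(double[] a, double[] b) {
            double sum = 0.0;
            for (int i = 0; i < a.length; i++) {
                double diff = a[i] - b[i];
                sum += diff * diff;
            }
            return Math.sqrt(sum);
        }
    },
    L1 {
        @Override
        double between(double[] a, double[] b) {
            double sum = 0.0;
            for (int i = 0; i < a.length; i++) {
                sum += Math.abs(a[i] - b[i]);
            }
            return sum;
        }
    },
    COSINE {
        // Cosine distance = 1 - cosine similarity (undefined for zero vectors).
        @Override
        double between(double[] a, double[] b) {
            double dot = 0.0, normA = 0.0, normB = 0.0;
            for (int i = 0; i < a.length; i++) {
                dot += a[i] * b[i];
                normA += a[i] * a[i];
                normB += b[i] * b[i];
            }
            return 1.0 - dot / (Math.sqrt(normA) * Math.sqrt(normB));
        }
    };

    // Distance between two dense vectors of equal length.
    abstract double between(double[] a, double[] b);
}
```

Putting the formula on each constant keeps the assignment loop independent of the metric: it calls `between` and never branches on the enum value.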
Fields inherited from interface Trainer: DEFAULT_SEED
Constructor and Description |
---|
KMeansTrainer(int centroids, int iterations, KMeansTrainer.Distance distanceType, int numThreads, long seed): Constructs a K-Means trainer using the supplied parameters and the default random initialisation. |
KMeansTrainer(int centroids, int iterations, KMeansTrainer.Distance distanceType, KMeansTrainer.Initialisation initialisationType, int numThreads, long seed): Constructs a K-Means trainer using the supplied parameters. |
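The two constructors differ only in whether the initialisation is supplied. A common Java idiom for such a default is for the shorter constructor to delegate to the longer one. The sketch below shows that pattern with a hypothetical `KMeansSettings` class and a hypothetical `Init` enum (RANDOM, PLUSPLUS). The distance parameter is omitted for brevity, and the code is not taken from Tribuo's source.

```java
// Hypothetical initialisation choices mirroring the two options described.
enum Init { RANDOM, PLUSPLUS }

// Hypothetical settings holder showing a defaulted constructor overload.
class KMeansSettings {
    final int centroids;
    final int iterations;
    final Init initialisation;
    final int numThreads;
    final long seed;

    // Shorter overload: falls back to random initialisation.
    KMeansSettings(int centroids, int iterations, int numThreads, long seed) {
        this(centroids, iterations, Init.RANDOM, numThreads, seed);
    }

    // Full overload: every parameter supplied explicitly.
    KMeansSettings(int centroids, int iterations, Init initialisation, int numThreads, long seed) {
        this.centroids = centroids;
        this.iterations = iterations;
        this.initialisation = initialisation;
        this.numThreads = numThreads;
        this.seed = seed;
    }
}
```

Delegating through `this(...)` keeps field assignment in one place, so the two overloads cannot drift apart.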
Modifier and Type | Method and Description |
---|---|
int | getInvocationCount(): The number of times this trainer instance has had its train method invoked. |
TrainerProvenance | getProvenance() |
protected void | mStep(ForkJoinPool fjp, DenseVector[] centroidVectors, Map<Integer,List<Integer>> clusterAssignments, SparseVector[] data, double[] weights): Runs the M-step, writing the updated centroids to the centroidVectors array. |
void | postConfig(): Used by the OLCUT configuration system, and should not be called by external code. |
String | toString() |
KMeansModel | train(Dataset<ClusterID> dataset): Trains a predictive model using the examples in the given data set. |
KMeansModel | train(Dataset<ClusterID> examples, Map<String,com.oracle.labs.mlrg.olcut.provenance.Provenance> runProvenance): Trains a predictive model using the examples in the given data set. |
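The mStep summarised above recomputes each centroid from its currently assigned examples and uses the example weights. It runs in parallel only when a ForkJoinPool is supplied. The sketch below assumes the new centroid is the weighted mean of its members and that an empty cluster keeps its old centroid. It uses plain `double[][]` arrays and a hypothetical `WeightedMStep.mStep` in place of Tribuo's DenseVector and SparseVector types, so it illustrates the idea rather than reproducing the library method.

```java
import java.util.List;
import java.util.Map;
import java.util.concurrent.ForkJoinPool;
import java.util.stream.IntStream;

// Hypothetical weighted M-step on dense arrays.
class WeightedMStep {

    // Overwrites centroids[c] with the weighted mean of the points assigned to c.
    // If fjp is null the loop runs sequentially on the calling thread.
    static void mStep(ForkJoinPool fjp, double[][] centroids,
                      Map<Integer, List<Integer>> assignments,
                      double[][] data, double[] weights) {
        if (fjp == null) {
            for (int c = 0; c < centroids.length; c++) {
                updateCentroid(c, centroids, assignments, data, weights);
            }
        } else {
            // Each task writes only its own centroid slot, so tasks do not conflict.
            fjp.submit(() -> IntStream.range(0, centroids.length).parallel()
                    .forEach(c -> updateCentroid(c, centroids, assignments, data, weights)))
               .join();
        }
    }

    private static void updateCentroid(int c, double[][] centroids,
                                       Map<Integer, List<Integer>> assignments,
                                       double[][] data, double[] weights) {
        List<Integer> members = assignments.get(c);
        if (members == null || members.isEmpty()) {
            return; // assumption: an empty cluster keeps its previous centroid
        }
        double[] sum = new double[centroids[c].length];
        double totalWeight = 0.0;
        for (int i : members) {
            for (int j = 0; j < sum.length; j++) {
                sum[j] += weights[i] * data[i][j];
            }
            totalWeight += weights[i];
        }
        for (int j = 0; j < sum.length; j++) {
            sum[j] /= totalWeight; // weighted mean (NaN if all weights are zero)
        }
        centroids[c] = sum;
    }
}
```

The sequential and parallel branches compute identical results. The pool only changes which threads do the work, which is why passing null is a safe way to request single-threaded execution.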
public KMeansTrainer(int centroids, int iterations, KMeansTrainer.Distance distanceType, int numThreads, long seed)
Parameters:
centroids - The number of centroids to use.
iterations - The maximum number of iterations.
distanceType - The distance function.
numThreads - The number of threads.
seed - The random seed.
public KMeansTrainer(int centroids, int iterations, KMeansTrainer.Distance distanceType, KMeansTrainer.Initialisation initialisationType, int numThreads, long seed)
Parameters:
centroids - The number of centroids to use.
iterations - The maximum number of iterations.
distanceType - The distance function.
initialisationType - The centroid initialization method.
numThreads - The number of threads.
seed - The random seed.
public void postConfig()
Specified by: postConfig in interface com.oracle.labs.mlrg.olcut.config.Configurable
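public KMeansModel train(Dataset<ClusterID> examples, Map<String,com.oracle.labs.mlrg.olcut.provenance.Provenance> runProvenance)
Description copied from interface: Trainer
Trains a predictive model using the examples in the given data set.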
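public KMeansModel train(Dataset<ClusterID> dataset)
Description copied from interface: Trainer
Trains a predictive model using the examples in the given data set.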
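public int getInvocationCount()
Description copied from interface: Trainer
The number of times this trainer instance has had its train method invoked. This is used to determine how many times the trainer's RNG has been accessed, to ensure replicability of the random number stream.
Specified by: getInvocationCount in interface Trainer<ClusterID>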
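Tracking invocations matters because the trainer draws from a seeded random number generator on every call to train. The sketch below assumes a hypothetical `CountingTrainer` that consumes one random value per training run. Rebuilding it from the same seed and fast-forwarding by a recorded count reproduces the random stream of the next run. The `setInvocationCount` method is an illustrative assumption, not a documented member of this class.

```java
import java.util.Random;

// Hypothetical trainer whose only randomness is one draw per train() call.
class CountingTrainer {
    private final long seed;
    private Random rng;
    private int invocationCount = 0;

    CountingTrainer(long seed) {
        this.seed = seed;
        this.rng = new Random(seed);
    }

    // Stand-in for training: returns the random value this run would use
    // (for example, to seed centroid initialisation).
    synchronized long train() {
        invocationCount++;
        return rng.nextLong();
    }

    synchronized int getInvocationCount() {
        return invocationCount;
    }

    // Replays the RNG from the seed so the next train() call matches
    // run number (count + 1) of a fresh trainer with the same seed.
    synchronized void setInvocationCount(int count) {
        rng = new Random(seed);
        for (int i = 0; i < count; i++) {
            rng.nextLong();
        }
        invocationCount = count;
    }
}
```

Recording the seed plus the invocation count is enough to replay any individual run, which is why the count belongs in the trainer's provenance rather than being discarded.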
protected void mStep(ForkJoinPool fjp, DenseVector[] centroidVectors, Map<Integer,List<Integer>> clusterAssignments, SparseVector[] data, double[] weights)
Runs the M-step, writing the updated centroids to the centroidVectors array.
Parameters:
fjp - The ForkJoinPool to run the computation in if it should be executed in parallel. If fjp is null, the computation is executed sequentially on the calling thread.
centroidVectors - The centroid vectors to write out.
clusterAssignments - The current cluster assignments.
data - The data points.
weights - The example weights.
public TrainerProvenance getProvenance()
Specified by: getProvenance in interface com.oracle.labs.mlrg.olcut.provenance.Provenancable<TrainerProvenance>
Copyright © 2015–2021 Oracle and/or its affiliates. All rights reserved.