/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.tribuo.clustering.kmeans;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import com.oracle.labs.mlrg.olcut.util.MutableLong;
import com.oracle.labs.mlrg.olcut.util.StreamUtil;
import org.tribuo.Dataset;
import org.tribuo.Example;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Trainer;
import org.tribuo.clustering.ClusterID;
import org.tribuo.clustering.ImmutableClusteringInfo;
import org.tribuo.math.la.DenseVector;
import org.tribuo.math.la.SparseVector;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.provenance.TrainerProvenance;
import org.tribuo.provenance.impl.TrainerProvenanceImpl;

import java.time.OffsetDateTime;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.SplittableRandom;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.logging.Level;
import java.util.logging.Logger;
import java.util.stream.IntStream;
import java.util.stream.Stream;

/**
 * A K-Means trainer, which generates a K-means clustering of the supplied
 * data. Training only finds the centres; predict must then be called on the
 * resulting model to infer the centre assignments for the input data.
 * <p>
 * It's slightly contorted to fit the Tribuo Trainer and Model API, as the cluster assignments
 * can only be retrieved from the model after training, and require re-evaluating each example.
 * <p>
 * The Trainer has a parameterised distance function, and a selectable number
 * of threads used in the training step. The thread pool is local to an invocation of train,
 * so there can be multiple concurrent trainings.
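 * <p>
 * A minimal usage sketch (assuming an existing {@code Dataset<ClusterID>} called {@code data}):
 * <pre>{@code
 * KMeansTrainer trainer = new KMeansTrainer(5, 10, KMeansTrainer.Distance.EUCLIDEAN, 1, 1L);
 * KMeansModel model = trainer.train(data);
 * // cluster assignments are recovered by re-evaluating the data through the model
 * List<Prediction<ClusterID>> assignments = model.predict(data);
 * }</pre>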
 * <p>
 * See:
 * <pre>
 * J. Friedman, T. Hastie, &amp; R. Tibshirani.
 * "The Elements of Statistical Learning"
 * Springer 2001. <a href="http://web.stanford.edu/~hastie/ElemStatLearn/">PDF</a>
 * </pre>
 */
public class KMeansTrainer implements Trainer<ClusterID> {
    private static final Logger logger = Logger.getLogger(KMeansTrainer.class.getName());

    /**
     * Possible distance functions.
     */
    public enum Distance {
        /**
         * Euclidean (or l2) distance.
         */
        EUCLIDEAN,
        /**
         * Cosine similarity as a distance measure.
         */
        COSINE,
        /**
         * L1 (or Manhattan) distance.
         */
        L1
    }

    @Config(mandatory = true, description = "Number of centroids (i.e., the \"k\" in k-means).")
    private int centroids;

    @Config(mandatory = true, description = "The number of iterations to run.")
    private int iterations;

    @Config(mandatory = true, description = "The distance function to use.")
    private Distance distanceType;

    @Config(description = "The number of threads to use for training.")
    private int numThreads = 1;

    @Config(mandatory = true, description = "The seed to use for the RNG.")
    private long seed;

    private SplittableRandom rng;

    private int trainInvocationCounter;

    /**
     * for olcut.
     */
    private KMeansTrainer() {}

    /**
     * Constructs a K-Means trainer using the supplied parameters.
     * @param centroids The number of centroids to use.
     * @param iterations The maximum number of iterations.
     * @param distanceType The distance function.
     * @param numThreads The number of threads.
     * @param seed The random seed.
     */
    public KMeansTrainer(int centroids, int iterations, Distance distanceType, int numThreads, long seed) {
        this.centroids = centroids;
        this.iterations = iterations;
        this.distanceType = distanceType;
        this.numThreads = numThreads;
        this.seed = seed;
        postConfig();
    }

    @Override
    public synchronized void postConfig() {
        this.rng = new SplittableRandom(seed);
    }

    @Override
    public KMeansModel train(Dataset<ClusterID> examples, Map<String, Provenance> runProvenance) {
        // Creates a new local RNG and adds one to the invocation count.
        TrainerProvenance trainerProvenance;
        SplittableRandom localRNG;
        synchronized(this) {
            localRNG = rng.split();
            trainerProvenance = getProvenance();
            trainInvocationCounter++;
        }
        ImmutableFeatureMap featureMap = examples.getFeatureIDMap();
        DenseVector[] centroidVectors = initialiseCentroids(centroids,examples,featureMap,localRNG);

        ForkJoinPool fjp = new ForkJoinPool(numThreads);
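        // Convert each example into a sparse vector and cache its weight; -1 marks
        // an example that has not yet been assigned to a centroid.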
        int[] oldCentre = new int[examples.size()];
        SparseVector[] data = new SparseVector[examples.size()];
        double[] weights = new double[examples.size()];
        int n = 0;
        for (Example<ClusterID> example : examples) {
            weights[n] = example.getWeight();
            data[n] = SparseVector.createSparseVector(example,featureMap,false);
            oldCentre[n] = -1;
            n++;
        }

        Map<Integer,List<Integer>> clusterAssignments = new HashMap<>();
        for (int i = 0; i < centroids; i++) {
            clusterAssignments.put(i,Collections.synchronizedList(new ArrayList<>()));
        }

        boolean converged = false;

        for (int i = 0; (i < iterations) && !converged; i++) {
            //logger.log(Level.INFO,"Beginning iteration " + i);
            AtomicInteger changeCounter = new AtomicInteger(0);

            for (Entry<Integer,List<Integer>> e : clusterAssignments.entrySet()) {
                e.getValue().clear();
            }

            // E step
            Stream<SparseVector> vecStream = Arrays.stream(data);
            Stream<Integer> intStream = IntStream.range(0,data.length).boxed();
            Stream<IntAndVector> eStream;
            if (numThreads > 1) {
                eStream = StreamUtil.boundParallelism(StreamUtil.zip(intStream,vecStream,IntAndVector::new).parallel());
            } else {
                eStream = StreamUtil.zip(intStream,vecStream,IntAndVector::new);
            }
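            // Assign each example to its nearest centroid under the configured distance,
            // counting how many assignments changed since the previous iteration.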
            try {
                fjp.submit(() -> eStream.forEach((IntAndVector e) -> {
                    double minDist = Double.POSITIVE_INFINITY;
                    int clusterID = -1;
                    int id = e.idx;
                    SparseVector vector = e.vector;
                    for (int j = 0; j < centroids; j++) {
                        DenseVector cluster = centroidVectors[j];
                        double distance;
                        switch (distanceType) {
                            case EUCLIDEAN:
                                distance = cluster.euclideanDistance(vector);
                                break;
                            case COSINE:
                                distance = cluster.cosineDistance(vector);
                                break;
                            case L1:
                                distance = cluster.l1Distance(vector);
                                break;
                            default:
                                throw new IllegalStateException("Unknown distance " + distanceType);
                        }
                        if (distance < minDist) {
                            minDist = distance;
                            clusterID = j;
                        }
                    }

                    clusterAssignments.get(clusterID).add(id);
                    if (oldCentre[id] != clusterID) {
                        // Changed the centroid of this vector.
                        oldCentre[id] = clusterID;
                        changeCounter.incrementAndGet();
                    }
                })).get();
            } catch (InterruptedException | ExecutionException e) {
                throw new RuntimeException("Parallel execution failed",e);
            }
            //logger.log(Level.INFO, "E step completed. " + changeCounter.get() + " examples updated.");

            mStep(fjp,centroidVectors,clusterAssignments,data,weights);

            logger.log(Level.INFO, "Iteration " + i + " completed. " + changeCounter.get() + " examples updated.");

            if (changeCounter.get() == 0) {
                converged = true;
                logger.log(Level.INFO, "K-Means converged at iteration " + i);
            }
        }

        Map<Integer,MutableLong> counts = new HashMap<>();
        for (Entry<Integer,List<Integer>> e : clusterAssignments.entrySet()) {
            counts.put(e.getKey(),new MutableLong(e.getValue().size()));
        }

        ImmutableOutputInfo<ClusterID> outputMap = new ImmutableClusteringInfo(counts);

        ModelProvenance provenance = new ModelProvenance(KMeansModel.class.getName(), OffsetDateTime.now(), examples.getProvenance(), trainerProvenance, runProvenance);

        return new KMeansModel("",provenance,featureMap,outputMap,centroidVectors,distanceType);
    }

    @Override
    public KMeansModel train(Dataset<ClusterID> dataset) {
        return train(dataset,Collections.emptyMap());
    }

    @Override
    public int getInvocationCount() {
        return trainInvocationCounter;
    }

    /**
     * Initialisation method called at the start of each train call.
     * <p>
     * Used to allow overriding for k-means++, k-medoids etc.
     *
     * @param centroids The number of centroids to create.
     * @param examples The dataset to use.
     * @param featureMap The feature map to use for centroid sampling.
     * @param rng The RNG to use.
     * @return A {@link DenseVector} array of centroids.
     */
    protected static DenseVector[] initialiseCentroids(int centroids, Dataset<ClusterID> examples, ImmutableFeatureMap featureMap, SplittableRandom rng) {
        DenseVector[] centroidVectors = new DenseVector[centroids];
        int numFeatures = featureMap.size();
        for (int i = 0; i < centroids; i++) {
            double[] newCentroid = new double[numFeatures];

            for (int j = 0; j < numFeatures; j++) {
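                // Draw each centroid element using the feature's uniformSample, based on
                // the values observed for that feature in the training data.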
                newCentroid[j] = featureMap.get(j).uniformSample(rng);
            }

            centroidVectors[i] = DenseVector.createDenseVector(newCentroid);
        }
        return centroidVectors;
    }
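    /**
     * Runs the M step, recomputing each centroid as the average of the weight-scaled
     * examples assigned to it in the preceding E step.
     * @param fjp The thread pool to run the computation in.
     * @param centroidVectors The centroid vectors, which are updated in place.
     * @param clusterAssignments A map from centroid id to the indices of the examples assigned to it.
     * @param data The example vectors.
     * @param weights The example weights.
     */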
    protected void mStep(ForkJoinPool fjp, DenseVector[] centroidVectors, Map<Integer,List<Integer>> clusterAssignments, SparseVector[] data, double[] weights) {
        // M step
        Stream<Entry<Integer,List<Integer>>> mStream;
        if (numThreads > 1) {
            mStream = StreamUtil.boundParallelism(clusterAssignments.entrySet().stream().parallel());
        } else {
            mStream = clusterAssignments.entrySet().stream();
        }
        try {
            fjp.submit(() -> mStream.forEach((e) -> {
                DenseVector newCentroid = centroidVectors[e.getKey()];
                newCentroid.fill(0.0);

                int counter = 0;
                for (Integer idx : e.getValue()) {
                    newCentroid.intersectAndAddInPlace(data[idx],(double f) -> f * weights[idx]);
                    counter++;
                }
                if (counter > 0) {
                    newCentroid.scaleInPlace(1.0/counter);
                }
            })).get();
        } catch (InterruptedException | ExecutionException e) {
            throw new RuntimeException("Parallel execution failed",e);
        }
    }

    @Override
    public String toString() {
        return "KMeansTrainer(centroids="+centroids+",distanceType="+distanceType+",seed="+seed+",numThreads="+numThreads+")";
    }

    @Override
    public TrainerProvenance getProvenance() {
        return new TrainerProvenanceImpl(this);
    }

    /**
     * Tuple of index and position. One day it'll be a record, but not today.
     */
    static class IntAndVector {
        final int idx;
        final SparseVector vector;

        public IntAndVector(int idx, SparseVector vector) {
            this.idx = idx;
            this.vector = vector;
        }
    }
}