/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.clustering.kmeans;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import com.oracle.labs.mlrg.olcut.util.MutableLong;
import com.oracle.labs.mlrg.olcut.util.StreamUtil;
import org.tribuo.Dataset;
import org.tribuo.Example;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Trainer;
import org.tribuo.clustering.ClusterID;
import org.tribuo.clustering.ImmutableClusteringInfo;
import org.tribuo.math.la.DenseVector;
import org.tribuo.math.la.SparseVector;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.provenance.TrainerProvenance;
import org.tribuo.provenance.impl.TrainerProvenanceImpl;

import java.time.OffsetDateTime;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.SplittableRandom;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.logging.Level;
import java.util.logging.Logger;
import java.util.stream.IntStream;
import java.util.stream.Stream;

/**
 * A K-Means trainer, which generates a K-means clustering of the supplied
 * data. The model finds the centres, and then predict needs to be
 * called to infer the centre assignments for the input data.
 * <p>
 * It's slightly contorted to fit the Tribuo Trainer and Model API, as the cluster assignments
 * can only be retrieved from the model after training, and require re-evaluating each example.
 * <p>
 * The Trainer has a parameterised distance function, and a selectable number
 * of threads used in the training step. The thread pool is local to an invocation of train
 * (it is created at the start of the call and shut down before the call returns),
 * so there can be multiple concurrent trainings.
 * <p>
 * See:
 * <pre>
 * J. Friedman, T. Hastie, &amp; R. Tibshirani.
 * "The Elements of Statistical Learning"
 * Springer 2001. <a href="http://web.stanford.edu/~hastie/ElemStatLearn/">PDF</a>
 * </pre>
 */
public class KMeansTrainer implements Trainer<ClusterID> {
    private static final Logger logger = Logger.getLogger(KMeansTrainer.class.getName());

    /**
     * Possible distance functions.
     */
    public enum Distance {
        /**
         * Euclidean (or l2) distance.
         */
        EUCLIDEAN,
        /**
         * Cosine similarity as a distance measure.
         */
        COSINE,
        /**
         * L1 (or Manhattan) distance.
         */
        L1
    }

    @Config(mandatory = true, description = "Number of centroids (i.e., the \"k\" in k-means).")
    private int centroids;

    @Config(mandatory = true,description="The number of iterations to run.")
    private int iterations;

    @Config(mandatory = true,description="The distance function to use.")
    private Distance distanceType;

    @Config(description="The number of threads to use for training.")
    private int numThreads = 1;

    @Config(mandatory = true,description="The seed to use for the RNG.")
    private long seed;

    // Parent RNG; each train call splits a private child from it while holding this object's monitor.
    private SplittableRandom rng;

    // Number of times train has been called; incremented while holding this object's monitor.
    private int trainInvocationCounter;

    /**
     * for olcut.
     */
    private KMeansTrainer() {}

    /**
     * Constructs a K-Means trainer using the supplied parameters.
     * @param centroids The number of centroids to use.
     * @param iterations The maximum number of iterations.
     * @param distanceType The distance function.
     * @param numThreads The number of threads.
     * @param seed The random seed.
     */
    public KMeansTrainer(int centroids, int iterations, Distance distanceType, int numThreads, long seed) {
        this.centroids = centroids;
        this.iterations = iterations;
        this.distanceType = distanceType;
        this.numThreads = numThreads;
        this.seed = seed;
        postConfig();
    }

    /**
     * Builds the trainer's RNG from the configured seed.
     * Called by the public constructor after the fields are assigned.
     */
    @Override
    public synchronized void postConfig() {
        this.rng = new SplittableRandom(seed);
    }

    /**
     * Trains a K-Means model on the supplied dataset.
     * <p>
     * Runs at most {@code iterations} rounds of assignment (E step) followed by
     * centroid recomputation (M step), stopping early once an assignment step
     * changes no example's cluster. The work is executed in a {@link ForkJoinPool}
     * of {@code numThreads} threads which is created for this call and shut down
     * before the call returns, whether it completes normally or exceptionally.
     *
     * @param examples The dataset to cluster.
     * @param runProvenance Run-specific provenance stored in the model provenance.
     * @return The trained model, holding the final centroids.
     * @throws RuntimeException If the E or M step fails or the calling thread is interrupted
     *                          while waiting for it; the underlying exception is the cause.
     */
    @Override
    public KMeansModel train(Dataset<ClusterID> examples, Map<String, Provenance> runProvenance) {
        // Creates a new local RNG and adds one to the invocation count.
        TrainerProvenance trainerProvenance;
        SplittableRandom localRNG;
        synchronized(this) {
            localRNG = rng.split();
            trainerProvenance = getProvenance();
            trainInvocationCounter++;
        }
        ImmutableFeatureMap featureMap = examples.getFeatureIDMap();
        DenseVector[] centroidVectors = initialiseCentroids(centroids,examples,featureMap,localRNG);

        ForkJoinPool fjp = new ForkJoinPool(numThreads);
        try {
            // Convert the examples into vectors once, up front.
            int[] oldCentre = new int[examples.size()];
            SparseVector[] data = new SparseVector[examples.size()];
            double[] weights = new double[examples.size()];
            int n = 0;
            for (Example<ClusterID> example : examples) {
                weights[n] = example.getWeight();
                data[n] = SparseVector.createSparseVector(example,featureMap,false);
                oldCentre[n] = -1; // -1 means "not assigned yet", so the first E step counts every example as changed
                n++;
            }

            // The lists are synchronized because the E step may append to them from several threads.
            Map<Integer,List<Integer>> clusterAssignments = new HashMap<>();
            for (int i = 0; i < centroids; i++) {
                clusterAssignments.put(i,Collections.synchronizedList(new ArrayList<>()));
            }

            boolean converged = false;

            for (int i = 0; (i < iterations) && !converged; i++) {
                for (List<Integer> assignment : clusterAssignments.values()) {
                    assignment.clear();
                }

                int changed = eStep(fjp,centroidVectors,clusterAssignments,data,oldCentre);

                mStep(fjp,centroidVectors,clusterAssignments,data,weights);

                logger.log(Level.INFO, "Iteration " + i + " completed. " + changed + " examples updated.");

                if (changed == 0) {
                    converged = true;
                    logger.log(Level.INFO, "K-Means converged at iteration " + i);
                }
            }

            // Cluster sizes come from the assignments made in the last E step.
            Map<Integer,MutableLong> counts = new HashMap<>();
            for (Entry<Integer,List<Integer>> e : clusterAssignments.entrySet()) {
                counts.put(e.getKey(),new MutableLong(e.getValue().size()));
            }

            ImmutableOutputInfo<ClusterID> outputMap = new ImmutableClusteringInfo(counts);

            ModelProvenance provenance = new ModelProvenance(KMeansModel.class.getName(), OffsetDateTime.now(), examples.getProvenance(), trainerProvenance, runProvenance);

            return new KMeansModel("",provenance,featureMap,outputMap,centroidVectors, distanceType);
        } finally {
            // The pool is private to this call; release its worker threads.
            fjp.shutdown();
        }
    }

    /**
     * Trains a K-Means model on the supplied dataset with empty run provenance.
     * @param dataset The dataset to cluster.
     * @return The trained model.
     */
    @Override
    public KMeansModel train(Dataset<ClusterID> dataset) {
        return train(dataset,Collections.emptyMap());
    }

    /**
     * Returns the number of times train has been invoked on this trainer.
     * @return The train invocation count.
     */
    @Override
    public int getInvocationCount() {
        return trainInvocationCounter;
    }

    /**
     * Initialisation method called at the start of each train call.
     * <p>
     * Creates each centroid by calling {@code uniformSample(rng)} on the feature map
     * entry of every feature id in {@code [0, featureMap.size())}.
     * <p>
     * Intended as the hook for kmeans++, kmedoids etc.; note that as this method
     * is static it can be hidden by a subclass, but the call in train always
     * resolves to this implementation.
     *
     * @param centroids The number of centroids to create.
     * @param examples The dataset to use (not read by this implementation).
     * @param featureMap The feature map to use for centroid sampling.
     * @param rng The RNG to use.
     * @return A {@link DenseVector} array of centroids.
     */
    protected static DenseVector[] initialiseCentroids(int centroids, Dataset<ClusterID> examples, ImmutableFeatureMap featureMap, SplittableRandom rng) {
        DenseVector[] centroidVectors = new DenseVector[centroids];
        int numFeatures = featureMap.size();
        for (int i = 0; i < centroids; i++) {
            double[] newCentroid = new double[numFeatures];

            for (int j = 0; j < numFeatures; j++) {
                newCentroid[j] = featureMap.get(j).uniformSample(rng);
            }

            centroidVectors[i] = DenseVector.createDenseVector(newCentroid);
        }
        return centroidVectors;
    }

    /**
     * Runs the assignment (E) step: finds the closest centroid for every example,
     * records the example's index in that centroid's assignment list, and updates
     * {@code oldCentre}.
     *
     * @param fjp The pool to execute the step in.
     * @param centroidVectors The current centroids.
     * @param clusterAssignments Map from centroid index to the (already cleared) list of assigned example indices.
     * @param data The example vectors.
     * @param oldCentre The previous centroid index of each example, updated in place.
     * @return The number of examples whose centroid changed in this step.
     */
    private int eStep(ForkJoinPool fjp, DenseVector[] centroidVectors, Map<Integer,List<Integer>> clusterAssignments, SparseVector[] data, int[] oldCentre) {
        AtomicInteger changeCounter = new AtomicInteger(0);

        Stream<SparseVector> vecStream = Arrays.stream(data);
        Stream<Integer> intStream = IntStream.range(0,data.length).boxed();
        Stream<IntAndVector> eStream;
        if (numThreads > 1) {
            eStream = StreamUtil.boundParallelism(StreamUtil.zip(intStream,vecStream,IntAndVector::new).parallel());
        } else {
            eStream = StreamUtil.zip(intStream,vecStream,IntAndVector::new);
        }

        runInPool(fjp, () -> eStream.forEach((IntAndVector e) -> {
            double minDist = Double.POSITIVE_INFINITY;
            int clusterID = -1;
            int id = e.idx;
            for (int j = 0; j < centroidVectors.length; j++) {
                double curDist = distance(centroidVectors[j],e.vector);
                if (curDist < minDist) {
                    minDist = curDist;
                    clusterID = j;
                }
            }

            if (clusterID < 0) {
                // Every comparison failed: the distances were NaN or infinite, or there are no centroids.
                throw new IllegalStateException("No centroid could be assigned to example " + id
                        + ", all distances were NaN or infinite, or there are no centroids");
            }

            clusterAssignments.get(clusterID).add(id);
            // Each example index is visited exactly once, so this slot is only touched by one task.
            if (oldCentre[id] != clusterID) {
                // Changed the centroid of this vector.
                oldCentre[id] = clusterID;
                changeCounter.incrementAndGet();
            }
        }));

        return changeCounter.get();
    }

    /**
     * Computes the distance between a centroid and an example using the configured distance function.
     *
     * @param centroid The centroid.
     * @param vector The example vector.
     * @return The distance.
     */
    private double distance(DenseVector centroid, SparseVector vector) {
        switch (distanceType) {
            case EUCLIDEAN:
                return centroid.euclideanDistance(vector);
            case COSINE:
                return centroid.cosineDistance(vector);
            case L1:
                return centroid.l1Distance(vector);
            default:
                throw new IllegalStateException("Unknown distance " + distanceType);
        }
    }

    /**
     * Runs the centroid update (M) step: overwrites each centroid with the
     * weighted mean of the examples currently assigned to it.
     * <p>
     * The weighted sum is normalised by the sum of the assigned examples' weights,
     * which equals the number of assigned examples when all weights are 1.0.
     * A centroid whose assigned weight sums to zero (e.g. no assigned examples) is left as the zero vector.
     *
     * @param fjp The pool to execute the step in.
     * @param centroidVectors The centroids, updated in place.
     * @param clusterAssignments Map from centroid index to the indices of its assigned examples.
     * @param data The example vectors.
     * @param weights The example weights, indexed like {@code data}.
     * @throws RuntimeException If the step fails or the calling thread is interrupted while waiting for it.
     */
    protected void mStep(ForkJoinPool fjp, DenseVector[] centroidVectors, Map<Integer,List<Integer>> clusterAssignments, SparseVector[] data, double[] weights) {
        // M step
        Stream<Entry<Integer,List<Integer>>> mStream;
        if (numThreads > 1) {
            mStream = StreamUtil.boundParallelism(clusterAssignments.entrySet().stream().parallel());
        } else {
            mStream = clusterAssignments.entrySet().stream();
        }
        runInPool(fjp, () -> mStream.forEach((e) -> {
            DenseVector newCentroid = centroidVectors[e.getKey()];
            newCentroid.fill(0.0);

            double weightSum = 0.0;
            for (Integer idx : e.getValue()) {
                newCentroid.intersectAndAddInPlace(data[idx],(double f) -> f * weights[idx]);
                weightSum += weights[idx];
            }
            if (weightSum != 0.0) {
                newCentroid.scaleInPlace(1.0/weightSum);
            }
        }));
    }

    /**
     * Submits the task to the pool and blocks until it completes.
     *
     * @param fjp The pool to run the task in.
     * @param task The task to run.
     * @throws RuntimeException If the task throws, or the calling thread is interrupted while
     *                          waiting; in the latter case the thread's interrupt flag is restored.
     */
    private static void runInPool(ForkJoinPool fjp, Runnable task) {
        try {
            fjp.submit(task).get();
        } catch (InterruptedException e) {
            // Restore the flag so code further up the stack can still observe the interruption.
            Thread.currentThread().interrupt();
            throw new RuntimeException("Parallel execution failed",e);
        } catch (ExecutionException e) {
            throw new RuntimeException("Parallel execution failed",e);
        }
    }

    @Override
    public String toString() {
        return "KMeansTrainer(centroids="+centroids+",distanceType="+ distanceType +",seed="+seed+",numThreads="+numThreads+")";
    }

    @Override
    public TrainerProvenance getProvenance() {
        return new TrainerProvenanceImpl(this);
    }

    /**
     * Tuple of an example's index and its feature vector. One day it'll be a record, but not today.
     */
    static class IntAndVector {
        final int idx;
        final SparseVector vector;

        public IntAndVector(int idx, SparseVector vector) {
            this.idx = idx;
            this.vector = vector;
        }
    }
}