/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.common.nearest;

import com.oracle.labs.mlrg.olcut.util.Pair;
import com.oracle.labs.mlrg.olcut.util.StreamUtil;
import org.tribuo.Example;
import org.tribuo.Excuse;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Model;
import org.tribuo.Output;
import org.tribuo.Prediction;
import org.tribuo.common.nearest.KNNTrainer.Distance;
import org.tribuo.ensemble.EnsembleCombiner;
import org.tribuo.math.la.SparseVector;
import org.tribuo.provenance.ModelProvenance;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.PriorityQueue;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.Future;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.logging.Level;
import java.util.logging.Logger;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/**
 * A k-nearest neighbours model.
 */
public class KNNModel<T extends Output<T>> extends Model<T> {

    private static final Logger logger = Logger.getLogger(KNNModel.class.getName());

    private static final long serialVersionUID = 1L;

    /**
     * The parallel backend for batch predictions.
     */
    public enum Backend {
        /**
         * Uses the streams API for parallelism when scoring a batch of predictions.
         */
        STREAMS,
        /**
         * Uses a thread pool at the outer level (i.e., one thread per prediction).
         */
        THREADPOOL,
        /**
         * Uses a thread pool at the inner level (i.e., the whole thread pool works on each prediction).
         */
        INNERTHREADPOOL
    }
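    /*
     * Illustrative usage sketch (not part of this class): a KNNModel is normally created by
     * KNNTrainer rather than constructed directly. Assuming the trainer's constructor mirrors
     * the fields of this model (k, distance, thread count, combiner, backend), training a
     * classification model looks roughly like:
     *
     *   KNNTrainer<Label> trainer = new KNNTrainer<>(3, Distance.L2, 4,
     *           new VotingCombiner(), KNNModel.Backend.THREADPOOL);
     *   Model<Label> model = trainer.train(trainingDataset);
     *
     * Label and VotingCombiner come from the classification modules; any Output type with a
     * suitable EnsembleCombiner works the same way. Check KNNTrainer for the exact signature.
     */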

    private final Pair<SparseVector,T>[] vectors;

    private final int k;
    private final Distance distance;
    private final int numThreads;

    private final Backend parallelBackend;

    private final EnsembleCombiner<T> combiner;

    KNNModel(String name, ModelProvenance provenance, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<T> outputIDInfo,
             boolean generatesProbabilities, int k, Distance distance, int numThreads, EnsembleCombiner<T> combiner,
             Pair<SparseVector,T>[] vectors, Backend backend) {
        super(name,provenance,featureIDMap,outputIDInfo,generatesProbabilities);
        this.k = k;
        this.distance = distance;
        this.numThreads = numThreads;
        this.combiner = combiner;
        this.parallelBackend = backend;
        this.vectors = vectors;
    }

    @Override
    public Prediction<T> predict(Example<T> example) {
        SparseVector input = SparseVector.createSparseVector(example,featureIDMap,false);

        Function<Pair<SparseVector,T>, OutputDoublePair<T>> distanceFunc;
        switch (distance) {
            case L1:
                distanceFunc = (a) -> new OutputDoublePair<>(a.getB(),a.getA().l1Distance(input));
                break;
            case L2:
                distanceFunc = (a) -> new OutputDoublePair<>(a.getB(),a.getA().l2Distance(input));
                break;
            case COSINE:
                distanceFunc = (a) -> new OutputDoublePair<>(a.getB(),a.getA().cosineDistance(input));
                break;
            default:
                throw new IllegalStateException("Unknown distance function " + distance);
        }

        List<Prediction<T>> predictions;
        Stream<Pair<SparseVector,T>> stream = Stream.of(vectors);
        if (numThreads > 1) {
            ForkJoinPool fjp = new ForkJoinPool(numThreads);
            try {
                predictions = fjp.submit(()->StreamUtil.boundParallelism(stream.parallel()).map(distanceFunc).sorted().limit(k).map((a) -> new Prediction<>(a.output, input.numActiveElements(), example)).collect(Collectors.toList())).get();
            } catch (InterruptedException | ExecutionException e) {
                logger.log(Level.SEVERE,"Exception when predicting in KNNModel",e);
                throw new IllegalStateException("Failed to process example in parallel",e);
            }
        } else {
            predictions = stream.map(distanceFunc).sorted().limit(k).map((a) -> new Prediction<>(a.output, input.numActiveElements(), example)).collect(Collectors.toList());
        }

        return combiner.combine(outputIDInfo,predictions);
    }
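    /*
     * Sketch of single-example scoring (variable names are illustrative):
     *
     *   Prediction<Label> prediction = model.predict(testExample);
     *   Label predicted = prediction.getOutput();
     *
     * The sorted().limit(k) pipeline above keeps the k closest neighbours because
     * OutputDoublePair orders ascending by distance; the combiner then reduces those
     * k neighbour predictions to a single output.
     */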

    /**
     * Uses the model to predict the output for multiple examples.
     * @param examples the examples to predict.
     * @return the results of the prediction, in the same order as the examples.
     */
    @Override
    protected List<Prediction<T>> innerPredict(Iterable<Example<T>> examples) {
        if (numThreads > 1) {
            return innerPredictMultithreaded(examples);
        } else {
            List<Prediction<T>> predictions = new ArrayList<>();
            List<Prediction<T>> innerPredictions = new ArrayList<>();
            PriorityQueue<OutputDoublePair<T>> queue = new PriorityQueue<>(k, (a,b) -> Double.compare(b.value, a.value));
            BiFunction<SparseVector,SparseVector,Double> distanceFunc;
            switch (distance) {
                case L1:
                    distanceFunc = (a,b) -> b.l1Distance(a);
                    break;
                case L2:
                    distanceFunc = (a,b) -> b.l2Distance(a);
                    break;
                case COSINE:
                    distanceFunc = (a,b) -> b.cosineDistance(a);
                    break;
                default:
                    throw new IllegalStateException("Unknown distance function " + distance);
            }

            for (Example<T> example : examples) {
                queue.clear();
                innerPredictions.clear();
                SparseVector input = SparseVector.createSparseVector(example, featureIDMap, false);

                for (int i = 0; i < vectors.length; i++) {
                    double curDistance = distanceFunc.apply(input,vectors[i].getA());

                    if (queue.size() < k) {
                        OutputDoublePair<T> newPair = new OutputDoublePair<>(vectors[i].getB(),curDistance);
                        queue.offer(newPair);
                    } else if (Double.compare(curDistance, queue.peek().value) < 0) {
                        OutputDoublePair<T> pair = queue.poll();
                        pair.output = vectors[i].getB();
                        pair.value = curDistance;
                        queue.offer(pair);
                    }
                }

                for (OutputDoublePair<T> pair : queue) {
                    innerPredictions.add(new Prediction<>(pair.output, input.numActiveElements(), example));
                }

                predictions.add(combiner.combine(outputIDInfo, innerPredictions));
            }
            return predictions;
        }
    }
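    /*
     * The loop above maintains a bounded max-heap: the queue holds the k closest neighbours
     * seen so far, ordered so that peek() returns the furthest of them. A candidate only
     * displaces the head when it is strictly closer, giving O(n log k) selection rather than
     * an O(n log n) full sort. A minimal standalone sketch of the same pattern over doubles:
     *
     *   PriorityQueue<Double> heap = new PriorityQueue<>(k, Comparator.reverseOrder());
     *   for (double d : distances) {
     *       if (heap.size() < k) {
     *           heap.offer(d);
     *       } else if (d < heap.peek()) {
     *           heap.poll();
     *           heap.offer(d);
     *       }
     *   }
     */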

    /**
     * Switches between the different multithreaded backends.
     * @param examples The examples to predict.
     * @return The predictions.
     */
    private List<Prediction<T>> innerPredictMultithreaded(Iterable<Example<T>> examples) {
        switch (parallelBackend) {
            case STREAMS:
                logger.log(Level.FINE, "Parallel backend - streams");
                return innerPredictStreams(examples);
            case THREADPOOL:
                logger.log(Level.FINE, "Parallel backend - threadpool");
                return innerPredictThreadPool(examples);
            case INNERTHREADPOOL:
                logger.log(Level.FINE, "Parallel backend - within example threadpool");
                return innerPredictWithinExampleThreadPool(examples);
            default:
                throw new IllegalArgumentException("Unknown backend " + parallelBackend);
        }
    }

    /**
     * Predicts using a ForkJoinPool and the Streams API.
     * @param examples The examples to predict.
     * @return The predictions.
     */
    private List<Prediction<T>> innerPredictStreams(Iterable<Example<T>> examples) {
        List<Prediction<T>> predictions = new ArrayList<>();
        List<Prediction<T>> innerPredictions = null;
        ForkJoinPool fjp = new ForkJoinPool(numThreads);
        for (Example<T> example : examples) {
            SparseVector input = SparseVector.createSparseVector(example, featureIDMap, false);

            Function<Pair<SparseVector, T>, OutputDoublePair<T>> distanceFunc;
            switch (distance) {
                case L1:
                    distanceFunc = (a) -> new OutputDoublePair<>(a.getB(), a.getA().l1Distance(input));
                    break;
                case L2:
                    distanceFunc = (a) -> new OutputDoublePair<>(a.getB(), a.getA().l2Distance(input));
                    break;
                case COSINE:
                    distanceFunc = (a) -> new OutputDoublePair<>(a.getB(), a.getA().cosineDistance(input));
                    break;
                default:
                    throw new IllegalStateException("Unknown distance function " + distance);
            }

            Stream<Pair<SparseVector, T>> stream = Stream.of(vectors);
            try {
                innerPredictions = fjp.submit(() -> StreamUtil.boundParallelism(stream.parallel()).map(distanceFunc).sorted().limit(k).map((a) -> new Prediction<>(a.output, input.numActiveElements(), example)).collect(Collectors.toList())).get();
            } catch (InterruptedException | ExecutionException e) {
                logger.log(Level.SEVERE, "Exception when predicting in KNNModel", e);
                throw new IllegalStateException("Failed to process example in parallel", e);
            }

            predictions.add(combiner.combine(outputIDInfo, innerPredictions));
        }

        return predictions;
    }

    /**
     * Uses a thread pool, one thread per prediction.
     * @param examples The examples to predict.
     * @return The predictions.
     */
    private List<Prediction<T>> innerPredictThreadPool(Iterable<Example<T>> examples) {
        BiFunction<SparseVector,SparseVector,Double> distanceFunc;
        switch (distance) {
            case L1:
                distanceFunc = (a,b) -> b.l1Distance(a);
                break;
            case L2:
                distanceFunc = (a,b) -> b.l2Distance(a);
                break;
            case COSINE:
                distanceFunc = (a,b) -> b.cosineDistance(a);
                break;
            default:
                throw new IllegalStateException("Unknown distance function " + distance);
        }

        List<Prediction<T>> predictions = new ArrayList<>();

        ExecutorService pool = Executors.newFixedThreadPool(numThreads);

        List<Future<Prediction<T>>> futures = new ArrayList<>();

        ThreadLocal<PriorityQueue<OutputDoublePair<T>>> queuePool = ThreadLocal.withInitial(() -> new PriorityQueue<>(k, (a,b) -> Double.compare(b.value, a.value)));

        for (Example<T> example : examples) {
            futures.add(pool.submit(() -> innerPredictOne(queuePool,vectors,combiner,distanceFunc,featureIDMap,outputIDInfo,k,example)));
        }

        try {
            for (Future<Prediction<T>> f : futures) {
                predictions.add(f.get());
            }
        } catch (InterruptedException | ExecutionException e) {
            throw new IllegalStateException("Thread pool went bang",e);
        }

        pool.shutdown();

        return predictions;
    }

    /**
     * Uses a thread pool where the pool collaborates on each example (best for large training dataset sizes).
     * @param examples The examples to predict.
     * @return The predictions.
     */
    private List<Prediction<T>> innerPredictWithinExampleThreadPool(Iterable<Example<T>> examples) {
        BiFunction<SparseVector,SparseVector,Double> distanceFunc;
        switch (distance) {
            case L1:
                distanceFunc = (a,b) -> b.l1Distance(a);
                break;
            case L2:
                distanceFunc = (a,b) -> b.l2Distance(a);
                break;
            case COSINE:
                distanceFunc = (a,b) -> b.cosineDistance(a);
                break;
            default:
                throw new IllegalStateException("Unknown distance function " + distance);
        }

        List<Prediction<T>> predictions = new ArrayList<>();

        ExecutorService pool = Executors.newFixedThreadPool(numThreads);

        ThreadLocal<PriorityQueue<OutputDoublePair<T>>> queuePool = ThreadLocal.withInitial(() -> new PriorityQueue<>(k, (a,b) -> Double.compare(b.value, a.value)));

        for (Example<T> example : examples) {
            predictions.add(innerPredictThreadPool(pool,queuePool,distanceFunc,example));
        }

        pool.shutdown();

        return predictions;
    }

    private Prediction<T> innerPredictThreadPool(ExecutorService pool,
                                                 ThreadLocal<PriorityQueue<OutputDoublePair<T>>> queuePool,
                                                 BiFunction<SparseVector,SparseVector,Double> distanceFunc,
                                                 Example<T> example) {
        SparseVector vector = SparseVector.createSparseVector(example, featureIDMap, false);
        List<Future<List<OutputDoublePair<T>>>> futures = new ArrayList<>();
        // Use a ceiling division so the chunks cover any remainder when the number of
        // vectors isn't an exact multiple of the thread count; innerPredictChunk clamps
        // the final range to the array length.
        int chunkSize = (vectors.length + numThreads - 1) / numThreads;
        for (int i = 0; i < numThreads; i++) {
            int start = i * chunkSize;
            int end = (i + 1) * chunkSize;
            futures.add(pool.submit(() -> innerPredictChunk(queuePool,vectors,start,end,distanceFunc,k,vector)));
        }

        PriorityQueue<OutputDoublePair<T>> queue = new PriorityQueue<>(k, (a,b) -> Double.compare(b.value, a.value));
        try {
            for (Future<List<OutputDoublePair<T>>> f : futures) {
                List<OutputDoublePair<T>> chunkOutputs = f.get();
                for (OutputDoublePair<T> curOutputPair : chunkOutputs) {
                    if (queue.size() < k) {
                        queue.offer(curOutputPair);
                    } else if (Double.compare(curOutputPair.value, queue.peek().value) < 0) {
                        queue.poll();
                        queue.offer(curOutputPair);
                    }
                }
            }
        } catch (InterruptedException | ExecutionException e) {
            throw new IllegalStateException("Thread pool went bang",e);
        }

        List<Prediction<T>> predictions = new ArrayList<>();

        for (OutputDoublePair<T> pair : queue) {
            predictions.add(new Prediction<>(pair.output,vector.numActiveElements(),example));
        }

        return combiner.combine(outputIDInfo,predictions);
    }

    private static <T extends Output<T>> List<OutputDoublePair<T>> innerPredictChunk(ThreadLocal<PriorityQueue<OutputDoublePair<T>>> queuePool,
                                                                                     Pair<SparseVector,T>[] vectors,
                                                                                     int start,
                                                                                     int end,
                                                                                     BiFunction<SparseVector,SparseVector,Double> distanceFunc,
                                                                                     int k,
                                                                                     SparseVector input) {
        PriorityQueue<OutputDoublePair<T>> queue = queuePool.get();
        queue.clear();

        end = Math.min(end, vectors.length);

        for (int i = start; i < end; i++) {
            double curDistance = distanceFunc.apply(input,vectors[i].getA());

            if (queue.size() < k) {
                OutputDoublePair<T> newPair = new OutputDoublePair<>(vectors[i].getB(),curDistance);
                queue.offer(newPair);
            } else if (Double.compare(curDistance, queue.peek().value) < 0) {
                OutputDoublePair<T> pair = queue.poll();
                pair.output = vectors[i].getB();
                pair.value = curDistance;
                queue.offer(pair);
            }
        }

        return new ArrayList<>(queue);
    }
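    /*
     * Partitioning sketch for the within-example backend: the reference vectors are split into
     * numThreads contiguous ranges, each task returns the local top-k for its range, and the
     * caller merges those partial results into a single global top-k. For instance, with 10
     * stored vectors and 3 threads:
     *
     *   int chunkSize = (10 + 3 - 1) / 3;   // 4
     *   // task 0 scans [0, 4), task 1 scans [4, 8), task 2 scans [8, 10) after clamping
     *
     * Merging at most numThreads * k candidates is cheap relative to the distance computations.
     */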
    private static <T extends Output<T>> Prediction<T> innerPredictOne(ThreadLocal<PriorityQueue<OutputDoublePair<T>>> queuePool,
                                                                       Pair<SparseVector,T>[] vectors,
                                                                       EnsembleCombiner<T> combiner,
                                                                       BiFunction<SparseVector,SparseVector,Double> distanceFunc,
                                                                       ImmutableFeatureMap featureIDMap,
                                                                       ImmutableOutputInfo<T> outputIDInfo,
                                                                       int k,
                                                                       Example<T> example) {
        SparseVector vector = SparseVector.createSparseVector(example, featureIDMap, false);
        PriorityQueue<OutputDoublePair<T>> queue = queuePool.get();
        queue.clear();

        for (int i = 0; i < vectors.length; i++) {
            double curDistance = distanceFunc.apply(vector,vectors[i].getA());

            if (queue.size() < k) {
                OutputDoublePair<T> newPair = new OutputDoublePair<>(vectors[i].getB(),curDistance);
                queue.offer(newPair);
            } else if (Double.compare(curDistance, queue.peek().value) < 0) {
                OutputDoublePair<T> pair = queue.poll();
                pair.output = vectors[i].getB();
                pair.value = curDistance;
                queue.offer(pair);
            }
        }

        List<Prediction<T>> localPredictions = new ArrayList<>();

        for (OutputDoublePair<T> pair : queue) {
            localPredictions.add(new Prediction<>(pair.output, vector.numActiveElements(), example));
        }

        return combiner.combine(outputIDInfo,localPredictions);
    }

    @Override
    public Map<String, List<Pair<String, Double>>> getTopFeatures(int n) {
        return Collections.emptyMap();
    }

    @Override
    public Optional<Excuse<T>> getExcuse(Example<T> example) {
        return Optional.empty();
    }

    @SuppressWarnings("unchecked") // Generic array creation.
    @Override
    protected KNNModel<T> copy(String newName, ModelProvenance newProvenance) {
        Pair<SparseVector,T>[] vectorCopy = new Pair[vectors.length];
        for (int i = 0; i < vectors.length; i++) {
            vectorCopy[i] = new Pair<>(vectors[i].getA().copy(),vectors[i].getB().copy());
        }
        return new KNNModel<>(newName,newProvenance,featureIDMap,outputIDInfo,generatesProbabilities,k,distance,numThreads,combiner,vectorCopy,parallelBackend);
    }

    private void readObject(java.io.ObjectInputStream in) throws IOException, ClassNotFoundException {
        in.defaultReadObject();
    }

    /**
     * A specialised mutable pair, reused as a buffer to reduce object creation.
     * @param <T> The output type.
     */
    private static final class OutputDoublePair<T extends Output<T>> implements Comparable<OutputDoublePair<T>> {
        T output;
        double value;

        public OutputDoublePair(T output, double value) {
            this.output = output;
            this.value = value;
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) return true;
            if (o == null || getClass() != o.getClass()) return false;
            OutputDoublePair<?> that = (OutputDoublePair<?>) o;
            return Double.compare(that.value, value) == 0 &&
                    output.equals(that.output);
        }

        @Override
        public int hashCode() {
            return Objects.hash(output, value);
        }

        @Override
        public int compareTo(OutputDoublePair<T> o) {
            return Double.compare(value, o.value);
        }
    }

}