001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.common.nearest;
018
019import com.oracle.labs.mlrg.olcut.util.Pair;
020import com.oracle.labs.mlrg.olcut.util.StreamUtil;
021import org.tribuo.Example;
022import org.tribuo.Excuse;
023import org.tribuo.ImmutableFeatureMap;
024import org.tribuo.ImmutableOutputInfo;
025import org.tribuo.Model;
026import org.tribuo.Output;
027import org.tribuo.Prediction;
028import org.tribuo.common.nearest.KNNTrainer.Distance;
029import org.tribuo.ensemble.EnsembleCombiner;
030import org.tribuo.math.la.SparseVector;
031import org.tribuo.provenance.ModelProvenance;
032
033import java.io.IOException;
034import java.util.ArrayList;
035import java.util.Collections;
036import java.util.List;
037import java.util.Map;
038import java.util.Objects;
039import java.util.Optional;
040import java.util.PriorityQueue;
041import java.util.concurrent.ExecutionException;
042import java.util.concurrent.ExecutorService;
043import java.util.concurrent.Executors;
044import java.util.concurrent.ForkJoinPool;
045import java.util.concurrent.Future;
046import java.util.function.BiFunction;
047import java.util.function.Function;
048import java.util.logging.Level;
049import java.util.logging.Logger;
050import java.util.stream.Collectors;
051import java.util.stream.Stream;
052
053/**
054 * A k-nearest neighbours model.
055 */
056public class KNNModel<T extends Output<T>> extends Model<T> {
057
058    private static final Logger logger = Logger.getLogger(KNNModel.class.getName());
059
060    private static final long serialVersionUID = 1L;
061
062    /**
063     * The parallel backend for batch predictions.
064     */
065    public enum Backend {
066        /**
067         * Uses the streams API for parallelism when scoring a batch of predictions.
068         */
069        STREAMS,
070        /**
071         * Uses a thread pool at the outer level (i.e., one thread per prediction).
072         */
073        THREADPOOL,
074        /**
075         * Uses a thread pool at the inner level (i.e., the whole thread pool works on each prediction).
076         */
077        INNERTHREADPOOL
078    }
079
080    private final Pair<SparseVector,T>[] vectors;
081
082    private final int k;
083    private final Distance distance;
084    private final int numThreads;
085
086    private final Backend parallelBackend;
087
088    private final EnsembleCombiner<T> combiner;
089
090    KNNModel(String name, ModelProvenance provenance, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<T> outputIDInfo,
091                    boolean generatesProbabilities, int k, Distance distance, int numThreads, EnsembleCombiner<T> combiner,
092                    Pair<SparseVector,T>[] vectors, Backend backend) {
093        super(name,provenance,featureIDMap,outputIDInfo,generatesProbabilities);
094        this.k = k;
095        this.distance = distance;
096        this.numThreads = numThreads;
097        this.combiner = combiner;
098        this.parallelBackend = backend;
099        this.vectors = vectors;
100    }
101
102    @Override
103    public Prediction<T> predict(Example<T> example) {
104        SparseVector input = SparseVector.createSparseVector(example,featureIDMap,false);
105
106        Function<Pair<SparseVector,T>, OutputDoublePair<T>> distanceFunc;
107        switch (distance) {
108            case L1:
109                distanceFunc = (a) -> new OutputDoublePair<>(a.getB(),a.getA().l1Distance(input));
110                break;
111            case L2:
112                distanceFunc = (a) -> new OutputDoublePair<>(a.getB(),a.getA().l2Distance(input));
113                break;
114            case COSINE:
115                distanceFunc = (a) -> new OutputDoublePair<>(a.getB(),a.getA().cosineDistance(input));
116                break;
117            default:
118                throw new IllegalStateException("Unknown distance function " + distance);
119        }
120
121        List<Prediction<T>> predictions;
122        Stream<Pair<SparseVector,T>> stream = Stream.of(vectors);
123        if (numThreads > 1) {
124            ForkJoinPool fjp = new ForkJoinPool(numThreads);
125            try {
126                predictions = fjp.submit(()->StreamUtil.boundParallelism(stream.parallel()).map(distanceFunc).sorted().limit(k).map((a) -> new Prediction<>(a.output, input.numActiveElements(), example)).collect(Collectors.toList())).get();
127            } catch (InterruptedException | ExecutionException e) {
128                logger.log(Level.SEVERE,"Exception when predicting in KNNModel",e);
129                throw new IllegalStateException("Failed to process example in parallel",e);
130            }
131        } else {
132            predictions = stream.map(distanceFunc).sorted().limit(k).map((a) -> new Prediction<>(a.output, input.numActiveElements(), example)).collect(Collectors.toList());
133        }
134
135        return combiner.combine(outputIDInfo,predictions);
136    }
137
138    /**
139     * Uses the model to predict the output for multiple examples.
140     * @param examples the examples to predict.
141     * @return the results of the prediction, in the same order as the
142     * examples.
143     */
144    @Override
145    protected List<Prediction<T>> innerPredict(Iterable<Example<T>> examples) {
146        if (numThreads > 1) {
147            return innerPredictMultithreaded(examples);
148        } else {
149            List<Prediction<T>> predictions = new ArrayList<>();
150            List<Prediction<T>> innerPredictions = new ArrayList<>();
151            PriorityQueue<OutputDoublePair<T>> queue = new PriorityQueue<>(k, (a,b) -> Double.compare(b.value, a.value));
152            BiFunction<SparseVector,SparseVector,Double> distanceFunc;
153            switch (distance) {
154                case L1:
155                    distanceFunc = (a,b) -> b.l1Distance(a);
156                    break;
157                case L2:
158                    distanceFunc = (a,b) -> b.l2Distance(a);
159                    break;
160                case COSINE:
161                    distanceFunc = (a,b) -> b.cosineDistance(a);
162                    break;
163                default:
164                    throw new IllegalStateException("Unknown distance function " + distance);
165            }
166
167            for (Example<T> example : examples) {
168                queue.clear();
169                innerPredictions.clear();
170                SparseVector input = SparseVector.createSparseVector(example, featureIDMap, false);
171
172                for (int i = 0; i < vectors.length; i++) {
173                    double curDistance = distanceFunc.apply(input,vectors[i].getA());
174
175                    if (queue.size() < k) {
176                        OutputDoublePair<T> newPair = new OutputDoublePair<>(vectors[i].getB(),curDistance);
177                        queue.offer(newPair);
178                    } else if (Double.compare(curDistance, queue.peek().value) < 0) {
179                        OutputDoublePair<T> pair = queue.poll();
180                        pair.output = vectors[i].getB();
181                        pair.value = curDistance;
182                        queue.offer(pair);
183                    }
184                }
185
186                for (OutputDoublePair<T> pair : queue) {
187                    innerPredictions.add(new Prediction<>(pair.output, input.numActiveElements(), example));
188                }
189
190                predictions.add(combiner.combine(outputIDInfo, innerPredictions));
191            }
192            return predictions;
193        }
194    }
195
196    /**
197     * Switches between the different multithreaded backends.
198     * @param examples The examples to predict.
199     * @return The predictions.
200     */
201    private List<Prediction<T>> innerPredictMultithreaded(Iterable<Example<T>> examples) {
202        switch (parallelBackend) {
203            case STREAMS:
204                logger.log(Level.FINE, "Parallel backend - streams");
205                return innerPredictStreams(examples);
206            case THREADPOOL:
207                logger.log(Level.FINE, "Parallel backend - threadpool");
208                return innerPredictThreadPool(examples);
209            case INNERTHREADPOOL:
210                logger.log(Level.FINE, "Parallel backend - within example threadpool");
211                return innerPredictWithinExampleThreadPool(examples);
212            default:
213                throw new IllegalArgumentException("Unknown backend " + parallelBackend);
214        }
215    }
216
217    /**
218     * Predicts using a FJP and the Streams API.
219     * @param examples The examples to predict.
220     * @return The predictions.
221     */
222    private List<Prediction<T>> innerPredictStreams(Iterable<Example<T>> examples) {
223        List<Prediction<T>> predictions = new ArrayList<>();
224        List<Prediction<T>> innerPredictions = null;
225        ForkJoinPool fjp = new ForkJoinPool(numThreads);
226        for (Example<T> example : examples) {
227            SparseVector input = SparseVector.createSparseVector(example, featureIDMap, false);
228
229            Function<Pair<SparseVector, T>, OutputDoublePair<T>> distanceFunc;
230            switch (distance) {
231                case L1:
232                    distanceFunc = (a) -> new OutputDoublePair<>(a.getB(), a.getA().l1Distance(input));
233                    break;
234                case L2:
235                    distanceFunc = (a) -> new OutputDoublePair<>(a.getB(), a.getA().l2Distance(input));
236                    break;
237                case COSINE:
238                    distanceFunc = (a) -> new OutputDoublePair<>(a.getB(), a.getA().cosineDistance(input));
239                    break;
240                default:
241                    throw new IllegalStateException("Unknown distance function " + distance);
242            }
243
244            Stream<Pair<SparseVector, T>> stream = Stream.of(vectors);
245            try {
246                innerPredictions = fjp.submit(() -> StreamUtil.boundParallelism(stream.parallel()).map(distanceFunc).sorted().limit(k).map((a) -> new Prediction<>(a.output, input.numActiveElements(), example)).collect(Collectors.toList())).get();
247            } catch (InterruptedException | ExecutionException e) {
248                logger.log(Level.SEVERE, "Exception when predicting in KNNModel", e);
249            }
250
251            predictions.add(combiner.combine(outputIDInfo, innerPredictions));
252        }
253
254        return predictions;
255    }
256
257    /**
258     * Uses a thread pool, one thread per prediction.
259     * @param examples The examples to predict.
260     * @return The predictions.
261     */
262    private List<Prediction<T>> innerPredictThreadPool(Iterable<Example<T>> examples) {
263        BiFunction<SparseVector,SparseVector,Double> distanceFunc;
264        switch (distance) {
265            case L1:
266                distanceFunc = (a,b) -> b.l1Distance(a);
267                break;
268            case L2:
269                distanceFunc = (a,b) -> b.l2Distance(a);
270                break;
271            case COSINE:
272                distanceFunc = (a,b) -> b.cosineDistance(a);
273                break;
274            default:
275                throw new IllegalStateException("Unknown distance function " + distance);
276        }
277
278        List<Prediction<T>> predictions = new ArrayList<>();
279
280        ExecutorService pool = Executors.newFixedThreadPool(numThreads);
281
282        List<Future<Prediction<T>>> futures = new ArrayList<>();
283
284        ThreadLocal<PriorityQueue<OutputDoublePair<T>>> queuePool = ThreadLocal.withInitial(() -> new PriorityQueue<>(k, (a,b) -> Double.compare(b.value, a.value)));
285
286        for (Example<T> example : examples) {
287            futures.add(pool.submit(() -> innerPredictOne(queuePool,vectors,combiner,distanceFunc,featureIDMap,outputIDInfo,k,example)));
288        }
289
290        try {
291            for (Future<Prediction<T>> f : futures) {
292                predictions.add(f.get());
293            }
294        } catch (InterruptedException | ExecutionException e) {
295            throw new IllegalStateException("Thread pool went bang",e);
296        }
297
298        pool.shutdown();
299
300        return predictions;
301    }
302
303    /**
304     * Uses a thread pool where the pool collaborates on each example (best for large training dataset sizes).
305     * @param examples The examples to predict.
306     * @return The predictions.
307     */
308    private List<Prediction<T>> innerPredictWithinExampleThreadPool(Iterable<Example<T>> examples) {
309        BiFunction<SparseVector,SparseVector,Double> distanceFunc;
310        switch (distance) {
311            case L1:
312                distanceFunc = (a,b) -> b.l1Distance(a);
313                break;
314            case L2:
315                distanceFunc = (a,b) -> b.l2Distance(a);
316                break;
317            case COSINE:
318                distanceFunc = (a,b) -> b.cosineDistance(a);
319                break;
320            default:
321                throw new IllegalStateException("Unknown distance function " + distance);
322        }
323
324        List<Prediction<T>> predictions = new ArrayList<>();
325
326        ExecutorService pool = Executors.newFixedThreadPool(numThreads);
327
328        ThreadLocal<PriorityQueue<OutputDoublePair<T>>> queuePool = ThreadLocal.withInitial(() -> new PriorityQueue<>(k, (a,b) -> Double.compare(b.value, a.value)));
329
330        for (Example<T> example : examples) {
331            predictions.add(innerPredictThreadPool(pool,queuePool,distanceFunc,example));
332        }
333
334        pool.shutdown();
335
336        return predictions;
337    }
338
339    private Prediction<T> innerPredictThreadPool(ExecutorService pool,
340                                                 ThreadLocal<PriorityQueue<OutputDoublePair<T>>> queuePool,
341                                                 BiFunction<SparseVector,SparseVector,Double> distanceFunc,
342                                                 Example<T> example) {
343        SparseVector vector = SparseVector.createSparseVector(example, featureIDMap, false);
344        List<Future<List<OutputDoublePair<T>>>> futures = new ArrayList<>();
345        for (int i = 0; i < numThreads; i++) {
346            int start = i * (vectors.length / numThreads);
347            int end = (i + 1) * (vectors.length / numThreads);
348            futures.add(pool.submit(() -> innerPredictChunk(queuePool,vectors,start,end,distanceFunc,k,vector)));
349        }
350
351        PriorityQueue<OutputDoublePair<T>> queue = new PriorityQueue<>(k, (a,b) -> Double.compare(b.value, a.value));
352        try {
353            for (Future<List<OutputDoublePair<T>>> f : futures) {
354                List<OutputDoublePair<T>> chunkOutputs = f.get();
355                for (OutputDoublePair<T> curOutputPair : chunkOutputs) {
356                    if (queue.size() < k) {
357                        queue.offer(curOutputPair);
358                    } else if (Double.compare(curOutputPair.value, queue.peek().value) < 0) {
359                        queue.poll();
360                        queue.offer(curOutputPair);
361                    }
362                }
363            }
364        } catch (InterruptedException | ExecutionException e) {
365            throw new IllegalStateException("Thread pool went bang",e);
366        }
367
368        List<Prediction<T>> predictions = new ArrayList<>();
369
370        for (OutputDoublePair<T> pair : queue) {
371            predictions.add(new Prediction<>(pair.output,vector.numActiveElements(),example));
372        }
373
374        return combiner.combine(outputIDInfo,predictions);
375    }
376
377    private static <T extends Output<T>> List<OutputDoublePair<T>> innerPredictChunk(ThreadLocal<PriorityQueue<OutputDoublePair<T>>> queuePool,
378                                                                            Pair<SparseVector,T>[] vectors,
379                                                                            int start,
380                                                                            int end,
381                                                                            BiFunction<SparseVector,SparseVector,Double> distanceFunc,
382                                                                            int k,
383                                                                            SparseVector input) {
384        PriorityQueue<OutputDoublePair<T>> queue = queuePool.get();
385        queue.clear();
386
387        end = Math.min(end, vectors.length);
388
389        for (int i = start; i < end; i++) {
390            double curDistance = distanceFunc.apply(input,vectors[i].getA());
391
392            if (queue.size() < k) {
393                OutputDoublePair<T> newPair = new OutputDoublePair<>(vectors[i].getB(),curDistance);
394                queue.offer(newPair);
395            } else if (Double.compare(curDistance, queue.peek().value) < 0) {
396                OutputDoublePair<T> pair = queue.poll();
397                pair.output = vectors[i].getB();
398                pair.value = curDistance;
399                queue.offer(pair);
400            }
401        }
402
403        return new ArrayList<>(queue);
404    }
405
406    private static <T extends Output<T>> Prediction<T> innerPredictOne(ThreadLocal<PriorityQueue<OutputDoublePair<T>>> queuePool,
407                                                                    Pair<SparseVector,T>[] vectors,
408                                                                    EnsembleCombiner<T> combiner,
409                                                                    BiFunction<SparseVector,SparseVector,Double> distanceFunc,
410                                                                    ImmutableFeatureMap featureIDMap,
411                                                                    ImmutableOutputInfo<T> outputIDInfo,
412                                                                    int k,
413                                                                    Example<T> example) {
414        SparseVector vector = SparseVector.createSparseVector(example, featureIDMap, false);
415        PriorityQueue<OutputDoublePair<T>> queue = queuePool.get();
416        queue.clear();
417
418        for (int i = 0; i < vectors.length; i++) {
419            double curDistance = distanceFunc.apply(vector,vectors[i].getA());
420
421            if (queue.size() < k) {
422                OutputDoublePair<T> newPair = new OutputDoublePair<>(vectors[i].getB(),curDistance);
423                queue.offer(newPair);
424            } else if (Double.compare(curDistance, queue.peek().value) < 0) {
425                OutputDoublePair<T> pair = queue.poll();
426                pair.output = vectors[i].getB();
427                pair.value = curDistance;
428                queue.offer(pair);
429            }
430        }
431
432        List<Prediction<T>> localPredictions = new ArrayList<>();
433
434        for (OutputDoublePair<T> pair : queue) {
435            localPredictions.add(new Prediction<>(pair.output, vector.numActiveElements(), example));
436        }
437
438        return combiner.combine(outputIDInfo,localPredictions);
439    }
440
441    @Override
442    public Map<String, List<Pair<String, Double>>> getTopFeatures(int n) {
443        return Collections.emptyMap();
444    }
445
446    @Override
447    public Optional<Excuse<T>> getExcuse(Example<T> example) {
448        return Optional.empty();
449    }
450
451    @SuppressWarnings("unchecked") // Generic array creation.
452    @Override
453    protected KNNModel<T> copy(String newName, ModelProvenance newProvenance) {
454        Pair<SparseVector,T>[] vectorCopy = new Pair[vectors.length];
455        for (int i = 0; i < vectors.length; i++) {
456            vectorCopy[i] = new Pair<>(vectors[i].getA().copy(),vectors[i].getB().copy());
457        }
458        return new KNNModel<>(newName,newProvenance,featureIDMap,outputIDInfo,generatesProbabilities,k,distance,numThreads,combiner,vectorCopy,parallelBackend);
459    }
460
461    private void readObject(java.io.ObjectInputStream in) throws IOException, ClassNotFoundException {
462        in.defaultReadObject();
463    }
464
465    /**
466     * It's a specialised non-final pair used for buffering and to reduce object creation.
467     * @param <T> The output type.
468     */
469    private static final class OutputDoublePair<T extends Output<T>> implements Comparable<OutputDoublePair<T>> {
470        T output;
471        double value;
472
473        public OutputDoublePair(T output, double value) {
474            this.output = output;
475            this.value = value;
476        }
477
478        @Override
479        public boolean equals(Object o) {
480            if (this == o) return true;
481            if (o == null || getClass() != o.getClass()) return false;
482            OutputDoublePair<?> that = (OutputDoublePair<?>) o;
483            return Double.compare(that.value, value) == 0 &&
484                    output.equals(that.output);
485        }
486
487        @Override
488        public int hashCode() {
489            return Objects.hash(output, value);
490        }
491
492        @Override
493        public int compareTo(OutputDoublePair<T> o) {
494            return Double.compare(value, o.value);
495        }
496    }
497
498}