/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.common.nearest;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.config.PropertyException;
import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import com.oracle.labs.mlrg.olcut.util.Pair;
import org.tribuo.Dataset;
import org.tribuo.Example;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Model;
import org.tribuo.Output;
import org.tribuo.Trainer;
import org.tribuo.common.nearest.KNNModel.Backend;
import org.tribuo.ensemble.EnsembleCombiner;
import org.tribuo.math.la.SparseVector;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.provenance.TrainerProvenance;
import org.tribuo.provenance.impl.TrainerProvenanceImpl;

import java.time.OffsetDateTime;
import java.util.Map;

/**
 * A {@link Trainer} for k-nearest neighbour models.
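 * <p>
 * A minimal usage sketch (the {@code Label} output type and the {@code VotingCombiner} below
 * come from the classification module and are assumed here purely for illustration):
 * <pre>{@code
 * // 5 nearest neighbours, Euclidean distance, 4 inference threads, majority vote over the neighbours
 * KNNTrainer<Label> trainer = new KNNTrainer<>(5, Distance.L2, 4,
 *         new VotingCombiner(), KNNModel.Backend.THREADPOOL);
 * Model<Label> model = trainer.train(trainingDataset);
 * }</pre>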
 */
public class KNNTrainer<T extends Output<T>> implements Trainer<T> {

    /**
     * The available distance functions.
     */
    public enum Distance {
        /**
         * L1 (or Manhattan) distance.
         */
        L1,
        /**
         * L2 (or Euclidean) distance.
         */
        L2,
        /**
         * Cosine similarity used as a distance measure.
         */
        COSINE
    }

    @Config(mandatory = true, description="The distance function used to measure nearest neighbours.")
    private Distance distance;

    @Config(mandatory = true, description="The number of nearest neighbours to check.")
    private int k;

    @Config(mandatory = true, description="The combination function to aggregate the nearest neighbours.")
    private EnsembleCombiner<T> combiner;

    @Config(description="The number of threads to use for inference.")
    private int numThreads = 1;

    @Config(description="The threading model to use.")
    private Backend backend = Backend.THREADPOOL;
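    // A hypothetical OLCUT XML configuration for this trainer; the property names mirror the
    // @Config fields above, while the component name and values are illustrative assumptions:
    //
    // <config>
    //   <component name="knn-trainer" type="org.tribuo.common.nearest.KNNTrainer">
    //     <property name="k" value="5"/>
    //     <property name="distance" value="L2"/>
    //     <property name="combiner" value="voting-combiner"/>
    //     <property name="numThreads" value="4"/>
    //     <property name="backend" value="THREADPOOL"/>
    //   </component>
    // </config>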

    private int invocationCount = 0;

    /**
     * For olcut.
     */
    private KNNTrainer() {}

    /**
     * Creates a K-NN trainer using the supplied parameters.
     * @param k The number of nearest neighbours to consider.
     * @param distance The distance function.
     * @param numThreads The number of threads to use.
     * @param combiner The combination function to aggregate the k predictions.
     * @param backend The computational backend.
     */
    public KNNTrainer(int k, Distance distance, int numThreads, EnsembleCombiner<T> combiner, Backend backend) {
        this.k = k;
        this.distance = distance;
        this.numThreads = numThreads;
        this.combiner = combiner;
        this.backend = backend;
        postConfig();
    }

    /**
     * Used by the OLCUT configuration system, and should not be called by external code.
     */
    @Override
    public void postConfig() {
        if (k < 1) {
            throw new PropertyException("","k","k must be greater than 0");
        }
    }

    @Override
    public Model<T> train(Dataset<T> examples, Map<String, Provenance> runProvenance) {
        ImmutableFeatureMap featureIDMap = examples.getFeatureIDMap();
        ImmutableOutputInfo<T> labelIDMap = examples.getOutputIDInfo();

        @SuppressWarnings("unchecked") // generic array creation
        Pair<SparseVector,T>[] vectors = new Pair[examples.size()];

        int i = 0;
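        // Convert every training example into a (sparse feature vector, output) pair; the k-NN model
        // stores these pairs directly and defers all distance computation to prediction time.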
        for (Example<T> e : examples) {
            vectors[i] = new Pair<>(SparseVector.createSparseVector(e,featureIDMap,false),e.getOutput());
            i++;
        }

        invocationCount++;

        ModelProvenance provenance = new ModelProvenance(KNNModel.class.getName(), OffsetDateTime.now(), examples.getProvenance(), getProvenance(), runProvenance);

        return new KNNModel<>(k+"nn",provenance, featureIDMap, labelIDMap, false, k, distance, numThreads, combiner, vectors, backend);
    }

    @Override
    public String toString() {
        return "KNNTrainer(k="+k+",distance="+distance+",combiner="+combiner.toString()+",numThreads="+numThreads+")";
    }

    @Override
    public int getInvocationCount() {
        return invocationCount;
    }

    @Override
    public TrainerProvenance getProvenance() {
        return new TrainerProvenanceImpl(this);
    }
}