001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.common.nearest; 018 019import com.oracle.labs.mlrg.olcut.config.Config; 020import com.oracle.labs.mlrg.olcut.config.PropertyException; 021import com.oracle.labs.mlrg.olcut.provenance.Provenance; 022import com.oracle.labs.mlrg.olcut.util.Pair; 023import org.tribuo.Dataset; 024import org.tribuo.Example; 025import org.tribuo.ImmutableFeatureMap; 026import org.tribuo.ImmutableOutputInfo; 027import org.tribuo.Model; 028import org.tribuo.Output; 029import org.tribuo.Trainer; 030import org.tribuo.common.nearest.KNNModel.Backend; 031import org.tribuo.ensemble.EnsembleCombiner; 032import org.tribuo.math.la.SparseVector; 033import org.tribuo.provenance.ModelProvenance; 034import org.tribuo.provenance.TrainerProvenance; 035import org.tribuo.provenance.impl.TrainerProvenanceImpl; 036 037import java.time.OffsetDateTime; 038import java.util.Map; 039 040/** 041 * A {@link Trainer} for k-nearest neighbour models. 042 */ 043public class KNNTrainer<T extends Output<T>> implements Trainer<T> { 044 045 /** 046 * The available distance functions. 047 */ 048 public enum Distance { 049 /** 050 * L1 (or Manhattan) distance. 051 */ 052 L1, 053 /** 054 * L2 (or Euclidean) distance. 055 */ 056 L2, 057 /** 058 * Cosine similarity used as a distance measure. 059 */ 060 COSINE 061 } 062 063 @Config(mandatory = true, description="The distance function used to measure nearest neighbours.") 064 private Distance distance; 065 066 @Config(mandatory = true, description="The number of nearest neighbours to check.") 067 private int k; 068 069 @Config(mandatory = true, description="The combination function to aggregate the nearest neighbours.") 070 private EnsembleCombiner<T> combiner; 071 072 @Config(description="The number of threads to use for inference.") 073 private int numThreads = 1; 074 075 @Config(description="The threading model to use.") 076 private Backend backend = Backend.THREADPOOL; 077 078 private int invocationCount = 0; 079 080 /** 081 * For olcut. 082 */ 083 private KNNTrainer() {} 084 085 /** 086 * Creates a K-NN trainer using the supplied parameters. 087 * @param k The number of nearest neighbours to consider. 088 * @param distance The distance function. 089 * @param numThreads The number of threads to use. 090 * @param combiner The combination function to aggregate the k predictions. 091 * @param backend The computational backend. 092 */ 093 public KNNTrainer(int k, Distance distance, int numThreads, EnsembleCombiner<T> combiner, Backend backend) { 094 this.k = k; 095 this.distance = distance; 096 this.numThreads = numThreads; 097 this.combiner = combiner; 098 this.backend = backend; 099 postConfig(); 100 } 101 102 /** 103 * Used by the OLCUT configuration system, and should not be called by external code. 104 */ 105 @Override 106 public void postConfig() { 107 if (k < 1) { 108 throw new PropertyException("","k","k must be greater than 0"); 109 } 110 } 111 112 @Override 113 public Model<T> train(Dataset<T> examples, Map<String, Provenance> runProvenance) { 114 ImmutableFeatureMap featureIDMap = examples.getFeatureIDMap(); 115 ImmutableOutputInfo<T> labelIDMap = examples.getOutputIDInfo(); 116 117 @SuppressWarnings("unchecked") // generic array creation 118 Pair<SparseVector,T>[] vectors = new Pair[examples.size()]; 119 120 int i = 0; 121 for (Example<T> e : examples) { 122 vectors[i] = new Pair<>(SparseVector.createSparseVector(e,featureIDMap,false),e.getOutput()); 123 i++; 124 } 125 126 invocationCount++; 127 128 ModelProvenance provenance = new ModelProvenance(KNNModel.class.getName(), OffsetDateTime.now(), examples.getProvenance(), getProvenance(), runProvenance); 129 130 return new KNNModel<>(k+"nn",provenance, featureIDMap, labelIDMap, false, k, distance, numThreads, combiner, vectors, backend); 131 } 132 133 @Override 134 public String toString() { 135 return "KNNTrainer(k="+k+",distance="+distance+",combiner="+combiner.toString()+",numThreads="+numThreads+")"; 136 } 137 138 @Override 139 public int getInvocationCount() { 140 return invocationCount; 141 } 142 143 @Override 144 public TrainerProvenance getProvenance() { 145 return new TrainerProvenanceImpl(this); 146 } 147}