/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.interop;

import org.tribuo.CategoricalInfo;
import org.tribuo.Example;
import org.tribuo.Excuse;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Model;
import org.tribuo.MutableFeatureMap;
import org.tribuo.Output;
import org.tribuo.OutputFactory;
import org.tribuo.Prediction;
import org.tribuo.math.la.SparseVector;
import org.tribuo.math.la.VectorTuple;
import org.tribuo.provenance.ModelProvenance;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;

/**
 * This is the base class for third party models which are trained externally and
 * loaded into Tribuo for prediction.
 * <p>
 * Batch size defaults to {@link ExternalModel#DEFAULT_BATCH_SIZE}
 * @param <T> The output subclass that this model operates on.
 * @param <U> The internal representation of features.
 * @param <V> The internal representation of outputs.
 */
public abstract class ExternalModel<T extends Output<T>,U,V> extends Model<T> {
    private static final long serialVersionUID = 1L;

    /**
     * Default batch size for external model batch predictions.
     */
    public static final int DEFAULT_BATCH_SIZE = 16;

    /**
     * Maps a Tribuo feature id (the array index) to the external model's
     * feature id (the stored value).
     */
    protected final int[] featureForwardMapping;

    /**
     * Maps an external model feature id (the array index) to the Tribuo
     * feature id (the stored value).
     */
    protected final int[] featureBackwardMapping;

    private int batchSize = DEFAULT_BATCH_SIZE;

    /**
     * Constructs an external model, building the forward and backward feature
     * index mappings from the supplied name to external id map.
     * <p>
     * The mapping must contain exactly one entry per feature in
     * {@code featureIDMap}, each external id must lie in
     * {@code [0, featureIDMap.size())}, and no external id may be used twice.
     * @param name The model name.
     * @param provenance The model provenance.
     * @param featureIDMap The Tribuo feature domain.
     * @param outputIDInfo The Tribuo output domain.
     * @param generatesProbabilities Flag forwarded to the {@link Model} constructor (does this model generate probabilities).
     * @param featureMapping The mapping from Tribuo feature names to external feature ids.
     * @throws IllegalArgumentException If the mapping is the wrong size, names an unknown feature,
     * contains an out of range id, or maps two features to the same id.
     */
    protected ExternalModel(String name, ModelProvenance provenance,
                            ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<T> outputIDInfo,
                            boolean generatesProbabilities, Map<String,Integer> featureMapping) {
        super(name, provenance, featureIDMap, outputIDInfo, generatesProbabilities);

        if (featureIDMap.size() != featureMapping.size()) {
            throw new IllegalArgumentException("The featureMapping must be the same size as the featureIDMap, found featureMapping.size()="+featureMapping.size()+", featureIDMap.size()="+featureIDMap.size());
        }

        this.featureForwardMapping = new int[featureIDMap.size()];
        this.featureBackwardMapping = new int[featureIDMap.size()];
        // -1 marks a slot which has not been assigned yet.
        Arrays.fill(featureForwardMapping,-1);
        Arrays.fill(featureBackwardMapping,-1);

        for (Map.Entry<String,Integer> e : featureMapping.entrySet()) {
            int tribuoID = featureIDMap.getID(e.getKey());
            int mappingID = e.getValue();
            if (tribuoID == -1) {
                throw new IllegalArgumentException("Found invalid feature name in mapping " + e);
            } else if (mappingID < 0 || mappingID >= featureForwardMapping.length) {
                // Negative ids are rejected here so they never reach the array access below.
                throw new IllegalArgumentException("Found invalid feature id in mapping " + e);
            } else if (featureBackwardMapping[mappingID] != -1) {
                throw new IllegalArgumentException("Mapping for " + e + " already exists as feature " + featureIDMap.get(featureBackwardMapping[mappingID]));
            }

            featureForwardMapping[tribuoID] = mappingID;
            featureBackwardMapping[mappingID] = tribuoID;
        }
    }

    /**
     * Constructs an external model from pre-built feature index mappings.
     * <p>
     * Both arrays are copied. No validation of their contents is performed here,
     * so callers are responsible for supplying consistent mappings.
     * @param name The model name.
     * @param provenance The model provenance.
     * @param featureIDMap The Tribuo feature domain.
     * @param outputIDInfo The Tribuo output domain.
     * @param featureForwardMapping The mapping from Tribuo feature ids to external feature ids.
     * @param featureBackwardMapping The mapping from external feature ids to Tribuo feature ids.
     * @param generatesProbabilities Flag forwarded to the {@link Model} constructor (does this model generate probabilities).
     */
    protected ExternalModel(String name, ModelProvenance provenance,
                            ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<T> outputIDInfo,
                            int[] featureForwardMapping, int[] featureBackwardMapping,
                            boolean generatesProbabilities) {
        super(name,provenance,featureIDMap,outputIDInfo,generatesProbabilities);
        this.featureBackwardMapping = Arrays.copyOf(featureBackwardMapping,featureBackwardMapping.length);
        this.featureForwardMapping = Arrays.copyOf(featureForwardMapping,featureForwardMapping.length);
    }

    /**
     * Predicts a single example by converting it into the external model's
     * format, running the external prediction, and converting the result back.
     * @param example The example to predict.
     * @return The prediction produced by {@link #convertOutput(Object, int, Example)}.
     */
    @Override
    public Prediction<T> predict(Example<T> example) {
        SparseVector features = SparseVector.createSparseVector(example,featureIDMap,false);
        SparseVector renumberedFeatures = renumberFeatureIndices(features);
        U transformedFeatures = convertFeatures(renumberedFeatures);
        V output = externalPrediction(transformedFeatures);
        return convertOutput(output,features.numActiveElements(),example);
    }

    /**
     * Predicts the supplied examples in batches of at most {@link #getBatchSize()}
     * examples, returning the predictions in iteration order.
     * @param examples The examples to predict.
     * @return One prediction per example.
     */
    @Override
    protected List<Prediction<T>> innerPredict(Iterable<Example<T>> examples) {
        List<Prediction<T>> predictions = new ArrayList<>();
        List<Example<T>> batchExamples = new ArrayList<>();
        for (Example<T> example : examples) {
            batchExamples.add(example);
            if (batchExamples.size() == batchSize) {
                predictions.addAll(predictBatch(batchExamples));
                // clear the batch
                batchExamples.clear();
            }
        }

        if (!batchExamples.isEmpty()) {
            // send the partial batch
            predictions.addAll(predictBatch(batchExamples));
        }
        return predictions;
    }

    /**
     * Runs a single batch through the external model.
     * @param batch The examples in this batch.
     * @return One prediction per example in the batch.
     * @throws IllegalStateException If the number of predictions does not match the batch size.
     */
    private List<Prediction<T>> predictBatch(List<Example<T>> batch) {
        List<SparseVector> vectors = new ArrayList<>(batch.size());
        int[] numValidFeatures = new int[batch.size()];
        for (int i = 0; i < batch.size(); i++) {
            SparseVector features = SparseVector.createSparseVector(batch.get(i),featureIDMap,false);
            vectors.add(renumberFeatureIndices(features));
            numValidFeatures[i] = features.numActiveElements();
        }
        U transformedFeatures = convertFeaturesList(vectors);
        V output = externalPrediction(transformedFeatures);
        List<Prediction<T>> predictions = convertOutput(output,numValidFeatures,batch);
        if (predictions.size() != vectors.size()) {
            throw new IllegalStateException("Unexpected number of predictions received from external model batch, found " + predictions.size() + ", expected " + vectors.size() + ".");
        } else {
            return predictions;
        }
    }

    /**
     * Renumbers the indices in a {@link SparseVector} switching from
     * Tribuo's internal indices to the external ones for this model.
     * @param input The features using internal indices.
     * @return The features using external indices.
     */
    private SparseVector renumberFeatureIndices(SparseVector input) {
        int inputSize = input.numActiveElements();
        int[] newIndices = new int[inputSize];
        double[] newValues = new double[inputSize];

        int i = 0;
        for (VectorTuple t : input) {
            int tribuoIdx = t.index;
            double value = t.value;
            newIndices[i] = featureForwardMapping[tribuoIdx];
            newValues[i] = value;
            i++;
        }

        return SparseVector.createSparseVector(input.size(),newIndices,newValues);
    }

    /**
     * Converts from a SparseVector using the external model's indices into
     * the ingestion format for the external model.
     * @param input The features using external indices.
     * @return The ingestion format for the external model.
     */
    protected abstract U convertFeatures(SparseVector input);

    /**
     * Converts from a list of SparseVector using the external model's indices
     * into the ingestion format for the external model.
     * @param input The features using external indices.
     * @return The ingestion format for the external model.
     */
    protected abstract U convertFeaturesList(List<SparseVector> input);

    /**
     * Runs the external model's prediction function.
     * @param input The input in the external model's format.
     * @return The output in the external model's format.
     */
    protected abstract V externalPrediction(U input);

    /**
     * Converts the output of the external model into a {@link Prediction}.
     * @param output The output of the external model.
     * @param numValidFeatures The number of valid features in the input.
     * @param example The input example, used to construct the Prediction.
     * @return A Tribuo Prediction.
     */
    protected abstract Prediction<T> convertOutput(V output, int numValidFeatures, Example<T> example);

    /**
     * Converts the output of the external model into a list of {@link Prediction}s.
     * @param output The output of the external model.
     * @param numValidFeatures An array with the number of valid features in each example.
     * @param examples The input examples, used to construct the Predictions.
     * @return A list of Tribuo Predictions.
     */
    protected abstract List<Prediction<T>> convertOutput(V output, int[] numValidFeatures, List<Example<T>> examples);

    /**
     * By default third party models don't return excuses.
     * @param example The input example.
     * @return Optional.empty.
     */
    @Override
    public Optional<Excuse<T>> getExcuse(Example<T> example) {
        return Optional.empty();
    }

    /**
     * Gets the current testing batch size.
     * @return The batch size.
     */
    public int getBatchSize() {
        return batchSize;
    }

    /**
     * Sets a new batch size.
     * <p>
     * Throws {@link IllegalArgumentException} if the batch size isn't positive.
     * @param batchSize The batch size to use.
     */
    public void setBatchSize(int batchSize) {
        if (batchSize > 0) {
            this.batchSize = batchSize;
        } else {
            throw new IllegalArgumentException("Batch size must be positive, found " + batchSize);
        }
    }

    /**
     * Creates an immutable feature map from a set of feature names.
     * <p>
     * Each feature is unobserved.
     * @param featureNames The names of the features to create.
     * @return A feature map representing the feature names.
     */
    protected static ImmutableFeatureMap createFeatureMap(Set<String> featureNames) {
        MutableFeatureMap featureMap = new MutableFeatureMap();

        for (String name : featureNames) {
            featureMap.put(new CategoricalInfo(name));
        }

        return new ImmutableFeatureMap(featureMap);
    }

    /**
     * Creates an output info from a set of outputs.
     * @param factory The output factory to use.
     * @param outputs The outputs and ids to observe.
     * @param <T> The type of the outputs.
     * @return An immutable output info representing the outputs.
     */
    protected static <T extends Output<T>> ImmutableOutputInfo<T> createOutputInfo(OutputFactory<T> factory, Map<T,Integer> outputs) {
        return factory.constructInfoForExternalModel(outputs);
    }

}