001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.interop;
018
019import org.tribuo.CategoricalInfo;
020import org.tribuo.Example;
021import org.tribuo.Excuse;
022import org.tribuo.ImmutableFeatureMap;
023import org.tribuo.ImmutableOutputInfo;
024import org.tribuo.Model;
025import org.tribuo.MutableFeatureMap;
026import org.tribuo.Output;
027import org.tribuo.OutputFactory;
028import org.tribuo.Prediction;
029import org.tribuo.math.la.SparseVector;
030import org.tribuo.math.la.VectorTuple;
031import org.tribuo.provenance.ModelProvenance;
032
033import java.util.ArrayList;
034import java.util.Arrays;
035import java.util.List;
036import java.util.Map;
037import java.util.Optional;
038import java.util.Set;
039
040/**
041 * This is the base class for third party models which are trained externally and
042 * loaded into Tribuo for prediction.
043 * <p>
044 * Batch size defaults to {@link ExternalModel#DEFAULT_BATCH_SIZE}
045 * @param <T> The output subclass that this model operates on.
046 * @param <U> The internal representation of features.
047 * @param <V> The internal representation of outputs.
048 */
049public abstract class ExternalModel<T extends Output<T>,U,V> extends Model<T> {
050    private static final long serialVersionUID = 1L;
051    /**
052     * Default batch size for external model batch predictions.
053     */
054    public static final int DEFAULT_BATCH_SIZE = 16;
055
056    protected final int[] featureForwardMapping;
057    protected final int[] featureBackwardMapping;
058
059    private int batchSize = DEFAULT_BATCH_SIZE;
060
061    protected ExternalModel(String name, ModelProvenance provenance,
062                            ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<T> outputIDInfo,
063                            boolean generatesProbabilities, Map<String,Integer> featureMapping) {
064        super(name, provenance, featureIDMap, outputIDInfo, generatesProbabilities);
065
066        if (featureIDMap.size() != featureMapping.size()) {
067            throw new IllegalArgumentException("The featureMapping must be the same size as the featureIDMap, found featureMapping.size()="+featureMapping.size()+", featureIDMap.size()="+featureIDMap.size());
068        }
069
070        this.featureForwardMapping = new int[featureIDMap.size()];
071        this.featureBackwardMapping = new int[featureIDMap.size()];
072        Arrays.fill(featureForwardMapping,-1);
073        Arrays.fill(featureBackwardMapping,-1);
074
075        for (Map.Entry<String,Integer> e : featureMapping.entrySet()) {
076            int tribuoID = featureIDMap.getID(e.getKey());
077            int mappingID = e.getValue();
078            if (tribuoID == -1) {
079                throw new IllegalArgumentException("Found invalid feature name in mapping " + e);
080            } else if (mappingID >= featureForwardMapping.length) {
081                throw new IllegalArgumentException("Found invalid feature id in mapping " + e);
082            } else if (featureBackwardMapping[mappingID] != -1) {
083                throw new IllegalArgumentException("Mapping for " + e + " already exists as feature " + featureIDMap.get(featureBackwardMapping[mappingID]));
084            }
085
086            featureForwardMapping[tribuoID] = mappingID;
087            featureBackwardMapping[mappingID] = tribuoID;
088        }
089    }
090
091    protected ExternalModel(String name, ModelProvenance provenance, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<T> outputIDInfo, int[] featureForwardMapping, int[] featureBackwardMapping, boolean generatesProbabilities) {
092        super(name,provenance,featureIDMap,outputIDInfo,generatesProbabilities);
093        this.featureBackwardMapping = Arrays.copyOf(featureBackwardMapping,featureBackwardMapping.length);
094        this.featureForwardMapping = Arrays.copyOf(featureForwardMapping,featureForwardMapping.length);
095    }
096
097    @Override
098    public Prediction<T> predict(Example<T> example) {
099        SparseVector features = SparseVector.createSparseVector(example,featureIDMap,false);
100        SparseVector renumberedFeatures = renumberFeatureIndices(features);
101        U transformedFeatures = convertFeatures(renumberedFeatures);
102        V output = externalPrediction(transformedFeatures);
103        return convertOutput(output,features.numActiveElements(),example);
104    }
105
106    @Override
107    protected List<Prediction<T>> innerPredict(Iterable<Example<T>> examples) {
108        List<Prediction<T>> predictions = new ArrayList<>();
109        List<Example<T>> batchExamples = new ArrayList<>();
110        for (Example<T> example : examples) {
111            batchExamples.add(example);
112            if (batchExamples.size() == batchSize) {
113                predictions.addAll(predictBatch(batchExamples));
114                // clear the batch
115                batchExamples.clear();
116            }
117        }
118
119        if (!batchExamples.isEmpty()) {
120            // send the partial batch
121            predictions.addAll(predictBatch(batchExamples));
122        }
123        return predictions;
124    }
125
126    private List<Prediction<T>> predictBatch(List<Example<T>> batch) {
127        List<SparseVector> vectors = new ArrayList<>();
128        int[] numValidFeatures = new int[batch.size()];
129        for (int i = 0; i < batch.size(); i++) {
130            SparseVector features = SparseVector.createSparseVector(batch.get(i),featureIDMap,false);
131            vectors.add(renumberFeatureIndices(features));
132            numValidFeatures[i] = features.numActiveElements();
133        }
134        U transformedFeatures = convertFeaturesList(vectors);
135        V output = externalPrediction(transformedFeatures);
136        List<Prediction<T>> predictions = convertOutput(output,numValidFeatures,batch);
137        if (predictions.size() != vectors.size()) {
138            throw new IllegalStateException("Unexpected number of predictions received from external model batch, found " + predictions.size() + ", expected " + vectors.size() + ".");
139        } else {
140            return predictions;
141        }
142    }
143
144    /**
145     * Renumbers the indices in a {@link SparseVector} switching from
146     * Tribuo's internal indices to the external ones for this model.
147     * @param input The features using internal indices.
148     * @return The features using external indices.
149     */
150    private SparseVector renumberFeatureIndices(SparseVector input) {
151        int inputSize = input.numActiveElements();
152        int[] newIndices = new int[inputSize];
153        double[] newValues = new double[inputSize];
154
155        int i = 0;
156        for (VectorTuple t : input) {
157            int tribuoIdx = t.index;
158            double value = t.value;
159            newIndices[i] = featureForwardMapping[tribuoIdx];
160            newValues[i] = value;
161            i++;
162        }
163
164        return SparseVector.createSparseVector(input.size(),newIndices,newValues);
165    }
166
167    /**
168     * Converts from a SparseVector using the external model's indices into
169     * the ingestion format for the external model.
170     * @param input The features using external indices.
171     * @return The ingestion format for the external model.
172     */
173    protected abstract U convertFeatures(SparseVector input);
174
175    /**
176     * Converts from a list of SparseVector using the external model's indices
177     * into the ingestion format for the external model.
178     * @param input The features using external indices.
179     * @return The ingestion format for the external model.
180     */
181    protected abstract U convertFeaturesList(List<SparseVector> input);
182
183    /**
184     * Runs the external model's prediction function.
185     * @param input The input in the external model's format.
186     * @return The output in the external model's format.
187     */
188    protected abstract V externalPrediction(U input);
189
190    /**
191     * Converts the output of the external model into a {@link Prediction}.
192     * @param output The output of the external model.
193     * @param numValidFeatures The number of valid features in the input.
194     * @param example The input example, used to construct the Prediction.
195     * @return A Tribuo Prediction.
196     */
197    protected abstract Prediction<T> convertOutput(V output, int numValidFeatures, Example<T> example);
198
199    /**
200     * Converts the output of the external model into a list of {@link Prediction}s.
201     * @param output The output of the external model.
202     * @param numValidFeatures An array with the number of valid features in each example.
203     * @param examples The input examples, used to construct the Predictions.
204     * @return A list of Tribuo Predictions.
205     */
206    protected abstract List<Prediction<T>> convertOutput(V output, int[] numValidFeatures, List<Example<T>> examples);
207
208    /**
209     * By default third party models don't return excuses.
210     * @param example The input example.
211     * @return Optional.empty.
212     */
213    @Override
214    public Optional<Excuse<T>> getExcuse(Example<T> example) {
215        return Optional.empty();
216    }
217
218    /**
219     * Gets the current testing batch size.
220     * @return The batch size.
221     */
222    public int getBatchSize() {
223        return batchSize;
224    }
225
226    /**
227     * Sets a new batch size.
228     * <p>
229     * Throws {@link IllegalArgumentException} if the batch size isn't positive.
230     * @param batchSize The batch size to use.
231     */
232    public void setBatchSize(int batchSize) {
233        if (batchSize > 0) {
234            this.batchSize = batchSize;
235        } else {
236            throw new IllegalArgumentException("Batch size must be positive, found " + batchSize);
237        }
238    }
239
240    /**
241     * Creates an immutable feature map from a set of feature names.
242     * <p>
243     * Each feature is unobserved.
244     * @param featureNames The names of the features to create.
245     * @return A feature map representing the feature names.
246     */
247    protected static ImmutableFeatureMap createFeatureMap(Set<String> featureNames) {
248        MutableFeatureMap featureMap = new MutableFeatureMap();
249
250        for (String name : featureNames) {
251            featureMap.put(new CategoricalInfo(name));
252        }
253
254        return new ImmutableFeatureMap(featureMap);
255    }
256
257    /**
258     * Creates an output info from a set of outputs.
259     * @param factory The output factory to use.
260     * @param outputs The outputs and ids to observe.
261     * @param <T> The type of the outputs.
262     * @return An immutable output info representing the outputs.
263     */
264    protected static <T extends Output<T>> ImmutableOutputInfo<T> createOutputInfo(OutputFactory<T> factory, Map<T,Integer> outputs) {
265        return factory.constructInfoForExternalModel(outputs);
266    }
267
268}