001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.interop.tensorflow;
018
019import com.oracle.labs.mlrg.olcut.util.Pair;
020import org.tribuo.Example;
021import org.tribuo.Excuse;
022import org.tribuo.ImmutableFeatureMap;
023import org.tribuo.ImmutableOutputInfo;
024import org.tribuo.Model;
025import org.tribuo.Output;
026import org.tribuo.Prediction;
027import org.tribuo.math.la.SparseVector;
028import org.tribuo.provenance.ModelProvenance;
029import org.tensorflow.Graph;
030import org.tensorflow.Session;
031import org.tensorflow.Tensor;
032
033import java.io.Closeable;
034import java.io.IOException;
035import java.util.ArrayList;
036import java.util.Collections;
037import java.util.List;
038import java.util.Map;
039import java.util.Optional;
040import java.util.logging.Logger;
041
042/**
043 * This model encapsulates a simple model with a single input tensor (labelled {@link TensorflowModel#INPUT_NAME}),
044 * and produces a single output tensor (labelled {@link TensorflowModel#OUTPUT_NAME}).
045 * <p>
046 * It accepts an {@link ExampleTransformer} that converts an example's features into a {@link Tensor}, and an
047 * {@link OutputTransformer} that converts a {@link Tensor} into a {@link Prediction}.
048 * <p>
049 * The model's serialVersionUID is set to the major Tensorflow version number times 100.
050 * <p>
051 * N.B. Tensorflow support is experimental and may change without a major version bump.
052 */
053public class TensorflowModel<T extends Output<T>> extends Model<T> implements Closeable {
054
055    private static final Logger logger = Logger.getLogger(TensorflowModel.class.getName());
056
057    private static final long serialVersionUID = 100L;
058
059    public static final String INPUT_NAME = "input";
060    public static final String OUTPUT_NAME = "output";
061
062    private transient Graph modelGraph = null;
063
064    private transient Session session = null;
065
066    private int batchSize;
067
068    private final ExampleTransformer<T> exampleTransformer;
069
070    private final OutputTransformer<T> outputTransformer;
071
072    TensorflowModel(String name, ModelProvenance description, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<T> outputIDMap, byte[] trainedGraphDef, Map<String, Object> tensorMap, int batchSize, ExampleTransformer<T> exampleTransformer, OutputTransformer<T> outputTransformer) {
073        super(name, description, featureIDMap, outputIDMap, outputTransformer.generatesProbabilities());
074        this.exampleTransformer = exampleTransformer;
075        this.outputTransformer = outputTransformer;
076        this.modelGraph = new Graph();
077        this.modelGraph.importGraphDef(trainedGraphDef);
078        this.session = new Session(modelGraph);
079        this.batchSize = batchSize;
080        // Initialises the parameters.
081        session.runner().addTarget(TensorflowTrainer.INIT).run();
082        TensorflowUtil.deserialise(session,tensorMap);
083    }
084
085    @Override
086    public Prediction<T> predict(Example<T> example) {
087        // This adds overhead and triggers lookups for each feature, but is necessary to correctly calculate
088        // the number of features used in this example.
089        SparseVector vec = SparseVector.createSparseVector(example,featureIDMap,false);
090        try (Tensor<?> transformedInput = exampleTransformer.transform(vec);
091             Tensor<?> isTraining = Tensor.create(false);
092             Tensor<?> outputTensor = session.runner()
093                     .feed(INPUT_NAME,transformedInput)
094                     .feed(TensorflowTrainer.IS_TRAINING,isTraining)
095                     .fetch(OUTPUT_NAME).run().get(0)) {
096            // Transform the returned tensor into a Prediction.
097            return outputTransformer.transformToPrediction(outputTensor,outputIDInfo,vec.numActiveElements(),example);
098        }
099    }
100
101    @Override
102    protected List<Prediction<T>> innerPredict(Iterable<Example<T>> examples) {
103        List<Prediction<T>> predictions = new ArrayList<>();
104        List<Example<T>> batchExamples = new ArrayList<>();
105        for (Example<T> example : examples) {
106            batchExamples.add(example);
107            if (batchExamples.size() == batchSize) {
108                predictions.addAll(predictBatch(batchExamples));
109                // clear the batch
110                batchExamples.clear();
111            }
112        }
113
114        if (!batchExamples.isEmpty()) {
115            // send the partial batch
116            predictions.addAll(predictBatch(batchExamples));
117        }
118        return predictions;
119    }
120
121    private List<Prediction<T>> predictBatch(List<Example<T>> batchExamples) {
122        // Convert the batch
123        List<SparseVector> vectors = new ArrayList<>(batchExamples.size());
124        int[] numActiveElements = new int[batchExamples.size()];
125        for (int i = 0; i < batchExamples.size(); i++) {
126            SparseVector vec = SparseVector.createSparseVector(batchExamples.get(i),featureIDMap,false);
127            numActiveElements[i] = vec.numActiveElements();
128            vectors.add(vec);
129        }
130
131        // Send a batch to Tensorflow
132        try (Tensor<?> transformedInput = exampleTransformer.transform(vectors);
133             Tensor<?> isTraining = Tensor.create(false);
134             Tensor<?> outputTensor = session.runner()
135                     .feed(INPUT_NAME,transformedInput)
136                     .feed(TensorflowTrainer.IS_TRAINING,isTraining)
137                     .fetch(OUTPUT_NAME).run().get(0)) {
138            // Transform the returned tensor into a list of Predictions.
139            return outputTransformer.transformToBatchPrediction(outputTensor,outputIDInfo,numActiveElements,batchExamples);
140        }
141    }
142
143    /**
144     * Gets the current testing batch size.
145     * @return The batch size.
146     */
147    public int getBatchSize() {
148        return batchSize;
149    }
150
151    /**
152     * Sets a new batch size.
153     *
154     * Throws {@link IllegalArgumentException} if the batch size isn't positive.
155     * @param batchSize The batch size to use.
156     */
157    public void setBatchSize(int batchSize) {
158        if (batchSize > 0) {
159            this.batchSize = batchSize;
160        } else {
161            throw new IllegalArgumentException("Batch size must be positive, found " + batchSize);
162        }
163    }
164
165    /**
166     * Deep learning models don't do feature rankings. Use an Explainer.
167     * <p>
168     * This method always returns the empty map.
169     * @param n the number of features to return.
170     * @return The empty map.
171     */
172    @Override
173    public Map<String, List<Pair<String, Double>>> getTopFeatures(int n) {
174        return Collections.emptyMap();
175    }
176
177    /**
178     * Deep learning models don't do excuses. Use an Explainer.
179     * <p>
180     * This method always returns {@link Optional#empty}.
181     * @param example The input example.
182     * @return {@link Optional#empty}.
183     */
184    @Override
185    public Optional<Excuse<T>> getExcuse(Example<T> example) {
186        return Optional.empty();
187    }
188
189    @Override
190    protected TensorflowModel<T> copy(String newName, ModelProvenance newProvenance) {
191        return new TensorflowModel<>(newName,newProvenance,featureIDMap,outputIDInfo,modelGraph.toGraphDef(),TensorflowUtil.serialise(modelGraph,session),batchSize,exampleTransformer,outputTransformer);
192    }
193
194    @Override
195    public void close() {
196        if (session != null) {
197            session.close();
198        }
199        if (modelGraph != null) {
200            modelGraph.close();
201        }
202    }
203
204    private void writeObject(java.io.ObjectOutputStream out) throws IOException {
205        out.defaultWriteObject();
206        byte[] modelBytes = modelGraph.toGraphDef();
207        out.writeObject(modelBytes);
208        Map<String,Object> tensorMap = TensorflowUtil.serialise(modelGraph, session);
209        out.writeObject(tensorMap);
210    }
211
212    @SuppressWarnings("unchecked") //deserialising a typed map
213    private void readObject(java.io.ObjectInputStream in) throws IOException, ClassNotFoundException {
214        in.defaultReadObject();
215        byte[] modelBytes = (byte[]) in.readObject();
216        Map<String,Object> tensorMap = (Map<String,Object>) in.readObject();
217        modelGraph = new Graph();
218        modelGraph.importGraphDef(modelBytes);
219        session = new Session(modelGraph);
220        // Initialises the parameters.
221        session.runner().addTarget(TensorflowTrainer.INIT).run();
222        TensorflowUtil.deserialise(session,tensorMap);
223    }
224}