001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.interop.tensorflow;
018
019import com.oracle.labs.mlrg.olcut.util.Pair;
020import org.tribuo.Example;
021import org.tribuo.ImmutableFeatureMap;
022import org.tribuo.ImmutableOutputInfo;
023import org.tribuo.Model;
024import org.tribuo.Output;
025import org.tribuo.OutputFactory;
026import org.tribuo.Prediction;
027import org.tribuo.interop.ExternalDatasetProvenance;
028import org.tribuo.interop.ExternalModel;
029import org.tribuo.interop.ExternalTrainerProvenance;
030import org.tribuo.math.la.SparseVector;
031import org.tribuo.provenance.DatasetProvenance;
032import org.tribuo.provenance.ModelProvenance;
033import org.tensorflow.Graph;
034import org.tensorflow.Session;
035import org.tensorflow.Tensor;
036
037import java.io.Closeable;
038import java.io.IOException;
039import java.net.URL;
040import java.nio.file.Files;
041import java.nio.file.Path;
042import java.nio.file.Paths;
043import java.time.OffsetDateTime;
044import java.util.Collections;
045import java.util.List;
046import java.util.Map;
047
048/**
049 * A Tribuo wrapper around a Tensorflow frozen model.
050 * <p>
051 * The model's serialVersionUID is set to the major Tensorflow version number times 100.
052 * <p>
053 * N.B. Tensorflow support is experimental and may change without a major version bump.
054 */
055public final class TensorflowExternalModel<T extends Output<T>> extends ExternalModel<T, Tensor<?>, Tensor<?>> implements Closeable {
056    private static final long serialVersionUID = 100L;
057
058    private transient Graph model;
059
060    private transient Session session;
061
062    private final ExampleTransformer<T> featureTransformer;
063
064    private final OutputTransformer<T> outputTransformer;
065
066    private final String inputName;
067
068    private final String outputName;
069
070    private TensorflowExternalModel(String name, ModelProvenance provenance,
071                                 ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<T> outputIDInfo,
072                                 Map<String, Integer> featureMapping,
073                                 Graph model, String inputName, String outputName,
074                                 ExampleTransformer<T> featureTransformer, OutputTransformer<T> outputTransformer) {
075        super(name, provenance, featureIDMap, outputIDInfo, outputTransformer.generatesProbabilities(), featureMapping);
076        this.model = model;
077        this.session = new Session(model);
078        this.inputName = inputName;
079        this.outputName = outputName;
080        this.featureTransformer = featureTransformer;
081        this.outputTransformer = outputTransformer;
082    }
083
084    private TensorflowExternalModel(String name, ModelProvenance provenance,
085                                    ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<T> outputIDInfo,
086                                    int[] featureForwardMapping, int[] featureBackwardMapping,
087                                    Graph model, String inputName, String outputName,
088                                    ExampleTransformer<T> featureTransformer, OutputTransformer<T> outputTransformer) {
089        super(name,provenance,featureIDMap,outputIDInfo,featureForwardMapping,featureBackwardMapping,
090                outputTransformer.generatesProbabilities());
091        this.model = model;
092        this.session = new Session(model);
093        this.inputName = inputName;
094        this.outputName = outputName;
095        this.featureTransformer = featureTransformer;
096        this.outputTransformer = outputTransformer;
097    }
098
099    @Override
100    protected Tensor<?> convertFeatures(SparseVector input) {
101        return featureTransformer.transform(input);
102    }
103
104    @Override
105    protected Tensor<?> convertFeaturesList(List<SparseVector> input) {
106        return featureTransformer.transform(input);
107    }
108
109    /**
110     * Runs the session to make a prediction.
111     *
112     * Closes the input tensor after the prediction has been made.
113     * @param input The input in the external model's format.
114     * @return A tensor representing the output.
115     */
116    @Override
117    protected Tensor<?> externalPrediction(Tensor<?> input) {
118        Tensor<?> output = session.runner().feed(inputName,input).fetch(outputName).run().get(0);
119        input.close();
120        return output;
121    }
122
123    /**
124     * Converts a tensor into a prediction.
125     * Closes the output tensor after it's been converted.
126     * @param output The output of the external model.
127     * @param numValidFeatures The number of valid features in the input.
128     * @param example The input example, used to construct the Prediction.
129     * @return A {@link Prediction} representing this tensor output.
130     */
131    @Override
132    protected Prediction<T> convertOutput(Tensor<?> output, int numValidFeatures, Example<T> example) {
133        Prediction<T> pred = outputTransformer.transformToPrediction(output,outputIDInfo,numValidFeatures,example);
134        output.close();
135        return pred;
136    }
137
138    /**
139     * Converts a tensor into a prediction.
140     * Closes the output tensor after it's been converted.
141     * @param output The output of the external model.
142     * @param numValidFeatures An array with the number of valid features in each example.
143     * @param examples The input examples, used to construct the Predictions.
144     * @return A list of {@link Prediction} representing this tensor output.
145     */
146    @Override
147    protected List<Prediction<T>> convertOutput(Tensor<?> output, int[] numValidFeatures, List<Example<T>> examples) {
148        List<Prediction<T>> predictions = outputTransformer.transformToBatchPrediction(output,outputIDInfo,numValidFeatures,examples);
149        output.close();
150        return predictions;
151    }
152
153    @Override
154    public Map<String, List<Pair<String, Double>>> getTopFeatures(int n) {
155        return Collections.emptyMap();
156    }
157
158    @Override
159    protected Model<T> copy(String newName, ModelProvenance newProvenance) {
160        byte[] modelBytes = model.toGraphDef();
161        Graph newGraph = new Graph();
162        newGraph.importGraphDef(modelBytes);
163        return new TensorflowExternalModel<>(newName,newProvenance,featureIDMap,outputIDInfo,
164                featureForwardMapping,featureBackwardMapping,
165                newGraph,inputName,outputName,featureTransformer,outputTransformer);
166    }
167
168    @Override
169    public void close() {
170        if (session != null) {
171            session.close();
172        }
173        if (model != null) {
174            model.close();
175        }
176    }
177
178    /**
179     * Creates a TensorflowExternalModel by loading in a frozen graph.
180     * @param factory The output factory.
181     * @param featureMapping The feature mapping between Tribuo's names and the TF integer ids.
182     * @param outputMapping The output mapping between Tribuo's names and the TF integer ids.
183     * @param inputName The name of the input placeholder.
184     * @param outputName The name of the output tensor.
185     * @param featureTransformer The feature transformation function.
186     * @param outputTransformer The output transformation function.
187     * @param filename The filename to load the graph from.
188     * @param <T> The type of the output.
189     * @return The TF model wrapped in a Tribuo ExternalModel.
190     */
191    public static <T extends Output<T>> TensorflowExternalModel<T> createTensorflowModel(OutputFactory<T> factory,
192                                                                                      Map<String, Integer> featureMapping,
193                                                                                      Map<T,Integer> outputMapping,
194                                                                                      String inputName,
195                                                                                      String outputName,
196                                                                                      ExampleTransformer<T> featureTransformer,
197                                                                                      OutputTransformer<T> outputTransformer,
198                                                                                      String filename) {
199        try {
200            Path path = Paths.get(filename);
201            byte[] model = Files.readAllBytes(path);
202            Graph graph = new Graph();
203            graph.importGraphDef(model);
204            URL provenanceLocation = path.toUri().toURL();
205            ImmutableFeatureMap featureMap = ExternalModel.createFeatureMap(featureMapping.keySet());
206            ImmutableOutputInfo<T> outputInfo = ExternalModel.createOutputInfo(factory,outputMapping);
207            OffsetDateTime now = OffsetDateTime.now();
208            ExternalTrainerProvenance trainerProvenance = new ExternalTrainerProvenance(provenanceLocation);
209            DatasetProvenance datasetProvenance = new ExternalDatasetProvenance("unknown-external-data",factory,false,featureMapping.size(),outputMapping.size());
210            ModelProvenance provenance = new ModelProvenance(TensorflowExternalModel.class.getName(),now,datasetProvenance,trainerProvenance);
211            return new TensorflowExternalModel<>("external-model",provenance,featureMap,outputInfo,
212                    featureMapping,graph,inputName,outputName,featureTransformer,outputTransformer);
213        } catch (IOException e) {
214            throw new IllegalArgumentException("Unable to load model from path " + filename, e);
215        }
216    }
217
218    private void writeObject(java.io.ObjectOutputStream out) throws IOException {
219        out.defaultWriteObject();
220        byte[] modelBytes = model.toGraphDef();
221        out.writeObject(modelBytes);
222    }
223
224    private void readObject(java.io.ObjectInputStream in) throws IOException, ClassNotFoundException {
225        in.defaultReadObject();
226        byte[] modelBytes = (byte[]) in.readObject();
227        model = new Graph();
228        model.importGraphDef(modelBytes);
229        session = new Session(model);
230    }
231
232}