001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.interop.tensorflow;
018
019import com.oracle.labs.mlrg.olcut.util.Pair;
020import org.tribuo.Example;
021import org.tribuo.Excuse;
022import org.tribuo.ImmutableFeatureMap;
023import org.tribuo.ImmutableOutputInfo;
024import org.tribuo.Model;
025import org.tribuo.Output;
026import org.tribuo.Prediction;
027import org.tribuo.math.la.SparseVector;
028import org.tribuo.provenance.ModelProvenance;
029import org.tensorflow.Graph;
030import org.tensorflow.Session;
031import org.tensorflow.Tensor;
032import org.tensorflow.Tensors;
033
034import java.io.Closeable;
035import java.io.IOException;
036import java.nio.file.Paths;
037import java.util.Collections;
038import java.util.List;
039import java.util.Map;
040import java.util.Optional;
041
042/**
043 * TensorFlow support is experimental, and may change without a major version bump.
044 * <p>
045 * This model encapsulates a simple model with a single input tensor (labelled {@link TensorflowModel#INPUT_NAME}),
046 * and produces a single output tensor (labelled {@link TensorflowModel#OUTPUT_NAME}).
047 * <p>
048 * It accepts an {@link ExampleTransformer} that converts an example's features into a {@link Tensor}, and an
049 * {@link OutputTransformer} that converts a {@link Tensor} into a {@link Prediction}.
050 * <p>
051 * The model's serialVersionUID is set to the major Tensorflow version number times 100.
052 * <p>
053 * N.B. Tensorflow support is experimental and may change without a major version bump.
054 */
055public class TensorflowCheckpointModel<T extends Output<T>> extends Model<T> implements Closeable {
056
057    private static final long serialVersionUID = 100L;
058
059    private transient Graph modelGraph = null;
060
061    private transient Session session = null;
062
063    private final String checkpointDirectory;
064
065    private final ExampleTransformer<T> exampleTransformer;
066
067    private final OutputTransformer<T> outputTransformer;
068
069    TensorflowCheckpointModel(String name, ModelProvenance description, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<T> outputIDMap, byte[] graphDef, String checkpointDirectory, ExampleTransformer<T> exampleTransformer, OutputTransformer<T> outputTransformer) {
070        super(name, description, featureIDMap, outputIDMap, outputTransformer.generatesProbabilities());
071        this.exampleTransformer = exampleTransformer;
072        this.outputTransformer = outputTransformer;
073        this.checkpointDirectory = checkpointDirectory;
074        this.modelGraph = new Graph();
075        this.modelGraph.importGraphDef(graphDef);
076        this.session = new Session(modelGraph);
077
078        try (Tensor<String> checkpointPrefix = Tensors.create(Paths.get(checkpointDirectory+"/"+TensorflowCheckpointTrainer.MODEL_FILENAME).toString())) {
079            // Initialises the parameters.
080            session.runner().feed("save/Const", checkpointPrefix).addTarget("save/restore_all").run();
081        }
082    }
083
084    @Override
085    public Prediction<T> predict(Example<T> example) {
086        // This adds overhead and triggers lookups for each feature, but is necessary to correctly calculate
087        // the number of features used in this example.
088        SparseVector vec = SparseVector.createSparseVector(example,featureIDMap,false);
089        try (Tensor<?> transformedInput = exampleTransformer.transform(example,featureIDMap);
090             Tensor<?> isTraining = Tensor.create(false);
091             Tensor<?> outputTensor = session.runner()
092                     .feed(TensorflowModel.INPUT_NAME,transformedInput)
093                     .feed(TensorflowTrainer.IS_TRAINING,isTraining)
094                     .fetch(TensorflowModel.OUTPUT_NAME).run().get(0)) {
095            // Transform the returned tensor into a Prediction.
096            return outputTransformer.transformToPrediction(outputTensor,outputIDInfo,vec.numActiveElements(),example);
097        }
098    }
099
100    /**
101     * Deep learning models don't do feature rankings. Use an Explainer.
102     * <p>
103     * This method always returns the empty map.
104     * @param n the number of features to return.
105     * @return The empty map.
106     */
107    @Override
108    public Map<String, List<Pair<String, Double>>> getTopFeatures(int n) {
109        return Collections.emptyMap();
110    }
111
112    /**
113     * Deep learning models don't do excuses. Use an Explainer.
114     * <p>
115     * This method always returns {@link Optional#empty}.
116     * @param example The input example.
117     * @return {@link Optional#empty}.
118     */
119    @Override
120    public Optional<Excuse<T>> getExcuse(Example<T> example) {
121        return Optional.empty();
122    }
123
124    @Override
125    protected TensorflowCheckpointModel<T> copy(String newName, ModelProvenance newProvenance) {
126        return new TensorflowCheckpointModel<>(newName,newProvenance,featureIDMap,outputIDInfo,modelGraph.toGraphDef(),checkpointDirectory,exampleTransformer,outputTransformer);
127    }
128
129    @Override
130    public void close() {
131        if (session != null) {
132            session.close();
133        }
134        if (modelGraph != null) {
135            modelGraph.close();
136        }
137    }
138
139    private void writeObject(java.io.ObjectOutputStream out) throws IOException {
140        out.defaultWriteObject();
141        byte[] modelBytes = modelGraph.toGraphDef();
142        out.writeObject(modelBytes);
143    }
144
145    private void readObject(java.io.ObjectInputStream in) throws IOException, ClassNotFoundException {
146        in.defaultReadObject();
147        byte[] modelBytes = (byte[]) in.readObject();
148        this.modelGraph = new Graph();
149        this.modelGraph.importGraphDef(modelBytes);
150        this.session = new Session(modelGraph);
151
152        try (Tensor<String> checkpointPrefix = Tensors.create(Paths.get(checkpointDirectory+"/"+TensorflowCheckpointTrainer.MODEL_FILENAME).toString())) {
153            // Initialises the parameters.
154            session.runner().feed("save/Const", checkpointPrefix).addTarget("save/restore_all").run();
155        }
156    }
157}