001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.interop.tensorflow.sequence;
018
019import com.oracle.labs.mlrg.olcut.util.Pair;
020import org.tribuo.ImmutableFeatureMap;
021import org.tribuo.ImmutableOutputInfo;
022import org.tribuo.Output;
023import org.tribuo.Prediction;
024import org.tribuo.interop.tensorflow.TensorflowUtil;
025import org.tribuo.provenance.ModelProvenance;
026import org.tribuo.sequence.SequenceExample;
027import org.tribuo.sequence.SequenceModel;
028import org.tensorflow.Graph;
029import org.tensorflow.Session;
030import org.tensorflow.Tensor;
031
032import java.io.Closeable;
033import java.io.IOException;
034import java.util.Collections;
035import java.util.List;
036import java.util.Map;
037
038/**
039 * A Tensorflow model which implements SequenceModel, suitable for use in sequential prediction tasks.
040 */
041public class TensorflowSequenceModel<T extends Output<T>> extends SequenceModel<T> implements Closeable {
042
043    private static final long serialVersionUID = 1L;
044
045    private transient Graph modelGraph = null;
046    private transient Session session = null;
047
048    protected final SequenceExampleTransformer<T> exampleTransformer;
049    protected final SequenceOutputTransformer<T> outputTransformer;
050
051    protected final String initOp;
052    protected final String predictOp;
053
054    TensorflowSequenceModel(String name,
055                                   ModelProvenance description,
056                                   ImmutableFeatureMap featureIDMap,
057                                   ImmutableOutputInfo<T> outputIDMap,
058                                   byte[] graphDef,
059                                   SequenceExampleTransformer<T> exampleTransformer,
060                                   SequenceOutputTransformer<T> outputTransformer,
061                                   String initOp,
062                                   String predictOp,
063                                   Map<String, Object> tensorMap
064    ) {
065        super(name, description, featureIDMap, outputIDMap);
066        this.exampleTransformer = exampleTransformer;
067        this.outputTransformer = outputTransformer;
068        this.initOp = initOp;
069        this.predictOp = predictOp;
070        this.modelGraph = new Graph();
071        this.modelGraph.importGraphDef(graphDef);
072        this.session = new Session(modelGraph);
073
074        // Initialises the parameters.
075        session.runner().addTarget(initOp).run();
076        TensorflowUtil.deserialise(session, tensorMap);
077    }
078
079    @Override
080    public List<Prediction<T>> predict(SequenceExample<T> example) {
081        Map<String, Tensor<?>> feed = exampleTransformer.encode(example, featureIDMap);
082        Session.Runner runner = session.runner();
083        for (Map.Entry<String, Tensor<?>> item : feed.entrySet()) {
084            runner.feed(item.getKey(), item.getValue());
085        }
086        Tensor<?> outputTensor = runner
087                .fetch(predictOp)
088                .run()
089                .get(0);
090        List<Prediction<T>> prediction = outputTransformer.decode(outputTensor, example, outputIDMap);
091        //
092        // Close all the open tensors
093        outputTensor.close();
094        for (Tensor<?> tensor : feed.values()) {
095            tensor.close();
096        }
097        return prediction;
098    }
099
100    /**
101     * Returns an empty map, as the top features are not well defined for most Tensorflow models.
102     */
103    @Override
104    public Map<String, List<Pair<String, Double>>> getTopFeatures(int i) {
105        return Collections.emptyMap();
106    }
107
108    /**
109     * Close the session and graph if they exist.
110     */
111    @Override
112    public void close() {
113        if (session != null) {
114            session.close();
115        }
116        if (modelGraph != null) {
117            modelGraph.close();
118        }
119    }
120
121    private void writeObject(java.io.ObjectOutputStream out) throws IOException {
122        out.defaultWriteObject();
123        byte[] modelBytes = modelGraph.toGraphDef();
124        out.writeObject(modelBytes);
125        Map<String,Object> tensorMap = TensorflowUtil.serialise(modelGraph, session);
126        out.writeObject(tensorMap);
127    }
128
129    @SuppressWarnings("unchecked") //deserialising a typed map
130    private void readObject(java.io.ObjectInputStream in) throws IOException, ClassNotFoundException {
131        in.defaultReadObject();
132        byte[] modelBytes = (byte[]) in.readObject();
133        Map<String,Object> tensorMap = (Map<String,Object>) in.readObject();
134        modelGraph = new Graph();
135        modelGraph.importGraphDef(modelBytes);
136        session = new Session(modelGraph);
137        // Initialises the parameters.
138        session.runner().addTarget(initOp).run();
139        TensorflowUtil.deserialise(session,tensorMap);
140    }
141}