001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.interop.tensorflow.sequence; 018 019import com.oracle.labs.mlrg.olcut.util.Pair; 020import org.tribuo.ImmutableFeatureMap; 021import org.tribuo.ImmutableOutputInfo; 022import org.tribuo.Output; 023import org.tribuo.Prediction; 024import org.tribuo.interop.tensorflow.TensorflowUtil; 025import org.tribuo.provenance.ModelProvenance; 026import org.tribuo.sequence.SequenceExample; 027import org.tribuo.sequence.SequenceModel; 028import org.tensorflow.Graph; 029import org.tensorflow.Session; 030import org.tensorflow.Tensor; 031 032import java.io.Closeable; 033import java.io.IOException; 034import java.util.Collections; 035import java.util.List; 036import java.util.Map; 037 038/** 039 * A Tensorflow model which implements SequenceModel, suitable for use in sequential prediction tasks. 040 */ 041public class TensorflowSequenceModel<T extends Output<T>> extends SequenceModel<T> implements Closeable { 042 043 private static final long serialVersionUID = 1L; 044 045 private transient Graph modelGraph = null; 046 private transient Session session = null; 047 048 protected final SequenceExampleTransformer<T> exampleTransformer; 049 protected final SequenceOutputTransformer<T> outputTransformer; 050 051 protected final String initOp; 052 protected final String predictOp; 053 054 TensorflowSequenceModel(String name, 055 ModelProvenance description, 056 ImmutableFeatureMap featureIDMap, 057 ImmutableOutputInfo<T> outputIDMap, 058 byte[] graphDef, 059 SequenceExampleTransformer<T> exampleTransformer, 060 SequenceOutputTransformer<T> outputTransformer, 061 String initOp, 062 String predictOp, 063 Map<String, Object> tensorMap 064 ) { 065 super(name, description, featureIDMap, outputIDMap); 066 this.exampleTransformer = exampleTransformer; 067 this.outputTransformer = outputTransformer; 068 this.initOp = initOp; 069 this.predictOp = predictOp; 070 this.modelGraph = new Graph(); 071 this.modelGraph.importGraphDef(graphDef); 072 this.session = new Session(modelGraph); 073 074 // Initialises the parameters. 075 session.runner().addTarget(initOp).run(); 076 TensorflowUtil.deserialise(session, tensorMap); 077 } 078 079 @Override 080 public List<Prediction<T>> predict(SequenceExample<T> example) { 081 Map<String, Tensor<?>> feed = exampleTransformer.encode(example, featureIDMap); 082 Session.Runner runner = session.runner(); 083 for (Map.Entry<String, Tensor<?>> item : feed.entrySet()) { 084 runner.feed(item.getKey(), item.getValue()); 085 } 086 Tensor<?> outputTensor = runner 087 .fetch(predictOp) 088 .run() 089 .get(0); 090 List<Prediction<T>> prediction = outputTransformer.decode(outputTensor, example, outputIDMap); 091 // 092 // Close all the open tensors 093 outputTensor.close(); 094 for (Tensor<?> tensor : feed.values()) { 095 tensor.close(); 096 } 097 return prediction; 098 } 099 100 /** 101 * Returns an empty map, as the top features are not well defined for most Tensorflow models. 102 */ 103 @Override 104 public Map<String, List<Pair<String, Double>>> getTopFeatures(int i) { 105 return Collections.emptyMap(); 106 } 107 108 /** 109 * Close the session and graph if they exist. 110 */ 111 @Override 112 public void close() { 113 if (session != null) { 114 session.close(); 115 } 116 if (modelGraph != null) { 117 modelGraph.close(); 118 } 119 } 120 121 private void writeObject(java.io.ObjectOutputStream out) throws IOException { 122 out.defaultWriteObject(); 123 byte[] modelBytes = modelGraph.toGraphDef(); 124 out.writeObject(modelBytes); 125 Map<String,Object> tensorMap = TensorflowUtil.serialise(modelGraph, session); 126 out.writeObject(tensorMap); 127 } 128 129 @SuppressWarnings("unchecked") //deserialising a typed map 130 private void readObject(java.io.ObjectInputStream in) throws IOException, ClassNotFoundException { 131 in.defaultReadObject(); 132 byte[] modelBytes = (byte[]) in.readObject(); 133 Map<String,Object> tensorMap = (Map<String,Object>) in.readObject(); 134 modelGraph = new Graph(); 135 modelGraph.importGraphDef(modelBytes); 136 session = new Session(modelGraph); 137 // Initialises the parameters. 138 session.runner().addTarget(initOp).run(); 139 TensorflowUtil.deserialise(session,tensorMap); 140 } 141}