001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.interop.tensorflow; 018 019import com.oracle.labs.mlrg.olcut.util.Pair; 020import org.tribuo.Example; 021import org.tribuo.ImmutableFeatureMap; 022import org.tribuo.ImmutableOutputInfo; 023import org.tribuo.Model; 024import org.tribuo.Output; 025import org.tribuo.OutputFactory; 026import org.tribuo.Prediction; 027import org.tribuo.interop.ExternalDatasetProvenance; 028import org.tribuo.interop.ExternalModel; 029import org.tribuo.interop.ExternalTrainerProvenance; 030import org.tribuo.math.la.SparseVector; 031import org.tribuo.provenance.DatasetProvenance; 032import org.tribuo.provenance.ModelProvenance; 033import org.tensorflow.Graph; 034import org.tensorflow.Session; 035import org.tensorflow.Tensor; 036 037import java.io.Closeable; 038import java.io.IOException; 039import java.net.URL; 040import java.nio.file.Files; 041import java.nio.file.Path; 042import java.nio.file.Paths; 043import java.time.OffsetDateTime; 044import java.util.Collections; 045import java.util.List; 046import java.util.Map; 047 048/** 049 * A Tribuo wrapper around a Tensorflow frozen model. 050 * <p> 051 * The model's serialVersionUID is set to the major Tensorflow version number times 100. 052 * <p> 053 * N.B. Tensorflow support is experimental and may change without a major version bump. 054 */ 055public final class TensorflowExternalModel<T extends Output<T>> extends ExternalModel<T, Tensor<?>, Tensor<?>> implements Closeable { 056 private static final long serialVersionUID = 100L; 057 058 private transient Graph model; 059 060 private transient Session session; 061 062 private final ExampleTransformer<T> featureTransformer; 063 064 private final OutputTransformer<T> outputTransformer; 065 066 private final String inputName; 067 068 private final String outputName; 069 070 private TensorflowExternalModel(String name, ModelProvenance provenance, 071 ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<T> outputIDInfo, 072 Map<String, Integer> featureMapping, 073 Graph model, String inputName, String outputName, 074 ExampleTransformer<T> featureTransformer, OutputTransformer<T> outputTransformer) { 075 super(name, provenance, featureIDMap, outputIDInfo, outputTransformer.generatesProbabilities(), featureMapping); 076 this.model = model; 077 this.session = new Session(model); 078 this.inputName = inputName; 079 this.outputName = outputName; 080 this.featureTransformer = featureTransformer; 081 this.outputTransformer = outputTransformer; 082 } 083 084 private TensorflowExternalModel(String name, ModelProvenance provenance, 085 ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<T> outputIDInfo, 086 int[] featureForwardMapping, int[] featureBackwardMapping, 087 Graph model, String inputName, String outputName, 088 ExampleTransformer<T> featureTransformer, OutputTransformer<T> outputTransformer) { 089 super(name,provenance,featureIDMap,outputIDInfo,featureForwardMapping,featureBackwardMapping, 090 outputTransformer.generatesProbabilities()); 091 this.model = model; 092 this.session = new Session(model); 093 this.inputName = inputName; 094 this.outputName = outputName; 095 this.featureTransformer = featureTransformer; 096 this.outputTransformer = outputTransformer; 097 } 098 099 @Override 100 protected Tensor<?> convertFeatures(SparseVector input) { 101 return featureTransformer.transform(input); 102 } 103 104 @Override 105 protected Tensor<?> convertFeaturesList(List<SparseVector> input) { 106 return featureTransformer.transform(input); 107 } 108 109 /** 110 * Runs the session to make a prediction. 111 * 112 * Closes the input tensor after the prediction has been made. 113 * @param input The input in the external model's format. 114 * @return A tensor representing the output. 115 */ 116 @Override 117 protected Tensor<?> externalPrediction(Tensor<?> input) { 118 Tensor<?> output = session.runner().feed(inputName,input).fetch(outputName).run().get(0); 119 input.close(); 120 return output; 121 } 122 123 /** 124 * Converts a tensor into a prediction. 125 * Closes the output tensor after it's been converted. 126 * @param output The output of the external model. 127 * @param numValidFeatures The number of valid features in the input. 128 * @param example The input example, used to construct the Prediction. 129 * @return A {@link Prediction} representing this tensor output. 130 */ 131 @Override 132 protected Prediction<T> convertOutput(Tensor<?> output, int numValidFeatures, Example<T> example) { 133 Prediction<T> pred = outputTransformer.transformToPrediction(output,outputIDInfo,numValidFeatures,example); 134 output.close(); 135 return pred; 136 } 137 138 /** 139 * Converts a tensor into a prediction. 140 * Closes the output tensor after it's been converted. 141 * @param output The output of the external model. 142 * @param numValidFeatures An array with the number of valid features in each example. 143 * @param examples The input examples, used to construct the Predictions. 144 * @return A list of {@link Prediction} representing this tensor output. 145 */ 146 @Override 147 protected List<Prediction<T>> convertOutput(Tensor<?> output, int[] numValidFeatures, List<Example<T>> examples) { 148 List<Prediction<T>> predictions = outputTransformer.transformToBatchPrediction(output,outputIDInfo,numValidFeatures,examples); 149 output.close(); 150 return predictions; 151 } 152 153 @Override 154 public Map<String, List<Pair<String, Double>>> getTopFeatures(int n) { 155 return Collections.emptyMap(); 156 } 157 158 @Override 159 protected Model<T> copy(String newName, ModelProvenance newProvenance) { 160 byte[] modelBytes = model.toGraphDef(); 161 Graph newGraph = new Graph(); 162 newGraph.importGraphDef(modelBytes); 163 return new TensorflowExternalModel<>(newName,newProvenance,featureIDMap,outputIDInfo, 164 featureForwardMapping,featureBackwardMapping, 165 newGraph,inputName,outputName,featureTransformer,outputTransformer); 166 } 167 168 @Override 169 public void close() { 170 if (session != null) { 171 session.close(); 172 } 173 if (model != null) { 174 model.close(); 175 } 176 } 177 178 /** 179 * Creates a TensorflowExternalModel by loading in a frozen graph. 180 * @param factory The output factory. 181 * @param featureMapping The feature mapping between Tribuo's names and the TF integer ids. 182 * @param outputMapping The output mapping between Tribuo's names and the TF integer ids. 183 * @param inputName The name of the input placeholder. 184 * @param outputName The name of the output tensor. 185 * @param featureTransformer The feature transformation function. 186 * @param outputTransformer The output transformation function. 187 * @param filename The filename to load the graph from. 188 * @param <T> The type of the output. 189 * @return The TF model wrapped in a Tribuo ExternalModel. 190 */ 191 public static <T extends Output<T>> TensorflowExternalModel<T> createTensorflowModel(OutputFactory<T> factory, 192 Map<String, Integer> featureMapping, 193 Map<T,Integer> outputMapping, 194 String inputName, 195 String outputName, 196 ExampleTransformer<T> featureTransformer, 197 OutputTransformer<T> outputTransformer, 198 String filename) { 199 try { 200 Path path = Paths.get(filename); 201 byte[] model = Files.readAllBytes(path); 202 Graph graph = new Graph(); 203 graph.importGraphDef(model); 204 URL provenanceLocation = path.toUri().toURL(); 205 ImmutableFeatureMap featureMap = ExternalModel.createFeatureMap(featureMapping.keySet()); 206 ImmutableOutputInfo<T> outputInfo = ExternalModel.createOutputInfo(factory,outputMapping); 207 OffsetDateTime now = OffsetDateTime.now(); 208 ExternalTrainerProvenance trainerProvenance = new ExternalTrainerProvenance(provenanceLocation); 209 DatasetProvenance datasetProvenance = new ExternalDatasetProvenance("unknown-external-data",factory,false,featureMapping.size(),outputMapping.size()); 210 ModelProvenance provenance = new ModelProvenance(TensorflowExternalModel.class.getName(),now,datasetProvenance,trainerProvenance); 211 return new TensorflowExternalModel<>("external-model",provenance,featureMap,outputInfo, 212 featureMapping,graph,inputName,outputName,featureTransformer,outputTransformer); 213 } catch (IOException e) { 214 throw new IllegalArgumentException("Unable to load model from path " + filename, e); 215 } 216 } 217 218 private void writeObject(java.io.ObjectOutputStream out) throws IOException { 219 out.defaultWriteObject(); 220 byte[] modelBytes = model.toGraphDef(); 221 out.writeObject(modelBytes); 222 } 223 224 private void readObject(java.io.ObjectInputStream in) throws IOException, ClassNotFoundException { 225 in.defaultReadObject(); 226 byte[] modelBytes = (byte[]) in.readObject(); 227 model = new Graph(); 228 model.importGraphDef(modelBytes); 229 session = new Session(model); 230 } 231 232}