/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.interop.tensorflow;

import com.oracle.labs.mlrg.olcut.util.Pair;
import org.tribuo.Example;
import org.tribuo.Excuse;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Model;
import org.tribuo.Output;
import org.tribuo.Prediction;
import org.tribuo.math.la.SparseVector;
import org.tribuo.provenance.ModelProvenance;
import org.tensorflow.Graph;
import org.tensorflow.Session;
import org.tensorflow.Tensor;

import java.io.Closeable;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.logging.Logger;

/**
 * This model encapsulates a simple model with a single input tensor (labelled {@link TensorflowModel#INPUT_NAME}),
 * and produces a single output tensor (labelled {@link TensorflowModel#OUTPUT_NAME}).
 * <p>
 * It accepts an {@link ExampleTransformer} that converts an example's features into a {@link Tensor}, and an
 * {@link OutputTransformer} that converts a {@link Tensor} into a {@link Prediction}.
 * <p>
 * The model's serialVersionUID is set to the major Tensorflow version number times 100.
 * <p>
 * N.B. Tensorflow support is experimental and may change without a major version bump.
 */
public class TensorflowModel<T extends Output<T>> extends Model<T> implements Closeable {

    private static final Logger logger = Logger.getLogger(TensorflowModel.class.getName());

    private static final long serialVersionUID = 100L;

    /**
     * The name of the input placeholder fed with the transformed features.
     */
    public static final String INPUT_NAME = "input";

    /**
     * The name of the output operation fetched to produce predictions.
     */
    public static final String OUTPUT_NAME = "output";

    // Native graph handle; transient as it is rebuilt from the serialised
    // graph def in readObject, and released in close().
    private transient Graph modelGraph = null;

    // Native session handle; transient for the same reason as modelGraph.
    private transient Session session = null;

    // Mutable via setBatchSize; controls how many examples are sent to
    // Tensorflow per call in innerPredict.
    private int batchSize;

    private final ExampleTransformer<T> exampleTransformer;

    private final OutputTransformer<T> outputTransformer;

    /**
     * Builds a TensorflowModel around a trained graph def, restoring the
     * trained parameter values from the supplied tensor map.
     * @param name The model name.
     * @param description The model provenance.
     * @param featureIDMap The feature domain.
     * @param outputIDMap The output domain.
     * @param trainedGraphDef The serialised Tensorflow graph definition.
     * @param tensorMap The serialised parameter values, keyed by variable name.
     * @param batchSize The initial test-time batch size.
     * @param exampleTransformer Converts examples into input {@link Tensor}s.
     * @param outputTransformer Converts output {@link Tensor}s into {@link Prediction}s.
     */
    TensorflowModel(String name, ModelProvenance description, ImmutableFeatureMap featureIDMap,
                    ImmutableOutputInfo<T> outputIDMap, byte[] trainedGraphDef, Map<String, Object> tensorMap,
                    int batchSize, ExampleTransformer<T> exampleTransformer, OutputTransformer<T> outputTransformer) {
        super(name, description, featureIDMap, outputIDMap, outputTransformer.generatesProbabilities());
        this.exampleTransformer = exampleTransformer;
        this.outputTransformer = outputTransformer;
        this.modelGraph = new Graph();
        this.modelGraph.importGraphDef(trainedGraphDef);
        this.session = new Session(modelGraph);
        this.batchSize = batchSize;
        // Initialises the parameters, then overwrites them with the trained values.
        session.runner().addTarget(TensorflowTrainer.INIT).run();
        TensorflowUtil.deserialise(session, tensorMap);
    }

    @Override
    public Prediction<T> predict(Example<T> example) {
        // This adds overhead and triggers lookups for each feature, but is necessary to correctly calculate
        // the number of features used in this example.
        SparseVector vec = SparseVector.createSparseVector(example, featureIDMap, false);
        // try-with-resources ensures the native tensors are released.
        try (Tensor<?> transformedInput = exampleTransformer.transform(vec);
             Tensor<?> isTraining = Tensor.create(false);
             Tensor<?> outputTensor = session.runner()
                     .feed(INPUT_NAME, transformedInput)
                     .feed(TensorflowTrainer.IS_TRAINING, isTraining)
                     .fetch(OUTPUT_NAME).run().get(0)) {
            // Transform the returned tensor into a Prediction.
            return outputTransformer.transformToPrediction(outputTensor, outputIDInfo, vec.numActiveElements(), example);
        }
    }

    /**
     * Predicts over the supplied examples in chunks of {@link #getBatchSize()},
     * sending any final partial batch at the end.
     * @param examples The examples to predict.
     * @return The predictions, in iteration order.
     */
    @Override
    protected List<Prediction<T>> innerPredict(Iterable<Example<T>> examples) {
        List<Prediction<T>> predictions = new ArrayList<>();
        List<Example<T>> batchExamples = new ArrayList<>();
        for (Example<T> example : examples) {
            batchExamples.add(example);
            if (batchExamples.size() == batchSize) {
                predictions.addAll(predictBatch(batchExamples));
                // clear the batch
                batchExamples.clear();
            }
        }

        if (!batchExamples.isEmpty()) {
            // send the partial batch
            predictions.addAll(predictBatch(batchExamples));
        }
        return predictions;
    }

    /**
     * Runs a single batch of examples through the Tensorflow session.
     * @param batchExamples The batch to predict.
     * @return One prediction per example, in batch order.
     */
    private List<Prediction<T>> predictBatch(List<Example<T>> batchExamples) {
        // Convert the batch, recording each example's active feature count for the Predictions.
        List<SparseVector> vectors = new ArrayList<>(batchExamples.size());
        int[] numActiveElements = new int[batchExamples.size()];
        for (int i = 0; i < batchExamples.size(); i++) {
            SparseVector vec = SparseVector.createSparseVector(batchExamples.get(i), featureIDMap, false);
            numActiveElements[i] = vec.numActiveElements();
            vectors.add(vec);
        }

        // Send a batch to Tensorflow
        try (Tensor<?> transformedInput = exampleTransformer.transform(vectors);
             Tensor<?> isTraining = Tensor.create(false);
             Tensor<?> outputTensor = session.runner()
                     .feed(INPUT_NAME, transformedInput)
                     .feed(TensorflowTrainer.IS_TRAINING, isTraining)
                     .fetch(OUTPUT_NAME).run().get(0)) {
            // Transform the returned tensor into a list of Predictions.
            return outputTransformer.transformToBatchPrediction(outputTensor, outputIDInfo, numActiveElements, batchExamples);
        }
    }

    /**
     * Gets the current testing batch size.
     * @return The batch size.
     */
    public int getBatchSize() {
        return batchSize;
    }

    /**
     * Sets a new batch size.
     *
     * Throws {@link IllegalArgumentException} if the batch size isn't positive.
     * @param batchSize The batch size to use.
     */
    public void setBatchSize(int batchSize) {
        if (batchSize > 0) {
            this.batchSize = batchSize;
        } else {
            throw new IllegalArgumentException("Batch size must be positive, found " + batchSize);
        }
    }

    /**
     * Deep learning models don't do feature rankings. Use an Explainer.
     * <p>
     * This method always returns the empty map.
     * @param n the number of features to return.
     * @return The empty map.
     */
    @Override
    public Map<String, List<Pair<String, Double>>> getTopFeatures(int n) {
        return Collections.emptyMap();
    }

    /**
     * Deep learning models don't do excuses. Use an Explainer.
     * <p>
     * This method always returns {@link Optional#empty}.
     * @param example The input example.
     * @return {@link Optional#empty}.
     */
    @Override
    public Optional<Excuse<T>> getExcuse(Example<T> example) {
        return Optional.empty();
    }

    @Override
    protected TensorflowModel<T> copy(String newName, ModelProvenance newProvenance) {
        // Serialise the current graph and parameters so the copy is fully independent.
        return new TensorflowModel<>(newName, newProvenance, featureIDMap, outputIDInfo, modelGraph.toGraphDef(),
                TensorflowUtil.serialise(modelGraph, session), batchSize, exampleTransformer, outputTransformer);
    }

    /**
     * Releases the native Tensorflow session and graph. The model must not be
     * used for prediction after this is called.
     */
    @Override
    public void close() {
        if (session != null) {
            session.close();
        }
        if (modelGraph != null) {
            modelGraph.close();
        }
    }

    // Serialises the native graph def and parameter values alongside the
    // default-serialised Java fields.
    private void writeObject(java.io.ObjectOutputStream out) throws IOException {
        out.defaultWriteObject();
        byte[] modelBytes = modelGraph.toGraphDef();
        out.writeObject(modelBytes);
        Map<String, Object> tensorMap = TensorflowUtil.serialise(modelGraph, session);
        out.writeObject(tensorMap);
    }

    // Rebuilds the transient graph and session from the serialised graph def,
    // then restores the trained parameter values.
    @SuppressWarnings("unchecked") //deserialising a typed map
    private void readObject(java.io.ObjectInputStream in) throws IOException, ClassNotFoundException {
        in.defaultReadObject();
        byte[] modelBytes = (byte[]) in.readObject();
        Map<String, Object> tensorMap = (Map<String, Object>) in.readObject();
        modelGraph = new Graph();
        modelGraph.importGraphDef(modelBytes);
        session = new Session(modelGraph);
        // Initialises the parameters, then overwrites them with the trained values.
        session.runner().addTarget(TensorflowTrainer.INIT).run();
        TensorflowUtil.deserialise(session, tensorMap);
    }
}