001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.interop.tensorflow; 018 019import com.oracle.labs.mlrg.olcut.util.Pair; 020import org.tribuo.Example; 021import org.tribuo.Excuse; 022import org.tribuo.ImmutableFeatureMap; 023import org.tribuo.ImmutableOutputInfo; 024import org.tribuo.Model; 025import org.tribuo.Output; 026import org.tribuo.Prediction; 027import org.tribuo.math.la.SparseVector; 028import org.tribuo.provenance.ModelProvenance; 029import org.tensorflow.Graph; 030import org.tensorflow.Session; 031import org.tensorflow.Tensor; 032import org.tensorflow.Tensors; 033 034import java.io.Closeable; 035import java.io.IOException; 036import java.nio.file.Paths; 037import java.util.Collections; 038import java.util.List; 039import java.util.Map; 040import java.util.Optional; 041 042/** 043 * TensorFlow support is experimental, and may change without a major version bump. 044 * <p> 045 * This model encapsulates a simple model with a single input tensor (labelled {@link TensorflowModel#INPUT_NAME}), 046 * and produces a single output tensor (labelled {@link TensorflowModel#OUTPUT_NAME}). 047 * <p> 048 * It accepts an {@link ExampleTransformer} that converts an example's features into a {@link Tensor}, and an 049 * {@link OutputTransformer} that converts a {@link Tensor} into a {@link Prediction}. 050 * <p> 051 * The model's serialVersionUID is set to the major Tensorflow version number times 100. 052 * <p> 053 * N.B. Tensorflow support is experimental and may change without a major version bump. 054 */ 055public class TensorflowCheckpointModel<T extends Output<T>> extends Model<T> implements Closeable { 056 057 private static final long serialVersionUID = 100L; 058 059 private transient Graph modelGraph = null; 060 061 private transient Session session = null; 062 063 private final String checkpointDirectory; 064 065 private final ExampleTransformer<T> exampleTransformer; 066 067 private final OutputTransformer<T> outputTransformer; 068 069 TensorflowCheckpointModel(String name, ModelProvenance description, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<T> outputIDMap, byte[] graphDef, String checkpointDirectory, ExampleTransformer<T> exampleTransformer, OutputTransformer<T> outputTransformer) { 070 super(name, description, featureIDMap, outputIDMap, outputTransformer.generatesProbabilities()); 071 this.exampleTransformer = exampleTransformer; 072 this.outputTransformer = outputTransformer; 073 this.checkpointDirectory = checkpointDirectory; 074 this.modelGraph = new Graph(); 075 this.modelGraph.importGraphDef(graphDef); 076 this.session = new Session(modelGraph); 077 078 try (Tensor<String> checkpointPrefix = Tensors.create(Paths.get(checkpointDirectory+"/"+TensorflowCheckpointTrainer.MODEL_FILENAME).toString())) { 079 // Initialises the parameters. 080 session.runner().feed("save/Const", checkpointPrefix).addTarget("save/restore_all").run(); 081 } 082 } 083 084 @Override 085 public Prediction<T> predict(Example<T> example) { 086 // This adds overhead and triggers lookups for each feature, but is necessary to correctly calculate 087 // the number of features used in this example. 088 SparseVector vec = SparseVector.createSparseVector(example,featureIDMap,false); 089 try (Tensor<?> transformedInput = exampleTransformer.transform(example,featureIDMap); 090 Tensor<?> isTraining = Tensor.create(false); 091 Tensor<?> outputTensor = session.runner() 092 .feed(TensorflowModel.INPUT_NAME,transformedInput) 093 .feed(TensorflowTrainer.IS_TRAINING,isTraining) 094 .fetch(TensorflowModel.OUTPUT_NAME).run().get(0)) { 095 // Transform the returned tensor into a Prediction. 096 return outputTransformer.transformToPrediction(outputTensor,outputIDInfo,vec.numActiveElements(),example); 097 } 098 } 099 100 /** 101 * Deep learning models don't do feature rankings. Use an Explainer. 102 * <p> 103 * This method always returns the empty map. 104 * @param n the number of features to return. 105 * @return The empty map. 106 */ 107 @Override 108 public Map<String, List<Pair<String, Double>>> getTopFeatures(int n) { 109 return Collections.emptyMap(); 110 } 111 112 /** 113 * Deep learning models don't do excuses. Use an Explainer. 114 * <p> 115 * This method always returns {@link Optional#empty}. 116 * @param example The input example. 117 * @return {@link Optional#empty}. 118 */ 119 @Override 120 public Optional<Excuse<T>> getExcuse(Example<T> example) { 121 return Optional.empty(); 122 } 123 124 @Override 125 protected TensorflowCheckpointModel<T> copy(String newName, ModelProvenance newProvenance) { 126 return new TensorflowCheckpointModel<>(newName,newProvenance,featureIDMap,outputIDInfo,modelGraph.toGraphDef(),checkpointDirectory,exampleTransformer,outputTransformer); 127 } 128 129 @Override 130 public void close() { 131 if (session != null) { 132 session.close(); 133 } 134 if (modelGraph != null) { 135 modelGraph.close(); 136 } 137 } 138 139 private void writeObject(java.io.ObjectOutputStream out) throws IOException { 140 out.defaultWriteObject(); 141 byte[] modelBytes = modelGraph.toGraphDef(); 142 out.writeObject(modelBytes); 143 } 144 145 private void readObject(java.io.ObjectInputStream in) throws IOException, ClassNotFoundException { 146 in.defaultReadObject(); 147 byte[] modelBytes = (byte[]) in.readObject(); 148 this.modelGraph = new Graph(); 149 this.modelGraph.importGraphDef(modelBytes); 150 this.session = new Session(modelGraph); 151 152 try (Tensor<String> checkpointPrefix = Tensors.create(Paths.get(checkpointDirectory+"/"+TensorflowCheckpointTrainer.MODEL_FILENAME).toString())) { 153 // Initialises the parameters. 154 session.runner().feed("save/Const", checkpointPrefix).addTarget("save/restore_all").run(); 155 } 156 } 157}