001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.interop.tensorflow; 018 019import com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance; 020import com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl; 021import org.tribuo.Example; 022import org.tribuo.ImmutableOutputInfo; 023import org.tribuo.Prediction; 024import org.tribuo.classification.Label; 025import org.tensorflow.Tensor; 026 027import java.util.ArrayList; 028import java.util.Arrays; 029import java.util.HashMap; 030import java.util.List; 031import java.util.Map; 032import java.util.logging.Logger; 033 034/** 035 * Can convert a {@link Label} into a {@link Tensor} containing a 32-bit integer and 036 * can convert a vector of 32-bit floats into a {@link Prediction} or a {@link Label}.
* <p>
 * Score tensors are expected to be rank 2 float tensors of shape
 * {@code [batchSize, outputIDInfo.size()]}, where column {@code i} holds the score for the
 * label with id {@code i}. The scores are copied into the outputs as-is; no normalisation is
 * applied here. Ground truth labels are encoded as rank 1 int32 tensors of label ids.
 */
public class LabelTransformer implements OutputTransformer<Label> {
    private static final long serialVersionUID = 1L;
    private static final Logger logger = Logger.getLogger(LabelTransformer.class.getName());

    /**
     * Converts a score tensor holding a single example into a {@link Prediction}.
     *
     * @param tensor The score tensor, expected shape {@code [1, outputIDInfo.size()]}.
     * @param outputIDInfo The mapping between label ids and labels.
     * @param numValidFeatures The number of features used to produce this prediction.
     * @param example The example which was scored.
     * @return A prediction whose output is the highest scoring label, with every label's score attached.
     * @throws IllegalArgumentException If the tensor is not rank 2, has the wrong number of columns,
     *                                  is not a float tensor, or does not contain exactly one row.
     */
    @Override
    public Prediction<Label> transformToPrediction(Tensor<?> tensor, ImmutableOutputInfo<Label> outputIDInfo, int numValidFeatures, Example<Label> example) {
        float[][] predictions = getBatchPredictions(tensor, outputIDInfo);
        checkSingleRow(predictions);
        return generatePrediction(predictions[0], outputIDInfo, numValidFeatures, example);
    }

    /**
     * Builds a {@link Prediction} from one row of scores.
     *
     * @param predictions The scores, indexed by label id.
     * @param outputIDInfo The mapping between label ids and labels.
     * @param numUsed The number of features used to produce this prediction.
     * @param example The example which was scored.
     * @return A prediction containing every label's score, with the argmax label as the output.
     */
    private Prediction<Label> generatePrediction(float[] predictions, ImmutableOutputInfo<Label> outputIDInfo, int numUsed, Example<Label> example) {
        // Use the shared argmax so this path selects the same label as generateLabel.
        int maxIdx = argmax(predictions);
        Label max = null;
        Map<String, Label> map = new HashMap<>();
        for (int i = 0; i < predictions.length; i++) {
            Label current = new Label(outputIDInfo.getOutput(i).getLabel(), predictions[i]);
            map.put(current.getLabel(), current);
            if (i == maxIdx) {
                max = current;
            }
        }
        // NOTE(review): the trailing 'true' presumably flags the scores as probabilities,
        // matching generatesProbabilities() — confirm against the Prediction constructor.
        return new Prediction<>(max, map, numUsed, example, true);
    }

    /**
     * Converts a score tensor holding a single example into the highest scoring {@link Label}.
     *
     * @param tensor The score tensor, expected shape {@code [1, outputIDInfo.size()]}.
     * @param outputIDInfo The mapping between label ids and labels.
     * @return The highest scoring label, carrying its score.
     * @throws IllegalArgumentException If the tensor is not rank 2, has the wrong number of columns,
     *                                  is not a float tensor, or does not contain exactly one row.
     */
    @Override
    public Label transformToOutput(Tensor<?> tensor, ImmutableOutputInfo<Label> outputIDInfo) {
        float[][] predictions = getBatchPredictions(tensor, outputIDInfo);
        checkSingleRow(predictions);
        return generateLabel(predictions[0], outputIDInfo);
    }

    /**
     * Selects the highest scoring label from one row of scores.
     *
     * @param predictions The scores, indexed by label id.
     * @param outputIDInfo The mapping between label ids and labels.
     * @return The argmax label, carrying its score.
     */
    private Label generateLabel(float[] predictions, ImmutableOutputInfo<Label> outputIDInfo) {
        int maxIdx = argmax(predictions);
        return new Label(outputIDInfo.getOutput(maxIdx).getLabel(), predictions[maxIdx]);
    }

    /**
     * Returns the index of the largest non-NaN score, preferring the earliest index on ties.
     * If every score is NaN, index 0 is returned.
     *
     * @param scores The scores to search.
     * @return The argmax index.
     * @throws IllegalArgumentException If {@code scores} is empty.
     */
    private static int argmax(float[] scores) {
        if (scores.length == 0) {
            throw new IllegalArgumentException("Cannot select a label from an empty score vector");
        }
        int maxIdx = -1;
        for (int i = 0; i < scores.length; i++) {
            // NaN never wins, and a leading NaN must not block later valid scores.
            if (!Float.isNaN(scores[i]) && ((maxIdx == -1) || (scores[i] > scores[maxIdx]))) {
                maxIdx = i;
            }
        }
        return maxIdx == -1 ? 0 : maxIdx;
    }

    /**
     * Checks that a single-example transformation received exactly one row of scores.
     *
     * @param predictions The per-example score rows.
     * @throws IllegalArgumentException If there is not exactly one row.
     */
    private static void checkSingleRow(float[][] predictions) {
        if (predictions.length != 1) {
            throw new IllegalArgumentException("Supplied tensor has the wrong number of results, expected 1, predictions.length = " + predictions.length);
        }
    }

    /**
     * Validates the score tensor's shape and copies it into a Java array.
     *
     * @param tensor The score tensor, expected shape {@code [batchSize, outputIDInfo.size()]}.
     * @param outputIDInfo The mapping between label ids and labels.
     * @return The scores as a {@code [batchSize][outputIDInfo.size()]} array.
     * @throws IllegalArgumentException If the tensor is not rank 2, has the wrong number of columns,
     *                                  or is not a float tensor.
     */
    private float[][] getBatchPredictions(Tensor<?> tensor, ImmutableOutputInfo<Label> outputIDInfo) {
        long[] shape = tensor.shape();
        if (shape.length != 2) {
            throw new IllegalArgumentException("Supplied tensor has the wrong number of dimensions, shape = " + Arrays.toString(shape));
        }
        // Compare as longs so an oversized dimension cannot wrap around when narrowed to int.
        if (shape[1] != outputIDInfo.size()) {
            throw new IllegalArgumentException("Supplied tensor has the wrong number of elements, tensor.length = " + shape[1] + ", outputIDInfo.size() = " + outputIDInfo.size());
        }
        int numValues = (int) shape[1];
        int batchSize = (int) shape[0];
        // expect() rejects tensors whose element type is not float before the copy.
        Tensor<Float> converted = tensor.expect(Float.class);
        return converted.copyTo(new float[batchSize][numValues]);
    }

    /**
     * Converts a batch score tensor into one {@link Prediction} per example.
     *
     * @param tensor The score tensor, expected shape {@code [examples.size(), outputIDInfo.size()]}.
     * @param outputIDInfo The mapping between label ids and labels.
     * @param numValidFeatures The number of features used for each example.
     * @param examples The examples which were scored, in row order.
     * @return The predictions, in the same order as {@code examples}.
     * @throws IllegalArgumentException If the tensor has the wrong shape or type, or if the row count,
     *                                  {@code examples.size()} and {@code numValidFeatures.length} differ.
     */
    @Override
    public List<Prediction<Label>> transformToBatchPrediction(Tensor<?> tensor, ImmutableOutputInfo<Label> outputIDInfo, int[] numValidFeatures, List<Example<Label>> examples) {
        float[][] predictions = getBatchPredictions(tensor, outputIDInfo);

        if ((predictions.length != examples.size()) || (predictions.length != numValidFeatures.length)) {
            throw new IllegalArgumentException("Invalid number of predictions received from Tensorflow, expected " + examples.size() + " examples and " + numValidFeatures.length + " feature counts, received " + predictions.length);
        }

        List<Prediction<Label>> output = new ArrayList<>(predictions.length);
        for (int i = 0; i < predictions.length; i++) {
            output.add(generatePrediction(predictions[i], outputIDInfo, numValidFeatures[i], examples.get(i)));
        }

        return output;
    }

    /**
     * Converts a batch score tensor into the highest scoring {@link Label} for each row.
     *
     * @param tensor The score tensor, expected shape {@code [batchSize, outputIDInfo.size()]}.
     * @param outputIDInfo The mapping between label ids and labels.
     * @return One label per row, in row order.
     * @throws IllegalArgumentException If the tensor has the wrong shape or type.
     */
    @Override
    public List<Label> transformToBatchOutput(Tensor<?> tensor, ImmutableOutputInfo<Label> outputIDInfo) {
        float[][] predictions = getBatchPredictions(tensor, outputIDInfo);
        List<Label> output = new ArrayList<>(predictions.length);

        for (float[] row : predictions) {
            output.add(generateLabel(row, outputIDInfo));
        }

        return output;
    }

    /**
     * Looks up the id of a label.
     *
     * @param label The label to look up.
     * @param outputIDInfo The mapping between label ids and labels.
     * @return The label's id.
     * @throws IllegalArgumentException If the label is unknown to {@code outputIDInfo}.
     */
    private int innerTransform(Label label, ImmutableOutputInfo<Label> outputIDInfo) {
        int id = outputIDInfo.getID(label);
        if (id == -1) {
            throw new IllegalArgumentException("Label " + label + " isn't known by the supplied outputIDInfo, " + outputIDInfo.toString());
        }
        return id;
    }

    /**
     * Encodes a single label as a length 1 int32 tensor containing its id.
     *
     * @param example The label to encode.
     * @param outputIDInfo The mapping between label ids and labels.
     * @return A tensor holding the label id.
     * @throws IllegalArgumentException If the label is unknown to {@code outputIDInfo}.
     */
    @Override
    public Tensor<?> transform(Label example, ImmutableOutputInfo<Label> outputIDInfo) {
        int[] output = new int[1];
        output[0] = innerTransform(example, outputIDInfo);
        return Tensor.create(output);
    }

    /**
     * Encodes the ground truth labels of a batch as an int32 tensor of label ids.
     *
     * @param examples The examples whose outputs are encoded, in order.
     * @param outputIDInfo The mapping between label ids and labels.
     * @return A tensor of length {@code examples.size()} holding the label ids.
     * @throws IllegalArgumentException If any label is unknown to {@code outputIDInfo}.
     */
    @Override
    public Tensor<?> transform(List<Example<Label>> examples, ImmutableOutputInfo<Label> outputIDInfo) {
        int[] output = new int[examples.size()];
        int i = 0;
        for (Example<Label> e : examples) {
            output[i] = innerTransform(e.getOutput(), outputIDInfo);
            i++;
        }
        return Tensor.create(output);
    }

    /**
     * Reports that the produced predictions carry probabilities.
     * This class does not normalise the scores. It assumes the model already emits probabilities,
     * for example through a softmax output.
     *
     * @return {@code true}.
     */
    @Override
    public boolean generatesProbabilities() {
        return true;
    }

    @Override
    public String toString() {
        return "LabelTransformer()";
    }

    @Override
    public ConfiguredObjectProvenance getProvenance() {
        return new ConfiguredObjectProvenanceImpl(this, "OutputTransformer");
    }
}