001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.interop.tensorflow;
018
019import com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance;
020import com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl;
021import org.tribuo.Example;
022import org.tribuo.ImmutableOutputInfo;
023import org.tribuo.Prediction;
024import org.tribuo.classification.Label;
025import org.tensorflow.Tensor;
026
027import java.util.ArrayList;
028import java.util.Arrays;
029import java.util.HashMap;
030import java.util.List;
031import java.util.Map;
032import java.util.logging.Logger;
033
034/**
035 * Can convert a {@link Label} into a {@link Tensor} containing a 32-bit integer and
036 * can convert a vector of 32-bit floats into a {@link Prediction} or a {@link Label}.
037 */
038public class LabelTransformer implements OutputTransformer<Label> {
039    private static final long serialVersionUID = 1L;
040    private static final Logger logger = Logger.getLogger(LabelTransformer.class.getName());
041
042    @Override
043    public Prediction<Label> transformToPrediction(Tensor<?> tensor, ImmutableOutputInfo<Label> outputIDInfo, int numValidFeatures, Example<Label> example) {
044        float[][] predictions = getBatchPredictions(tensor,outputIDInfo);
045        if (predictions.length != 1) {
046            throw new IllegalArgumentException("Supplied tensor has too many results, predictions.length = " + predictions.length);
047        }
048        return generatePrediction(predictions[0],outputIDInfo,numValidFeatures,example);
049    }
050
051    private Prediction<Label> generatePrediction(float[] predictions, ImmutableOutputInfo<Label> outputIDInfo, int numUsed, Example<Label> example) {
052        Label max = null;
053        Map<String,Label> map = new HashMap<>();
054        for (int i = 0; i < predictions.length; i++) {
055            Label current = new Label(outputIDInfo.getOutput(i).getLabel(),predictions[i]);
056            map.put(current.getLabel(),current);
057            if ((max == null) || (current.getScore() > max.getScore())) {
058                max = current;
059            }
060        }
061        return new Prediction<>(max,map,numUsed,example,true);
062    }
063
064    @Override
065    public Label transformToOutput(Tensor<?> tensor, ImmutableOutputInfo<Label> outputIDInfo) {
066        float[][] predictions = getBatchPredictions(tensor,outputIDInfo);
067        if (predictions.length != 1) {
068            throw new IllegalArgumentException("Supplied tensor has too many results, predictions.length = " + predictions.length);
069        }
070        return generateLabel(predictions[0],outputIDInfo);
071    }
072
073    private Label generateLabel(float[] predictions, ImmutableOutputInfo<Label> outputIDInfo) {
074        int maxIdx = 0;
075        float max = Float.NEGATIVE_INFINITY;
076        for (int i = 0; i < predictions.length; i++) {
077            if (predictions[i] > max) {
078                maxIdx = i;
079                max = predictions[i];
080            }
081        }
082        return new Label(outputIDInfo.getOutput(maxIdx).getLabel(),max);
083    }
084
085    private float[][] getBatchPredictions(Tensor<?> tensor, ImmutableOutputInfo<Label> outputIDInfo) {
086        long[] shape = tensor.shape();
087        if (shape.length != 2) {
088            throw new IllegalArgumentException("Supplied tensor has the wrong number of dimensions, shape = " + Arrays.toString(shape));
089        }
090        int numValues = (int) shape[1];
091        if (numValues != outputIDInfo.size()) {
092            throw new IllegalArgumentException("Supplied tensor has too many elements, tensor.length = " + numValues + ", outputIDInfo.size() = " + outputIDInfo.size());
093        }
094        int batchSize = (int) shape[0];
095        Tensor<Float> converted = tensor.expect(Float.class);
096        return converted.copyTo(new float[batchSize][numValues]);
097    }
098
099    @Override
100    public List<Prediction<Label>> transformToBatchPrediction(Tensor<?> tensor, ImmutableOutputInfo<Label> outputIDInfo, int[] numValidFeatures, List<Example<Label>> examples) {
101        float[][] predictions = getBatchPredictions(tensor,outputIDInfo);
102        List<Prediction<Label>> output = new ArrayList<>();
103
104        if ((predictions.length != examples.size()) || (predictions.length != numValidFeatures.length)) {
105            throw new IllegalArgumentException("Invalid number of predictions received from Tensorflow, expected " + numValidFeatures.length + ", received " + predictions.length);
106        }
107
108        for (int i = 0; i < predictions.length; i++) {
109            output.add(generatePrediction(predictions[i],outputIDInfo,numValidFeatures[i],examples.get(i)));
110        }
111
112        return output;
113    }
114
115    @Override
116    public List<Label> transformToBatchOutput(Tensor<?> tensor, ImmutableOutputInfo<Label> outputIDInfo) {
117        float[][] predictions = getBatchPredictions(tensor,outputIDInfo);
118        List<Label> output = new ArrayList<>();
119
120        for (int i = 0; i < predictions.length; i++) {
121            output.add(generateLabel(predictions[i],outputIDInfo));
122        }
123
124        return output;
125    }
126
127    private int innerTransform(Label label, ImmutableOutputInfo<Label> outputIDInfo) {
128        int id = outputIDInfo.getID(label);
129        if (id == -1) {
130            throw new IllegalArgumentException("Label " + label + " isn't known by the supplied outputIDInfo, " + outputIDInfo.toString());
131        }
132        return id;
133    }
134
135    @Override
136    public Tensor<?> transform(Label example, ImmutableOutputInfo<Label> outputIDInfo) {
137        int[] output = new int[1];
138        output[0] = innerTransform(example, outputIDInfo);
139        return Tensor.create(output);
140    }
141
142    @Override
143    public Tensor<?> transform(List<Example<Label>> examples, ImmutableOutputInfo<Label> outputIDInfo) {
144        int[] output = new int[examples.size()];
145        int i = 0;
146        for (Example<Label> e : examples) {
147            output[i] = innerTransform(e.getOutput(), outputIDInfo);
148            i++;
149        }
150        return Tensor.create(output);
151    }
152
153    @Override
154    public boolean generatesProbabilities() {
155        return true;
156    }
157
158    @Override
159    public String toString() {
160        return "LabelTransformer()";
161    }
162
163    @Override
164    public ConfiguredObjectProvenance getProvenance() {
165        return new ConfiguredObjectProvenanceImpl(this,"OutputTransformer");
166    }
167}