001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.interop.tensorflow; 018 019import com.oracle.labs.mlrg.olcut.config.Configurable; 020import com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance; 021import com.oracle.labs.mlrg.olcut.provenance.Provenancable; 022import org.tribuo.Example; 023import org.tribuo.ImmutableOutputInfo; 024import org.tribuo.Output; 025import org.tribuo.Prediction; 026import org.tensorflow.Tensor; 027 028import java.io.Serializable; 029import java.util.List; 030 031/** 032 * TensorFlow support is experimental, and may change without a major version bump. 033 * <p> 034 * Converts the {@link Output} into a {@link Tensor} and vice versa. 035 */ 036public interface OutputTransformer<T extends Output<T>> extends Configurable, Provenancable<ConfiguredObjectProvenance>, Serializable { 037 038 /** 039 * Converts a {@link Tensor} into a {@link Prediction}. 040 * @param tensor The tensor to convert. 041 * @param outputIDInfo The output info to use to identify the outputs. 042 * @param numValidFeatures The number of valid features used by the prediction. 043 * @param example The example to insert into the prediction. 044 * @return A prediction object. 045 */ 046 public Prediction<T> transformToPrediction(Tensor<?> tensor, ImmutableOutputInfo<T> outputIDInfo, int numValidFeatures, Example<T> example); 047 048 /** 049 * Converts a {@link Tensor} into the specified output type. 050 * @param tensor The tensor to convert. 051 * @param outputIDInfo The output info to use to identify the outputs. 052 * @return A output. 053 */ 054 public T transformToOutput(Tensor<?> tensor, ImmutableOutputInfo<T> outputIDInfo); 055 056 /** 057 * Converts a {@link Tensor} containing multiple outputs into a list of {@link Prediction}s. 058 * @param tensor The tensor to convert. 059 * @param outputIDInfo The output info to use to identify the outputs. 060 * @param numValidFeatures The number of valid features used by the prediction. 061 * @param examples The example to insert into the prediction. 062 * @return A list of predictions. 063 */ 064 public List<Prediction<T>> transformToBatchPrediction(Tensor<?> tensor, ImmutableOutputInfo<T> outputIDInfo, int[] numValidFeatures, List<Example<T>> examples); 065 066 /** 067 * Converts a {@link Tensor} containing multiple outputs into a list of {@link Output}s. 068 * @param tensor The tensor to convert. 069 * @param outputIDInfo The output info to use to identify the outputs. 070 * @return A list of outputs. 071 */ 072 public List<T> transformToBatchOutput(Tensor<?> tensor, ImmutableOutputInfo<T> outputIDInfo); 073 074 /** 075 * Converts an {@link Output} into a {@link Tensor} representing it's output. 076 * @param output The output to convert. 077 * @param outputIDInfo The output info to use to identify the outputs. 078 * @return A Tensor representing this output. 079 */ 080 public Tensor<?> transform(T output, ImmutableOutputInfo<T> outputIDInfo); 081 082 /** 083 * Converts a list of {@link Example} into a {@link Tensor} representing all the outputs 084 * in the list. It accepts a list of Example rather than a list of Output for efficiency reasons. 085 * @param examples The examples to convert. 086 * @param outputIDInfo The output info to use to identify the outputs. 087 * @return A Tensor representing all the supplied Outputs. 088 */ 089 public Tensor<?> transform(List<Example<T>> examples, ImmutableOutputInfo<T> outputIDInfo); 090 091 /** 092 * Does this OutputTransformer generate probabilities. 093 * @return True if it produces a probability distribution in the {@link Prediction}. 094 */ 095 public boolean generatesProbabilities(); 096 097}