001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.interop.tensorflow.sequence;
018
019import com.oracle.labs.mlrg.olcut.config.Configurable;
020import com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance;
021import com.oracle.labs.mlrg.olcut.provenance.Provenancable;
022import org.tribuo.ImmutableOutputInfo;
023import org.tribuo.Output;
024import org.tribuo.Prediction;
025import org.tribuo.sequence.SequenceExample;
026import org.tensorflow.Tensor;
027
028import java.io.Serializable;
029import java.util.List;
030import java.util.Map;
031
032/**
033 * Converts a Tensorflow output tensor into a list of predictions, and a Tribuo sequence example into
034 * a Tensorflow tensor suitable for training.
035 */
036public interface SequenceOutputTransformer<T extends Output<T>> extends Configurable, Provenancable<ConfiguredObjectProvenance>, Serializable {
037
038    /**
039     * Decode a tensor of graph output into a list of predictions for the input sequence.
040     *
041     * @param output graph output
042     * @param input original input sequence example
043     * @param labelMap label domain
044     * @return the model's decoded prediction for the input sequence.
045     */
046    List<Prediction<T>> decode(Tensor<?> output, SequenceExample<T> input, ImmutableOutputInfo<T> labelMap);
047
048    /**
049     * Decode graph output tensors corresponding to a batch of input sequences.
050     *
051     * @param outputs a tensor corresponding to a batch of outputs.
052     * @param inputBatch the original input batch.
053     * @param labelMap label domain
054     * @return the model's decoded predictions, one for each example in the input batch.
055     */
056    List<List<Prediction<T>>> decode(Tensor<?> outputs, List<SequenceExample<T>> inputBatch, ImmutableOutputInfo<T> labelMap);
057
058    /**
059     * Encodes an example's label as a feed dict.
060     *
061     * @param example the input example
062     * @param labelMap label domain
063     * @return a map from graph placeholder names to their fed-in values.
064     */
065    Map<String, Tensor<?>> encode(SequenceExample<T> example, ImmutableOutputInfo<T> labelMap);
066
067    /**
068     * Encodes a batch of labels as a feed dict.
069     *
070     * @param batch a batch of examples.
071     * @param labelMap label domain
072     * @return a map from graph placeholder names to their fed-in values.
073     */
074    Map<String, Tensor<?>> encode(List<SequenceExample<T>> batch, ImmutableOutputInfo<T> labelMap);
075
076}