001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.interop.tensorflow.sequence; 018 019import com.oracle.labs.mlrg.olcut.config.Configurable; 020import com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance; 021import com.oracle.labs.mlrg.olcut.provenance.Provenancable; 022import org.tribuo.ImmutableOutputInfo; 023import org.tribuo.Output; 024import org.tribuo.Prediction; 025import org.tribuo.sequence.SequenceExample; 026import org.tensorflow.Tensor; 027 028import java.io.Serializable; 029import java.util.List; 030import java.util.Map; 031 032/** 033 * Converts a Tensorflow output tensor into a list of predictions, and a Tribuo sequence example into 034 * a Tensorflow tensor suitable for training. 035 */ 036public interface SequenceOutputTransformer<T extends Output<T>> extends Configurable, Provenancable<ConfiguredObjectProvenance>, Serializable { 037 038 /** 039 * Decode a tensor of graph output into a list of predictions for the input sequence. 040 * 041 * @param output graph output 042 * @param input original input sequence example 043 * @param labelMap label domain 044 * @return the model's decoded prediction for the input sequence. 045 */ 046 List<Prediction<T>> decode(Tensor<?> output, SequenceExample<T> input, ImmutableOutputInfo<T> labelMap); 047 048 /** 049 * Decode graph output tensors corresponding to a batch of input sequences. 050 * 051 * @param outputs a tensor corresponding to a batch of outputs. 052 * @param inputBatch the original input batch. 053 * @param labelMap label domain 054 * @return the model's decoded predictions, one for each example in the input batch. 055 */ 056 List<List<Prediction<T>>> decode(Tensor<?> outputs, List<SequenceExample<T>> inputBatch, ImmutableOutputInfo<T> labelMap); 057 058 /** 059 * Encodes an example's label as a feed dict. 060 * 061 * @param example the input example 062 * @param labelMap label domain 063 * @return a map from graph placeholder names to their fed-in values. 064 */ 065 Map<String, Tensor<?>> encode(SequenceExample<T> example, ImmutableOutputInfo<T> labelMap); 066 067 /** 068 * Encodes a batch of labels as a feed dict. 069 * 070 * @param batch a batch of examples. 071 * @param labelMap label domain 072 * @return a map from graph placeholder names to their fed-in values. 073 */ 074 Map<String, Tensor<?>> encode(List<SequenceExample<T>> batch, ImmutableOutputInfo<T> labelMap); 075 076}