001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.interop.onnx; 018 019import ai.onnxruntime.OnnxValue; 020import com.oracle.labs.mlrg.olcut.config.Configurable; 021import com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance; 022import com.oracle.labs.mlrg.olcut.provenance.Provenancable; 023import org.tribuo.Example; 024import org.tribuo.ImmutableOutputInfo; 025import org.tribuo.Output; 026import org.tribuo.Prediction; 027 028import java.io.Serializable; 029import java.util.List; 030 031/** 032 * Converts an {@link OnnxValue} into an {@link Output} or a {@link Prediction}. 033 * <p> 034 * N.B. ONNX support is experimental, and may change without a major version bump. 035 */ 036public interface OutputTransformer<T extends Output<T>> extends Configurable, Provenancable<ConfiguredObjectProvenance>, Serializable { 037 038 /** 039 * Converts a {@link OnnxValue} into a {@link Prediction}. 040 * @param value The value to convert. 041 * @param outputIDInfo The output info to use to identify the outputs. 042 * @param numValidFeatures The number of valid features used by the prediction. 043 * @param example The example to insert into the prediction. 044 * @return A prediction object. 045 */ 046 public Prediction<T> transformToPrediction(List<OnnxValue> value, ImmutableOutputInfo<T> outputIDInfo, int numValidFeatures, Example<T> example); 047 048 /** 049 * Converts a {@link OnnxValue} into the specified output type. 050 * @param value The value to convert. 051 * @param outputIDInfo The output info to use to identify the outputs. 052 * @return A output. 053 */ 054 public T transformToOutput(List<OnnxValue> value, ImmutableOutputInfo<T> outputIDInfo); 055 056 /** 057 * Converts a {@link OnnxValue} containing multiple outputs into a list of {@link Prediction}s. 058 * @param value The value to convert. 059 * @param outputIDInfo The output info to use to identify the outputs. 060 * @param numValidFeatures The number of valid features used by the prediction. 061 * @param examples The example to insert into the prediction. 062 * @return A list of predictions. 063 */ 064 public List<Prediction<T>> transformToBatchPrediction(List<OnnxValue> value, ImmutableOutputInfo<T> outputIDInfo, int[] numValidFeatures, List<Example<T>> examples); 065 066 /** 067 * Converts a {@link OnnxValue} containing multiple outputs into a list of {@link Output}s. 068 * @param value The value to convert. 069 * @param outputIDInfo The output info to use to identify the outputs. 070 * @return A list of outputs. 071 */ 072 public List<T> transformToBatchOutput(List<OnnxValue> value, ImmutableOutputInfo<T> outputIDInfo); 073 074 /** 075 * Does this OutputTransformer generate probabilities. 076 * @return True if it produces a probability distribution in the {@link Prediction}. 077 */ 078 public boolean generatesProbabilities(); 079 080}