001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo; 018 019import com.oracle.labs.mlrg.olcut.config.Configurable; 020import com.oracle.labs.mlrg.olcut.provenance.Provenancable; 021import org.tribuo.evaluation.Evaluation; 022import org.tribuo.evaluation.Evaluator; 023import org.tribuo.provenance.OutputFactoryProvenance; 024 025import java.io.Serializable; 026import java.util.ArrayList; 027import java.util.HashMap; 028import java.util.List; 029import java.util.Map; 030 031/** 032 * An interface associated with a specific {@link Output}, which can generate the 033 * appropriate Output subclass, and {@link OutputInfo} subclass. 034 * <p> 035 * Must be {@link Configurable} so it can be loaded from an olcut config file. 036 */ 037public interface OutputFactory<T extends Output<T>> extends Configurable, Provenancable<OutputFactoryProvenance>, Serializable { 038 039 /** 040 * Parses the {@code V} and generates the appropriate {@link Output} value. 041 * <p> 042 * Most implementations call toString on the label before parsing it, but this is not required. 043 * 044 * @param label An input value. 045 * @param <V> The type of the input value. 046 * @return The parsed Output as an instance of {@code T}. 047 */ 048 public <V> T generateOutput(V label); 049 050 /** 051 * Returns the singleton unknown output of type T which can be used for prediction time examples. 052 * @return An unknown output. 053 */ 054 public T getUnknownOutput(); 055 056 /** 057 * Generates the appropriate {@link MutableOutputInfo} so the 058 * output values can be tracked by a {@link Dataset} or other 059 * aggregate. 060 * @return The appropriate subclass of {@link MutableOutputInfo} initialised to zero. 061 */ 062 public MutableOutputInfo<T> generateInfo(); 063 064 /** 065 * Creates an {@link ImmutableOutputInfo} from the supplied mapping. 066 * Requires that the mapping is dense in the integers [0,mapping.size()) and 067 * each mapping is unique. 068 * <p> 069 * <b>This call is used to import external models, and should not be used for other purposes.</b> 070 * </p> 071 * @param mapping The mapping to use. 072 * @return The appropriate subclass of {@link ImmutableOutputInfo} with a single observation of each element. 073 */ 074 public ImmutableOutputInfo<T> constructInfoForExternalModel(Map<T,Integer> mapping); 075 076 /** 077 * Gets an {@link Evaluator} suitable for measuring performance of predictions for the Output subclass. 078 * <p> 079 * {@link Evaluator} instances are thread safe and immutable, and commonly this is a singleton 080 * stored in the {@code OutputFactory} implementation. 081 * </p> 082 * @return An evaluator. 083 */ 084 public Evaluator<T,? extends Evaluation<T>> getEvaluator(); 085 086 /** 087 * Generate a list of outputs from the supplied list of inputs. 088 * <p> 089 * Makes inputs.size() calls to {@link OutputFactory#generateOutput}. 090 * @param inputs The list to convert. 091 * @param <V> The type of the inputs 092 * @return A list of outputs. 093 */ 094 default public <V> List<T> generateOutputs(List<V> inputs) { 095 ArrayList<T> outputs = new ArrayList<>(); 096 097 for (V input : inputs) { 098 outputs.add(generateOutput(input)); 099 } 100 101 return outputs; 102 } 103 104 /** 105 * Validates that the mapping can be used as an output info, i.e. 106 * that it is dense in the region [0,mapping.size()) - meaning no duplicate 107 * ids, each id 0 through mapping.size() is used, and there are no negative ids. 108 * @param mapping The mapping to use. 109 * @param <T> The type of the output. 110 */ 111 public static <T extends Output<T>> void validateMapping(Map<T,Integer> mapping) { 112 Map<Integer,T> reverse = new HashMap<>(); 113 for (Map.Entry<T,Integer> e : mapping.entrySet()) { 114 if (e.getValue() < 0 || e.getValue() >= mapping.size()) { 115 throw new IllegalArgumentException("Invalid mapping, expected an integer between 0 and mapping.size(), received " + e.getValue()); 116 } 117 T l = reverse.put(e.getValue(),e.getKey()); 118 if (l != null) { 119 throw new IllegalArgumentException("Invalid mapping, both " + e.getKey() + " and " + l + " map to " + e.getValue()); 120 } 121 } 122 123 if (reverse.size() != mapping.size()) { 124 throw new IllegalArgumentException("The Output<->Integer mapping is not a bijection, reverse mapping had " + reverse.size() + " elements, forward mapping had " + mapping.size() + " elements."); 125 } 126 } 127}