001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo;
018
019import com.oracle.labs.mlrg.olcut.config.Configurable;
020import com.oracle.labs.mlrg.olcut.provenance.Provenancable;
021import org.tribuo.evaluation.Evaluation;
022import org.tribuo.evaluation.Evaluator;
023import org.tribuo.provenance.OutputFactoryProvenance;
024
025import java.io.Serializable;
026import java.util.ArrayList;
027import java.util.HashMap;
028import java.util.List;
029import java.util.Map;
030
031/**
032 * An interface associated with a specific {@link Output}, which can generate the
033 * appropriate Output subclass, and {@link OutputInfo} subclass.
034 * <p>
035 * Must be {@link Configurable} so it can be loaded from an olcut config file.
036 */
037public interface OutputFactory<T extends Output<T>> extends Configurable, Provenancable<OutputFactoryProvenance>, Serializable {
038
039    /**
040     * Parses the {@code V} and generates the appropriate {@link Output} value.
041     * <p>
042     * Most implementations call toString on the label before parsing it, but this is not required.
043     *
044     * @param label An input value.
045     * @param <V> The type of the input value.
046     * @return The parsed Output as an instance of {@code T}.
047     */
048    public <V> T generateOutput(V label);
049
050    /**
051     * Returns the singleton unknown output of type T which can be used for prediction time examples.
052     * @return An unknown output.
053     */
054    public T getUnknownOutput();
055
056    /**
057     * Generates the appropriate {@link MutableOutputInfo} so the
058     * output values can be tracked by a {@link Dataset} or other
059     * aggregate.
060     * @return The appropriate subclass of {@link MutableOutputInfo} initialised to zero.
061     */
062    public MutableOutputInfo<T> generateInfo();
063
064    /**
065     * Creates an {@link ImmutableOutputInfo} from the supplied mapping.
066     * Requires that the mapping is dense in the integers [0,mapping.size()) and
067     * each mapping is unique.
068     * <p>
069     *     <b>This call is used to import external models, and should not be used for other purposes.</b>
070     * </p>
071     * @param mapping The mapping to use.
072     * @return The appropriate subclass of {@link ImmutableOutputInfo} with a single observation of each element.
073     */
074    public ImmutableOutputInfo<T> constructInfoForExternalModel(Map<T,Integer> mapping);
075
076    /**
077     * Gets an {@link Evaluator} suitable for measuring performance of predictions for the Output subclass.
078     * <p>
079     * {@link Evaluator} instances are thread safe and immutable, and commonly this is a singleton
080     * stored in the {@code OutputFactory} implementation.
081     * </p>
082     * @return An evaluator.
083     */
084    public Evaluator<T,? extends Evaluation<T>> getEvaluator();
085
086    /**
087     * Generate a list of outputs from the supplied list of inputs.
088     * <p>
089     * Makes inputs.size() calls to {@link OutputFactory#generateOutput}.
090     * @param inputs The list to convert.
091     * @param <V> The type of the inputs
092     * @return A list of outputs.
093     */
094    default public <V> List<T> generateOutputs(List<V> inputs) {
095        ArrayList<T> outputs = new ArrayList<>();
096
097        for (V input : inputs) {
098            outputs.add(generateOutput(input));
099        }
100
101        return outputs;
102    }
103
104    /**
105     * Validates that the mapping can be used as an output info, i.e.
106     * that it is dense in the region [0,mapping.size()) - meaning no duplicate
107     * ids, each id 0 through mapping.size() is used, and there are no negative ids.
108     * @param mapping The mapping to use.
109     * @param <T> The type of the output.
110     */
111    public static <T extends Output<T>> void validateMapping(Map<T,Integer> mapping) {
112        Map<Integer,T> reverse = new HashMap<>();
113        for (Map.Entry<T,Integer> e : mapping.entrySet()) {
114            if (e.getValue() < 0 || e.getValue() >= mapping.size()) {
115                throw new IllegalArgumentException("Invalid mapping, expected an integer between 0 and mapping.size(), received " + e.getValue());
116            }
117            T l = reverse.put(e.getValue(),e.getKey());
118            if (l != null) {
119                throw new IllegalArgumentException("Invalid mapping, both " + e.getKey() + " and " + l + " map to " + e.getValue());
120            }
121        }
122
123        if (reverse.size() != mapping.size()) {
124            throw new IllegalArgumentException("The Output<->Integer mapping is not a bijection, reverse mapping had " + reverse.size() + " elements, forward mapping had " + mapping.size() + " elements.");
125        }
126    }
127}