001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.evaluation.metrics;
018
019import org.tribuo.Output;
020
021import java.util.Objects;
022import java.util.Optional;
023
024/**
025 * Used by a given {@link EvaluationMetric} to determine whether it should compute its value for a specific {@link Output} value
026 * or whether it should average them.
027 *
028 * @param <T> The {@link Output} type.
029 */
030public class MetricTarget<T extends Output<T>> {
031
032    private final T target;
033    private final EvaluationMetric.Average avg;
034    // TODO none value? what about cases like balanced error rate / ami?
035    // - rename Average to Aggregate? then return Aggregate "all"/"domainwise"
036    //   sometimes?
037
038    /**
039     * Builds a metric target for an output.
040     * @param target The output to target.
041     */
042    public MetricTarget(T target) {
043        this.target = target;
044        this.avg = null;
045    }
046
047    /**
048     * Builds a metric target for an average.
049     * @param avg The average to compute.
050     */
051    public MetricTarget(EvaluationMetric.Average avg) {
052        this.target = null;
053        this.avg = avg;
054    }
055
056    /**
057     * Returns the Output this metric targets, or {@link Optional#empty} if it's an average.
058     * @return The output this metric targets, or {@link Optional#empty}.
059     */
060    public Optional<T> getOutputTarget() { return Optional.ofNullable(target); }
061
062    /**
063     * Returns the average this metric computes, or {@link Optional#empty} if it targets an output.
064     * @return The average this metric computes, or {@link Optional#empty}.
065     */
066    public Optional<EvaluationMetric.Average> getAverageTarget() { return Optional.ofNullable(avg); }
067
068    @Override
069    public String toString() {
070        if (getOutputTarget().isPresent()) {
071            return String.format("MetricTarget{output=%s}", getOutputTarget().get());
072        } else {
073            return String.format("MetricTarget{average=%s}", getAverageTarget().get().name());
074        }
075    }
076
077    @Override
078    public boolean equals(Object o) {
079        if (this == o) return true;
080        if (o == null || getClass() != o.getClass()) return false;
081        MetricTarget<?> that = (MetricTarget<?>) o;
082        return Objects.equals(target, that.target) &&
083                avg == that.avg;
084    }
085
086    @Override
087    public int hashCode() {
088        return Objects.hash(target, avg);
089    }
090
091    private static final MetricTarget<?> macroTarget = new MetricTarget<>(EvaluationMetric.Average.MACRO);
092    private static final MetricTarget<?> microTarget = new MetricTarget<>(EvaluationMetric.Average.MICRO);
093
094    /**
095     * Get the singleton {@code MetricTarget} which contains the {@link EvaluationMetric.Average#MACRO}.
096     *
097     * @param <U> The output type of the {@code MetricTarget}
098     * @return The {@code MetricTarget} that provides a macro average.
099     */
100    @SuppressWarnings("unchecked")
101    public static <U extends Output<U>> MetricTarget<U> macroAverageTarget() {
102        return (MetricTarget<U>) macroTarget;
103    }
104
105    /**
106     * Get the singleton {@code MetricTarget} which contains the {@link EvaluationMetric.Average#MICRO}.
107     *
108     * @param <U> The output type of the {@code MetricTarget}
109     * @return The {@code MetricTarget} that provides a micro average.
110     */
111    @SuppressWarnings("unchecked")
112    public static <U extends Output<U>> MetricTarget<U> microAverageTarget() {
113        return (MetricTarget<U>) microTarget;
114    }
115}