001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.evaluation.metrics; 018 019import org.tribuo.Output; 020 021import java.util.Objects; 022import java.util.Optional; 023 024/** 025 * Used by a given {@link EvaluationMetric} to determine whether it should compute its value for a specific {@link Output} value 026 * or whether it should average them. 027 * 028 * @param <T> The {@link Output} type. 029 */ 030public class MetricTarget<T extends Output<T>> { 031 032 private final T target; 033 private final EvaluationMetric.Average avg; 034 // TODO none value? what about cases like balanced error rate / ami? 035 // - rename Average to Aggregate? then return Aggregate "all"/"domainwise" 036 // sometimes? 037 038 /** 039 * Builds a metric target for an output. 040 * @param target The output to target. 041 */ 042 public MetricTarget(T target) { 043 this.target = target; 044 this.avg = null; 045 } 046 047 /** 048 * Builds a metric target for an average. 049 * @param avg The average to compute. 050 */ 051 public MetricTarget(EvaluationMetric.Average avg) { 052 this.target = null; 053 this.avg = avg; 054 } 055 056 /** 057 * Returns the Output this metric targets, or {@link Optional#empty} if it's an average. 058 * @return The output this metric targets, or {@link Optional#empty}. 059 */ 060 public Optional<T> getOutputTarget() { return Optional.ofNullable(target); } 061 062 /** 063 * Returns the average this metric computes, or {@link Optional#empty} if it targets an output. 064 * @return The average this metric computes, or {@link Optional#empty}. 065 */ 066 public Optional<EvaluationMetric.Average> getAverageTarget() { return Optional.ofNullable(avg); } 067 068 @Override 069 public String toString() { 070 if (getOutputTarget().isPresent()) { 071 return String.format("MetricTarget{output=%s}", getOutputTarget().get()); 072 } else { 073 return String.format("MetricTarget{average=%s}", getAverageTarget().get().name()); 074 } 075 } 076 077 @Override 078 public boolean equals(Object o) { 079 if (this == o) return true; 080 if (o == null || getClass() != o.getClass()) return false; 081 MetricTarget<?> that = (MetricTarget<?>) o; 082 return Objects.equals(target, that.target) && 083 avg == that.avg; 084 } 085 086 @Override 087 public int hashCode() { 088 return Objects.hash(target, avg); 089 } 090 091 private static final MetricTarget<?> macroTarget = new MetricTarget<>(EvaluationMetric.Average.MACRO); 092 private static final MetricTarget<?> microTarget = new MetricTarget<>(EvaluationMetric.Average.MICRO); 093 094 /** 095 * Get the singleton {@code MetricTarget} which contains the {@link EvaluationMetric.Average#MACRO}. 096 * 097 * @param <U> The output type of the {@code MetricTarget} 098 * @return The {@code MetricTarget} that provides a macro average. 099 */ 100 @SuppressWarnings("unchecked") 101 public static <U extends Output<U>> MetricTarget<U> macroAverageTarget() { 102 return (MetricTarget<U>) macroTarget; 103 } 104 105 /** 106 * Get the singleton {@code MetricTarget} which contains the {@link EvaluationMetric.Average#MICRO}. 107 * 108 * @param <U> The output type of the {@code MetricTarget} 109 * @return The {@code MetricTarget} that provides a micro average. 110 */ 111 @SuppressWarnings("unchecked") 112 public static <U extends Output<U>> MetricTarget<U> microAverageTarget() { 113 return (MetricTarget<U>) microTarget; 114 } 115}