001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.multilabel.evaluation; 018 019import org.tribuo.Model; 020import org.tribuo.Prediction; 021import org.tribuo.classification.evaluation.ConfusionMatrix; 022import org.tribuo.evaluation.metrics.EvaluationMetric; 023import org.tribuo.evaluation.metrics.MetricContext; 024import org.tribuo.evaluation.metrics.MetricTarget; 025import org.tribuo.multilabel.MultiLabel; 026 027import java.util.List; 028import java.util.Objects; 029import java.util.function.BiFunction; 030 031/** 032 * A {@link EvaluationMetric} for evaluating {@link MultiLabel} problems. 033 * The sufficient statistics used must be held in a {@link ConfusionMatrix}. 034 */ 035public class MultiLabelMetric implements EvaluationMetric<MultiLabel, MultiLabelMetric.Context> { 036 037 private final MetricTarget<MultiLabel> target; 038 private final String name; 039 private final BiFunction<MetricTarget<MultiLabel>, Context, Double> impl; 040 041 public MultiLabelMetric(MetricTarget<MultiLabel> target, String name, BiFunction<MetricTarget<MultiLabel>, Context, Double> impl) { 042 this.target = target; 043 this.name = name; 044 this.impl = impl; 045 } 046 047 @Override 048 public MetricTarget<MultiLabel> getTarget() { 049 return target; 050 } 051 052 @Override 053 public String getName() { 054 return name; 055 } 056 057 @Override 058 public double compute(Context context) { 059 return impl.apply(target, context); 060 } 061 062 @Override 063 public String toString() { 064 return "MultiLabelMetric{" + 065 "target=" + target + 066 ", name='" + name + '\'' + 067 ", impl=" + impl + 068 '}'; 069 } 070 071 @Override 072 public boolean equals(Object o) { 073 if (this == o) return true; 074 if (o == null || getClass() != o.getClass()) return false; 075 MultiLabelMetric that = (MultiLabelMetric) o; 076 return Objects.equals(target, that.target) && 077 Objects.equals(name, that.name) && 078 Objects.equals(impl, that.impl); 079 } 080 081 @Override 082 public int hashCode() { 083 return Objects.hash(target, name, impl); 084 } 085 086 @Override 087 public Context createContext(Model<MultiLabel> model, List<Prediction<MultiLabel>> predictions) { 088 return buildContext(model, predictions); 089 } 090 091 static final class Context extends MetricContext<MultiLabel> { 092 private final ConfusionMatrix<MultiLabel> cm; 093 094 Context(Model<MultiLabel> model, List<Prediction<MultiLabel>> predictions, ConfusionMatrix<MultiLabel> cm) { 095 super(model, predictions); 096 this.cm = cm; 097 } 098 099 ConfusionMatrix<MultiLabel> getCM() { 100 return cm; 101 } 102 } 103 104 static Context buildContext(Model<MultiLabel> model, List<Prediction<MultiLabel>> predictions) { 105 ConfusionMatrix<MultiLabel> cm = new MultiLabelConfusionMatrix(model, predictions); 106 return new Context(model, predictions, cm); 107 } 108 109}