001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.multilabel.evaluation;
018
019import org.tribuo.Model;
020import org.tribuo.Prediction;
021import org.tribuo.classification.evaluation.ConfusionMatrix;
022import org.tribuo.evaluation.metrics.EvaluationMetric;
023import org.tribuo.evaluation.metrics.MetricContext;
024import org.tribuo.evaluation.metrics.MetricTarget;
025import org.tribuo.multilabel.MultiLabel;
026
027import java.util.List;
028import java.util.Objects;
029import java.util.function.BiFunction;
030
031/**
032 * A {@link EvaluationMetric} for evaluating {@link MultiLabel} problems.
033 * The sufficient statistics used must be held in a {@link ConfusionMatrix}.
034 */
035public class MultiLabelMetric implements EvaluationMetric<MultiLabel, MultiLabelMetric.Context> {
036
037    private final MetricTarget<MultiLabel> target;
038    private final String name;
039    private final BiFunction<MetricTarget<MultiLabel>, Context, Double> impl;
040
041    public MultiLabelMetric(MetricTarget<MultiLabel> target, String name, BiFunction<MetricTarget<MultiLabel>, Context, Double> impl) {
042        this.target = target;
043        this.name = name;
044        this.impl = impl;
045    }
046
047    @Override
048    public MetricTarget<MultiLabel> getTarget() {
049        return target;
050    }
051
052    @Override
053    public String getName() {
054        return name;
055    }
056
057    @Override
058    public double compute(Context context) {
059        return impl.apply(target, context);
060    }
061
062    @Override
063    public String toString() {
064        return "MultiLabelMetric{" +
065                "target=" + target +
066                ", name='" + name + '\'' +
067                ", impl=" + impl +
068                '}';
069    }
070
071    @Override
072    public boolean equals(Object o) {
073        if (this == o) return true;
074        if (o == null || getClass() != o.getClass()) return false;
075        MultiLabelMetric that = (MultiLabelMetric) o;
076        return Objects.equals(target, that.target) &&
077                Objects.equals(name, that.name) &&
078                Objects.equals(impl, that.impl);
079    }
080
081    @Override
082    public int hashCode() {
083        return Objects.hash(target, name, impl);
084    }
085
086    @Override
087    public Context createContext(Model<MultiLabel> model, List<Prediction<MultiLabel>> predictions) {
088        return buildContext(model, predictions);
089    }
090
091    static final class Context extends MetricContext<MultiLabel> {
092        private final ConfusionMatrix<MultiLabel> cm;
093
094        Context(Model<MultiLabel> model, List<Prediction<MultiLabel>> predictions, ConfusionMatrix<MultiLabel> cm) {
095            super(model, predictions);
096            this.cm = cm;
097        }
098
099        ConfusionMatrix<MultiLabel> getCM() {
100            return cm;
101        }
102    }
103
104    static Context buildContext(Model<MultiLabel> model, List<Prediction<MultiLabel>> predictions) {
105        ConfusionMatrix<MultiLabel> cm = new MultiLabelConfusionMatrix(model, predictions);
106        return new Context(model, predictions, cm);
107    }
108
109}