001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.classification.evaluation;
018
019import org.tribuo.Model;
020import org.tribuo.Prediction;
021import org.tribuo.classification.Label;
022import org.tribuo.evaluation.metrics.EvaluationMetric;
023import org.tribuo.evaluation.metrics.MetricContext;
024import org.tribuo.evaluation.metrics.MetricTarget;
025import org.tribuo.sequence.SequenceModel;
026
027import java.util.List;
028import java.util.Objects;
029import java.util.function.ToDoubleBiFunction;
030
031/**
032 * A {@link EvaluationMetric} for {@link Label}s which calculates the value based on a
033 * {@link ConfusionMatrix}.
034 */
035public class LabelMetric implements EvaluationMetric<Label, LabelMetric.Context> {
036
037    private final MetricTarget<Label> tgt;
038    private final String name;
039    private final ToDoubleBiFunction<MetricTarget<Label>, Context> impl;
040
041    /**
042     * Construct a new {@code LabelMetric} for the supplied metric target,
043     * using the supplied function.
044     * @param tgt The metric target.
045     * @param name The name of the metric.
046     * @param impl The implementing function.
047     */
048    public LabelMetric(MetricTarget<Label> tgt, String name,
049                       ToDoubleBiFunction<MetricTarget<Label>, Context> impl) {
050        this.tgt = tgt;
051        this.name = name;
052        this.impl = impl;
053    }
054
055    @Override
056    public double compute(LabelMetric.Context context) {
057        return impl.applyAsDouble(tgt, context);
058    }
059
060    @Override
061    public MetricTarget<Label> getTarget() {
062        return tgt;
063    }
064
065    @Override
066    public String getName() {
067        return name;
068    }
069
070    @Override
071    public boolean equals(Object o) {
072        if (this == o) return true;
073        if (o == null || getClass() != o.getClass()) return false;
074        LabelMetric that = (LabelMetric) o;
075        return Objects.equals(tgt, that.tgt) &&
076                Objects.equals(name, that.name) &&
077                Objects.equals(impl, that.impl);
078    }
079
080    @Override
081    public int hashCode() {
082        return Objects.hash(tgt, name, impl);
083    }
084
085    @Override
086    public String toString() {
087        return "LabelMetric{" +
088                "target=" + tgt +
089                ", name='" + name +
090                '}';
091    }
092
093    @Override
094    public Context createContext(Model<Label> model, List<Prediction<Label>> predictions) {
095        return new Context(model, predictions);
096    }
097
098    /**
099     * The context for a {@link LabelMetric} is a {@link ConfusionMatrix}.
100     */
101    public static final class Context extends MetricContext<Label> {
102
103        private final ConfusionMatrix<Label> cm;
104
105        public Context(Model<Label> model, List<Prediction<Label>> predictions) {
106            super(model, predictions);
107            this.cm = new LabelConfusionMatrix(model.getOutputIDInfo(), predictions);
108        }
109
110        public Context(SequenceModel<Label> model, List<Prediction<Label>> predictions) {
111            super(model, predictions);
112            this.cm = new LabelConfusionMatrix(model.getOutputIDInfo(), predictions);
113        }
114
115        public ConfusionMatrix<Label> getCM() {
116            return cm;
117        }
118    }
119}