001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.multilabel.evaluation;
018
019import org.tribuo.classification.evaluation.ConfusionMetrics;
020import org.tribuo.evaluation.metrics.MetricTarget;
021import org.tribuo.multilabel.MultiLabel;
022
023import java.util.function.BiFunction;
024
025/**
026 * An enum of the default {@link MultiLabelMetric}s supported by the multi-label classification
027 * evaluation package.
028 */
029public enum MultiLabelMetrics {
030
031    /**
032     * The number of true positives.
033     */
034    TP((tgt, ctx) -> ConfusionMetrics.tp(tgt, ctx.getCM())),
035    /**
036     * The number of false positives.
037     */
038    FP((tgt, ctx) -> ConfusionMetrics.fp(tgt, ctx.getCM())),
039    /**
040     * The number of true negatives.
041     */
042    TN((tgt, ctx) -> ConfusionMetrics.tn(tgt, ctx.getCM())),
043    /**
044     * The number of false negatives.
045     */
046    FN((tgt, ctx) -> ConfusionMetrics.fn(tgt, ctx.getCM())),
047    /**
048     * The precision, i.e., the number of true positives divided by the number of predicted positives.
049     */
050    PRECISION((tgt, ctx) -> ConfusionMetrics.precision(tgt, ctx.getCM())),
051    /**
052     * The recall, i.e., the number of true positives divided by the number of ground truth positives.
053     */
054    RECALL((tgt, ctx) -> ConfusionMetrics.recall(tgt, ctx.getCM())),
055    /**
056     * The F_1 score, i.e., the harmonic mean of the precision and the recall.
057     */
058    F1((tgt, ctx) -> ConfusionMetrics.f1(tgt, ctx.getCM())),
059    /**
060     * The balanced error rate, i.e., the mean of the per class recalls.
061     */
062    BALANCED_ERROR_RATE((tgt, ctx) -> ConfusionMetrics.balancedErrorRate(ctx.getCM()));
063
064    private final BiFunction<MetricTarget<MultiLabel>, MultiLabelMetric.Context, Double> impl;
065
066    MultiLabelMetrics(BiFunction<MetricTarget<MultiLabel>, MultiLabelMetric.Context, Double> impl) {
067        this.impl = impl;
068    }
069
070    public BiFunction<MetricTarget<MultiLabel>, MultiLabelMetric.Context, Double> getImpl() {
071        return impl;
072    }
073
074    public MultiLabelMetric forTarget(MetricTarget<MultiLabel> tgt) {
075        return new MultiLabelMetric(tgt, this.name(), this.getImpl());
076    }
077}