001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.anomaly.evaluation;
018
019import org.tribuo.Example;
020import org.tribuo.Model;
021import org.tribuo.Prediction;
022import org.tribuo.anomaly.Event;
023import org.tribuo.anomaly.Event.EventType;
024import org.tribuo.evaluation.metrics.EvaluationMetric;
025import org.tribuo.evaluation.metrics.MetricContext;
026import org.tribuo.evaluation.metrics.MetricTarget;
027
028import java.util.List;
029import java.util.function.ToDoubleBiFunction;
030
031/**
032 * A metric for evaluating anomaly detection problems. The sufficient statistics
033 * must be encoded in the number of true positives, false positives, true negatives
034 * and false negatives.
035 */
036public class AnomalyMetric implements EvaluationMetric<Event, AnomalyMetric.Context> {
037
038    private final MetricTarget<Event> target;
039    private final String name;
040
041    private final ToDoubleBiFunction<MetricTarget<Event>, Context> impl;
042
043    /**
044     * Creates an anomaly detection metric, with a specific name, using the supplied evaluation function.
045     * @param target The target of the metric (i.e., the event type or an average).
046     * @param name The name of the metric.
047     * @param impl The implementation function.
048     */
049    public AnomalyMetric(MetricTarget<Event> target, String name, ToDoubleBiFunction<MetricTarget<Event>, Context> impl) {
050        this.target = target;
051        this.name = name;
052        this.impl = impl;
053    }
054
055    @Override
056    public double compute(Context context) {
057        return impl.applyAsDouble(target, context);
058    }
059
060    @Override
061    public MetricTarget<Event> getTarget() {
062        return target;
063    }
064
065    @Override
066    public String getName() {
067        return name;
068    }
069
070    @Override
071    public Context createContext(Model<Event> model, List<Prediction<Event>> predictions) {
072        return buildContext(model, predictions);
073    }
074
075    static Context buildContext(Model<Event> model, List<Prediction<Event>> predictions) {
076        return new Context(model, predictions);
077    }
078
079    /**
080     * The context for anomaly detection is the tp,fp,tn,fn statistics.
081     */
082    static final class Context extends MetricContext<Event> {
083
084        // predicted anomalous, actually anomalous
085        private final long truePositive;
086        // predicted anomalous, actually expected
087        private final long falsePositive;
088        // predicted expected, actually expected
089        private final long trueNegative;
090        // predicted expected, actually anomalous
091        private final long falseNegative;
092
093        Context(Model<Event> model, List<Prediction<Event>> predictions) {
094            super(model, predictions);
095            PredictionStatistics tab = tabulate(predictions);
096            truePositive = tab.truePositive;
097            falsePositive = tab.falsePositive;
098            trueNegative = tab.trueNegative;
099            falseNegative = tab.falseNegative;
100        }
101
102        long getTruePositive() {
103            return truePositive;
104        }
105
106        long getFalsePositive() {
107            return falsePositive;
108        }
109
110        long getTrueNegative() {
111            return trueNegative;
112        }
113
114        long getFalseNegative() {
115            return falseNegative;
116        }
117
118        private static PredictionStatistics tabulate(List<Prediction<Event>> predictions) {
119            // predicted anomalous, actually anomalous
120            long truePositive = 0;
121            // predicted anomalous, actually expected
122            long falsePositive = 0;
123            // predicted expected, actually expected
124            long trueNegative = 0;
125            // predicted expected, actually anomalous
126            long falseNegative = 0;
127
128            for (Prediction<Event> prediction : predictions) {
129                Example<Event> example = prediction.getExample();
130                Event.EventType truth = example.getOutput().getType();
131                Event.EventType predicted = prediction.getOutput().getType();
132
133                if (truth == EventType.ANOMALOUS) {
134                    if (predicted == EventType.ANOMALOUS) {
135                        truePositive++;
136                    } else if (predicted == EventType.EXPECTED) {
137                        falseNegative++;
138                    } else {
139                        //unknown predicted
140                    }
141                } else if (truth == EventType.EXPECTED) {
142                    if (predicted == EventType.ANOMALOUS) {
143                        falsePositive++;
144                    } else if (predicted == EventType.EXPECTED) {
145                        trueNegative++;
146                    } else {
147                        //unknown predicted
148                    }
149                } else {
150                    // truth unknown
151                    throw new IllegalArgumentException("Evaluation data contained EventType.UNKNOWN as the ground truth output.");
152                }
153            }
154            return new PredictionStatistics(truePositive, falsePositive, trueNegative, falseNegative);
155        }
156    }
157
158    /**
159     * One day it will be a record. Not today though.
160     */
161    private static final class PredictionStatistics {
162        private final long truePositive;
163        private final long falsePositive;
164        private final long trueNegative;
165        private final long falseNegative;
166        PredictionStatistics(long truePositive, long falsePositive, long trueNegative, long falseNegative) {
167            this.truePositive = truePositive;
168            this.falsePositive = falsePositive;
169            this.trueNegative = trueNegative;
170            this.falseNegative = falseNegative;
171        }
172    }
173
174}