001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.clustering.evaluation;
018
019import org.tribuo.Model;
020import org.tribuo.Prediction;
021import org.tribuo.clustering.ClusterID;
022import org.tribuo.clustering.ClusteringFactory;
023import org.tribuo.evaluation.metrics.EvaluationMetric;
024import org.tribuo.evaluation.metrics.MetricContext;
025import org.tribuo.evaluation.metrics.MetricTarget;
026
027import java.util.ArrayList;
028import java.util.List;
029import java.util.function.BiFunction;
030
031/**
032 * A metric for evaluating clustering problems. The sufficient statistics are the cluster
033 * ids assigned to every point, along with the "true" ids.
034 */
035public class ClusteringMetric implements EvaluationMetric<ClusterID, ClusteringMetric.Context> {
036
037    private final MetricTarget<ClusterID> target;
038    private final String name;
039    private final BiFunction<MetricTarget<ClusterID>, Context, Double> impl;
040
041    public ClusteringMetric(MetricTarget<ClusterID> target, String name, BiFunction<MetricTarget<ClusterID>, Context, Double> impl) {
042        this.target = target;
043        this.name = name;
044        this.impl = impl;
045    }
046
047    @Override
048    public double compute(Context context) {
049        return impl.apply(target, context);
050    }
051
052    @Override
053    public MetricTarget<ClusterID> getTarget() {
054        return target;
055    }
056
057    @Override
058    public String getName() {
059        return name;
060    }
061
062    @Override
063    public Context createContext(Model<ClusterID> model, List<Prediction<ClusterID>> predictions) {
064        return buildContext(model, predictions);
065    }
066
067    @Override
068    public String toString() {
069        return "ClusteringMetric(" +
070                "target=" + target +
071                ",name='" + name + '\'' +
072                ')';
073    }
074
075    static final class Context extends MetricContext<ClusterID> {
076
077        private final ArrayList<Integer> predictedIDs = new ArrayList<>();
078        private final ArrayList<Integer> trueIDs = new ArrayList<>();
079
080        Context(Model<ClusterID> model, List<Prediction<ClusterID>> predictions) {
081            super(model, predictions);
082            int i = 0;
083            for (Prediction<ClusterID> pred : predictions) {
084                if (pred.getOutput().equals(ClusteringFactory.UNASSIGNED_CLUSTER_ID)) {
085                    throw new IllegalArgumentException("The sentinel unassigned cluster id was used as a ground truth output at prediction number " + i);
086                } else if (pred.getExample().getOutput().equals(ClusteringFactory.UNASSIGNED_CLUSTER_ID)) {
087                    throw new IllegalArgumentException("The sentinel unassigned cluster id was predicted by the model at prediction number " + i);
088                }
089                predictedIDs.add(pred.getOutput().getID());
090                trueIDs.add(pred.getExample().getOutput().getID());
091                i++;
092            }
093        }
094
095        public ArrayList<Integer> getPredictedIDs() {
096            return predictedIDs;
097        }
098
099        public ArrayList<Integer> getTrueIDs() {
100            return trueIDs;
101        }
102    }
103
104    static Context buildContext(Model<ClusterID> model, List<Prediction<ClusterID>> predictions) {
105        return new Context(model, predictions);
106    }
107
108}