001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.regression.evaluation;
018
019import org.tribuo.Model;
020import org.tribuo.Prediction;
021import org.tribuo.evaluation.AbstractEvaluator;
022import org.tribuo.evaluation.Evaluator;
023import org.tribuo.evaluation.metrics.EvaluationMetric;
024import org.tribuo.evaluation.metrics.MetricID;
025import org.tribuo.evaluation.metrics.MetricTarget;
026import org.tribuo.provenance.EvaluationProvenance;
027import org.tribuo.regression.Regressor;
028
029import java.util.HashSet;
030import java.util.List;
031import java.util.Map;
032import java.util.Set;
033
034/**
035 * A {@link Evaluator} for multi-dimensional regression using {@link Regressor}s.
036 * <p>
037 * If the dataset contains an unknown Regressor (as generated by {@link org.tribuo.regression.RegressionFactory#getUnknownOutput()})
038 * then the evaluate methods will throw {@link IllegalArgumentException} with an appropriate message.
039 */
040public final class RegressionEvaluator extends AbstractEvaluator<Regressor, RegressionMetric.Context, RegressionEvaluation, RegressionMetric> {
041
042    private final boolean useExampleWeights;
043
044    /**
045     * By default, don't use example weights.
046     */
047    public RegressionEvaluator() {
048        this(false);
049    }
050
051    /**
052     * Construct an evaluator.
053     * <p>
054     * Will weight the examples if requested.
055     * @param useExampleWeights Set to true to use the example weights to adjust the importance of the predictions.
056     */
057    public RegressionEvaluator(boolean useExampleWeights) {
058        this.useExampleWeights = useExampleWeights;
059    }
060
061    @Override
062    protected Set<RegressionMetric> createMetrics(Model<Regressor> model) {
063        Set<RegressionMetric> metrics = new HashSet<>();
064        for (Regressor variable : model.getOutputIDInfo().getDomain()) {
065            MetricTarget<Regressor> target = new MetricTarget<>(variable);
066            metrics.add(RegressionMetrics.R2.forTarget(target));
067            metrics.add(RegressionMetrics.RMSE.forTarget(target));
068            metrics.add(RegressionMetrics.MAE.forTarget(target));
069            metrics.add(RegressionMetrics.EV.forTarget(target));
070        }
071        MetricTarget<Regressor> macroAverage = new MetricTarget<>(EvaluationMetric.Average.MACRO);
072        metrics.add(RegressionMetrics.R2.forTarget(macroAverage));
073        metrics.add(RegressionMetrics.RMSE.forTarget(macroAverage));
074        metrics.add(RegressionMetrics.MAE.forTarget(macroAverage));
075        metrics.add(RegressionMetrics.EV.forTarget(macroAverage));
076        return metrics;
077    }
078
079    @Override
080    protected RegressionMetric.Context createContext(Model<Regressor> model, List<Prediction<Regressor>> predictions) {
081        return new RegressionMetric.Context(model, predictions, useExampleWeights);
082    }
083
084    @Override
085    protected RegressionEvaluation createEvaluation(RegressionMetric.Context context,
086                                                    Map<MetricID<Regressor>, Double> results,
087                                                    EvaluationProvenance provenance) {
088        return new RegressionEvaluationImpl(results, context, provenance);
089    }
090}