001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.regression.evaluation;
018
import org.tribuo.Model;
import org.tribuo.Prediction;
import org.tribuo.evaluation.metrics.EvaluationMetric;
import org.tribuo.evaluation.metrics.MetricContext;
import org.tribuo.evaluation.metrics.MetricTarget;
import org.tribuo.regression.Regressor;

import java.util.List;
import java.util.Objects;
import java.util.function.BiFunction;
import java.util.function.ToDoubleBiFunction;
029
030/**
031 * A {@link EvaluationMetric} for {@link Regressor}s which calculates the metric based on a
032 * the true values and the predicted values.
033 */
034public class RegressionMetric implements EvaluationMetric<Regressor, RegressionMetric.Context> {
035
036    private final MetricTarget<Regressor> tgt;
037    private final String name;
038    private final ToDoubleBiFunction<MetricTarget<Regressor>, Context> impl;
039    private final boolean useExampleWeights;
040
041    /**
042     * Construct a new {@code RegressionMetric} for the supplied metric target,
043     * using the supplied function. This does not use example weights.
044     * @param tgt The metric target.
045     * @param name The name of the metric.
046     * @param impl The implementing function.
047     */
048    public RegressionMetric(MetricTarget<Regressor> tgt,
049                            String name,
050                            ToDoubleBiFunction<MetricTarget<Regressor>, Context> impl) {
051        this(tgt, name, impl, false);
052    }
053
054    /**
055     * Construct a new {@code RegressionMetric} for the supplied metric target,
056     * using the supplied function.
057     * @param tgt The metric target.
058     * @param name The name of the metric.
059     * @param impl The implementing function.
060     * @param useExampleWeights If true then the example weights are used to scale the example importance.
061     */
062    public RegressionMetric(MetricTarget<Regressor> tgt,
063                            String name,
064                            ToDoubleBiFunction<MetricTarget<Regressor>, Context> impl,
065                            boolean useExampleWeights) {
066        this.tgt = tgt;
067        this.name = name;
068        this.impl = impl;
069        this.useExampleWeights = useExampleWeights;
070    }
071
072    @Override
073    public double compute(Context context) {
074        return impl.applyAsDouble(tgt, context);
075    }
076
077    @Override
078    public MetricTarget<Regressor> getTarget() {
079        return tgt;
080    }
081
082    @Override
083    public String getName() {
084        return name;
085    }
086
087    @Override
088    public Context createContext(Model<Regressor> model, List<Prediction<Regressor>> predictions) {
089        return new Context(model, predictions, useExampleWeights);
090    }
091
092    /**
093     * The {@link MetricContext} for a {@link Regressor} is each true value and each predicted value for all dimensions.
094     */
095    static class Context extends MetricContext<Regressor> {
096        private final RegressionSufficientStatistics memo;
097
098        Context(Model<Regressor> model, List<Prediction<Regressor>> predictions, boolean useExampleWeights) {
099            super(model, predictions);
100            this.memo = new RegressionSufficientStatistics(model.getOutputIDInfo(),predictions, useExampleWeights);
101        }
102
103        RegressionSufficientStatistics getMemo() {
104            return memo;
105        }
106    }
107}