001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.regression.evaluation; 018 019import org.tribuo.Model; 020import org.tribuo.Prediction; 021import org.tribuo.evaluation.metrics.EvaluationMetric; 022import org.tribuo.evaluation.metrics.MetricContext; 023import org.tribuo.evaluation.metrics.MetricTarget; 024import org.tribuo.regression.Regressor; 025 026import java.util.List; 027import java.util.function.BiFunction; 028import java.util.function.ToDoubleBiFunction; 029 030/** 031 * A {@link EvaluationMetric} for {@link Regressor}s which calculates the metric based on a 032 * the true values and the predicted values. 033 */ 034public class RegressionMetric implements EvaluationMetric<Regressor, RegressionMetric.Context> { 035 036 private final MetricTarget<Regressor> tgt; 037 private final String name; 038 private final ToDoubleBiFunction<MetricTarget<Regressor>, Context> impl; 039 private final boolean useExampleWeights; 040 041 /** 042 * Construct a new {@code RegressionMetric} for the supplied metric target, 043 * using the supplied function. This does not use example weights. 044 * @param tgt The metric target. 045 * @param name The name of the metric. 046 * @param impl The implementing function. 047 */ 048 public RegressionMetric(MetricTarget<Regressor> tgt, 049 String name, 050 ToDoubleBiFunction<MetricTarget<Regressor>, Context> impl) { 051 this(tgt, name, impl, false); 052 } 053 054 /** 055 * Construct a new {@code RegressionMetric} for the supplied metric target, 056 * using the supplied function. 057 * @param tgt The metric target. 058 * @param name The name of the metric. 059 * @param impl The implementing function. 060 * @param useExampleWeights If true then the example weights are used to scale the example importance. 061 */ 062 public RegressionMetric(MetricTarget<Regressor> tgt, 063 String name, 064 ToDoubleBiFunction<MetricTarget<Regressor>, Context> impl, 065 boolean useExampleWeights) { 066 this.tgt = tgt; 067 this.name = name; 068 this.impl = impl; 069 this.useExampleWeights = useExampleWeights; 070 } 071 072 @Override 073 public double compute(Context context) { 074 return impl.applyAsDouble(tgt, context); 075 } 076 077 @Override 078 public MetricTarget<Regressor> getTarget() { 079 return tgt; 080 } 081 082 @Override 083 public String getName() { 084 return name; 085 } 086 087 @Override 088 public Context createContext(Model<Regressor> model, List<Prediction<Regressor>> predictions) { 089 return new Context(model, predictions, useExampleWeights); 090 } 091 092 /** 093 * The {@link MetricContext} for a {@link Regressor} is each true value and each predicted value for all dimensions. 094 */ 095 static class Context extends MetricContext<Regressor> { 096 private final RegressionSufficientStatistics memo; 097 098 Context(Model<Regressor> model, List<Prediction<Regressor>> predictions, boolean useExampleWeights) { 099 super(model, predictions); 100 this.memo = new RegressionSufficientStatistics(model.getOutputIDInfo(),predictions, useExampleWeights); 101 } 102 103 RegressionSufficientStatistics getMemo() { 104 return memo; 105 } 106 } 107}