/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.regression.evaluation;

import org.tribuo.Model;
import org.tribuo.Prediction;
import org.tribuo.evaluation.AbstractEvaluator;
import org.tribuo.evaluation.Evaluator;
import org.tribuo.evaluation.metrics.EvaluationMetric;
import org.tribuo.evaluation.metrics.MetricID;
import org.tribuo.evaluation.metrics.MetricTarget;
import org.tribuo.provenance.EvaluationProvenance;
import org.tribuo.regression.Regressor;

import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

/**
 * An {@link Evaluator} for multi-dimensional regression using {@link Regressor}s.
 * <p>
 * If the dataset contains an unknown Regressor (as generated by {@link org.tribuo.regression.RegressionFactory#getUnknownOutput()})
 * then the evaluate methods will throw {@link IllegalArgumentException} with an appropriate message.
 */
public final class RegressionEvaluator extends AbstractEvaluator<Regressor, RegressionMetric.Context, RegressionEvaluation, RegressionMetric> {

    /**
     * The metrics computed for every target (each output dimension and the macro average),
     * in the order they are added to the metric set.
     */
    private static final RegressionMetrics[] DEFAULT_METRICS = {
        RegressionMetrics.R2,
        RegressionMetrics.RMSE,
        RegressionMetrics.MAE,
        RegressionMetrics.EV
    };

    /** Whether the per-example weights are passed through to the metric context. */
    private final boolean useExampleWeights;

    /**
     * Constructs an evaluator which ignores example weights.
     */
    public RegressionEvaluator() {
        this(false);
    }

    /**
     * Constructs an evaluator, optionally weighting each example.
     * @param useExampleWeights If true, the example weights adjust the importance of the predictions.
     */
    public RegressionEvaluator(boolean useExampleWeights) {
        this.useExampleWeights = useExampleWeights;
    }

    /**
     * Builds the metric set for the supplied model.
     * <p>
     * R2, RMSE, MAE and EV are produced for every regressor in the model's output domain.
     * The same four metrics are also produced for the macro average.
     * @param model The model being evaluated.
     * @return The set of metrics to compute.
     */
    @Override
    protected Set<RegressionMetric> createMetrics(Model<Regressor> model) {
        Set<RegressionMetric> metricSet = new HashSet<>();
        for (Regressor dimension : model.getOutputIDInfo().getDomain()) {
            MetricTarget<Regressor> perDimension = new MetricTarget<>(dimension);
            addDefaultMetrics(metricSet, perDimension);
        }
        MetricTarget<Regressor> macro = new MetricTarget<>(EvaluationMetric.Average.MACRO);
        addDefaultMetrics(metricSet, macro);
        return metricSet;
    }

    /**
     * Adds one instance of each default metric, bound to {@code target}, into {@code sink}.
     * @param sink The set receiving the metrics.
     * @param target The target the metrics are computed against.
     */
    private static void addDefaultMetrics(Set<RegressionMetric> sink, MetricTarget<Regressor> target) {
        for (RegressionMetrics metric : DEFAULT_METRICS) {
            sink.add(metric.forTarget(target));
        }
    }

    /**
     * Wraps the model, its predictions and the example-weighting flag into a metric context.
     * @param model The model being evaluated.
     * @param predictions The model's predictions.
     * @return The context shared by the metrics.
     */
    @Override
    protected RegressionMetric.Context createContext(Model<Regressor> model, List<Prediction<Regressor>> predictions) {
        boolean weighted = this.useExampleWeights;
        return new RegressionMetric.Context(model, predictions, weighted);
    }

    /**
     * Packages the computed metric values into a {@link RegressionEvaluation}.
     * @param context The metric context the values were computed from.
     * @param results The metric values, keyed by metric ID.
     * @param provenance The provenance of this evaluation.
     * @return The evaluation.
     */
    @Override
    protected RegressionEvaluation createEvaluation(RegressionMetric.Context context,
                                                    Map<MetricID<Regressor>, Double> results,
                                                    EvaluationProvenance provenance) {
        RegressionEvaluation evaluation = new RegressionEvaluationImpl(results, context, provenance);
        return evaluation;
    }
}