001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.regression.baseline; 018 019import com.oracle.labs.mlrg.olcut.config.Config; 020import com.oracle.labs.mlrg.olcut.config.PropertyException; 021import com.oracle.labs.mlrg.olcut.provenance.ObjectProvenance; 022import com.oracle.labs.mlrg.olcut.provenance.Provenance; 023import com.oracle.labs.mlrg.olcut.provenance.primitives.DoubleProvenance; 024import com.oracle.labs.mlrg.olcut.provenance.primitives.EnumProvenance; 025import com.oracle.labs.mlrg.olcut.provenance.primitives.LongProvenance; 026import com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance; 027import com.oracle.labs.mlrg.olcut.util.Pair; 028import org.tribuo.Dataset; 029import org.tribuo.Example; 030import org.tribuo.ImmutableOutputInfo; 031import org.tribuo.Trainer; 032import org.tribuo.provenance.ModelProvenance; 033import org.tribuo.provenance.TrainerProvenance; 034import org.tribuo.provenance.impl.TrainerProvenanceImpl; 035import org.tribuo.regression.Regressor; 036import org.tribuo.util.Util; 037 038import java.time.OffsetDateTime; 039import java.util.Arrays; 040import java.util.HashMap; 041import java.util.Map; 042import java.util.Objects; 043import java.util.Set; 044 045/** 046 * A trainer for simple baseline regressors. Use this only for comparison purposes, if you can't beat these 047 * baselines, your ML system doesn't work. 048 */ 049public final class DummyRegressionTrainer implements Trainer<Regressor> { 050 051 /** 052 * Types of dummy regression model. 053 */ 054 public enum DummyType { 055 /** 056 * Returns the mean of the training data outputs. 057 */ 058 MEAN, 059 /** 060 * Returns the median of the training data outputs. 061 */ 062 MEDIAN, 063 /** 064 * Returns the training data output at the specified fraction of the sorted output. 065 */ 066 QUARTILE, 067 /** 068 * Returns the specified constant value. 069 */ 070 CONSTANT, 071 /** 072 * Samples from a Gaussian using the means and variances from the training data. 073 */ 074 GAUSSIAN 075 } 076 077 @Config(mandatory = true, description="Type of dummy regressor.") 078 private DummyType dummyType; 079 080 @Config(description="Constant value to use for the constant regressor.") 081 private double constantValue = Double.NaN; 082 083 @Config(description="Quartile to use.") 084 private double quartile = Double.NaN; 085 086 @Config(description="The seed for the RNG.") 087 private long seed = 1L; 088 089 private int invocationCount = 0; 090 091 private DummyRegressionTrainer() { } 092 093 /** 094 * Used by the OLCUT configuration system, and should not be called by external code. 095 */ 096 @Override 097 public void postConfig() { 098 if ((dummyType == DummyType.CONSTANT) && (Double.isNaN(constantValue))) { 099 throw new PropertyException("","constantValue","Please supply a constant value when using the type CONSTANT."); 100 } 101 if ((dummyType == DummyType.QUARTILE) && ((quartile < 0.) || (quartile > 1.0))) { 102 throw new PropertyException("","quartile","Please supply a quartile between zero and one when using the type QUARTILE."); 103 } 104 } 105 106 @Override 107 public DummyRegressionModel train(Dataset<Regressor> examples, Map<String, Provenance> instanceProvenance) { 108 ModelProvenance provenance = new ModelProvenance(DummyRegressionModel.class.getName(), OffsetDateTime.now(), examples.getProvenance(), getProvenance(), instanceProvenance); 109 invocationCount++; 110 ImmutableOutputInfo<Regressor> outputInfo = examples.getOutputIDInfo(); 111 Set<Regressor> domain = outputInfo.getDomain(); 112 double[][] outputs = new double[outputInfo.size()][examples.size()]; 113 int i = 0; 114 for (Example<Regressor> e : examples) { 115 for (Regressor.DimensionTuple r : e.getOutput()) { 116 int id = outputInfo.getID(r); 117 outputs[id][i] = r.getValue(); 118 } 119 i++; 120 } 121 Regressor regressor; 122 switch (dummyType) { 123 case CONSTANT: { 124 Regressor.DimensionTuple[] output = new Regressor.DimensionTuple[outputs.length]; 125 for (Regressor r : domain) { 126 int id = outputInfo.getID(r); 127 output[id] = new Regressor.DimensionTuple(r.getNames()[0],constantValue); 128 } 129 regressor = new Regressor(output); 130 return new DummyRegressionModel(provenance,examples.getFeatureIDMap(),outputInfo,dummyType,regressor); 131 } 132 case MEAN: { 133 Regressor.DimensionTuple[] output = new Regressor.DimensionTuple[outputs.length]; 134 for (Regressor r : domain) { 135 int id = outputInfo.getID(r); 136 output[id] = new Regressor.DimensionTuple(r.getNames()[0],Util.mean(outputs[id])); 137 } 138 regressor = new Regressor(output); 139 return new DummyRegressionModel(provenance,examples.getFeatureIDMap(),outputInfo,dummyType,regressor); 140 } 141 case MEDIAN: { 142 Regressor.DimensionTuple[] output = new Regressor.DimensionTuple[outputs.length]; 143 for (Regressor r : domain) { 144 int id = outputInfo.getID(r); 145 Arrays.sort(outputs[id]); 146 output[id] = new Regressor.DimensionTuple(r.getNames()[0],outputs[id][outputs[id].length/2]); 147 } 148 regressor = new Regressor(output); 149 return new DummyRegressionModel(provenance,examples.getFeatureIDMap(),outputInfo,dummyType,regressor); 150 } 151 case QUARTILE: { 152 Regressor.DimensionTuple[] output = new Regressor.DimensionTuple[outputs.length]; 153 for (Regressor r : domain) { 154 int id = outputInfo.getID(r); 155 Arrays.sort(outputs[id]); 156 output[id] = new Regressor.DimensionTuple(r.getNames()[0],outputs[id][(int) (quartile*outputs[id].length)]); 157 } 158 regressor = new Regressor(output); 159 return new DummyRegressionModel(provenance,examples.getFeatureIDMap(),outputInfo,dummyType,regressor); 160 } 161 case GAUSSIAN: { 162 double[] means = new double[outputs.length]; 163 double[] variances = new double[outputs.length]; 164 String[] names = new String[outputs.length]; 165 for (Regressor r : domain) { 166 int id = outputInfo.getID(r); 167 names[id] = r.getNames()[0]; 168 Pair<Double,Double> meanVariance = Util.meanAndVariance(outputs[id]); 169 means[id] = meanVariance.getA(); 170 variances[id] = meanVariance.getB(); 171 } 172 return new DummyRegressionModel(provenance,examples.getFeatureIDMap(),outputInfo,seed,means,variances,names); 173 } 174 default: 175 throw new IllegalStateException("Unknown dummyType " + dummyType); 176 } 177 } 178 179 @Override 180 public String toString() { 181 switch (dummyType) { 182 case CONSTANT: 183 return "DummyRegressionTrainer(dummyType=CONSTANT,constantValue="+constantValue+")"; 184 case MEAN: 185 return "DummyRegressionTrainer(dummyType=MEAN)"; 186 case MEDIAN: 187 return "DummyRegressionTrainer(dummyType=MEDIAN)"; 188 case QUARTILE: 189 return "DummyRegressionTrainer(dummyType=QUARTILE,quartile="+quartile+")"; 190 case GAUSSIAN: 191 return "DummyRegressionTrainer(dummyType=GAUSSIAN,seed="+seed+")"; 192 default: 193 return "DummyRegressionTrainer(dummyType="+dummyType+")"; 194 } 195 } 196 197 @Override 198 public int getInvocationCount() { 199 return invocationCount; 200 } 201 202 @Override 203 public TrainerProvenance getProvenance() { 204 return new TrainerProvenanceImpl(this); 205 } 206 207 /** 208 * Creates a trainer which create models which return a fixed value. 209 * @param value The value to return 210 * @return A regression trainer. 211 */ 212 public static DummyRegressionTrainer createConstantTrainer(double value) { 213 DummyRegressionTrainer trainer = new DummyRegressionTrainer(); 214 trainer.dummyType = DummyType.CONSTANT; 215 trainer.constantValue = value; 216 return trainer; 217 } 218 219 /** 220 * Creates a trainer which create models which sample the output from a gaussian distribution fit to the training data. 221 * @param seed The RNG seed. 222 * @return A regression trainer. 223 */ 224 public static DummyRegressionTrainer createGaussianTrainer(long seed) { 225 DummyRegressionTrainer trainer = new DummyRegressionTrainer(); 226 trainer.dummyType = DummyType.GAUSSIAN; 227 trainer.seed = seed; 228 return trainer; 229 } 230 231 /** 232 * Creates a trainer which create models which return the mean of the training data. 233 * @return A regression trainer. 234 */ 235 public static DummyRegressionTrainer createMeanTrainer() { 236 DummyRegressionTrainer trainer = new DummyRegressionTrainer(); 237 trainer.dummyType = DummyType.MEAN; 238 return trainer; 239 } 240 241 /** 242 * Creates a trainer which create models which return the median of the training data. 243 * @return A regression trainer. 244 */ 245 public static DummyRegressionTrainer createMedianTrainer() { 246 DummyRegressionTrainer trainer = new DummyRegressionTrainer(); 247 trainer.dummyType = DummyType.MEDIAN; 248 return trainer; 249 } 250 251 /** 252 * Creates a trainer which create models which return the value at the specified fraction of the sorted training data. 253 * @param value The quartile value. 254 * @return A regression trainer. 255 */ 256 public static DummyRegressionTrainer createQuartileTrainer(double value) { 257 if (Double.isNaN(value) || value < 0.0 || value > 1.0) { 258 throw new IllegalArgumentException("Please provide an appropriate value between 0.0 and 1.0, found " + value); 259 } 260 DummyRegressionTrainer trainer = new DummyRegressionTrainer(); 261 trainer.dummyType = DummyType.QUARTILE; 262 trainer.quartile = value; 263 return trainer; 264 } 265 266 /** 267 * Provenance for {@link DummyRegressionTrainer}. 268 */ 269 @Deprecated 270 public final static class DummyRegressionTrainerProvenance implements TrainerProvenance { 271 private static final long serialVersionUID = 1L; 272 273 private final String className; 274 private final DummyType dummyType; 275 private final long seed; 276 private final double constantValue; 277 private final double quartile; 278 279 /** 280 * Constructs a provenance from the host. 281 * @param host The host trainer. 282 */ 283 public DummyRegressionTrainerProvenance(DummyRegressionTrainer host) { 284 this.className = host.getClass().getName(); 285 this.dummyType = host.dummyType; 286 this.seed = host.seed; 287 this.constantValue = host.constantValue; 288 this.quartile = host.quartile; 289 } 290 291 /** 292 * Constructs a provenance from the marshalled form. 293 * @param map The map of field values. 294 */ 295 public DummyRegressionTrainerProvenance(Map<String, Provenance> map) { 296 className = ObjectProvenance.checkAndExtractProvenance(map,CLASS_NAME, StringProvenance.class, DummyRegressionTrainerProvenance.class.getSimpleName()).getValue(); 297 dummyType = (DummyType) ObjectProvenance.checkAndExtractProvenance(map,"dummyType", EnumProvenance.class, DummyRegressionTrainerProvenance.class.getSimpleName()).getValue(); 298 seed = ObjectProvenance.checkAndExtractProvenance(map,"seed", LongProvenance.class, DummyRegressionTrainerProvenance.class.getSimpleName()).getValue(); 299 constantValue = ObjectProvenance.checkAndExtractProvenance(map,"constantValue", DoubleProvenance.class, DummyRegressionTrainerProvenance.class.getSimpleName()).getValue(); 300 quartile = ObjectProvenance.checkAndExtractProvenance(map,"quartile", DoubleProvenance.class, DummyRegressionTrainerProvenance.class.getSimpleName()).getValue(); 301 } 302 303 @Override 304 public Map<String, Provenance> getConfiguredParameters() { 305 Map<String, Provenance> map = new HashMap<>(); 306 307 map.put("dummyType",new EnumProvenance<>("dummyType",dummyType)); 308 map.put("constantValue",new DoubleProvenance("constantValue",constantValue)); 309 map.put("quartile",new DoubleProvenance("quartile",quartile)); 310 map.put("seed",new LongProvenance("seed",seed)); 311 312 return map; 313 } 314 315 @Override 316 public String getClassName() { 317 return className; 318 } 319 320 @Override 321 public String toString() { 322 return generateString("Trainer"); 323 } 324 325 @Override 326 public boolean equals(Object o) { 327 if (this == o) return true; 328 if (o == null || getClass() != o.getClass()) return false; 329 DummyRegressionTrainerProvenance pairs = (DummyRegressionTrainerProvenance) o; 330 return seed == pairs.seed && 331 Double.compare(pairs.constantValue, constantValue) == 0 && 332 Double.compare(pairs.quartile, quartile) == 0 && 333 className.equals(pairs.className) && 334 dummyType == pairs.dummyType; 335 } 336 337 @Override 338 public int hashCode() { 339 return Objects.hash(className, dummyType, seed, constantValue, quartile); 340 } 341 } 342}