001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.classification.baseline; 018 019import com.oracle.labs.mlrg.olcut.config.Config; 020import com.oracle.labs.mlrg.olcut.config.PropertyException; 021import com.oracle.labs.mlrg.olcut.provenance.Provenance; 022import org.tribuo.Dataset; 023import org.tribuo.ImmutableFeatureMap; 024import org.tribuo.ImmutableOutputInfo; 025import org.tribuo.Model; 026import org.tribuo.MutableOutputInfo; 027import org.tribuo.Trainer; 028import org.tribuo.classification.Label; 029import org.tribuo.provenance.ModelProvenance; 030import org.tribuo.provenance.TrainerProvenance; 031import org.tribuo.provenance.impl.TrainerProvenanceImpl; 032 033import java.time.OffsetDateTime; 034import java.util.Map; 035 036/** 037 * A trainer for simple baseline classifiers. Use this only for comparison purposes, if you can't beat these 038 * baselines, your ML system doesn't work. 039 */ 040public final class DummyClassifierTrainer implements Trainer<Label> { 041 042 /** 043 * Types of dummy classifier. 044 */ 045 public enum DummyType { 046 /** 047 * Samples the label proprotional to the training label frequencies. 048 */ 049 STRATIFIED, 050 /** 051 * Returns the most frequent training label. 052 */ 053 MOST_FREQUENT, 054 /** 055 * Samples uniformly from the label domain. 056 */ 057 UNIFORM, 058 /** 059 * Returns the supplied label for all inputs. 060 */ 061 CONSTANT 062 } 063 064 @Config(mandatory = true,description="Type of dummy classifier.") 065 private DummyType dummyType; 066 067 @Config(description="Label to use for the constant classifier.") 068 private String constantLabel; 069 070 @Config(description="Seed for the RNG.") 071 private long seed = 1L; 072 073 private int invocationCount = 0; 074 075 private DummyClassifierTrainer() {} 076 077 /** 078 * Used by the OLCUT configuration system, and should not be called by external code. 079 */ 080 @Override 081 public void postConfig() { 082 if ((dummyType == DummyType.CONSTANT) && (constantLabel == null)) { 083 throw new PropertyException("","constantLabel","Please supply a label string when using the type CONSTANT."); 084 } 085 } 086 087 @Override 088 public Model<Label> train(Dataset<Label> examples, Map<String, Provenance> instanceProvenance) { 089 ModelProvenance provenance = new ModelProvenance(DummyClassifierModel.class.getName(), OffsetDateTime.now(), examples.getProvenance(), getProvenance(), instanceProvenance); 090 ImmutableFeatureMap featureMap = examples.getFeatureIDMap(); 091 invocationCount++; 092 switch (dummyType) { 093 case CONSTANT: 094 MutableOutputInfo<Label> labelInfo = examples.getOutputInfo().generateMutableOutputInfo(); 095 Label constLabel = new Label(constantLabel); 096 labelInfo.observe(constLabel); 097 return new DummyClassifierModel(provenance,featureMap,labelInfo.generateImmutableOutputInfo(),constLabel); 098 case MOST_FREQUENT: { 099 ImmutableOutputInfo<Label> immutableLabelInfo = examples.getOutputIDInfo(); 100 return new DummyClassifierModel(provenance, featureMap, immutableLabelInfo); 101 } 102 case UNIFORM: 103 case STRATIFIED: { 104 ImmutableOutputInfo<Label> immutableLabelInfo = examples.getOutputIDInfo(); 105 return new DummyClassifierModel(provenance, featureMap, immutableLabelInfo, dummyType, seed); 106 } 107 default: 108 throw new IllegalStateException("Unknown dummyType " + dummyType); 109 } 110 } 111 112 @Override 113 public int getInvocationCount() { 114 return invocationCount; 115 } 116 117 @Override 118 public String toString() { 119 switch (dummyType) { 120 case CONSTANT: 121 return "DummyClassifierTrainer(dummyType="+dummyType+",constantLabel="+constantLabel+")"; 122 case MOST_FREQUENT: { 123 return "DummyClassifierTrainer(dummyType="+dummyType+")"; 124 } 125 case UNIFORM: 126 case STRATIFIED: { 127 return "DummyClassifierTrainer(dummyType="+dummyType+",seed="+seed+")"; 128 } 129 default: 130 return "DummyClassifierTrainer(dummyType="+dummyType+")"; 131 } 132 } 133 134 @Override 135 public TrainerProvenance getProvenance() { 136 return new TrainerProvenanceImpl(this); 137 } 138 139 /** 140 * Creates a trainer which creates models which return random labels sampled from the training label distribution. 141 * @param seed The RNG seed to use. 142 * @return A classification trainer. 143 */ 144 public static DummyClassifierTrainer createStratifiedTrainer(long seed) { 145 DummyClassifierTrainer trainer = new DummyClassifierTrainer(); 146 trainer.dummyType = DummyType.STRATIFIED; 147 trainer.seed = seed; 148 return trainer; 149 } 150 151 /** 152 * Creates a trainer which creates models which return a fixed label. 153 * @param constantLabel The label to return. 154 * @return A classification trainer. 155 */ 156 public static DummyClassifierTrainer createConstantTrainer(String constantLabel) { 157 DummyClassifierTrainer trainer = new DummyClassifierTrainer(); 158 trainer.dummyType = DummyType.CONSTANT; 159 trainer.constantLabel = constantLabel; 160 return trainer; 161 } 162 163 /** 164 * Creates a trainer which creates models which return random labels sampled uniformly from the labels seen at training time. 165 * @param seed The RNG seed to use. 166 * @return A classification trainer. 167 */ 168 public static DummyClassifierTrainer createUniformTrainer(long seed) { 169 DummyClassifierTrainer trainer = new DummyClassifierTrainer(); 170 trainer.dummyType = DummyType.UNIFORM; 171 trainer.seed = seed; 172 return trainer; 173 } 174 175 /** 176 * Creates a trainer which creates models which return a fixed label, the one which was most frequent in the training data. 177 * @return A classification trainer. 178 */ 179 public static DummyClassifierTrainer createMostFrequentTrainer() { 180 DummyClassifierTrainer trainer = new DummyClassifierTrainer(); 181 trainer.dummyType = DummyType.MOST_FREQUENT; 182 return trainer; 183 } 184}