001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.classification.baseline;
018
019import com.oracle.labs.mlrg.olcut.config.Config;
020import com.oracle.labs.mlrg.olcut.config.PropertyException;
021import com.oracle.labs.mlrg.olcut.provenance.Provenance;
022import org.tribuo.Dataset;
023import org.tribuo.ImmutableFeatureMap;
024import org.tribuo.ImmutableOutputInfo;
025import org.tribuo.Model;
026import org.tribuo.MutableOutputInfo;
027import org.tribuo.Trainer;
028import org.tribuo.classification.Label;
029import org.tribuo.provenance.ModelProvenance;
030import org.tribuo.provenance.TrainerProvenance;
031import org.tribuo.provenance.impl.TrainerProvenanceImpl;
032
033import java.time.OffsetDateTime;
034import java.util.Map;
035
036/**
037 * A trainer for simple baseline classifiers. Use this only for comparison purposes, if you can't beat these
038 * baselines, your ML system doesn't work.
039 */
040public final class DummyClassifierTrainer implements Trainer<Label> {
041
042    /**
043     * Types of dummy classifier.
044     */
045    public enum DummyType {
046        /**
047         * Samples the label proprotional to the training label frequencies.
048         */
049        STRATIFIED,
050        /**
051         * Returns the most frequent training label.
052         */
053        MOST_FREQUENT,
054        /**
055         * Samples uniformly from the label domain.
056         */
057        UNIFORM,
058        /**
059         * Returns the supplied label for all inputs.
060         */
061        CONSTANT
062    }
063
064    @Config(mandatory = true,description="Type of dummy classifier.")
065    private DummyType dummyType;
066
067    @Config(description="Label to use for the constant classifier.")
068    private String constantLabel;
069
070    @Config(description="Seed for the RNG.")
071    private long seed = 1L;
072
073    private int invocationCount = 0;
074
075    private DummyClassifierTrainer() {}
076
077    /**
078     * Used by the OLCUT configuration system, and should not be called by external code.
079     */
080    @Override
081    public void postConfig() {
082        if ((dummyType == DummyType.CONSTANT) && (constantLabel == null)) {
083            throw new PropertyException("","constantLabel","Please supply a label string when using the type CONSTANT.");
084        }
085    }
086
087    @Override
088    public Model<Label> train(Dataset<Label> examples, Map<String, Provenance> instanceProvenance) {
089        ModelProvenance provenance = new ModelProvenance(DummyClassifierModel.class.getName(), OffsetDateTime.now(), examples.getProvenance(), getProvenance(), instanceProvenance);
090        ImmutableFeatureMap featureMap = examples.getFeatureIDMap();
091        invocationCount++;
092        switch (dummyType) {
093            case CONSTANT:
094                MutableOutputInfo<Label> labelInfo = examples.getOutputInfo().generateMutableOutputInfo();
095                Label constLabel = new Label(constantLabel);
096                labelInfo.observe(constLabel);
097                return new DummyClassifierModel(provenance,featureMap,labelInfo.generateImmutableOutputInfo(),constLabel);
098            case MOST_FREQUENT: {
099                ImmutableOutputInfo<Label> immutableLabelInfo = examples.getOutputIDInfo();
100                return new DummyClassifierModel(provenance, featureMap, immutableLabelInfo);
101            }
102            case UNIFORM:
103            case STRATIFIED: {
104                ImmutableOutputInfo<Label> immutableLabelInfo = examples.getOutputIDInfo();
105                return new DummyClassifierModel(provenance, featureMap, immutableLabelInfo, dummyType, seed);
106            }
107            default:
108                throw new IllegalStateException("Unknown dummyType " + dummyType);
109        }
110    }
111
112    @Override
113    public int getInvocationCount() {
114        return invocationCount;
115    }
116
117    @Override
118    public String toString() {
119        switch (dummyType) {
120            case CONSTANT:
121                return "DummyClassifierTrainer(dummyType="+dummyType+",constantLabel="+constantLabel+")";
122            case MOST_FREQUENT: {
123                return "DummyClassifierTrainer(dummyType="+dummyType+")";
124            }
125            case UNIFORM:
126            case STRATIFIED: {
127                return "DummyClassifierTrainer(dummyType="+dummyType+",seed="+seed+")";
128            }
129            default:
130                return "DummyClassifierTrainer(dummyType="+dummyType+")";
131        }
132    }
133
134    @Override
135    public TrainerProvenance getProvenance() {
136        return new TrainerProvenanceImpl(this);
137    }
138
139    /**
140     * Creates a trainer which creates models which return random labels sampled from the training label distribution.
141     * @param seed The RNG seed to use.
142     * @return A classification trainer.
143     */
144    public static DummyClassifierTrainer createStratifiedTrainer(long seed) {
145        DummyClassifierTrainer trainer = new DummyClassifierTrainer();
146        trainer.dummyType = DummyType.STRATIFIED;
147        trainer.seed = seed;
148        return trainer;
149    }
150
151    /**
152     * Creates a trainer which creates models which return a fixed label.
153     * @param constantLabel The label to return.
154     * @return A classification trainer.
155     */
156    public static DummyClassifierTrainer createConstantTrainer(String constantLabel) {
157        DummyClassifierTrainer trainer = new DummyClassifierTrainer();
158        trainer.dummyType = DummyType.CONSTANT;
159        trainer.constantLabel = constantLabel;
160        return trainer;
161    }
162
163    /**
164     * Creates a trainer which creates models which return random labels sampled uniformly from the labels seen at training time.
165     * @param seed The RNG seed to use.
166     * @return A classification trainer.
167     */
168    public static DummyClassifierTrainer createUniformTrainer(long seed) {
169        DummyClassifierTrainer trainer = new DummyClassifierTrainer();
170        trainer.dummyType = DummyType.UNIFORM;
171        trainer.seed = seed;
172        return trainer;
173    }
174
175    /**
176     * Creates a trainer which creates models which return a fixed label, the one which was most frequent in the training data.
177     * @return A classification trainer.
178     */
179    public static DummyClassifierTrainer createMostFrequentTrainer() {
180        DummyClassifierTrainer trainer = new DummyClassifierTrainer();
181        trainer.dummyType = DummyType.MOST_FREQUENT;
182        return trainer;
183    }
184}