001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.regression.baseline;
018
019import com.oracle.labs.mlrg.olcut.config.Config;
020import com.oracle.labs.mlrg.olcut.config.PropertyException;
021import com.oracle.labs.mlrg.olcut.provenance.ObjectProvenance;
022import com.oracle.labs.mlrg.olcut.provenance.Provenance;
023import com.oracle.labs.mlrg.olcut.provenance.primitives.DoubleProvenance;
024import com.oracle.labs.mlrg.olcut.provenance.primitives.EnumProvenance;
025import com.oracle.labs.mlrg.olcut.provenance.primitives.LongProvenance;
026import com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance;
027import com.oracle.labs.mlrg.olcut.util.Pair;
028import org.tribuo.Dataset;
029import org.tribuo.Example;
030import org.tribuo.ImmutableOutputInfo;
031import org.tribuo.Trainer;
032import org.tribuo.provenance.ModelProvenance;
033import org.tribuo.provenance.TrainerProvenance;
034import org.tribuo.provenance.impl.TrainerProvenanceImpl;
035import org.tribuo.regression.Regressor;
036import org.tribuo.util.Util;
037
038import java.time.OffsetDateTime;
039import java.util.Arrays;
040import java.util.HashMap;
041import java.util.Map;
042import java.util.Objects;
043import java.util.Set;
044
045/**
046 * A trainer for simple baseline regressors. Use this only for comparison purposes, if you can't beat these
047 * baselines, your ML system doesn't work.
048 */
049public final class DummyRegressionTrainer implements Trainer<Regressor> {
050
051    /**
052     * Types of dummy regression model.
053     */
054    public enum DummyType {
055        /**
056         * Returns the mean of the training data outputs.
057         */
058        MEAN,
059        /**
060         * Returns the median of the training data outputs.
061         */
062        MEDIAN,
063        /**
064         * Returns the training data output at the specified fraction of the sorted output.
065         */
066        QUARTILE,
067        /**
068         * Returns the specified constant value.
069         */
070        CONSTANT,
071        /**
072         * Samples from a Gaussian using the means and variances from the training data.
073         */
074        GAUSSIAN
075    }
076
077    @Config(mandatory = true, description="Type of dummy regressor.")
078    private DummyType dummyType;
079
080    @Config(description="Constant value to use for the constant regressor.")
081    private double constantValue = Double.NaN;
082
083    @Config(description="Quartile to use.")
084    private double quartile = Double.NaN;
085
086    @Config(description="The seed for the RNG.")
087    private long seed = 1L;
088
089    private int invocationCount = 0;
090
091    private DummyRegressionTrainer() { }
092
093    /**
094     * Used by the OLCUT configuration system, and should not be called by external code.
095     */
096    @Override
097    public void postConfig() {
098        if ((dummyType == DummyType.CONSTANT) && (Double.isNaN(constantValue))) {
099            throw new PropertyException("","constantValue","Please supply a constant value when using the type CONSTANT.");
100        }
101        if ((dummyType == DummyType.QUARTILE) && ((quartile < 0.) || (quartile > 1.0))) {
102            throw new PropertyException("","quartile","Please supply a quartile between zero and one when using the type QUARTILE.");
103        }
104    }
105
106    @Override
107    public DummyRegressionModel train(Dataset<Regressor> examples, Map<String, Provenance> instanceProvenance) {
108        ModelProvenance provenance = new ModelProvenance(DummyRegressionModel.class.getName(), OffsetDateTime.now(), examples.getProvenance(), getProvenance(), instanceProvenance);
109        invocationCount++;
110        ImmutableOutputInfo<Regressor> outputInfo = examples.getOutputIDInfo();
111        Set<Regressor> domain = outputInfo.getDomain();
112        double[][] outputs = new double[outputInfo.size()][examples.size()];
113        int i = 0;
114        for (Example<Regressor> e : examples) {
115            for (Regressor.DimensionTuple r : e.getOutput()) {
116                int id = outputInfo.getID(r);
117                outputs[id][i] = r.getValue();
118            }
119            i++;
120        }
121        Regressor regressor;
122        switch (dummyType) {
123            case CONSTANT: {
124                Regressor.DimensionTuple[] output = new Regressor.DimensionTuple[outputs.length];
125                for (Regressor r : domain) {
126                    int id = outputInfo.getID(r);
127                    output[id] = new Regressor.DimensionTuple(r.getNames()[0],constantValue);
128                }
129                regressor = new Regressor(output);
130                return new DummyRegressionModel(provenance,examples.getFeatureIDMap(),outputInfo,dummyType,regressor);
131            }
132            case MEAN: {
133                Regressor.DimensionTuple[] output = new Regressor.DimensionTuple[outputs.length];
134                for (Regressor r : domain) {
135                    int id = outputInfo.getID(r);
136                    output[id] = new Regressor.DimensionTuple(r.getNames()[0],Util.mean(outputs[id]));
137                }
138                regressor = new Regressor(output);
139                return new DummyRegressionModel(provenance,examples.getFeatureIDMap(),outputInfo,dummyType,regressor);
140            }
141            case MEDIAN: {
142                Regressor.DimensionTuple[] output = new Regressor.DimensionTuple[outputs.length];
143                for (Regressor r : domain) {
144                    int id = outputInfo.getID(r);
145                    Arrays.sort(outputs[id]);
146                    output[id] = new Regressor.DimensionTuple(r.getNames()[0],outputs[id][outputs[id].length/2]);
147                }
148                regressor = new Regressor(output);
149                return new DummyRegressionModel(provenance,examples.getFeatureIDMap(),outputInfo,dummyType,regressor);
150            }
151            case QUARTILE: {
152                Regressor.DimensionTuple[] output = new Regressor.DimensionTuple[outputs.length];
153                for (Regressor r : domain) {
154                    int id = outputInfo.getID(r);
155                    Arrays.sort(outputs[id]);
156                    output[id] = new Regressor.DimensionTuple(r.getNames()[0],outputs[id][(int) (quartile*outputs[id].length)]);
157                }
158                regressor = new Regressor(output);
159                return new DummyRegressionModel(provenance,examples.getFeatureIDMap(),outputInfo,dummyType,regressor);
160            }
161            case GAUSSIAN: {
162                double[] means = new double[outputs.length];
163                double[] variances = new double[outputs.length];
164                String[] names = new String[outputs.length];
165                for (Regressor r : domain) {
166                    int id = outputInfo.getID(r);
167                    names[id] = r.getNames()[0];
168                    Pair<Double,Double> meanVariance = Util.meanAndVariance(outputs[id]);
169                    means[id] = meanVariance.getA();
170                    variances[id] = meanVariance.getB();
171                }
172                return new DummyRegressionModel(provenance,examples.getFeatureIDMap(),outputInfo,seed,means,variances,names);
173            }
174            default:
175                throw new IllegalStateException("Unknown dummyType " + dummyType);
176        }
177    }
178
179    @Override
180    public String toString() {
181        switch (dummyType) {
182            case CONSTANT:
183                return "DummyRegressionTrainer(dummyType=CONSTANT,constantValue="+constantValue+")";
184            case MEAN:
185                return "DummyRegressionTrainer(dummyType=MEAN)";
186            case MEDIAN:
187                return "DummyRegressionTrainer(dummyType=MEDIAN)";
188            case QUARTILE:
189                return "DummyRegressionTrainer(dummyType=QUARTILE,quartile="+quartile+")";
190            case GAUSSIAN:
191                return "DummyRegressionTrainer(dummyType=GAUSSIAN,seed="+seed+")";
192            default:
193                return "DummyRegressionTrainer(dummyType="+dummyType+")";
194        }
195    }
196
197    @Override
198    public int getInvocationCount() {
199        return invocationCount;
200    }
201
202    @Override
203    public TrainerProvenance getProvenance() {
204        return new TrainerProvenanceImpl(this);
205    }
206
207    /**
208     * Creates a trainer which create models which return a fixed value.
209     * @param value The value to return
210     * @return A regression trainer.
211     */
212    public static DummyRegressionTrainer createConstantTrainer(double value) {
213        DummyRegressionTrainer trainer = new DummyRegressionTrainer();
214        trainer.dummyType = DummyType.CONSTANT;
215        trainer.constantValue = value;
216        return trainer;
217    }
218
219    /**
220     * Creates a trainer which create models which sample the output from a gaussian distribution fit to the training data.
221     * @param seed The RNG seed.
222     * @return A regression trainer.
223     */
224    public static DummyRegressionTrainer createGaussianTrainer(long seed) {
225        DummyRegressionTrainer trainer = new DummyRegressionTrainer();
226        trainer.dummyType = DummyType.GAUSSIAN;
227        trainer.seed = seed;
228        return trainer;
229    }
230
231    /**
232     * Creates a trainer which create models which return the mean of the training data.
233     * @return A regression trainer.
234     */
235    public static DummyRegressionTrainer createMeanTrainer() {
236        DummyRegressionTrainer trainer = new DummyRegressionTrainer();
237        trainer.dummyType = DummyType.MEAN;
238        return trainer;
239    }
240
241    /**
242     * Creates a trainer which create models which return the median of the training data.
243     * @return A regression trainer.
244     */
245    public static DummyRegressionTrainer createMedianTrainer() {
246        DummyRegressionTrainer trainer = new DummyRegressionTrainer();
247        trainer.dummyType = DummyType.MEDIAN;
248        return trainer;
249    }
250
251    /**
252     * Creates a trainer which create models which return the value at the specified fraction of the sorted training data.
253     * @param value The quartile value.
254     * @return A regression trainer.
255     */
256    public static DummyRegressionTrainer createQuartileTrainer(double value) {
257        if (Double.isNaN(value) || value < 0.0 || value > 1.0) {
258            throw new IllegalArgumentException("Please provide an appropriate value between 0.0 and 1.0, found " + value);
259        }
260        DummyRegressionTrainer trainer = new DummyRegressionTrainer();
261        trainer.dummyType = DummyType.QUARTILE;
262        trainer.quartile = value;
263        return trainer;
264    }
265
266    /**
267     * Provenance for {@link DummyRegressionTrainer}.
268     */
269    @Deprecated
270    public final static class DummyRegressionTrainerProvenance implements TrainerProvenance {
271        private static final long serialVersionUID = 1L;
272
273        private final String className;
274        private final DummyType dummyType;
275        private final long seed;
276        private final double constantValue;
277        private final double quartile;
278
279        /**
280         * Constructs a provenance from the host.
281         * @param host The host trainer.
282         */
283        public DummyRegressionTrainerProvenance(DummyRegressionTrainer host) {
284            this.className = host.getClass().getName();
285            this.dummyType = host.dummyType;
286            this.seed = host.seed;
287            this.constantValue = host.constantValue;
288            this.quartile = host.quartile;
289        }
290
291        /**
292         * Constructs a provenance from the marshalled form.
293         * @param map The map of field values.
294         */
295        public DummyRegressionTrainerProvenance(Map<String, Provenance> map) {
296            className = ObjectProvenance.checkAndExtractProvenance(map,CLASS_NAME, StringProvenance.class, DummyRegressionTrainerProvenance.class.getSimpleName()).getValue();
297            dummyType = (DummyType) ObjectProvenance.checkAndExtractProvenance(map,"dummyType", EnumProvenance.class, DummyRegressionTrainerProvenance.class.getSimpleName()).getValue();
298            seed = ObjectProvenance.checkAndExtractProvenance(map,"seed", LongProvenance.class, DummyRegressionTrainerProvenance.class.getSimpleName()).getValue();
299            constantValue = ObjectProvenance.checkAndExtractProvenance(map,"constantValue", DoubleProvenance.class, DummyRegressionTrainerProvenance.class.getSimpleName()).getValue();
300            quartile = ObjectProvenance.checkAndExtractProvenance(map,"quartile", DoubleProvenance.class, DummyRegressionTrainerProvenance.class.getSimpleName()).getValue();
301        }
302
303        @Override
304        public Map<String, Provenance> getConfiguredParameters() {
305            Map<String, Provenance> map = new HashMap<>();
306
307            map.put("dummyType",new EnumProvenance<>("dummyType",dummyType));
308            map.put("constantValue",new DoubleProvenance("constantValue",constantValue));
309            map.put("quartile",new DoubleProvenance("quartile",quartile));
310            map.put("seed",new LongProvenance("seed",seed));
311
312            return map;
313        }
314
315        @Override
316        public String getClassName() {
317            return className;
318        }
319
320        @Override
321        public String toString() {
322            return generateString("Trainer");
323        }
324
325        @Override
326        public boolean equals(Object o) {
327            if (this == o) return true;
328            if (o == null || getClass() != o.getClass()) return false;
329            DummyRegressionTrainerProvenance pairs = (DummyRegressionTrainerProvenance) o;
330            return seed == pairs.seed &&
331                    Double.compare(pairs.constantValue, constantValue) == 0 &&
332                    Double.compare(pairs.quartile, quartile) == 0 &&
333                    className.equals(pairs.className) &&
334                    dummyType == pairs.dummyType;
335        }
336
337        @Override
338        public int hashCode() {
339            return Objects.hash(className, dummyType, seed, constantValue, quartile);
340        }
341    }
342}