001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.regression.baseline;
018
019import com.oracle.labs.mlrg.olcut.util.Pair;
020import org.tribuo.Example;
021import org.tribuo.Excuse;
022import org.tribuo.ImmutableFeatureMap;
023import org.tribuo.ImmutableOutputInfo;
024import org.tribuo.Model;
025import org.tribuo.Prediction;
026import org.tribuo.Trainer;
027import org.tribuo.provenance.ModelProvenance;
028import org.tribuo.regression.Regressor;
029import org.tribuo.regression.baseline.DummyRegressionTrainer.DummyType;
030
031import java.util.Arrays;
032import java.util.Collections;
033import java.util.List;
034import java.util.Map;
035import java.util.Optional;
036import java.util.Random;
037
038/**
039 * A model which performs dummy regressions (e.g., constant output, gaussian sampled output, mean value, median, quartile).
040 */
041public class DummyRegressionModel extends Model<Regressor> {
042    private static final long serialVersionUID = 2L;
043
044    private final DummyType dummyType;
045
046    private final Regressor output;
047
048    private final long seed;
049
050    private final Random rng;
051
052    private final double[] means;
053
054    private final double[] variances;
055
056    private final String[] dimensionNames;
057
058    DummyRegressionModel(ModelProvenance description, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<Regressor> outputIDInfo, long seed, double[] means, double[] variances, String[] names) {
059        super("dummy-GAUSSIAN-regression", description, featureIDMap, outputIDInfo, false);
060        this.dummyType = DummyType.GAUSSIAN;
061        this.output = null;
062        this.seed = seed;
063        this.rng = new Random(seed);
064        this.means = Arrays.copyOf(means,means.length);
065        this.variances = Arrays.copyOf(variances,variances.length);
066        this.dimensionNames = Arrays.copyOf(names,names.length);
067    }
068
069    DummyRegressionModel(ModelProvenance description, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<Regressor> outputIDInfo, DummyType dummyType, Regressor regressor) {
070        super("dummy-"+dummyType+"-regression", description, featureIDMap, outputIDInfo, false);
071        this.dummyType = dummyType;
072        this.output = regressor;
073        this.seed = Trainer.DEFAULT_SEED;
074        this.rng = null;
075        this.means = new double[0];
076        this.variances = new double[0];
077        this.dimensionNames = new String[0];
078    }
079
080    @Override
081    public Prediction<Regressor> predict(Example<Regressor> example) {
082        switch (dummyType) {
083            case CONSTANT:
084            case MEAN:
085            case MEDIAN:
086            case QUARTILE:
087                return new Prediction<>(output,0,example);
088            case GAUSSIAN: {
089                Regressor.DimensionTuple[] dimensions = new Regressor.DimensionTuple[dimensionNames.length];
090                for (int i = 0; i < dimensionNames.length; i++) {
091                    double regressionValue = (rng.nextGaussian() * variances[i]) + means[i];
092                    dimensions[i] = new Regressor.DimensionTuple(dimensionNames[i],regressionValue);
093                }
094                return new Prediction<>(new Regressor(dimensions), 0, example);
095            }
096            default:
097                throw new IllegalStateException("Unknown dummyType " + dummyType);
098        }
099    }
100
101    @Override
102    public Map<String, List<Pair<String, Double>>> getTopFeatures(int n) {
103        if (n != 0) {
104            return Collections.singletonMap(Model.ALL_OUTPUTS, Collections.singletonList(new Pair<>(BIAS_FEATURE, 1.0)));
105        } else {
106            return Collections.emptyMap();
107        }
108    }
109
110    @Override
111    public Optional<Excuse<Regressor>> getExcuse(Example<Regressor> example) {
112        return Optional.of(new Excuse<>(example,predict(example),getTopFeatures(1)));
113    }
114
115    @Override
116    protected Model<Regressor> copy(String newName, ModelProvenance newProvenance) {
117        switch (dummyType) {
118            case GAUSSIAN:
119                return new DummyRegressionModel(newProvenance,featureIDMap,outputIDInfo,seed,means,variances,dimensionNames);
120            case CONSTANT:
121            case MEAN:
122            case MEDIAN:
123            case QUARTILE:
124                return new DummyRegressionModel(newProvenance,featureIDMap,outputIDInfo,dummyType,output.copy());
125            default:
126                throw new IllegalStateException("Unknown dummyType " + dummyType);
127        }
128    }
129}