001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.regression.example;
018
019import com.oracle.labs.mlrg.olcut.config.Config;
020import com.oracle.labs.mlrg.olcut.config.PropertyException;
021import com.oracle.labs.mlrg.olcut.provenance.ObjectProvenance;
022import com.oracle.labs.mlrg.olcut.provenance.Provenance;
023import com.oracle.labs.mlrg.olcut.provenance.impl.SkeletalConfiguredObjectProvenance;
024import com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance;
025import org.tribuo.ConfigurableDataSource;
026import org.tribuo.Dataset;
027import org.tribuo.Example;
028import org.tribuo.MutableDataset;
029import org.tribuo.OutputFactory;
030import org.tribuo.Trainer;
031import org.tribuo.impl.ArrayExample;
032import org.tribuo.provenance.ConfiguredDataSourceProvenance;
033import org.tribuo.provenance.DataSourceProvenance;
034import org.tribuo.regression.RegressionFactory;
035import org.tribuo.regression.Regressor;
036
037import java.util.ArrayList;
038import java.util.Collections;
039import java.util.HashMap;
040import java.util.Iterator;
041import java.util.List;
042import java.util.Map;
043import java.util.Random;
044
045/**
046 * Generates a single dimensional output drawn from N(slope*x + intercept,variance).
047 * <p>
048 * The single feature is drawn from a uniform distribution over the range.
049 * <p>
050 * Set slope to zero to draw from a gaussian.
051 */
052public class GaussianDataSource implements ConfigurableDataSource<Regressor> {
053    @Config(mandatory=true)
054    private int numSamples;
055
056    @Config
057    private float slope;
058
059    @Config
060    private float intercept;
061
062    @Config
063    private float variance = 1.0f;
064
065    @Config(mandatory=true)
066    private float xMin;
067
068    @Config(mandatory=true)
069    private float xMax;
070
071    @Config
072    private long seed = Trainer.DEFAULT_SEED;
073
074    private List<Example<Regressor>> examples;
075
076    private final RegressionFactory factory = new RegressionFactory();
077
078    /**
079     * For OLCUT
080     */
081    private GaussianDataSource() {}
082
083    /**
084     * Generates a single dimensional output drawn from N(slope*x + intercept,variance).
085     * <p>
086     * The single feature is drawn from a uniform distribution over the range.
087     * <p>
088     * Set slope to zero to draw from a gaussian.
089     * @param numSamples The size of the output dataset.
090     * @param slope The slope of the line.
091     * @param intercept The y intercept of the line.
092     * @param variance The variance of the gaussian.
093     * @param xMin The minimum x value (inclusive).
094     * @param xMax The maximum x value (exclusive).
095     * @param seed The rng seed to use.
096     */
097    public GaussianDataSource(int numSamples, float slope, float intercept, float variance, float xMin, float xMax, long seed) {
098        this.numSamples = numSamples;
099        this.slope = slope;
100        this.intercept = intercept;
101        this.variance = variance;
102        this.xMin = xMin;
103        this.xMax = xMax;
104        this.seed = seed;
105        postConfig();
106    }
107
108    /**
109     * Used by the OLCUT configuration system, and should not be called by external code.
110     */
111    @Override
112    public void postConfig() {
113        Random rng = new Random(seed);
114        List<Example<Regressor>> examples = new ArrayList<>(numSamples);
115        if (xMax <= xMin) {
116            throw new PropertyException("","xMax","xMax must be greater than xMin, found xMax = " + xMax + ", xMin = " + xMin);
117        }
118        if (variance <= 0.0) {
119            throw new PropertyException("","variance","Variance must be positive, found variance = " + variance);
120        }
121        double range = xMax - xMin;
122        for (int i = 0; i < numSamples; i++) {
123            double input = (rng.nextDouble() * range) + xMin;
124            Regressor output = new Regressor("Y",(rng.nextGaussian() * variance) + ((slope * input) + intercept));
125            ArrayExample<Regressor> e = new ArrayExample<>(output,new String[]{"X"},new double[]{input});
126            examples.add(e);
127        }
128        this.examples = Collections.unmodifiableList(examples);
129    }
130
131    @Override
132    public OutputFactory<Regressor> getOutputFactory() {
133        return factory;
134    }
135
136    @Override
137    public DataSourceProvenance getProvenance() {
138        return new GaussianDataSourceProvenance(this);
139    }
140
141    @Override
142    public Iterator<Example<Regressor>> iterator() {
143        return examples.iterator();
144    }
145
146    /**
147     * Generates a single dimensional output drawn from N(slope*x + intercept,variance).
148     * <p>
149     * The single feature is drawn from a uniform distribution over the range.
150     * <p>
151     * Set slope to zero to draw from a gaussian.
152     * @param numSamples The size of the output dataset.
153     * @param slope The slope of the line.
154     * @param intercept The y intercept of the line.
155     * @param variance The variance of the gaussian.
156     * @param xMin The minimum x value (inclusive).
157     * @param xMax The maximum x value (exclusive).
158     * @param seed The rng seed to use.
159     * @return A dataset drawn from a gaussian.
160     */
161    public static Dataset<Regressor> generateDataset(int numSamples, float slope, float intercept, float variance, float xMin, float xMax, long seed) {
162        GaussianDataSource source = new GaussianDataSource(numSamples,slope,intercept,variance,xMin,xMax,seed);
163        return new MutableDataset<>(source);
164    }
165
166    /**
167     * Provenance for {@link GaussianDataSource}.
168     */
169    public static class GaussianDataSourceProvenance extends SkeletalConfiguredObjectProvenance implements ConfiguredDataSourceProvenance {
170        private static final long serialVersionUID = 1L;
171
172        /**
173         * Constructs a provenance from the host data source.
174         * @param host The host to read.
175         */
176        GaussianDataSourceProvenance(GaussianDataSource host) {
177            super(host,"DataSource");
178        }
179
180        /**
181         * Constructs a provenance from the marshalled form.
182         * @param map The map of field values.
183         */
184        public GaussianDataSourceProvenance(Map<String, Provenance> map) {
185            this(extractProvenanceInfo(map));
186        }
187
188        private GaussianDataSourceProvenance(ExtractedInfo info) {
189            super(info);
190        }
191
192        /**
193         * Extracts the relevant provenance information fields for this class.
194         * @param map The map to remove values from.
195         * @return The extracted information.
196         */
197        protected static ExtractedInfo extractProvenanceInfo(Map<String,Provenance> map) {
198            Map<String,Provenance> configuredParameters = new HashMap<>(map);
199            String className = ObjectProvenance.checkAndExtractProvenance(configuredParameters,CLASS_NAME, StringProvenance.class, GaussianDataSourceProvenance.class.getSimpleName()).getValue();
200            String hostTypeStringName = ObjectProvenance.checkAndExtractProvenance(configuredParameters,HOST_SHORT_NAME, StringProvenance.class, GaussianDataSourceProvenance.class.getSimpleName()).getValue();
201
202            return new ExtractedInfo(className,hostTypeStringName,configuredParameters,Collections.emptyMap());
203        }
204    }
205}