001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.regression.example; 018 019import com.oracle.labs.mlrg.olcut.config.Config; 020import com.oracle.labs.mlrg.olcut.config.PropertyException; 021import com.oracle.labs.mlrg.olcut.provenance.ObjectProvenance; 022import com.oracle.labs.mlrg.olcut.provenance.Provenance; 023import com.oracle.labs.mlrg.olcut.provenance.impl.SkeletalConfiguredObjectProvenance; 024import com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance; 025import org.tribuo.ConfigurableDataSource; 026import org.tribuo.Dataset; 027import org.tribuo.Example; 028import org.tribuo.MutableDataset; 029import org.tribuo.OutputFactory; 030import org.tribuo.Trainer; 031import org.tribuo.impl.ArrayExample; 032import org.tribuo.provenance.ConfiguredDataSourceProvenance; 033import org.tribuo.provenance.DataSourceProvenance; 034import org.tribuo.regression.RegressionFactory; 035import org.tribuo.regression.Regressor; 036 037import java.util.ArrayList; 038import java.util.Collections; 039import java.util.HashMap; 040import java.util.Iterator; 041import java.util.List; 042import java.util.Map; 043import java.util.Random; 044 045/** 046 * Generates a single dimensional output drawn from N(slope*x + intercept,variance). 047 * <p> 048 * The single feature is drawn from a uniform distribution over the range. 049 * <p> 050 * Set slope to zero to draw from a gaussian. 051 */ 052public class GaussianDataSource implements ConfigurableDataSource<Regressor> { 053 @Config(mandatory=true) 054 private int numSamples; 055 056 @Config 057 private float slope; 058 059 @Config 060 private float intercept; 061 062 @Config 063 private float variance = 1.0f; 064 065 @Config(mandatory=true) 066 private float xMin; 067 068 @Config(mandatory=true) 069 private float xMax; 070 071 @Config 072 private long seed = Trainer.DEFAULT_SEED; 073 074 private List<Example<Regressor>> examples; 075 076 private final RegressionFactory factory = new RegressionFactory(); 077 078 /** 079 * For OLCUT 080 */ 081 private GaussianDataSource() {} 082 083 /** 084 * Generates a single dimensional output drawn from N(slope*x + intercept,variance). 085 * <p> 086 * The single feature is drawn from a uniform distribution over the range. 087 * <p> 088 * Set slope to zero to draw from a gaussian. 089 * @param numSamples The size of the output dataset. 090 * @param slope The slope of the line. 091 * @param intercept The y intercept of the line. 092 * @param variance The variance of the gaussian. 093 * @param xMin The minimum x value (inclusive). 094 * @param xMax The maximum x value (exclusive). 095 * @param seed The rng seed to use. 096 */ 097 public GaussianDataSource(int numSamples, float slope, float intercept, float variance, float xMin, float xMax, long seed) { 098 this.numSamples = numSamples; 099 this.slope = slope; 100 this.intercept = intercept; 101 this.variance = variance; 102 this.xMin = xMin; 103 this.xMax = xMax; 104 this.seed = seed; 105 postConfig(); 106 } 107 108 /** 109 * Used by the OLCUT configuration system, and should not be called by external code. 110 */ 111 @Override 112 public void postConfig() { 113 Random rng = new Random(seed); 114 List<Example<Regressor>> examples = new ArrayList<>(numSamples); 115 if (xMax <= xMin) { 116 throw new PropertyException("","xMax","xMax must be greater than xMin, found xMax = " + xMax + ", xMin = " + xMin); 117 } 118 if (variance <= 0.0) { 119 throw new PropertyException("","variance","Variance must be positive, found variance = " + variance); 120 } 121 double range = xMax - xMin; 122 for (int i = 0; i < numSamples; i++) { 123 double input = (rng.nextDouble() * range) + xMin; 124 Regressor output = new Regressor("Y",(rng.nextGaussian() * variance) + ((slope * input) + intercept)); 125 ArrayExample<Regressor> e = new ArrayExample<>(output,new String[]{"X"},new double[]{input}); 126 examples.add(e); 127 } 128 this.examples = Collections.unmodifiableList(examples); 129 } 130 131 @Override 132 public OutputFactory<Regressor> getOutputFactory() { 133 return factory; 134 } 135 136 @Override 137 public DataSourceProvenance getProvenance() { 138 return new GaussianDataSourceProvenance(this); 139 } 140 141 @Override 142 public Iterator<Example<Regressor>> iterator() { 143 return examples.iterator(); 144 } 145 146 /** 147 * Generates a single dimensional output drawn from N(slope*x + intercept,variance). 148 * <p> 149 * The single feature is drawn from a uniform distribution over the range. 150 * <p> 151 * Set slope to zero to draw from a gaussian. 152 * @param numSamples The size of the output dataset. 153 * @param slope The slope of the line. 154 * @param intercept The y intercept of the line. 155 * @param variance The variance of the gaussian. 156 * @param xMin The minimum x value (inclusive). 157 * @param xMax The maximum x value (exclusive). 158 * @param seed The rng seed to use. 159 * @return A dataset drawn from a gaussian. 160 */ 161 public static Dataset<Regressor> generateDataset(int numSamples, float slope, float intercept, float variance, float xMin, float xMax, long seed) { 162 GaussianDataSource source = new GaussianDataSource(numSamples,slope,intercept,variance,xMin,xMax,seed); 163 return new MutableDataset<>(source); 164 } 165 166 /** 167 * Provenance for {@link GaussianDataSource}. 168 */ 169 public static class GaussianDataSourceProvenance extends SkeletalConfiguredObjectProvenance implements ConfiguredDataSourceProvenance { 170 private static final long serialVersionUID = 1L; 171 172 /** 173 * Constructs a provenance from the host data source. 174 * @param host The host to read. 175 */ 176 GaussianDataSourceProvenance(GaussianDataSource host) { 177 super(host,"DataSource"); 178 } 179 180 /** 181 * Constructs a provenance from the marshalled form. 182 * @param map The map of field values. 183 */ 184 public GaussianDataSourceProvenance(Map<String, Provenance> map) { 185 this(extractProvenanceInfo(map)); 186 } 187 188 private GaussianDataSourceProvenance(ExtractedInfo info) { 189 super(info); 190 } 191 192 /** 193 * Extracts the relevant provenance information fields for this class. 194 * @param map The map to remove values from. 195 * @return The extracted information. 196 */ 197 protected static ExtractedInfo extractProvenanceInfo(Map<String,Provenance> map) { 198 Map<String,Provenance> configuredParameters = new HashMap<>(map); 199 String className = ObjectProvenance.checkAndExtractProvenance(configuredParameters,CLASS_NAME, StringProvenance.class, GaussianDataSourceProvenance.class.getSimpleName()).getValue(); 200 String hostTypeStringName = ObjectProvenance.checkAndExtractProvenance(configuredParameters,HOST_SHORT_NAME, StringProvenance.class, GaussianDataSourceProvenance.class.getSimpleName()).getValue(); 201 202 return new ExtractedInfo(className,hostTypeStringName,configuredParameters,Collections.emptyMap()); 203 } 204 } 205}