Package org.tribuo.multilabel.example
Class MultiLabelGaussianDataSource
java.lang.Object
org.tribuo.multilabel.example.MultiLabelGaussianDataSource
- All Implemented Interfaces:
com.oracle.labs.mlrg.olcut.config.Configurable, com.oracle.labs.mlrg.olcut.provenance.Provenancable<DataSourceProvenance>, Iterable<Example<MultiLabel>>, ConfigurableDataSource<MultiLabel>, DataSource<MultiLabel>
public final class MultiLabelGaussianDataSource
extends Object
implements ConfigurableDataSource<MultiLabel>
Generates a multi-label output drawn from a series of functions.
The functions are:
- y_0 is positive if N(w_00*x_0 + w_01*x_1 + w_02*x_1*x_0 + w_03*x_1*x_1*x_1,variance) > threshold_0.
- y_1 is positive if N(w_10*x_0 + w_11*x_1 + w_12*x_1*x_0 + w_13*x_1*x_1*x_1,variance) < threshold_1.
- y_2 is positive if N(w_20*x_0 + w_21*x_2 + w_22*x_1*x_0 + w_23*x_1*x_2*x_2,variance) > threshold_2.
By default y_1 is the inverse of y_0, and y_2 shares the same weights w_0 and w_2.
- y_0 weights = [1.0,1.0,1.0,1.0]
- y_1 weights = [1.0,1.0,1.0,1.0]
- y_2 weights = [1.0,-3.0,1.0,3.0]
- threshold = [0.0,0.0,2.0]
The features are drawn from a uniform distribution over the range set by xMin (inclusive) and xMax (exclusive).
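The three rules above can be sketched in plain Java. This is a minimal stand-alone illustration, not Tribuo's implementation: the class GaussianLabelSketch, its method names, and the use of java.util.Random are assumptions, and each label draws its own noise term here, so y_1 is only the exact inverse of y_0 when the variance is zero.

```java
import java.util.Arrays;
import java.util.Random;

/** Hypothetical stand-alone sketch of the labelling rules; not Tribuo code. */
final class GaussianLabelSketch {
    // Default weights and thresholds quoted in the class description.
    static final float[] Y0_WEIGHTS = {1.0f, 1.0f, 1.0f, 1.0f};
    static final float[] Y1_WEIGHTS = {1.0f, 1.0f, 1.0f, 1.0f};
    static final float[] Y2_WEIGHTS = {1.0f, -3.0f, 1.0f, 3.0f};
    static final float[] THRESHOLD = {0.0f, 0.0f, 2.0f};

    /** Mean used for y_0 and y_1: w0*x0 + w1*x1 + w2*x1*x0 + w3*x1*x1*x1. */
    static double meanZeroOne(float[] w, double[] x) {
        return w[0] * x[0] + w[1] * x[1] + w[2] * x[1] * x[0] + w[3] * x[1] * x[1] * x[1];
    }

    /** Mean used for y_2: w0*x0 + w1*x2 + w2*x1*x0 + w3*x1*x2*x2. */
    static double meanTwo(float[] w, double[] x) {
        return w[0] * x[0] + w[1] * x[2] + w[2] * x[1] * x[0] + w[3] * x[1] * x[2] * x[2];
    }

    /** Draws one value from N(mean, variance). */
    static double gaussian(double mean, double variance, Random rng) {
        return mean + Math.sqrt(variance) * rng.nextGaussian();
    }

    /** Returns {y_0, y_1, y_2} for one feature vector x = {x_0, x_1, x_2}. */
    static boolean[] labels(double[] x, double variance, Random rng) {
        boolean y0 = gaussian(meanZeroOne(Y0_WEIGHTS, x), variance, rng) > THRESHOLD[0];
        boolean y1 = gaussian(meanZeroOne(Y1_WEIGHTS, x), variance, rng) < THRESHOLD[1];
        boolean y2 = gaussian(meanTwo(Y2_WEIGHTS, x), variance, rng) > THRESHOLD[2];
        return new boolean[] {y0, y1, y2};
    }

    public static void main(String[] args) {
        // With zero variance the labels depend only on the polynomial means (4, 4 and 2).
        boolean[] y = labels(new double[] {1.0, 1.0, 1.0}, 0.0, new Random(1L));
        System.out.println(Arrays.toString(y)); // prints [true, false, false]
    }
}
```

For x = (1, 1, 1) the y_2 mean equals the threshold of 2.0 exactly, so the strict comparison leaves y_2 negative.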
-
Nested Class Summary
-
Constructor Summary
Constructor / Description
MultiLabelGaussianDataSource(int numSamples, float[] yZeroWeights, float[] yOneWeights, float[] yTwoWeights, float[] threshold, boolean[] negate, float variance, float[] xMin, float[] xMax, long seed)
Generates a multi-label output drawn from three Gaussian functions.
-
Method Summary
Modifier and Type / Method / Description
static MutableDataset<MultiLabel> generateDataset(int numSamples, float[] yZeroWeights, float[] yOneWeights, float[] yTwoWeights, float[] threshold, boolean[] negate, float variance, float[] xMin, float[] xMax, long seed)
Generates a multi-label output drawn from three Gaussian functions.
OutputFactory<MultiLabel> getOutputFactory()
Returns the OutputFactory associated with this Output subclass.
DataSourceProvenance getProvenance()
Iterator<Example<MultiLabel>> iterator()
static MultiLabelGaussianDataSource makeDefaultSource(int numSamples, long seed)
Generates a multi-label output drawn from a series of functions.
void postConfig()
Used by the OLCUT configuration system, and should not be called by external code.
Methods inherited from class java.lang.Object
clone, equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
Methods inherited from interface java.lang.Iterable
forEach, spliterator
-
Constructor Details
-
MultiLabelGaussianDataSource
public MultiLabelGaussianDataSource(int numSamples, float[] yZeroWeights, float[] yOneWeights, float[] yTwoWeights, float[] threshold, boolean[] negate, float variance, float[] xMin, float[] xMax, long seed)
Generates a multi-label output drawn from three Gaussian functions.
- N(w_00*x_0 + w_01*x_1 + w_02*x_1*x_0 + w_03*x_1*x_1*x_1,variance)
- N(w_10*x_0 + w_11*x_1 + w_12*x_1*x_0 + w_13*x_1*x_1*x_1,variance)
- N(w_20*x_0 + w_21*x_2 + w_22*x_1*x_0 + w_23*x_1*x_2*x_2,variance)
The features are drawn from a uniform distribution over the range [xMin, xMax).
- Parameters:
numSamples - The size of the output dataset.
yZeroWeights - The feature weights for label y_0.
yOneWeights - The feature weights for label y_1.
yTwoWeights - The feature weights for label y_2.
threshold - The y threshold of each label.
negate - Should the computed value be negated before thresholding?
variance - The variance of the Gaussian.
xMin - The minimum feature values (inclusive).
xMax - The maximum feature values (exclusive).
seed - The RNG seed to use.
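As a minimal sketch of how the xMin and xMax bounds could be applied, the block below draws each feature uniformly from [xMin[i], xMax[i]). The class UniformFeatureSketch, its sample method, the argument checks, and the use of java.util.Random are assumptions made for illustration; they are not taken from Tribuo.

```java
import java.util.Arrays;
import java.util.Random;

/** Hypothetical sketch of per-feature uniform sampling; not Tribuo code. */
final class UniformFeatureSketch {
    /** Draws one feature vector with x[i] uniform in [xMin[i], xMax[i]). */
    static double[] sample(float[] xMin, float[] xMax, Random rng) {
        if (xMin.length != xMax.length) {
            throw new IllegalArgumentException("xMin and xMax must have the same length");
        }
        double[] x = new double[xMin.length];
        for (int i = 0; i < x.length; i++) {
            if (xMax[i] <= xMin[i]) {
                throw new IllegalArgumentException("xMax[" + i + "] must exceed xMin[" + i + "]");
            }
            // nextDouble() lies in [0, 1), so the minimum is reachable and the
            // maximum is excluded (up to floating-point rounding).
            x[i] = xMin[i] + rng.nextDouble() * ((double) xMax[i] - xMin[i]);
        }
        return x;
    }

    public static void main(String[] args) {
        float[] xMin = {-1.0f, 0.0f, 2.0f};
        float[] xMax = {1.0f, 5.0f, 3.0f};
        System.out.println(Arrays.toString(sample(xMin, xMax, new Random(7L))));
    }
}
```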
-
-
Method Details
-
postConfig
public void postConfig()
Used by the OLCUT configuration system, and should not be called by external code.
- Specified by:
postConfig in interface com.oracle.labs.mlrg.olcut.config.Configurable
-
getOutputFactory
Description copied from interface: DataSource
Returns the OutputFactory associated with this Output subclass.
- Specified by:
getOutputFactory in interface DataSource<MultiLabel>
- Returns:
- The output factory.
-
getProvenance
- Specified by:
getProvenance in interface com.oracle.labs.mlrg.olcut.provenance.Provenancable<DataSourceProvenance>
-
iterator
- Specified by:
iterator in interface Iterable<Example<MultiLabel>>
-
generateDataset
public static MutableDataset<MultiLabel> generateDataset(int numSamples, float[] yZeroWeights, float[] yOneWeights, float[] yTwoWeights, float[] threshold, boolean[] negate, float variance, float[] xMin, float[] xMax, long seed)
Generates a multi-label output drawn from three Gaussian functions.
- N(w_00*x_0 + w_01*x_1 + w_02*x_1*x_0 + w_03*x_1*x_1*x_1,variance)
- N(w_10*x_0 + w_11*x_1 + w_12*x_1*x_0 + w_13*x_1*x_1*x_1,variance)
- N(w_20*x_0 + w_21*x_2 + w_22*x_1*x_0 + w_23*x_1*x_2*x_2,variance)
The features are drawn from a uniform distribution over the range [xMin, xMax).
- Parameters:
numSamples - The size of the output dataset.
yZeroWeights - The feature weights for label y_0.
yOneWeights - The feature weights for label y_1.
yTwoWeights - The feature weights for label y_2.
threshold - The y threshold of each label.
negate - Should the computed value be negated before thresholding?
variance - The variance of the Gaussian.
xMin - The minimum feature values (inclusive).
xMax - The maximum feature values (exclusive).
seed - The RNG seed to use.
- Returns:
- A dataset drawn from several Gaussian-generated labels.
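The negate parameter is described only as negating the computed value before thresholding. One way to read that is sketched below, assuming a strict greater-than comparison against the threshold; the class NegateThresholdSketch and its methods are hypothetical, and the negate array {false, true, false} used in main is an assumption chosen to reproduce the three default rules in the class description, not a documented default.

```java
import java.util.Arrays;

/** Hypothetical sketch of negate-then-threshold; not Tribuo code. */
final class NegateThresholdSketch {
    /** A label is positive when the (optionally negated) value exceeds the threshold. */
    static boolean isPositive(double value, float threshold, boolean negate) {
        double v = negate ? -value : value;
        return v > threshold;
    }

    /** Applies isPositive element-wise to the three sampled values. */
    static boolean[] threshold(double[] values, float[] threshold, boolean[] negate) {
        if (values.length != threshold.length || values.length != negate.length) {
            throw new IllegalArgumentException("All arrays must have the same length");
        }
        boolean[] labels = new boolean[values.length];
        for (int i = 0; i < values.length; i++) {
            labels[i] = isPositive(values[i], threshold[i], negate[i]);
        }
        return labels;
    }

    public static void main(String[] args) {
        double[] sampled = {4.0, 4.0, 2.0};        // example sampled values
        float[] thresholds = {0.0f, 0.0f, 2.0f};   // thresholds from the class description
        boolean[] negate = {false, true, false};   // assumed, see the lead-in
        System.out.println(Arrays.toString(threshold(sampled, thresholds, negate)));
        // prints [true, false, false]
    }
}
```

With a threshold of zero, negating the value turns "value > 0" into "value < 0", which matches the y_1 rule.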
-
makeDefaultSource
Generates a multi-label output drawn from a series of functions.
The functions are:
- y_0 is positive if N(w_00*x_0 + w_01*x_1 + w_02*x_1*x_0 + w_03*x_1*x_1*x_1,variance) > threshold_0.
- y_1 is positive if N(w_10*x_0 + w_11*x_1 + w_12*x_1*x_0 + w_13*x_1*x_1*x_1,variance) < threshold_1.
- y_2 is positive if N(w_20*x_0 + w_21*x_2 + w_22*x_1*x_0 + w_23*x_1*x_2*x_2,variance) > threshold_2.
By default y_1 is the inverse of y_0, and y_2 shares the same weights w_0 and w_2.
- y_0 weights = [1.0,1.0,1.0,1.0]
- y_1 weights = [1.0,1.0,1.0,1.0]
- y_2 weights = [1.0,-3.0,1.0,3.0]
- threshold = [0.0,0.0,2.0]
The features are drawn from a uniform distribution over the range.
- Parameters:
numSamples - The number of samples to draw.
seed - The RNG seed.
- Returns:
- A data source drawn from multiple Gaussians.
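The seed parameter implies that generation is reproducible. The sketch below illustrates that idea in isolation, assuming a generator that is re-created from the seed on every call; SeededDrawSketch and its draw method are hypothetical, and java.util.Random stands in for whatever generator the library uses.

```java
import java.util.Arrays;
import java.util.Random;

/** Hypothetical sketch of seed-driven reproducibility; not Tribuo code. */
final class SeededDrawSketch {
    /** Draws numSamples values from N(0, variance) using a generator seeded with seed. */
    static double[] draw(int numSamples, double variance, long seed) {
        if (numSamples < 1) {
            throw new IllegalArgumentException("numSamples must be positive");
        }
        Random rng = new Random(seed);
        double sd = Math.sqrt(variance);
        double[] out = new double[numSamples];
        for (int i = 0; i < numSamples; i++) {
            out[i] = sd * rng.nextGaussian();
        }
        return out;
    }

    public static void main(String[] args) {
        double[] first = draw(5, 1.0, 42L);
        double[] second = draw(5, 1.0, 42L);
        // The same seed yields the same sequence of draws.
        System.out.println(Arrays.equals(first, second)); // prints true
    }
}
```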
-