Package org.tribuo.clustering.example
Class GaussianClusterDataSource
java.lang.Object
org.tribuo.clustering.example.GaussianClusterDataSource
- All Implemented Interfaces:
com.oracle.labs.mlrg.olcut.config.Configurable, com.oracle.labs.mlrg.olcut.provenance.Provenancable<DataSourceProvenance>, Iterable<Example<ClusterID>>, ConfigurableDataSource<ClusterID>, DataSource<ClusterID>
public final class GaussianClusterDataSource
extends Object
implements ConfigurableDataSource<ClusterID>
Generates a clustering dataset drawn from a mixture of 5 Gaussians.
The Gaussians can be at most 4 dimensional, resulting in 4 features.
By default the Gaussians are 2-dimensional with the following means and variances:
- N([0.0,0.0], [[1.0,0.0],[0.0,1.0]])
- N([5.0,5.0], [[1.0,0.0],[0.0,1.0]])
- N([2.5,2.5], [[1.0,0.5],[0.5,1.0]])
- N([10.0,0.0], [[0.1,0.0],[0.0,0.1]])
- N([-1.0,0.0], [[1.0,0.0],[0.0,0.1]])
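The third default Gaussian has a non-diagonal variance matrix, so its two features are correlated. The following is a minimal sketch of one standard way to draw such a point, using a Cholesky factor of the variance matrix. It is not taken from Tribuo's source; the class name GaussianSampleSketch and its methods are hypothetical and use only the Java standard library.

```java
import java.util.Random;

// Hypothetical illustration class, not part of Tribuo.
final class GaussianSampleSketch {
    /** Lower-triangular Cholesky factor L of a symmetric positive definite 2x2 matrix, so that L * L^T = cov. */
    public static double[][] cholesky2x2(double[][] cov) {
        double l00 = Math.sqrt(cov[0][0]);
        double l10 = cov[1][0] / l00;
        double l11 = Math.sqrt(cov[1][1] - l10 * l10);
        return new double[][] {{l00, 0.0}, {l10, l11}};
    }

    /** Draws one point as mean + L * z, where z holds two independent standard normal draws. */
    public static double[] sample(double[] mean, double[][] chol, Random rng) {
        double z0 = rng.nextGaussian();
        double z1 = rng.nextGaussian();
        return new double[] {
            mean[0] + chol[0][0] * z0,
            mean[1] + chol[1][0] * z0 + chol[1][1] * z1
        };
    }

    public static void main(String[] args) {
        // Mean and variance of the third default Gaussian listed above.
        double[] mean = {2.5, 2.5};
        double[][] cov = {{1.0, 0.5}, {0.5, 1.0}};
        double[][] chol = cholesky2x2(cov);
        Random rng = new Random(1L);
        double[] point = sample(mean, chol, rng);
        System.out.println(point[0] + ", " + point[1]);
    }
}
```

For a diagonal variance matrix the factor is itself diagonal, and each feature is simply scaled by its own standard deviation.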
-
Constructor Summary
Constructor and Description
GaussianClusterDataSource(int numSamples, double[] mixingDistribution, double[] firstMean, double[] firstVariance, double[] secondMean, double[] secondVariance, double[] thirdMean, double[] thirdVariance, double[] fourthMean, double[] fourthVariance, double[] fifthMean, double[] fifthVariance, long seed)
Generates a clustering dataset drawn from a mixture of 5 Gaussians.
GaussianClusterDataSource(int numSamples, long seed)
Generates a clustering dataset drawn from a mixture of 5 Gaussians.
-
Method Summary
Modifier and Type, Method, and Description
static MutableDataset<ClusterID> generateDataset(int numSamples, double[] mixingDistribution, double[] firstMean, double[] firstVariance, double[] secondMean, double[] secondVariance, double[] thirdMean, double[] thirdVariance, double[] fourthMean, double[] fourthVariance, double[] fifthMean, double[] fifthVariance, long seed)
Generates a clustering dataset drawn from a mixture of 5 Gaussians.
getOutputFactory()
Returns the OutputFactory associated with this Output subclass.
getProvenance()
iterator()
void postConfig()
Used by the OLCUT configuration system, and should not be called by external code.
Methods inherited from class java.lang.Object
clone, equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
Methods inherited from interface java.lang.Iterable
forEach, spliterator
-
Constructor Details
-
GaussianClusterDataSource
public GaussianClusterDataSource(int numSamples, long seed)
Generates a clustering dataset drawn from a mixture of 5 Gaussians.
The default Gaussians are:
- N([0.0,0.0], [[1.0,0.0],[0.0,1.0]])
- N([5.0,5.0], [[1.0,0.0],[0.0,1.0]])
- N([2.5,2.5], [[1.0,0.5],[0.5,1.0]])
- N([10.0,0.0], [[0.1,0.0],[0.0,0.1]])
- N([-1.0,0.0], [[1.0,0.0],[0.0,0.1]])
- Parameters:
numSamples - The size of the output dataset.
seed - The rng seed to use.
-
GaussianClusterDataSource
public GaussianClusterDataSource(int numSamples, double[] mixingDistribution, double[] firstMean, double[] firstVariance, double[] secondMean, double[] secondVariance, double[] thirdMean, double[] thirdVariance, double[] fourthMean, double[] fourthVariance, double[] fifthMean, double[] fifthVariance, long seed)
Generates a clustering dataset drawn from a mixture of 5 Gaussians.
The Gaussians can be at most 4 dimensional, resulting in 4 features.
- Parameters:
numSamples - The size of the output dataset.
mixingDistribution - The probability of each cluster.
firstMean - The mean of the first Gaussian.
firstVariance - The variance of the first Gaussian, linearised from a row-major matrix.
secondMean - The mean of the second Gaussian.
secondVariance - The variance of the second Gaussian, linearised from a row-major matrix.
thirdMean - The mean of the third Gaussian.
thirdVariance - The variance of the third Gaussian, linearised from a row-major matrix.
fourthMean - The mean of the fourth Gaussian.
fourthVariance - The variance of the fourth Gaussian, linearised from a row-major matrix.
fifthMean - The mean of the fifth Gaussian.
fifthVariance - The variance of the fifth Gaussian, linearised from a row-major matrix.
seed - The rng seed to use.
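The variance parameters are flat double[] arrays rather than matrices. As a minimal sketch of what "linearised from a row-major matrix" means, the hypothetical helper below (not part of Tribuo) copies each row of a square matrix in turn, so element (i, j) of an n by n matrix lands at index i * n + j.

```java
import java.util.Arrays;

// Hypothetical illustration class, not part of Tribuo.
final class RowMajorSketch {
    /** Flattens a square matrix into a row-major array: element (i, j) lands at index i * n + j. */
    public static double[] linearise(double[][] matrix) {
        int n = matrix.length;
        double[] flat = new double[n * n];
        for (int i = 0; i < n; i++) {
            if (matrix[i].length != n) {
                throw new IllegalArgumentException("Matrix must be square");
            }
            // Copy row i into the slice starting at i * n.
            System.arraycopy(matrix[i], 0, flat, i * n, n);
        }
        return flat;
    }

    public static void main(String[] args) {
        // Variance matrix of the fifth default Gaussian.
        double[][] fifthVariance = {{1.0, 0.0}, {0.0, 0.1}};
        System.out.println(Arrays.toString(linearise(fifthVariance))); // prints [1.0, 0.0, 0.0, 0.1]
    }
}
```

Because a variance matrix is symmetric, its row-major and column-major flattenings are identical, so the ordering only matters if an asymmetric matrix is passed by mistake.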
-
-
Method Details
-
postConfig
public void postConfig()
Used by the OLCUT configuration system, and should not be called by external code.
- Specified by:
postConfig in interface com.oracle.labs.mlrg.olcut.config.Configurable
-
getOutputFactory
Description copied from interface: DataSource
Returns the OutputFactory associated with this Output subclass.
- Specified by:
getOutputFactory in interface DataSource<ClusterID>
- Returns:
The output factory.
-
getProvenance
- Specified by:
getProvenance in interface com.oracle.labs.mlrg.olcut.provenance.Provenancable<DataSourceProvenance>
-
iterator
-
generateDataset
public static MutableDataset<ClusterID> generateDataset(int numSamples, double[] mixingDistribution, double[] firstMean, double[] firstVariance, double[] secondMean, double[] secondVariance, double[] thirdMean, double[] thirdVariance, double[] fourthMean, double[] fourthVariance, double[] fifthMean, double[] fifthVariance, long seed)
Generates a clustering dataset drawn from a mixture of 5 Gaussians.
The Gaussians can be at most 4 dimensional, resulting in 4 features.
- Parameters:
numSamples - The size of the output dataset.
mixingDistribution - The probability of each cluster.
firstMean - The mean of the first Gaussian.
firstVariance - The variance of the first Gaussian, linearised from a row-major matrix.
secondMean - The mean of the second Gaussian.
secondVariance - The variance of the second Gaussian, linearised from a row-major matrix.
thirdMean - The mean of the third Gaussian.
thirdVariance - The variance of the third Gaussian, linearised from a row-major matrix.
fourthMean - The mean of the fourth Gaussian.
fourthVariance - The variance of the fourth Gaussian, linearised from a row-major matrix.
fifthMean - The mean of the fifth Gaussian.
fifthVariance - The variance of the fifth Gaussian, linearised from a row-major matrix.
seed - The rng seed to use.
- Returns:
A dataset drawn from a mixture of Gaussians.
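The mixingDistribution parameter gives the probability of each of the five clusters. This page does not show how the class turns those probabilities into a cluster choice, so the following is only a sketch of one common approach: walk the cumulative sum of the probabilities until it exceeds a uniform draw. The class MixingSketch, its method, and the probability values are hypothetical.

```java
import java.util.Arrays;
import java.util.Random;

// Hypothetical illustration class, not part of Tribuo.
final class MixingSketch {
    /** Returns the index of the component whose cumulative probability interval contains u, for u in [0, 1). */
    public static int pickComponent(double[] mixingDistribution, double u) {
        double cumulative = 0.0;
        for (int i = 0; i < mixingDistribution.length; i++) {
            cumulative += mixingDistribution[i];
            if (u < cumulative) {
                return i;
            }
        }
        // Guard against rounding when the probabilities sum to slightly less than 1.
        return mixingDistribution.length - 1;
    }

    public static void main(String[] args) {
        // Illustrative probabilities for the five clusters (assumed values).
        double[] mixing = {0.1, 0.35, 0.05, 0.25, 0.25};
        Random rng = new Random(12345L);
        int[] counts = new int[mixing.length];
        for (int i = 0; i < 10000; i++) {
            counts[pickComponent(mixing, rng.nextDouble())]++;
        }
        // Counts should be roughly proportional to the mixing probabilities.
        System.out.println(Arrays.toString(counts));
    }
}
```

With a fixed seed the sequence of choices is reproducible, which matches the role the seed parameter plays in the signature above.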
-