Class GaussianClusterDataSource

java.lang.Object
org.tribuo.clustering.example.GaussianClusterDataSource
All Implemented Interfaces:
com.oracle.labs.mlrg.olcut.config.Configurable, com.oracle.labs.mlrg.olcut.provenance.Provenancable<DataSourceProvenance>, Iterable<Example<ClusterID>>, ConfigurableDataSource<ClusterID>, DataSource<ClusterID>

public final class GaussianClusterDataSource extends Object implements ConfigurableDataSource<ClusterID>
Generates a clustering dataset drawn from a mixture of 5 Gaussians.

The Gaussians can be at most 4-dimensional. Each dimension becomes one feature, so examples have at most 4 features.

By default the Gaussians are 2-dimensional, with the following means and covariance matrices:

  • N([0.0,0.0], [[1.0,0.0],[0.0,1.0]])
  • N([5.0,5.0], [[1.0,0.0],[0.0,1.0]])
  • N([2.5,2.5], [[1.0,0.5],[0.5,1.0]])
  • N([10.0,0.0], [[0.1,0.0],[0.0,0.1]])
  • N([-1.0,0.0], [[1.0,0.0],[0.0,0.1]])
and the mixing distribution is: [0.1, 0.35, 0.05, 0.25, 0.25].
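Written as a density, the default configuration is a five-component Gaussian mixture. The weights w_k are the mixing distribution and (μ_k, Σ_k) are the means and covariance matrices listed above. Reading it generatively (pick a component, then sample from it) is an assumption about how examples are drawn; it is consistent with the description but not spelled out by it.

```latex
p(\mathbf{x}) = \sum_{k=1}^{5} w_k \, \mathcal{N}(\mathbf{x} \mid \boldsymbol{\mu}_k, \boldsymbol{\Sigma}_k),
\qquad \mathbf{w} = (0.1,\ 0.35,\ 0.05,\ 0.25,\ 0.25),

\mathcal{N}(\mathbf{x} \mid \boldsymbol{\mu}, \boldsymbol{\Sigma})
  = (2\pi)^{-d/2} \, \lvert \boldsymbol{\Sigma} \rvert^{-1/2}
    \exp\!\left( -\tfrac{1}{2} (\mathbf{x} - \boldsymbol{\mu})^{\top} \boldsymbol{\Sigma}^{-1} (\mathbf{x} - \boldsymbol{\mu}) \right),

\text{e.g. } \boldsymbol{\mu}_3 = \begin{pmatrix} 2.5 \\ 2.5 \end{pmatrix}, \quad
\boldsymbol{\Sigma}_3 = \begin{pmatrix} 1.0 & 0.5 \\ 0.5 & 1.0 \end{pmatrix}, \quad d = 2.
```

Under that reading, sampling one example means drawing k ~ Categorical(w) and then x ~ N(μ_k, Σ_k). With these weights, about 35% of examples are expected to come from the Gaussian centred at (5, 5).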
  • Constructor Details

    • GaussianClusterDataSource

      public GaussianClusterDataSource(int numSamples, long seed)
      Generates a clustering dataset drawn from a mixture of 5 Gaussians.

      The default Gaussians are:

      • N([0.0,0.0], [[1.0,0.0],[0.0,1.0]])
      • N([5.0,5.0], [[1.0,0.0],[0.0,1.0]])
      • N([2.5,2.5], [[1.0,0.5],[0.5,1.0]])
      • N([10.0,0.0], [[0.1,0.0],[0.0,0.1]])
      • N([-1.0,0.0], [[1.0,0.0],[0.0,0.1]])
      and the mixing distribution is: [0.1, 0.35, 0.05, 0.25, 0.25].
      Parameters:
      numSamples - The size of the output dataset.
      seed - The RNG seed to use.
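For reference, the default parameters listed above can be written in the linearised row-major form expected by the 13-argument constructor. This is a minimal sketch. DefaultGaussianParams and mixingTotal are hypothetical names, and the commented constructor call assumes Tribuo is on the classpath.

```java
/**
 * Hypothetical holder class (not part of Tribuo) for the documented default
 * mixture, with each 2x2 covariance matrix flattened in row-major order.
 */
public class DefaultGaussianParams {
    static final double[] MIXING = {0.1, 0.35, 0.05, 0.25, 0.25};

    static final double[] FIRST_MEAN = {0.0, 0.0};
    static final double[] FIRST_VARIANCE = {1.0, 0.0, 0.0, 1.0};
    static final double[] SECOND_MEAN = {5.0, 5.0};
    static final double[] SECOND_VARIANCE = {1.0, 0.0, 0.0, 1.0};
    static final double[] THIRD_MEAN = {2.5, 2.5};
    static final double[] THIRD_VARIANCE = {1.0, 0.5, 0.5, 1.0};
    static final double[] FOURTH_MEAN = {10.0, 0.0};
    static final double[] FOURTH_VARIANCE = {0.1, 0.0, 0.0, 0.1};
    static final double[] FIFTH_MEAN = {-1.0, 0.0};
    static final double[] FIFTH_VARIANCE = {1.0, 0.0, 0.0, 0.1};

    // With Tribuo on the classpath, these arrays could be passed as:
    //   new GaussianClusterDataSource(numSamples, MIXING,
    //       FIRST_MEAN, FIRST_VARIANCE, SECOND_MEAN, SECOND_VARIANCE,
    //       THIRD_MEAN, THIRD_VARIANCE, FOURTH_MEAN, FOURTH_VARIANCE,
    //       FIFTH_MEAN, FIFTH_VARIANCE, seed)
    // which, according to the documentation, describes the same mixture as
    //   new GaussianClusterDataSource(numSamples, seed)

    /** Sum of the mixing probabilities; should be 1 up to floating point round-off. */
    static double mixingTotal() {
        double total = 0.0;
        for (double p : MIXING) {
            total += p;
        }
        return total;
    }
}
```

Spelling the defaults out this way makes it easy to start from the documented mixture and perturb one component, for example by moving FOURTH_MEAN closer to the others to make the clusters overlap.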
    • GaussianClusterDataSource

      public GaussianClusterDataSource(int numSamples, double[] mixingDistribution, double[] firstMean, double[] firstVariance, double[] secondMean, double[] secondVariance, double[] thirdMean, double[] thirdVariance, double[] fourthMean, double[] fourthVariance, double[] fifthMean, double[] fifthVariance, long seed)
      Generates a clustering dataset drawn from a mixture of 5 Gaussians.

      The Gaussians can be at most 4-dimensional. Each dimension becomes one feature, so examples have at most 4 features.

      Parameters:
      numSamples - The size of the output dataset.
      mixingDistribution - The probability of each cluster.
      firstMean - The mean of the first Gaussian.
      firstVariance - The covariance matrix of the first Gaussian, linearised in row-major order.
      secondMean - The mean of the second Gaussian.
      secondVariance - The covariance matrix of the second Gaussian, linearised in row-major order.
      thirdMean - The mean of the third Gaussian.
      thirdVariance - The covariance matrix of the third Gaussian, linearised in row-major order.
      fourthMean - The mean of the fourth Gaussian.
      fourthVariance - The covariance matrix of the fourth Gaussian, linearised in row-major order.
      fifthMean - The mean of the fifth Gaussian.
      fifthVariance - The covariance matrix of the fifth Gaussian, linearised in row-major order.
      seed - The RNG seed to use.
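Each covariance argument is a d × d matrix flattened in row-major order. For a d-dimensional mean, the array must therefore hold d² entries, with entry (row, col) at index row * d + col. The sketch below undoes that flattening and checks the lengths. RowMajorCovariance, unflatten and the specific checks are hypothetical. They assume only the layout described in the parameter list, not how Tribuo validates its inputs.

```java
import java.util.Arrays;

/**
 * Hypothetical helper (not part of Tribuo) that rebuilds a square covariance
 * matrix from its row-major linearisation and checks it against the mean.
 */
public class RowMajorCovariance {

    // The documentation states the Gaussians are at most 4-dimensional.
    static final int MAX_DIMENSION = 4;

    static double[][] unflatten(double[] mean, double[] linearised) {
        int d = mean.length;
        if (d < 1 || d > MAX_DIMENSION) {
            throw new IllegalArgumentException(
                "Mean must have between 1 and " + MAX_DIMENSION + " elements, found " + d);
        }
        if (linearised.length != d * d) {
            throw new IllegalArgumentException("Expected " + (d * d)
                + " covariance entries for a " + d + "-dimensional mean, found " + linearised.length);
        }
        double[][] matrix = new double[d][d];
        for (int row = 0; row < d; row++) {
            for (int col = 0; col < d; col++) {
                // Row-major layout: entry (row, col) is stored at index row * d + col.
                matrix[row][col] = linearised[row * d + col];
            }
        }
        return matrix;
    }

    public static void main(String[] args) {
        // The default third Gaussian: mean [2.5, 2.5], covariance [[1.0, 0.5], [0.5, 1.0]].
        double[][] third = unflatten(new double[]{2.5, 2.5}, new double[]{1.0, 0.5, 0.5, 1.0});
        System.out.println(Arrays.deepToString(third)); // [[1.0, 0.5], [0.5, 1.0]]
    }
}
```

Because a valid covariance matrix is symmetric, row-major and column-major orderings give the same array. The layout only becomes visible when an array is passed whose off-diagonal entries differ, which would not describe a valid covariance matrix.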
  • Method Details

    • postConfig

      public void postConfig()
      Used by the OLCUT configuration system, and should not be called by external code.
      Specified by:
      postConfig in interface com.oracle.labs.mlrg.olcut.config.Configurable
    • getOutputFactory

      public OutputFactory<ClusterID> getOutputFactory()
      Description copied from interface: DataSource
      Returns the OutputFactory associated with this Output subclass.
      Specified by:
      getOutputFactory in interface DataSource<ClusterID>
      Returns:
      The output factory.
    • getProvenance

      public DataSourceProvenance getProvenance()
      Specified by:
      getProvenance in interface com.oracle.labs.mlrg.olcut.provenance.Provenancable<DataSourceProvenance>
    • iterator

      public Iterator<Example<ClusterID>> iterator()
      Specified by:
      iterator in interface Iterable<Example<ClusterID>>
    • generateDataset

      public static MutableDataset<ClusterID> generateDataset(int numSamples, double[] mixingDistribution, double[] firstMean, double[] firstVariance, double[] secondMean, double[] secondVariance, double[] thirdMean, double[] thirdVariance, double[] fourthMean, double[] fourthVariance, double[] fifthMean, double[] fifthVariance, long seed)
      Generates a clustering dataset drawn from a mixture of 5 Gaussians.

      The Gaussians can be at most 4-dimensional. Each dimension becomes one feature, so examples have at most 4 features.

      Parameters:
      numSamples - The size of the output dataset.
      mixingDistribution - The probability of each cluster.
      firstMean - The mean of the first Gaussian.
      firstVariance - The covariance matrix of the first Gaussian, linearised in row-major order.
      secondMean - The mean of the second Gaussian.
      secondVariance - The covariance matrix of the second Gaussian, linearised in row-major order.
      thirdMean - The mean of the third Gaussian.
      thirdVariance - The covariance matrix of the third Gaussian, linearised in row-major order.
      fourthMean - The mean of the fourth Gaussian.
      fourthVariance - The covariance matrix of the fourth Gaussian, linearised in row-major order.
      fifthMean - The mean of the fifth Gaussian.
      fifthVariance - The covariance matrix of the fifth Gaussian, linearised in row-major order.
      seed - The RNG seed to use.
      Returns:
      A dataset drawn from a mixture of Gaussians.
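The kind of generation procedure described here can be sketched without Tribuo. Choose a Gaussian according to the mixing distribution, draw a point from it using a Cholesky factor of its covariance matrix, and record the component index as the cluster label. This is a minimal illustration under stated assumptions: java.util.Random as the RNG, Cholesky-based sampling, and the component index standing in for the ClusterID. It is not Tribuo's implementation and will not reproduce the points generateDataset returns for a given seed. GaussianMixtureSketch, Point, cholesky and generate are hypothetical names.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Random;

/**
 * Tribuo-independent sketch of sampling labelled points from a Gaussian mixture.
 * GaussianMixtureSketch and Point are hypothetical names; Tribuo's own
 * generateDataset may use a different RNG and sampling routine.
 */
public class GaussianMixtureSketch {

    // Default parameters copied from the class documentation.
    static final double[] DEFAULT_MIXING = {0.1, 0.35, 0.05, 0.25, 0.25};
    static final double[][] DEFAULT_MEANS = {{0.0, 0.0}, {5.0, 5.0}, {2.5, 2.5}, {10.0, 0.0}, {-1.0, 0.0}};
    static final double[][][] DEFAULT_COVARIANCES = {
        {{1.0, 0.0}, {0.0, 1.0}},
        {{1.0, 0.0}, {0.0, 1.0}},
        {{1.0, 0.5}, {0.5, 1.0}},
        {{0.1, 0.0}, {0.0, 0.1}},
        {{1.0, 0.0}, {0.0, 0.1}}
    };

    /** A sampled point and the index of the Gaussian that produced it. */
    static final class Point {
        final double[] features;
        final int clusterId;

        Point(double[] features, int clusterId) {
            this.features = features;
            this.clusterId = clusterId;
        }
    }

    /** Lower-triangular Cholesky factor L with L * L^T = cov (cov must be positive definite). */
    static double[][] cholesky(double[][] cov) {
        int d = cov.length;
        double[][] l = new double[d][d];
        for (int i = 0; i < d; i++) {
            for (int j = 0; j <= i; j++) {
                double sum = cov[i][j];
                for (int k = 0; k < j; k++) {
                    sum -= l[i][k] * l[j][k];
                }
                if (i == j) {
                    if (sum <= 0.0) {
                        throw new IllegalArgumentException("Covariance is not positive definite");
                    }
                    l[i][i] = Math.sqrt(sum);
                } else {
                    l[i][j] = sum / l[j][j];
                }
            }
        }
        return l;
    }

    static List<Point> generate(int numSamples, double[] mixing, double[][] means,
                                double[][][] covariances, long seed) {
        Random rng = new Random(seed);
        double[][][] factors = new double[covariances.length][][];
        for (int c = 0; c < covariances.length; c++) factors[c] = cholesky(covariances[c]);
        List<Point> points = new ArrayList<>(numSamples);
        for (int n = 0; n < numSamples; n++) {
            // 1. Pick a component by inverse-CDF sampling over the mixing distribution;
            //    the last component is the fallback if round-off leaves the total below u.
            double u = rng.nextDouble();
            int c = mixing.length - 1;
            double cumulative = 0.0;
            for (int i = 0; i < mixing.length; i++) {
                cumulative += mixing[i];
                if (u < cumulative) {
                    c = i;
                    break;
                }
            }
            // 2. Draw z ~ N(0, I) and set x = mean + L z, so that x ~ N(mean, cov).
            int d = means[c].length;
            double[] z = new double[d];
            for (int i = 0; i < d; i++) z[i] = rng.nextGaussian();
            double[] x = new double[d];
            for (int i = 0; i < d; i++) {
                x[i] = means[c][i];
                for (int k = 0; k <= i; k++) {
                    x[i] += factors[c][i][k] * z[k];
                }
            }
            points.add(new Point(x, c));
        }
        return points;
    }

    public static void main(String[] args) {
        List<Point> data = generate(1000, DEFAULT_MIXING, DEFAULT_MEANS, DEFAULT_COVARIANCES, 1L);
        System.out.println(data.size() + " points; first point is from cluster " + data.get(0).clusterId);
    }
}
```

The Cholesky route was chosen because it works unchanged for any dimension up to the documented limit of 4. Fixing the seed makes the sketch reproducible run to run, which mirrors the role of the seed parameter above.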