Class KFoldSplitter<T extends Output<T>>

java.lang.Object
org.tribuo.evaluation.KFoldSplitter<T>
Type Parameters:
T - the output type of the examples that make up the dataset to be split

public class KFoldSplitter<T extends Output<T>> extends Object
A k-fold splitter to be used in cross-validation.
  • Field Details

    • nsplits

      protected final int nsplits
    • seed

      protected final long seed
    • rng

      protected final SplittableRandom rng
  • Constructor Details

    • KFoldSplitter

      public KFoldSplitter(int nsplits, long randomSeed)
      Builds a k-fold splitter.
      Parameters:
      nsplits - The number of folds.
      randomSeed - The RNG seed.
    • KFoldSplitter

      public KFoldSplitter(int nsplits)
      Builds a k-fold splitter using Trainer.DEFAULT_SEED as the seed.
      Parameters:
      nsplits - The number of folds.
  • Method Details

    • split

      public Iterator<KFoldSplitter.TrainTestFold<T>> split(Dataset<T> dataset, boolean shuffle)
      Splits a dataset into k consecutive folds, where k = nsplits. Each fold is used once as the test set, and the remaining k-1 folds form the training set.

      Note: the first nsamples % nsplits folds have size nsamples/nsplits + 1 and the remaining folds have size nsamples/nsplits (integer division), where nsamples = dataset.size().

      Parameters:
      dataset - The full dataset.
      shuffle - Whether to shuffle the dataset before generating folds.
      Returns:
      An iterator over TrainTestFolds
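
      The fold-size rule in the note on split can be checked with a small standalone sketch. It does not call Tribuo: the class FoldSizes, the method foldSizes, and the argument check are hypothetical names and choices made for illustration only. It reproduces just the arithmetic stated in the note, assuming integer division.

```java
import java.util.Arrays;

// Hypothetical illustration class; not part of Tribuo.
class FoldSizes {

    // Returns the test-fold sizes described in the note on split:
    // the first (nsamples % nsplits) folds get one extra example.
    static int[] foldSizes(int nsamples, int nsplits) {
        // Guard chosen for this sketch only; the library's own checks may differ.
        if (nsplits < 1 || nsamples < 0) {
            throw new IllegalArgumentException("nsplits must be >= 1 and nsamples >= 0");
        }
        int base = nsamples / nsplits;      // integer division
        int remainder = nsamples % nsplits; // number of folds that get one extra example
        int[] sizes = new int[nsplits];
        for (int i = 0; i < nsplits; i++) {
            sizes[i] = base + (i < remainder ? 1 : 0);
        }
        return sizes;
    }

    public static void main(String[] args) {
        int nsamples = 10;
        int[] sizes = foldSizes(nsamples, 3);
        System.out.println(Arrays.toString(sizes)); // prints [4, 3, 3]
        for (int i = 0; i < sizes.length; i++) {
            // The training set for a fold is everything outside that fold.
            System.out.println("fold " + i + ": test=" + sizes[i]
                    + ", train=" + (nsamples - sizes[i]));
        }
    }
}
```

      With 10 examples and 3 folds, the first fold holds 4 test examples and the other two hold 3 each, so the training sets hold 6, 7, and 7 examples respectively.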