Class CrossValidation<T extends Output<T>,E extends Evaluation<T>>

java.lang.Object
org.tribuo.evaluation.CrossValidation<T,E>

public class CrossValidation<T extends Output<T>,E extends Evaluation<T>> extends Object
A class that does k-fold cross-validation.

This splits the data into k folds. Each fold is used once as the test set, with a model trained on the remaining k-1 folds.

It produces a list of k Evaluations, one per test fold, each paired with the model trained for that fold.
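The following plain-Java sketch illustrates the splitting described above. It partitions the indices 0..n-1 into k contiguous test folds whose sizes differ by at most one, and uses every other index as that fold's training set. It is not Tribuo's implementation. The class `KFoldSketch`, its `Fold` holder, and the contiguous (unshuffled) layout are assumptions made for this example.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical sketch of k-fold partitioning (not Tribuo's implementation).
final class KFoldSketch {
    // One fold: the indices held out for testing and the indices used for training.
    static final class Fold {
        final List<Integer> train;
        final List<Integer> test;

        Fold(List<Integer> train, List<Integer> test) {
            this.train = train;
            this.test = test;
        }
    }

    // Splits indices 0..n-1 into k contiguous test folds; sizes differ by at most one.
    static List<Fold> split(int n, int k) {
        if (k < 2 || k > n) {
            throw new IllegalArgumentException("k must be between 2 and n");
        }
        List<Fold> folds = new ArrayList<>();
        int start = 0;
        for (int f = 0; f < k; f++) {
            // The first n % k folds receive one extra example.
            int end = start + n / k + (f < n % k ? 1 : 0);
            List<Integer> train = new ArrayList<>();
            List<Integer> test = new ArrayList<>();
            for (int i = 0; i < n; i++) {
                if (i >= start && i < end) {
                    test.add(i);   // held out in this fold
                } else {
                    train.add(i);  // used to train this fold's model
                }
            }
            folds.add(new Fold(train, test));
            start = end;
        }
        return folds;
    }

    public static void main(String[] args) {
        for (Fold fold : split(10, 3)) {
            System.out.println("test=" + fold.test + " train size=" + fold.train.size());
        }
    }
}
```

With n = 10 and k = 3, the first fold tests on four examples and the other two on three each, so every example is tested exactly once.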

  • Constructor Details

    • CrossValidation

      public CrossValidation(Trainer<T> trainer, Dataset<T> data, Evaluator<T,E> evaluator, int k)
      Builds a k-fold cross-validation loop.
      Parameters:
      trainer - the trainer to use.
      data - the dataset to split.
      evaluator - the evaluator to use.
      k - the number of folds.
    • CrossValidation

      public CrossValidation(Trainer<T> trainer, Dataset<T> data, Evaluator<T,E> evaluator, int k, long seed)
      Builds a k-fold cross-validation loop.
      Parameters:
      trainer - the trainer to use.
      data - the dataset to split.
      evaluator - the evaluator to use.
      k - the number of folds.
      seed - the RNG seed used when splitting the data into folds.
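A seed makes the fold assignment reproducible. The minimal sketch below assumes the indices are shuffled with a seeded `java.util.Random` and then dealt into folds round-robin. The class `SeededFoldAssignment` is hypothetical, and this page does not say how Tribuo uses the seed internally.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Random;

// Hypothetical sketch: a seeded shuffle gives a reproducible fold assignment.
final class SeededFoldAssignment {
    // Returns foldOf[i] = the fold (0..k-1) in which example i is held out.
    static int[] assignFolds(int n, int k, long seed) {
        List<Integer> order = new ArrayList<>();
        for (int i = 0; i < n; i++) {
            order.add(i);
        }
        // The same seed always yields the same permutation.
        Collections.shuffle(order, new Random(seed));
        int[] foldOf = new int[n];
        for (int pos = 0; pos < n; pos++) {
            // Deal the shuffled examples round-robin so fold sizes differ by at most one.
            foldOf[order.get(pos)] = pos % k;
        }
        return foldOf;
    }

    public static void main(String[] args) {
        int[] a = assignFolds(10, 3, 42L);
        int[] b = assignFolds(10, 3, 42L);
        System.out.println(Arrays.equals(a, b)); // prints "true": same seed, same folds
    }
}
```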
  • Method Details

    • getK

      public int getK()
      Returns the number of folds.
      Returns:
      The number of folds.
    • evaluate

      public List<com.oracle.labs.mlrg.olcut.util.Pair<E,Model<T>>> evaluate()
      Performs k-fold cross-validation, returning the k evaluations.
      Returns:
      The k evaluations, one per fold, each paired with the model trained for that fold.
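With the documented signatures, a call looks like `new CrossValidation<>(trainer, data, evaluator, 5).evaluate()`. To show the shape of the result without depending on Tribuo, the sketch below runs a train-then-evaluate loop over generic functions and returns one (evaluation, model) pair per fold. `Map.Entry` stands in for OLCUT's `Pair`. The class `CrossValidationLoopSketch`, the mean-predictor "model", and the mean-absolute-error "evaluation" are toy assumptions.

```java
import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.function.BiFunction;
import java.util.function.Function;

// Hypothetical sketch of a loop returning (evaluation, model) pairs, one per fold.
final class CrossValidationLoopSketch {
    static <D, M, E> List<Map.Entry<E, M>> crossValidate(
            List<D> data, int k,
            Function<List<D>, M> train,
            BiFunction<M, List<D>, E> evaluate) {
        int n = data.size();
        if (k < 2 || k > n) {
            throw new IllegalArgumentException("k must be between 2 and the dataset size");
        }
        List<Map.Entry<E, M>> results = new ArrayList<>();
        int start = 0;
        for (int f = 0; f < k; f++) {
            // Contiguous test fold; the first n % k folds get one extra example.
            int end = start + n / k + (f < n % k ? 1 : 0);
            List<D> test = data.subList(start, end);
            // Training set: everything outside the test fold.
            List<D> trainSet = new ArrayList<>(data.subList(0, start));
            trainSet.addAll(data.subList(end, n));
            M model = train.apply(trainSet);
            results.add(new AbstractMap.SimpleImmutableEntry<>(evaluate.apply(model, test), model));
            start = end;
        }
        return results;
    }

    public static void main(String[] args) {
        List<Double> ys = List.of(1.0, 2.0, 3.0, 4.0, 5.0, 6.0);
        // Toy "model": predicts the mean of the training targets.
        Function<List<Double>, Double> meanTrainer =
                xs -> xs.stream().mapToDouble(Double::doubleValue).average().orElse(0.0);
        // Toy "evaluation": mean absolute error on the test fold.
        BiFunction<Double, List<Double>, Double> mae =
                (mean, xs) -> xs.stream().mapToDouble(x -> Math.abs(x - mean)).average().orElse(0.0);
        for (Map.Entry<Double, Double> r : crossValidate(ys, 3, meanTrainer, mae)) {
            System.out.println("MAE=" + r.getKey() + " model mean=" + r.getValue());
        }
    }
}
```

On the six targets in `main`, each fold holds out two values. The fold testing on 3.0 and 4.0 scores an MAE of 0.5, and the other two folds each score 3.0.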