Class TrainTestSplitter<T extends Output<T>>

java.lang.Object
org.tribuo.evaluation.TrainTestSplitter<T>
Type Parameters:
T - The output type of the examples in the datasource.

public class TrainTestSplitter<T extends Output<T>> extends Object
Splits data into training and testing sets. Note that this operates on a DataSource rather than a Dataset.
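As a rough illustration of what such a split involves, here is a minimal, self-contained sketch in plain Java. It is not Tribuo's implementation. The class name SimpleTrainTestSplit is hypothetical, a plain Iterable stands in for a DataSource, and the seeded shuffle with java.util.Random and the round-down cut-off are assumptions made for the example. Because a DataSource-like input can only be iterated, the sketch first copies the examples into a list, then shuffles and cuts it.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

// Hypothetical sketch, NOT Tribuo's implementation. A plain Iterable stands in
// for a Tribuo DataSource, and plain lists stand in for the train/test sources.
final class SimpleTrainTestSplit<E> {
    private final List<E> train;
    private final List<E> test;

    // No argument checking here: trainProportion is assumed to lie in [0, 1].
    SimpleTrainTestSplit(Iterable<E> source, double trainProportion, long seed) {
        // Copy the source into a list: an Iterable can only be walked in order,
        // so it has to be materialized before it can be shuffled.
        List<E> all = new ArrayList<>();
        for (E example : source) {
            all.add(example);
        }
        // Seeded shuffle: the same seed and input order give the same split.
        Collections.shuffle(all, new Random(seed));
        // Assumed rule: round the training count down.
        int nTrain = (int) Math.floor(trainProportion * all.size());
        // The first nTrain shuffled examples go to training, the rest to testing.
        this.train = new ArrayList<>(all.subList(0, nTrain));
        this.test = new ArrayList<>(all.subList(nTrain, all.size()));
    }

    List<E> getTrain() {
        return Collections.unmodifiableList(train);
    }

    List<E> getTest() {
        return Collections.unmodifiableList(test);
    }

    // Train and test sizes combined, i.e. the number of input examples.
    int totalSize() {
        return train.size() + test.size();
    }
}
```

With the same input order and seed, two instances produce identical train and test lists, which is what makes a seeded split reproducible.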
  • Constructor Details

    • TrainTestSplitter

      public TrainTestSplitter(DataSource<T> data)
      Creates a splitter that splits the data 70/30 into training and testing sets, using a default seed.
      Parameters:
      data - The data to split.
    • TrainTestSplitter

      public TrainTestSplitter(DataSource<T> data, long seed)
      Creates a splitter that splits the data 70/30 into training and testing sets.
      Parameters:
      data - The data to split.
      seed - The seed for the RNG.
    • TrainTestSplitter

      public TrainTestSplitter(DataSource<T> data, double trainProportion, long seed)
      Creates a splitter that will split the given data into a training set and a testing set. The given proportion of the data will be randomly selected for the training set; the remainder will be in the test set.
      Parameters:
      data - The data to split.
      trainProportion - The proportion of the data to select for training. This should be a number between 0 and 1; for example, a value of 0.7 means that 70% of the data is selected for the training set.
      seed - The seed for the RNG.
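The three constructors differ only in which settings they fix. The sketch below is a hypothetical model of that relationship, not Tribuo's source: SplitSettings, DEFAULT_SEED and the round-down rule are illustrative assumptions, the data argument is omitted, and Tribuo's actual default seed and argument checks are not documented above. The two 70/30 constructors delegate to the fully specified one, which in this sketch rejects a trainProportion outside the open interval (0, 1), since either endpoint would leave one side of the split empty.

```java
// Hypothetical sketch, NOT Tribuo's source: how the three constructor
// overloads might relate, with argument checking on trainProportion.
// The data argument of the real constructors is omitted for brevity.
final class SplitSettings {
    // Hypothetical default seed; the value Tribuo actually uses is not stated here.
    static final long DEFAULT_SEED = 1L;
    // The 70/30 default documented for the one- and two-argument constructors.
    static final double DEFAULT_TRAIN_PROPORTION = 0.7;

    final double trainProportion;
    final long seed;

    // Mirrors TrainTestSplitter(data): 70/30 split with a default seed.
    SplitSettings() {
        this(DEFAULT_TRAIN_PROPORTION, DEFAULT_SEED);
    }

    // Mirrors TrainTestSplitter(data, seed): 70/30 split with the caller's seed.
    SplitSettings(long seed) {
        this(DEFAULT_TRAIN_PROPORTION, seed);
    }

    // Mirrors TrainTestSplitter(data, trainProportion, seed).
    SplitSettings(double trainProportion, long seed) {
        // Written as a negated range test so that NaN is rejected as well.
        if (!(trainProportion > 0.0 && trainProportion < 1.0)) {
            throw new IllegalArgumentException(
                    "trainProportion must be in (0, 1), got " + trainProportion);
        }
        this.trainProportion = trainProportion;
        this.seed = seed;
    }

    // Number of examples sent to the training set, rounding down (an assumed rule).
    int trainSize(int totalSize) {
        return (int) Math.floor(trainProportion * totalSize);
    }

    // Everything not selected for training goes to the test set.
    int testSize(int totalSize) {
        return totalSize - trainSize(totalSize);
    }
}
```

For example, the default settings applied to 10 examples give 7 training and 3 testing examples, and a proportion of 0.5 applied to 7 examples gives 3 and 4.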
  • Method Details

    • totalSize

      public int totalSize()
      The total number of examples in the train and test sets combined.
      Returns:
      The number of examples.
    • getTrain

      public DataSource<T> getTrain()
      Gets the training data source.
      Returns:
      The training data.
    • getTest

      public DataSource<T> getTest()
      Gets the testing data source.
      Returns:
      The testing data.
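Since getTrain() and getTest() should together cover the whole input, and totalSize() should equal the sum of their sizes, a quick sanity check can be useful after splitting. The helper below is hypothetical and not part of Tribuo; it works on any Iterable and assumes the element type has a meaningful equals and hashCode (examples without one could be compared through some identifier instead).

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical helper, NOT part of Tribuo: checks that a train/test pair
// partitions the original data exactly, counting duplicates correctly.
final class SplitChecks {
    private SplitChecks() {}

    // Counts how many elements an Iterable yields (one full pass over it).
    static <E> int count(Iterable<E> source) {
        int n = 0;
        for (E ignored : source) {
            n++;
        }
        return n;
    }

    // True if train and test together contain exactly the elements of `all`,
    // with matching multiplicities.
    static <E> boolean isPartition(Iterable<E> all, Iterable<E> train, Iterable<E> test) {
        Map<E, Integer> remaining = new HashMap<>();
        for (E e : all) {
            remaining.merge(e, 1, Integer::sum);
        }
        return consume(remaining, train)
                && consume(remaining, test)
                && remaining.isEmpty(); // every original element was assigned
    }

    // Removes one occurrence per element of `part`; false if an element is
    // absent from the original data or appears more often than it did there.
    private static <E> boolean consume(Map<E, Integer> remaining, Iterable<E> part) {
        for (E e : part) {
            Integer c = remaining.get(e);
            if (c == null) {
                return false;
            }
            if (c == 1) {
                remaining.remove(e);
            } else {
                remaining.put(e, c - 1);
            }
        }
        return true;
    }
}
```

Counting with multiplicities, rather than building sets, keeps the check correct when the input contains duplicate examples.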