001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.evaluation; 018 019import org.tribuo.Dataset; 020import org.tribuo.Output; 021import org.tribuo.Trainer; 022import org.tribuo.dataset.DatasetView; 023import org.tribuo.util.Util; 024 025import java.util.Arrays; 026import java.util.Iterator; 027import java.util.SplittableRandom; 028import java.util.logging.Logger; 029import java.util.stream.IntStream; 030 031/** 032 * A k-fold splitter to be used in cross-validation. 033 * 034 * @param <T> the type of the examples that make up the dataset to be split 035 */ 036public class KFoldSplitter<T extends Output<T>> { 037 038 private static final Logger logger = Logger.getLogger(KFoldSplitter.class.getName()); 039 040 protected final int nsplits; 041 protected final long seed; 042 protected final SplittableRandom rng; 043 044 /** 045 * Builds a k-fold splitter. 046 * @param nsplits The number of folds. 047 * @param randomSeed The RNG seed. 048 */ 049 public KFoldSplitter(int nsplits, long randomSeed) { 050 if (nsplits < 2) { 051 throw new IllegalArgumentException("nsplits must be at least 2"); 052 } 053 this.nsplits = nsplits; 054 this.seed = randomSeed; 055 this.rng = new SplittableRandom(randomSeed); 056 } 057 058 /** 059 * Builds a k-fold splitter using {@link org.tribuo.Trainer#DEFAULT_SEED} as the seed. 060 * @param nsplits The number of folds. 061 */ 062 public KFoldSplitter(int nsplits) { 063 this(nsplits, Trainer.DEFAULT_SEED); 064 } 065 066 /** 067 * Splits a dataset into k consecutive folds; for each fold, the remaining k-1 folds form the training set. 068 * <p> 069 * Note: the first {@code nsamples % nsplits} folds have size {@code nsamples/nsplits + 1} and the remaining have 070 * size {@code nsamples/nsplits}, where {@code nsamples = dataset.size()}. 071 * @param dataset The full dataset 072 * @param shuffle Whether or not shuffle the dataset before generating folds 073 * @return An iterator over TrainTestFolds 074 */ 075 public Iterator<TrainTestFold<T>> split(Dataset<T> dataset, boolean shuffle) { 076 int nsamples = dataset.size(); 077 if (nsamples == 0) { 078 throw new IllegalArgumentException("empty input data"); 079 } 080 if (nsplits > nsamples) { 081 throw new IllegalArgumentException("cannot have nsplits > nsamples"); 082 } 083 int[] indices; 084 if (shuffle) { 085 indices = Util.randperm(nsamples,rng); 086 } else { 087 indices = IntStream.range(0, nsamples).toArray(); 088 } 089 int[] foldSizes = new int[nsplits]; 090 Arrays.fill(foldSizes, nsamples/nsplits); 091 for (int i = 0; i < (nsamples%nsplits); i++) { 092 foldSizes[i] += 1; 093 } 094 095 return new Iterator<TrainTestFold<T>>() { 096 int foldPtr = 0; 097 int dataPtr = 0; 098 099 @Override 100 public boolean hasNext() { 101 return foldPtr < foldSizes.length; 102 } 103 104 @Override 105 public TrainTestFold<T> next() { 106 int size = foldSizes[foldPtr]; 107 foldPtr++; 108 int start = dataPtr; 109 int stop = dataPtr+size; 110 dataPtr = stop; 111 int[] holdOut = Arrays.copyOfRange(indices, start, stop); 112 int[] rest = new int[indices.length - holdOut.length]; 113 System.arraycopy(indices, 0, rest, 0, start); 114 System.arraycopy(indices, stop, rest, start, nsamples-stop); 115 return new TrainTestFold<>( 116 new DatasetView<>(dataset, rest, "TrainFold(seed="+seed+","+foldPtr+" of " + nsplits+")"), 117 new DatasetView<>(dataset, holdOut, "TestFold(seed="+seed+","+foldPtr+" of " + nsplits+")" ) 118 ); 119 } 120 }; 121 } 122 123 /** 124 * Stores a train/test split for a dataset. 125 * 126 * @see KFoldSplitter#split 127 * 128 * @param <T> the type of the examples that make up the data we've split. 129 */ 130 public static class TrainTestFold<T extends Output<T>> { 131 132 public final DatasetView<T> train; 133 public final DatasetView<T> test; 134 135 TrainTestFold(DatasetView<T> train, DatasetView<T> test) { 136 this.train = train; 137 this.test = test; 138 } 139 140 } 141}