001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.evaluation;
018
019import org.tribuo.Dataset;
020import org.tribuo.Output;
021import org.tribuo.Trainer;
022import org.tribuo.dataset.DatasetView;
023import org.tribuo.util.Util;
024
025import java.util.Arrays;
026import java.util.Iterator;
027import java.util.SplittableRandom;
028import java.util.logging.Logger;
029import java.util.stream.IntStream;
030
031/**
032 * A k-fold splitter to be used in cross-validation.
033 *
034 * @param <T> the type of the examples that make up the dataset to be split
035 */
036public class KFoldSplitter<T extends Output<T>> {
037
038    private static final Logger logger = Logger.getLogger(KFoldSplitter.class.getName());
039
040    protected final int nsplits;
041    protected final long seed;
042    protected final SplittableRandom rng;
043
044    /**
045     * Builds a k-fold splitter.
046     * @param nsplits The number of folds.
047     * @param randomSeed The RNG seed.
048     */
049    public KFoldSplitter(int nsplits, long randomSeed) {
050        if (nsplits < 2) {
051            throw new IllegalArgumentException("nsplits must be at least 2");
052        }
053        this.nsplits = nsplits;
054        this.seed = randomSeed;
055        this.rng = new SplittableRandom(randomSeed);
056    }
057
058    /**
059     * Builds a k-fold splitter using {@link org.tribuo.Trainer#DEFAULT_SEED} as the seed.
060     * @param nsplits The number of folds.
061     */
062    public KFoldSplitter(int nsplits) {
063        this(nsplits, Trainer.DEFAULT_SEED);
064    }
065
066    /**
067     * Splits a dataset into k consecutive folds; for each fold, the remaining k-1 folds form the training set.
068     * <p>
069     * Note: the first {@code nsamples % nsplits} folds have size {@code nsamples/nsplits + 1} and the remaining have
070     * size {@code nsamples/nsplits}, where {@code nsamples = dataset.size()}.
071     * @param dataset The full dataset
072     * @param shuffle Whether or not shuffle the dataset before generating folds
073     * @return An iterator over TrainTestFolds
074     */
075    public Iterator<TrainTestFold<T>> split(Dataset<T> dataset, boolean shuffle) {
076        int nsamples = dataset.size();
077        if (nsamples == 0) {
078            throw new IllegalArgumentException("empty input data");
079        }
080        if (nsplits > nsamples) {
081            throw new IllegalArgumentException("cannot have nsplits > nsamples");
082        }
083        int[] indices;
084        if (shuffle) {
085            indices = Util.randperm(nsamples,rng);
086        } else {
087            indices = IntStream.range(0, nsamples).toArray();
088        }
089        int[] foldSizes = new int[nsplits];
090        Arrays.fill(foldSizes, nsamples/nsplits);
091        for (int i = 0; i < (nsamples%nsplits); i++) {
092            foldSizes[i] += 1;
093        }
094
095        return new Iterator<TrainTestFold<T>>() {
096            int foldPtr = 0;
097            int dataPtr = 0;
098
099            @Override
100            public boolean hasNext() {
101                return foldPtr < foldSizes.length;
102            }
103
104            @Override
105            public TrainTestFold<T> next() {
106                int size = foldSizes[foldPtr];
107                foldPtr++;
108                int start = dataPtr;
109                int stop = dataPtr+size;
110                dataPtr = stop;
111                int[] holdOut = Arrays.copyOfRange(indices, start, stop);
112                int[] rest = new int[indices.length - holdOut.length];
113                System.arraycopy(indices, 0, rest, 0, start);
114                System.arraycopy(indices, stop, rest, start, nsamples-stop);
115                return new TrainTestFold<>(
116                        new DatasetView<>(dataset, rest, "TrainFold(seed="+seed+","+foldPtr+" of " + nsplits+")"),
117                        new DatasetView<>(dataset, holdOut, "TestFold(seed="+seed+","+foldPtr+" of " + nsplits+")" )
118                );
119            }
120        };
121    }
122
123    /**
124     * Stores a train/test split for a dataset.
125     *
126     * @see KFoldSplitter#split
127     *
128     * @param <T> the type of the examples that make up the data we've split.
129     */
130    public static class TrainTestFold<T extends Output<T>> {
131
132        public final DatasetView<T> train;
133        public final DatasetView<T> test;
134
135        TrainTestFold(DatasetView<T> train, DatasetView<T> test) {
136            this.train = train;
137            this.test = test;
138        }
139
140    }
141}