001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.evaluation;
018
import com.oracle.labs.mlrg.olcut.util.Pair;
import org.tribuo.Dataset;
import org.tribuo.Model;
import org.tribuo.Output;
import org.tribuo.Trainer;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import java.util.logging.Level;
import java.util.logging.Logger;
030
031/**
032 * A class that does k-fold cross-validation.
033 * <p>
034 * This splits the data into k pieces, tests on one of them and trains on the rest.
035 * <p>
036 * It produces a list of {@link Evaluation}s for each of the test sets.
037 */
038public class CrossValidation<T extends Output<T>, E extends Evaluation<T>> {
039
040    private static final Logger logger = Logger.getLogger(CrossValidation.class.getName());
041
042    private final Trainer<T> trainer;
043    private final int numFolds;
044    private final Dataset<T> data;
045    private final Evaluator<T, E> evaluator;
046    private final KFoldSplitter<T> splitter;
047
048    /**
049     * Builds a k-fold cross-validation loop.
050     * @param trainer the trainer to use.
051     * @param data the dataset to split.
052     * @param evaluator the evaluator to use.
053     * @param k the number of folds.
054     */
055    public CrossValidation(Trainer<T> trainer,
056                           Dataset<T> data,
057                           Evaluator<T, E> evaluator,
058                           int k) {
059        this(trainer, data, evaluator, k, Trainer.DEFAULT_SEED); }
060
061    /**
062     * Builds a k-fold cross-validation loop.
063     * @param trainer the trainer to use.
064     * @param data the dataset to split.
065     * @param evaluator the evaluator to use.
066     * @param k the number of folds.
067     * @param seed The RNG seed.
068     */
069    public CrossValidation(Trainer<T> trainer,
070                           Dataset<T> data,
071                           Evaluator<T, E> evaluator,
072                           int k,
073                           long seed) {
074        this.trainer = trainer;
075        this.data = data;
076        this.evaluator = evaluator;
077        this.numFolds = k;
078        this.splitter = new KFoldSplitter<>(k, seed);
079    }
080
081    /**
082     * Returns the number of folds.
083     * @return The number of folds.
084     */
085    public int getK() { return numFolds; }
086
087    /**
088     * Performs k fold cross validation, returning the k evaluations.
089     * @return The k evaluators one per fold.
090     */
091    public List<Pair<E, Model<T>>> evaluate() {
092        List<Pair<E, Model<T>>> evals = new ArrayList<>();
093        Iterator<KFoldSplitter.TrainTestFold<T>> iter = splitter.split(data, true);
094        int ct = 0;
095        while (iter.hasNext()) {
096            logger.log(Level.INFO, "Training for fold " + ct);
097            KFoldSplitter.TrainTestFold<T> fold = iter.next();
098            Model<T> model = trainer.train(fold.train);
099            evals.add(new Pair<>(evaluator.evaluate(model, fold.test), model));
100            ct++;
101        }
102        return evals;
103    }
104}