/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.evaluation;

import com.oracle.labs.mlrg.olcut.util.Pair;
import org.tribuo.Dataset;
import org.tribuo.Model;
import org.tribuo.Output;
import org.tribuo.Trainer;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.logging.Level;
import java.util.logging.Logger;

/**
 * Drives a k-fold cross-validation run.
 * <p>
 * The dataset is divided into k pieces; each piece is used once as the test set
 * while a model is trained on the remaining pieces.
 * <p>
 * The result is one {@link Evaluation} per test set, each paired with the
 * {@link Model} that was trained for that fold.
 *
 * @param <T> the output type produced by the trained models.
 * @param <E> the evaluation type produced by the evaluator.
 */
public class CrossValidation<T extends Output<T>, E extends Evaluation<T>> {

    private static final Logger logger = Logger.getLogger(CrossValidation.class.getName());

    private final Trainer<T> trainer;
    private final int numFolds;
    private final Dataset<T> data;
    private final Evaluator<T, E> evaluator;
    private final KFoldSplitter<T> splitter;

    /**
     * Creates a k-fold cross-validation loop seeded with {@link Trainer#DEFAULT_SEED}.
     * @param trainer the trainer used to build a model for each fold.
     * @param data the dataset to split into folds.
     * @param evaluator the evaluator applied to each fold's model.
     * @param k the number of folds.
     */
    public CrossValidation(Trainer<T> trainer,
                           Dataset<T> data,
                           Evaluator<T, E> evaluator,
                           int k) {
        this(trainer, data, evaluator, k, Trainer.DEFAULT_SEED);
    }

    /**
     * Creates a k-fold cross-validation loop with an explicit RNG seed.
     * @param trainer the trainer used to build a model for each fold.
     * @param data the dataset to split into folds.
     * @param evaluator the evaluator applied to each fold's model.
     * @param k the number of folds.
     * @param seed the RNG seed handed to the fold splitter.
     */
    public CrossValidation(Trainer<T> trainer,
                           Dataset<T> data,
                           Evaluator<T, E> evaluator,
                           int k,
                           long seed) {
        this.trainer = trainer;
        this.data = data;
        this.evaluator = evaluator;
        this.numFolds = k;
        // The splitter is configured once here and reused by evaluate().
        this.splitter = new KFoldSplitter<>(k, seed);
    }

    /**
     * Gets the number of folds this cross-validation was configured with.
     * @return The number of folds.
     */
    public int getK() {
        return numFolds;
    }

    /**
     * Runs the cross-validation: for every fold supplied by the splitter a model
     * is trained on the fold's training portion and evaluated on its test portion.
     * @return The evaluation and the trained model for each fold, in the order
     * the folds were produced.
     */
    public List<Pair<E, Model<T>>> evaluate() {
        List<Pair<E, Model<T>>> results = new ArrayList<>();
        // NOTE(review): the meaning of the boolean flag is defined by KFoldSplitter.split
        // and is not visible here - presumably it enables shuffling; confirm.
        Iterator<KFoldSplitter.TrainTestFold<T>> folds = splitter.split(data, true);
        // foldIndex is only used to label the log message.
        for (int foldIndex = 0; folds.hasNext(); foldIndex++) {
            logger.log(Level.INFO, "Training for fold " + foldIndex);
            KFoldSplitter.TrainTestFold<T> current = folds.next();
            Model<T> trained = trainer.train(current.train);
            E evaluation = evaluator.evaluate(trained, current.test);
            results.add(new Pair<>(evaluation, trained));
        }
        return results;
    }
}