001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.data; 018 019import com.oracle.labs.mlrg.olcut.config.ConfigurationManager; 020import com.oracle.labs.mlrg.olcut.config.Option; 021import com.oracle.labs.mlrg.olcut.config.Options; 022import com.oracle.labs.mlrg.olcut.config.UsageException; 023import com.oracle.labs.mlrg.olcut.util.LabsLogFormatter; 024import org.tribuo.ConfigurableDataSource; 025import org.tribuo.DataSource; 026import org.tribuo.Dataset; 027import org.tribuo.Model; 028import org.tribuo.MutableDataset; 029import org.tribuo.Output; 030import org.tribuo.Trainer; 031import org.tribuo.dataset.MinimumCardinalityDataset; 032import org.tribuo.evaluation.Evaluation; 033import org.tribuo.evaluation.Evaluator; 034import org.tribuo.transform.TransformTrainer; 035import org.tribuo.transform.TransformationMap; 036import org.tribuo.util.Util; 037 038import java.io.FileOutputStream; 039import java.io.IOException; 040import java.io.ObjectOutputStream; 041import java.nio.file.Path; 042import java.util.logging.Level; 043import java.util.logging.Logger; 044 045/** 046 * Build and run a predictor for a standard dataset. 047 */ 048public final class CompletelyConfigurableTrainTest { 049 050 private static final Logger logger = Logger.getLogger(CompletelyConfigurableTrainTest.class.getName()); 051 052 private CompletelyConfigurableTrainTest() {} 053 054 public static class ConfigurableTrainTestOptions implements Options { 055 @Override 056 public String getOptionsDescription() { 057 return "Loads a Trainer and two DataSources from a config file, trains a Model, tests it and optionally saves it to disk."; 058 } 059 060 @Option(charName='f',longName="model-output-path",usage="Path to serialize model to.") 061 public Path outputPath; 062 063 @Option(charName='u',longName="train-source",usage="Load the training DataSource from the config file. Overrides the training path.") 064 public ConfigurableDataSource<?> trainSource; 065 066 @Option(charName='v',longName="test-source",usage="Load the testing DataSource from the config file. Overrides the testing path.") 067 public ConfigurableDataSource<?> testSource; 068 069 @Option(charName='t',longName="trainer",usage="Load a trainer from the config file.") 070 public Trainer<?> trainer; 071 072 @Option(longName="transformer",usage="Load a transformation map from the config file.") 073 public TransformationMap transformationMap; 074 075 @Option(charName='m',longName="minimum-count",usage="Remove features which occur fewer than <int> times.") 076 public int minCount = -1; 077 } 078 079 /** 080 * @param args the command line arguments 081 * @param <T> The {@link Output} subclass. 082 */ 083 @SuppressWarnings("unchecked") 084 public static <T extends Output<T>> void main(String[] args) { 085 086 // 087 // Use the labs format logging. 088 LabsLogFormatter.setAllLogFormatters(); 089 090 ConfigurableTrainTestOptions o = new ConfigurableTrainTestOptions(); 091 ConfigurationManager cm; 092 try { 093 cm = new ConfigurationManager(args,o); 094 } catch (UsageException e) { 095 logger.info(e.getMessage()); 096 return; 097 } 098 099 if (o.trainSource == null || o.testSource == null) { 100 logger.info(cm.usage()); 101 System.exit(1); 102 } else if (o.trainer == null) { 103 logger.warning("No trainer supplied"); 104 logger.info(cm.usage()); 105 System.exit(1); 106 } 107 108 Dataset<T> train = new MutableDataset<>((DataSource<T>)o.trainSource); 109 if (o.minCount > 0) { 110 logger.info("Removing features which occur fewer than " + o.minCount + " times."); 111 train = new MinimumCardinalityDataset<>(train,o.minCount); 112 } 113 Dataset<T> test = new MutableDataset<>((DataSource<T>)o.testSource); 114 115 if (o.transformationMap != null) { 116 o.trainer = new TransformTrainer<>(o.trainer,o.transformationMap); 117 } 118 logger.info("Trainer is " + o.trainer.getProvenance().toString()); 119 120 logger.info("Outputs are " + train.getOutputInfo().toReadableString()); 121 122 logger.info("Number of features: " + train.getFeatureMap().size()); 123 124 final long trainStart = System.currentTimeMillis(); 125 Model<T> model = ((Trainer<T>)o.trainer).train(train); 126 final long trainStop = System.currentTimeMillis(); 127 128 logger.info("Finished training classifier " + Util.formatDuration(trainStart,trainStop)); 129 130 Evaluator<T,? extends Evaluation<T>> evaluator = train.getOutputFactory().getEvaluator(); 131 final long testStart = System.currentTimeMillis(); 132 Evaluation<T> evaluation = evaluator.evaluate(model,test); 133 final long testStop = System.currentTimeMillis(); 134 logger.info("Finished evaluating model " + Util.formatDuration(testStart,testStop)); 135 System.out.println(evaluation.toString()); 136 137 if (o.outputPath != null) { 138 try (ObjectOutputStream oout = new ObjectOutputStream(new FileOutputStream(o.outputPath.toFile()))) { 139 oout.writeObject(model); 140 logger.info("Serialized model to file: " + o.outputPath); 141 } catch (IOException e) { 142 logger.log(Level.SEVERE, "Error writing model", e); 143 } 144 } 145 } 146}