001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.classification.experiments; 018 019import com.oracle.labs.mlrg.olcut.config.ConfigurationManager; 020import com.oracle.labs.mlrg.olcut.config.Option; 021import com.oracle.labs.mlrg.olcut.config.Options; 022import com.oracle.labs.mlrg.olcut.config.UsageException; 023import com.oracle.labs.mlrg.olcut.util.LabsLogFormatter; 024import com.oracle.labs.mlrg.olcut.util.Pair; 025import org.tribuo.Dataset; 026import org.tribuo.Model; 027import org.tribuo.Trainer; 028import org.tribuo.classification.Label; 029import org.tribuo.classification.LabelFactory; 030import org.tribuo.classification.evaluation.ConfusionMatrix; 031import org.tribuo.classification.evaluation.LabelEvaluation; 032import org.tribuo.classification.evaluation.LabelEvaluator; 033import org.tribuo.data.DataOptions; 034 035import java.io.BufferedWriter; 036import java.io.File; 037import java.io.FileOutputStream; 038import java.io.FileWriter; 039import java.io.IOException; 040import java.io.ObjectOutputStream; 041import java.io.OutputStreamWriter; 042import java.io.PrintWriter; 043import java.nio.charset.StandardCharsets; 044import java.util.HashMap; 045import java.util.List; 046import java.util.Map; 047import java.util.logging.Level; 048import java.util.logging.Logger; 049 050/** 051 * Trains and tests a model using the supplied data, for each trainer inside a configuration file. 052 */ 053public class RunAll { 054 private static final Logger logger = Logger.getLogger(RunAll.class.getName()); 055 056 public static class RunAllOptions implements Options { 057 @Override 058 public String getOptionsDescription() { 059 return "Performs the same training and test experiment on all Trainers in the supplied configuration file."; 060 } 061 public DataOptions general; 062 063 @Option(charName='d',longName="output-directory",usage="Directory to write out the models and test reports.") 064 public File directory; 065 } 066 067 public static void main(String[] args) throws IOException { 068 LabsLogFormatter.setAllLogFormatters(); 069 070 RunAllOptions o = new RunAllOptions(); 071 ConfigurationManager cm; 072 try { 073 cm = new ConfigurationManager(args,o); 074 } catch (UsageException e) { 075 logger.info(e.getMessage()); 076 return; 077 } 078 079 if (o.general.trainingPath == null || o.general.testingPath == null || o.directory == null) { 080 logger.info(cm.usage()); 081 System.exit(1); 082 } 083 Pair<Dataset<Label>,Dataset<Label>> data = null; 084 try { 085 data = o.general.load(new LabelFactory()); 086 } catch (IOException e) { 087 logger.log(Level.SEVERE, "Failed to load data", e); 088 System.exit(1); 089 } 090 Dataset<Label> train = data.getA(); 091 Dataset<Label> test = data.getB(); 092 093 logger.info("Creating directory - " + o.directory.toString()); 094 if (!o.directory.exists() && !o.directory.mkdirs()) { 095 logger.warning("Failed to create directory."); 096 } 097 098 Map<String,Double> performances = new HashMap<>(); 099 List<Trainer> trainers = cm.lookupAll(Trainer.class); 100 for (Trainer<?> t : trainers) { 101 String name = t.getClass().getSimpleName(); 102 logger.info("Training model using " + t.toString()); 103 @SuppressWarnings("unchecked") // configuration system cast. 104 Model<Label> curModel = ((Trainer<Label>)t).train(train); 105 LabelEvaluator evaluator = new LabelEvaluator(); 106 LabelEvaluation evaluation = evaluator.evaluate(curModel,test); 107 Double old = performances.put(name,evaluation.microAveragedF1()); 108 if (old != null) { 109 logger.info("Found two trainers with the name " + name); 110 } 111 String outputPath = o.directory.toString()+"/"+name; 112 try (ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(outputPath+".model"))) { 113 oos.writeObject(curModel); 114 } 115 try (PrintWriter writer = new PrintWriter(new OutputStreamWriter(new FileOutputStream(outputPath+".output"), StandardCharsets.UTF_8))) { 116 writer.println("Model = " + name); 117 writer.println("Provenance = " + curModel.toString()); 118 writer.println(); 119 ConfusionMatrix<Label> matrix = evaluation.getConfusionMatrix(); 120 writer.println("ConfusionMatrix:\n" + matrix.toString()); 121 writer.println(); 122 writer.println("Evaluation:\n" + evaluation.toString()); 123 } 124 } 125 126 for (Map.Entry<String,Double> e : performances.entrySet()) { 127 logger.info("Trainer = " + e.getKey() + ", F1 = " + e.getValue()); 128 } 129 130 } 131 132}