001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.data;
018
019import com.oracle.labs.mlrg.olcut.config.ConfigurationManager;
020import com.oracle.labs.mlrg.olcut.config.Option;
021import com.oracle.labs.mlrg.olcut.config.Options;
022import com.oracle.labs.mlrg.olcut.config.UsageException;
023import com.oracle.labs.mlrg.olcut.util.LabsLogFormatter;
024import org.tribuo.ConfigurableDataSource;
025import org.tribuo.DataSource;
026import org.tribuo.Dataset;
027import org.tribuo.Model;
028import org.tribuo.MutableDataset;
029import org.tribuo.Output;
030import org.tribuo.Trainer;
031import org.tribuo.dataset.MinimumCardinalityDataset;
032import org.tribuo.evaluation.Evaluation;
033import org.tribuo.evaluation.Evaluator;
034import org.tribuo.transform.TransformTrainer;
035import org.tribuo.transform.TransformationMap;
036import org.tribuo.util.Util;
037
038import java.io.FileOutputStream;
039import java.io.IOException;
040import java.io.ObjectOutputStream;
041import java.nio.file.Path;
042import java.util.logging.Level;
043import java.util.logging.Logger;
044
045/**
046 * Build and run a predictor for a standard dataset.
047 */
048public final class CompletelyConfigurableTrainTest {
049
050    private static final Logger logger = Logger.getLogger(CompletelyConfigurableTrainTest.class.getName());
051
052    private CompletelyConfigurableTrainTest() {}
053
054    public static class ConfigurableTrainTestOptions implements Options {
055        @Override
056        public String getOptionsDescription() {
057            return "Loads a Trainer and two DataSources from a config file, trains a Model, tests it and optionally saves it to disk.";
058        }
059
060        @Option(charName='f',longName="model-output-path",usage="Path to serialize model to.")
061        public Path outputPath;
062
063        @Option(charName='u',longName="train-source",usage="Load the training DataSource from the config file. Overrides the training path.")
064        public ConfigurableDataSource<?> trainSource;
065
066        @Option(charName='v',longName="test-source",usage="Load the testing DataSource from the config file. Overrides the testing path.")
067        public ConfigurableDataSource<?> testSource;
068
069        @Option(charName='t',longName="trainer",usage="Load a trainer from the config file.")
070        public Trainer<?> trainer;
071
072        @Option(longName="transformer",usage="Load a transformation map from the config file.")
073        public TransformationMap transformationMap;
074
075        @Option(charName='m',longName="minimum-count",usage="Remove features which occur fewer than <int> times.")
076        public int minCount = -1;
077    }
078
079    /**
080     * @param args the command line arguments
081     * @param <T> The {@link Output} subclass.
082     */
083    @SuppressWarnings("unchecked")
084    public static <T extends Output<T>> void main(String[] args) {
085
086        //
087        // Use the labs format logging.
088        LabsLogFormatter.setAllLogFormatters();
089
090        ConfigurableTrainTestOptions o = new ConfigurableTrainTestOptions();
091        ConfigurationManager cm;
092        try {
093            cm = new ConfigurationManager(args,o);
094        } catch (UsageException e) {
095            logger.info(e.getMessage());
096            return;
097        }
098
099        if (o.trainSource == null || o.testSource == null) {
100            logger.info(cm.usage());
101            System.exit(1);
102        } else if (o.trainer == null) {
103            logger.warning("No trainer supplied");
104            logger.info(cm.usage());
105            System.exit(1);
106        }
107
108        Dataset<T> train = new MutableDataset<>((DataSource<T>)o.trainSource);
109        if (o.minCount > 0) {
110            logger.info("Removing features which occur fewer than " + o.minCount + " times.");
111            train = new MinimumCardinalityDataset<>(train,o.minCount);
112        }
113        Dataset<T> test = new MutableDataset<>((DataSource<T>)o.testSource);
114
115        if (o.transformationMap != null) {
116            o.trainer = new TransformTrainer<>(o.trainer,o.transformationMap);
117        }
118        logger.info("Trainer is " + o.trainer.getProvenance().toString());
119
120        logger.info("Outputs are " + train.getOutputInfo().toReadableString());
121
122        logger.info("Number of features: " + train.getFeatureMap().size());
123
124        final long trainStart = System.currentTimeMillis();
125        Model<T> model = ((Trainer<T>)o.trainer).train(train);
126        final long trainStop = System.currentTimeMillis();
127                
128        logger.info("Finished training classifier " + Util.formatDuration(trainStart,trainStop));
129
130        Evaluator<T,? extends Evaluation<T>> evaluator = train.getOutputFactory().getEvaluator();
131        final long testStart = System.currentTimeMillis();
132        Evaluation<T> evaluation = evaluator.evaluate(model,test);
133        final long testStop = System.currentTimeMillis();
134        logger.info("Finished evaluating model " + Util.formatDuration(testStart,testStop));
135        System.out.println(evaluation.toString());
136
137        if (o.outputPath != null) {
138            try (ObjectOutputStream oout = new ObjectOutputStream(new FileOutputStream(o.outputPath.toFile()))) {
139                oout.writeObject(model);
140                logger.info("Serialized model to file: " + o.outputPath);
141            } catch (IOException e) {
142                logger.log(Level.SEVERE, "Error writing model", e);
143            }
144        }
145    }
146}