001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.classification.sgd.crf;
018
019import com.oracle.labs.mlrg.olcut.config.ConfigurationManager;
020import com.oracle.labs.mlrg.olcut.config.Option;
021import com.oracle.labs.mlrg.olcut.config.Options;
022import com.oracle.labs.mlrg.olcut.config.UsageException;
023import com.oracle.labs.mlrg.olcut.util.LabsLogFormatter;
024import org.tribuo.classification.Label;
025import org.tribuo.classification.sequence.LabelSequenceEvaluation;
026import org.tribuo.classification.sequence.LabelSequenceEvaluator;
027import org.tribuo.classification.sequence.example.SequenceDataGenerator;
028import org.tribuo.hash.HashCodeHasher;
029import org.tribuo.hash.HashingOptions.ModelHashingType;
030import org.tribuo.hash.MessageDigestHasher;
031import org.tribuo.math.StochasticGradientOptimiser;
032import org.tribuo.math.optimisers.GradientOptimiserOptions;
033import org.tribuo.sequence.HashingSequenceTrainer;
034import org.tribuo.sequence.ImmutableSequenceDataset;
035import org.tribuo.sequence.SequenceDataset;
036import org.tribuo.sequence.SequenceTrainer;
037
038import java.io.BufferedInputStream;
039import java.io.FileInputStream;
040import java.io.FileOutputStream;
041import java.io.IOException;
042import java.io.ObjectInputStream;
043import java.io.ObjectOutputStream;
044import java.nio.file.Path;
045import java.util.logging.Logger;
046
047/**
048 * Build and run a linear chain CRF sequence classifier on a generated or serialised dataset.
049 */
050public class SeqTest {
051
052    private static final Logger logger = Logger.getLogger(SeqTest.class.getName());
053
054    public static class CRFOptions implements Options {
055        @Override
056        public String getOptionsDescription() {
057            return "Tests a linear chain CRF model on the specified dataset.";
058        }
059
060        public GradientOptimiserOptions gradientOptions;
061        @Option(charName = 'd', longName = "dataset-name", usage = "Name of the example dataset, options are {gorilla}.")
062        public String datasetName = "";
063        @Option(charName = 'f', longName = "output-path", usage = "Path to serialize model to.")
064        public Path outputPath;
065        @Option(charName = 'i', longName = "epochs", usage = "Number of SGD epochs.")
066        public int epochs = 5;
067        @Option(charName = 'o', longName = "print-model", usage = "Print out feature, label and other model details.")
068        public boolean logModel = false;
069        @Option(charName = 'p', longName = "logging-interval", usage = "Log the objective after <int> examples.")
070        public int loggingInterval = 100;
071        @Option(charName = 'r', longName = "seed", usage = "RNG seed.")
072        public long seed = 1;
073        @Option(longName = "shuffle", usage = "Shuffle the data each epoch (default: true).")
074        public boolean shuffle = true;
075        @Option(charName = 'u', longName = "train-dataset", usage = "Path to a serialised SequenceDataset used for training.")
076        public Path trainDataset = null;
077        @Option(charName = 'v', longName = "test-dataset", usage = "Path to a serialised SequenceDataset used for testing.")
078        public Path testDataset = null;
079        @Option(longName = "model-hashing-algorithm", usage = "Hash the model during training. Defaults to no hashing.")
080        public ModelHashingType modelHashingAlgorithm = ModelHashingType.NONE;
081        @Option(longName = "model-hashing-salt", usage = "Salt for hashing the model.")
082        public String modelHashingSalt = "";
083    }
084
085    /**
086     * @param args the command line arguments
087     * @throws ClassNotFoundException if it failed to load the model.
088     * @throws IOException            if there is any error reading the examples.
089     */
090    @SuppressWarnings("unchecked") // deserialising a generic dataset.
091    public static void main(String[] args) throws ClassNotFoundException, IOException {
092
093        //
094        // Use the labs format logging.
095        LabsLogFormatter.setAllLogFormatters();
096
097        CRFOptions o = new CRFOptions();
098        ConfigurationManager cm;
099        try {
100            cm = new ConfigurationManager(args, o);
101        } catch (UsageException e) {
102            logger.info(e.getMessage());
103            return;
104        }
105
106        logger.info("Configuring gradient optimiser");
107        StochasticGradientOptimiser grad = o.gradientOptions.getOptimiser();
108
109        logger.info(String.format("Set logging interval to %d", o.loggingInterval));
110
111        SequenceDataset<Label> train;
112        SequenceDataset<Label> test;
113        switch (o.datasetName) {
114            case "Gorilla":
115            case "gorilla":
116                logger.info("Generating gorilla dataset");
117                train = SequenceDataGenerator.generateGorillaDataset(1);
118                test = SequenceDataGenerator.generateGorillaDataset(1);
119                break;
120            default:
121                if ((o.trainDataset != null) && (o.testDataset != null)) {
122                    logger.info("Loading training data from " + o.trainDataset);
123                    try (ObjectInputStream ois = new ObjectInputStream(new BufferedInputStream(new FileInputStream(o.trainDataset.toFile())));
124                         ObjectInputStream oits = new ObjectInputStream(new BufferedInputStream(new FileInputStream(o.testDataset.toFile())))) {
125                        train = (SequenceDataset<Label>) ois.readObject();
126                        logger.info(String.format("Loaded %d training examples for %s", train.size(), train.getOutputs().toString()));
127                        logger.info("Found " + train.getFeatureIDMap().size() + " features");
128                        logger.info("Loading testing data from " + o.testDataset);
129                        SequenceDataset<Label> deserTest = (SequenceDataset<Label>) oits.readObject();
130                        test = ImmutableSequenceDataset.copyDataset(deserTest, train.getFeatureIDMap(), train.getOutputIDInfo());
131                        logger.info(String.format("Loaded %d testing examples", test.size()));
132                    }
133                } else {
134                    logger.warning("Unknown dataset " + o.datasetName);
135                    logger.info(cm.usage());
136                    return;
137                }
138        }
139
140        SequenceTrainer<Label> trainer = new CRFTrainer(grad, o.epochs, o.loggingInterval, o.seed);
141        ((CRFTrainer) trainer).setShuffle(o.shuffle);
142        switch (o.modelHashingAlgorithm) {
143            case NONE:
144                break;
145            case HC:
146                trainer = new HashingSequenceTrainer<>(trainer, new HashCodeHasher(o.modelHashingSalt));
147                break;
148            case SHA1:
149                trainer = new HashingSequenceTrainer<>(trainer, new MessageDigestHasher("SHA1", o.modelHashingSalt));
150                break;
151            case SHA256:
152                trainer = new HashingSequenceTrainer<>(trainer, new MessageDigestHasher("SHA-256", o.modelHashingSalt));
153                break;
154            default:
155                logger.info("Unknown hasher " + o.modelHashingAlgorithm);
156        }
157
158        logger.info("Training using " + trainer.toString());
159        CRFModel model = (CRFModel) trainer.train(train);
160        logger.info("Finished training");
161
162        if (o.logModel) {
163            System.out.println("FeatureMap = " + model.getFeatureIDMap().toString());
164            System.out.println("LabelMap = " + model.getOutputIDInfo().toString());
165            System.out.println("Features - " + model.generateWeightsString());
166        }
167
168        LabelSequenceEvaluator labelEvaluator = new LabelSequenceEvaluator();
169        LabelSequenceEvaluation evaluation = labelEvaluator.evaluate(model,test);
170        logger.info("Finished evaluating model");
171        System.out.println(evaluation.toString());
172        System.out.println();
173        System.out.println(evaluation.getConfusionMatrix().toString());
174
175        if (o.outputPath != null) {
176            FileOutputStream fout = new FileOutputStream(o.outputPath.toFile());
177            ObjectOutputStream oout = new ObjectOutputStream(fout);
178            oout.writeObject(model);
179            oout.close();
180            fout.close();
181            logger.info("Serialized model to file: " + o.outputPath);
182        }
183    }
184}