001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.classification.sgd.crf; 018 019import com.oracle.labs.mlrg.olcut.config.ConfigurationManager; 020import com.oracle.labs.mlrg.olcut.config.Option; 021import com.oracle.labs.mlrg.olcut.config.Options; 022import com.oracle.labs.mlrg.olcut.config.UsageException; 023import com.oracle.labs.mlrg.olcut.util.LabsLogFormatter; 024import org.tribuo.classification.Label; 025import org.tribuo.classification.sequence.LabelSequenceEvaluation; 026import org.tribuo.classification.sequence.LabelSequenceEvaluator; 027import org.tribuo.classification.sequence.example.SequenceDataGenerator; 028import org.tribuo.hash.HashCodeHasher; 029import org.tribuo.hash.HashingOptions.ModelHashingType; 030import org.tribuo.hash.MessageDigestHasher; 031import org.tribuo.math.StochasticGradientOptimiser; 032import org.tribuo.math.optimisers.GradientOptimiserOptions; 033import org.tribuo.sequence.HashingSequenceTrainer; 034import org.tribuo.sequence.ImmutableSequenceDataset; 035import org.tribuo.sequence.SequenceDataset; 036import org.tribuo.sequence.SequenceTrainer; 037 038import java.io.BufferedInputStream; 039import java.io.FileInputStream; 040import java.io.FileOutputStream; 041import java.io.IOException; 042import java.io.ObjectInputStream; 043import java.io.ObjectOutputStream; 044import java.nio.file.Path; 045import java.util.logging.Logger; 046 047/** 048 * Build and run a sequence classifier on a generated dataset. 049 */ 050public class SeqTest { 051 052 private static final Logger logger = Logger.getLogger(SeqTest.class.getName()); 053 054 public static class CRFOptions implements Options { 055 @Override 056 public String getOptionsDescription() { 057 return "Tests a linear chain CRF model on the specified dataset."; 058 } 059 060 public GradientOptimiserOptions gradientOptions; 061 @Option(charName = 'd', longName = "dataset-name", usage = "Name of the example dataset, options are {gorilla}.") 062 public String datasetName = ""; 063 @Option(charName = 'f', longName = "output-path", usage = "Path to serialize model to.") 064 public Path outputPath; 065 @Option(charName = 'i', longName = "epochs", usage = "Number of SGD epochs.") 066 public int epochs = 5; 067 @Option(charName = 'o', longName = "print-model", usage = "Print out feature, label and other model details.") 068 public boolean logModel = false; 069 @Option(charName = 'p', longName = "logging-interval", usage = "Log the objective after <int> examples.") 070 public int loggingInterval = 100; 071 @Option(charName = 'r', longName = "seed", usage = "RNG seed.") 072 public long seed = 1; 073 @Option(longName = "shuffle", usage = "Shuffle the data each epoch (default: true).") 074 public boolean shuffle = true; 075 @Option(charName = 'u', longName = "train-dataset", usage = "Path to a serialised SequenceDataset used for training.") 076 public Path trainDataset = null; 077 @Option(charName = 'v', longName = "test-dataset", usage = "Path to a serialised SequenceDataset used for testing.") 078 public Path testDataset = null; 079 @Option(longName = "model-hashing-algorithm", usage = "Hash the model during training. Defaults to no hashing.") 080 public ModelHashingType modelHashingAlgorithm = ModelHashingType.NONE; 081 @Option(longName = "model-hashing-salt", usage = "Salt for hashing the model.") 082 public String modelHashingSalt = ""; 083 } 084 085 /** 086 * @param args the command line arguments 087 * @throws ClassNotFoundException if it failed to load the model. 088 * @throws IOException if there is any error reading the examples. 089 */ 090 @SuppressWarnings("unchecked") // deserialising a generic dataset. 091 public static void main(String[] args) throws ClassNotFoundException, IOException { 092 093 // 094 // Use the labs format logging. 095 LabsLogFormatter.setAllLogFormatters(); 096 097 CRFOptions o = new CRFOptions(); 098 ConfigurationManager cm; 099 try { 100 cm = new ConfigurationManager(args, o); 101 } catch (UsageException e) { 102 logger.info(e.getMessage()); 103 return; 104 } 105 106 logger.info("Configuring gradient optimiser"); 107 StochasticGradientOptimiser grad = o.gradientOptions.getOptimiser(); 108 109 logger.info(String.format("Set logging interval to %d", o.loggingInterval)); 110 111 SequenceDataset<Label> train; 112 SequenceDataset<Label> test; 113 switch (o.datasetName) { 114 case "Gorilla": 115 case "gorilla": 116 logger.info("Generating gorilla dataset"); 117 train = SequenceDataGenerator.generateGorillaDataset(1); 118 test = SequenceDataGenerator.generateGorillaDataset(1); 119 break; 120 default: 121 if ((o.trainDataset != null) && (o.testDataset != null)) { 122 logger.info("Loading training data from " + o.trainDataset); 123 try (ObjectInputStream ois = new ObjectInputStream(new BufferedInputStream(new FileInputStream(o.trainDataset.toFile()))); 124 ObjectInputStream oits = new ObjectInputStream(new BufferedInputStream(new FileInputStream(o.testDataset.toFile())))) { 125 train = (SequenceDataset<Label>) ois.readObject(); 126 logger.info(String.format("Loaded %d training examples for %s", train.size(), train.getOutputs().toString())); 127 logger.info("Found " + train.getFeatureIDMap().size() + " features"); 128 logger.info("Loading testing data from " + o.testDataset); 129 SequenceDataset<Label> deserTest = (SequenceDataset<Label>) oits.readObject(); 130 test = ImmutableSequenceDataset.copyDataset(deserTest, train.getFeatureIDMap(), train.getOutputIDInfo()); 131 logger.info(String.format("Loaded %d testing examples", test.size())); 132 } 133 } else { 134 logger.warning("Unknown dataset " + o.datasetName); 135 logger.info(cm.usage()); 136 return; 137 } 138 } 139 140 SequenceTrainer<Label> trainer = new CRFTrainer(grad, o.epochs, o.loggingInterval, o.seed); 141 ((CRFTrainer) trainer).setShuffle(o.shuffle); 142 switch (o.modelHashingAlgorithm) { 143 case NONE: 144 break; 145 case HC: 146 trainer = new HashingSequenceTrainer<>(trainer, new HashCodeHasher(o.modelHashingSalt)); 147 break; 148 case SHA1: 149 trainer = new HashingSequenceTrainer<>(trainer, new MessageDigestHasher("SHA1", o.modelHashingSalt)); 150 break; 151 case SHA256: 152 trainer = new HashingSequenceTrainer<>(trainer, new MessageDigestHasher("SHA-256", o.modelHashingSalt)); 153 break; 154 default: 155 logger.info("Unknown hasher " + o.modelHashingAlgorithm); 156 } 157 158 logger.info("Training using " + trainer.toString()); 159 CRFModel model = (CRFModel) trainer.train(train); 160 logger.info("Finished training"); 161 162 if (o.logModel) { 163 System.out.println("FeatureMap = " + model.getFeatureIDMap().toString()); 164 System.out.println("LabelMap = " + model.getOutputIDInfo().toString()); 165 System.out.println("Features - " + model.generateWeightsString()); 166 } 167 168 LabelSequenceEvaluator labelEvaluator = new LabelSequenceEvaluator(); 169 LabelSequenceEvaluation evaluation = labelEvaluator.evaluate(model,test); 170 logger.info("Finished evaluating model"); 171 System.out.println(evaluation.toString()); 172 System.out.println(); 173 System.out.println(evaluation.getConfusionMatrix().toString()); 174 175 if (o.outputPath != null) { 176 FileOutputStream fout = new FileOutputStream(o.outputPath.toFile()); 177 ObjectOutputStream oout = new ObjectOutputStream(fout); 178 oout.writeObject(model); 179 oout.close(); 180 fout.close(); 181 logger.info("Serialized model to file: " + o.outputPath); 182 } 183 } 184}