/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.interop.tensorflow;

import com.oracle.labs.mlrg.olcut.config.ConfigurationManager;
import com.oracle.labs.mlrg.olcut.config.Option;
import com.oracle.labs.mlrg.olcut.config.Options;
import com.oracle.labs.mlrg.olcut.config.UsageException;
import com.oracle.labs.mlrg.olcut.util.LabsLogFormatter;
import com.oracle.labs.mlrg.olcut.util.Pair;
import org.tribuo.Dataset;
import org.tribuo.ImmutableDataset;
import org.tribuo.Model;
import org.tribuo.MutableDataset;
import org.tribuo.OutputFactory;
import org.tribuo.Trainer;
import org.tribuo.classification.Label;
import org.tribuo.classification.LabelFactory;
import org.tribuo.classification.evaluation.LabelEvaluation;
import org.tribuo.classification.evaluation.LabelEvaluator;
import org.tribuo.datasource.LibSVMDataSource;
import org.tribuo.util.Util;

import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectOutputStream;
import java.nio.file.Path;
import java.util.logging.Logger;

/**
 * Build and run a Tensorflow multi-class classifier for a standard dataset.
046 */ 047public class TrainTest { 048 049 private static final Logger logger = Logger.getLogger(TrainTest.class.getName()); 050 051 public enum InputType { DENSE, IMAGE } 052 053 public static Pair<Dataset<Label>,Dataset<Label>> load(Path trainingPath, Path testingPath, OutputFactory<Label> outputFactory) throws IOException { 054 logger.info(String.format("Loading data from %s", trainingPath)); 055 Dataset<Label> train; 056 Dataset<Label> test; 057 // 058 // Load the libsvm text-based training data format. 059 LibSVMDataSource<Label> trainSource = new LibSVMDataSource<>(trainingPath,outputFactory); 060 train = new MutableDataset<>(trainSource); 061 boolean zeroIndexed = trainSource.isZeroIndexed(); 062 int maxFeatureID = trainSource.getMaxFeatureID(); 063 logger.info(String.format("Loaded %d training examples for %s", train.size(), train.getOutputs().toString())); 064 logger.info("Found " + train.getFeatureIDMap().size() + " features"); 065 test = new ImmutableDataset<>(new LibSVMDataSource<>(testingPath,outputFactory,zeroIndexed,maxFeatureID),train.getFeatureIDMap(),train.getOutputIDInfo(),false); 066 logger.info(String.format("Loaded %d testing examples", test.size())); 067 return new Pair<>(train,test); 068 } 069 070 public static void saveModel(Path outputPath, Model<Label> model) throws IOException { 071 FileOutputStream fout = new FileOutputStream(outputPath.toFile()); 072 ObjectOutputStream oout = new ObjectOutputStream(fout); 073 oout.writeObject(model); 074 oout.close(); 075 fout.close(); 076 logger.info("Serialized model to file: " + outputPath); 077 } 078 079 public static class TensorflowOptions implements Options { 080 @Override 081 public String getOptionsDescription() { 082 return "Trains and tests a Tensorflow model."; 083 } 084 @Option(charName='f',longName="model-output-path",usage="Path to serialize model to.") 085 public Path outputPath; 086 @Option(charName='u',longName="training-file",usage="Path to the libsvm format training file.") 087 
public Path trainingPath; 088 @Option(charName='v',longName="testing-file",usage="Path to the libsvm format testing file.") 089 public Path testingPath; 090 091 @Option(charName='b',longName="batch-size",usage="Test time minibatch size.") 092 public int testBatchSize = 16; 093 094 @Option(charName='b',longName="batch-size",usage="Minibatch size.") 095 public int batchSize = 128; 096 @Option(charName='e',longName="num-epochs",usage="Number of gradient descent epochs.") 097 public int epochs = 5; 098 @Option(charName='i',longName="image-format",usage="Image format, in [W,H,C]. Defaults to MNIST.") 099 public String imageFormat = "28,28,1"; 100 @Option(charName='t',longName="input-type",usage="Input type.") 101 public InputType inputType = InputType.IMAGE; 102 @Option(charName='m',longName="model-protobuf",usage="Path to the protobuf containing the network description.") 103 public Path protobufPath; 104 @Option(charName='p',longName="checkpoint-dir",usage="Path to the checkpoint base directory.") 105 public Path checkpointPath; 106 } 107 108 /** 109 * @param args the command line arguments 110 * @throws IOException if there is any error reading the examples. 111 */ 112 public static void main(String[] args) throws IOException { 113 // 114 // Use the labs format logging. 
115 LabsLogFormatter.setAllLogFormatters(); 116 117 TensorflowOptions o = new TensorflowOptions(); 118 ConfigurationManager cm; 119 try { 120 cm = new ConfigurationManager(args,o); 121 } catch (UsageException e) { 122 logger.info(e.getMessage()); 123 return; 124 } 125 126 if (o.trainingPath == null || o.testingPath == null) { 127 logger.info(cm.usage()); 128 return; 129 } 130 131 Pair<Dataset<Label>,Dataset<Label>> data = load(o.trainingPath, o.testingPath, new LabelFactory()); 132 Dataset<Label> train = data.getA(); 133 Dataset<Label> test = data.getB(); 134 135 ExampleTransformer<Label> inputTransformer; 136 switch (o.inputType) { 137 case IMAGE: 138 String[] splitFormat = o.imageFormat.split(","); 139 if (splitFormat.length != 3) { 140 logger.info(cm.usage()); 141 logger.info("Invalid image format specified. Found " + o.imageFormat); 142 return; 143 } 144 int width = Integer.parseInt(splitFormat[0]); 145 int height = Integer.parseInt(splitFormat[1]); 146 int channels = Integer.parseInt(splitFormat[2]); 147 inputTransformer = new ImageTransformer<>(width,height,channels); 148 break; 149 case DENSE: 150 inputTransformer = new DenseTransformer<>(); 151 break; 152 default: 153 logger.info(cm.usage()); 154 logger.info("Unknown input type. 
Found " + o.inputType); 155 return; 156 } 157 OutputTransformer<Label> labelTransformer = new LabelTransformer(); 158 159 //public TensorflowTrainer(Path graphPath, ExampleTransformer<T> exampleTransformer, OutputTransformer<T> outputTransformer, int batchSize, int numEpochs) throws IOException { 160 Trainer<Label> trainer; 161 if (o.checkpointPath == null) { 162 logger.info("Using TensorflowTrainer"); 163 trainer = new TensorflowTrainer<>(o.protobufPath, inputTransformer, labelTransformer, o.batchSize, o.epochs, o.testBatchSize); 164 } else { 165 logger.info("Using TensorflowCheckpointTrainer, writing to path " + o.checkpointPath); 166 trainer = new TensorflowCheckpointTrainer<>(o.protobufPath, o.checkpointPath, inputTransformer, labelTransformer, o.batchSize, o.epochs); 167 } 168 logger.info("Training using " + trainer.toString()); 169 final long trainStart = System.currentTimeMillis(); 170 Model<Label> model = trainer.train(train); 171 final long trainStop = System.currentTimeMillis(); 172 logger.info("Finished training classifier " + Util.formatDuration(trainStart, trainStop)); 173 final long testStart = System.currentTimeMillis(); 174 LabelEvaluator evaluator = new LabelEvaluator(); 175 LabelEvaluation evaluation = evaluator.evaluate(model, test); 176 final long testStop = System.currentTimeMillis(); 177 logger.info("Finished evaluating model " + Util.formatDuration(testStart, testStop)); 178 179 if (model.generatesProbabilities()) { 180 logger.info("Average AUC = " + evaluation.averageAUCROC(false)); 181 logger.info("Average weighted AUC = " + evaluation.averageAUCROC(true)); 182 } 183 184 logger.info(evaluation.toString()); 185 186 logger.info(evaluation.getConfusionMatrix().toString()); 187 188 if (o.outputPath != null) { 189 saveModel(o.outputPath, model); 190 } 191 192 if (o.checkpointPath == null) { 193 ((TensorflowModel<?>) model).close(); 194 } else { 195 ((TensorflowCheckpointModel<?>) model).close(); 196 } 197 } 198}