001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.classification.explanations.lime; 018 019import com.oracle.labs.mlrg.olcut.command.Command; 020import com.oracle.labs.mlrg.olcut.command.CommandGroup; 021import com.oracle.labs.mlrg.olcut.command.CommandInterpreter; 022import com.oracle.labs.mlrg.olcut.config.ConfigurationManager; 023import com.oracle.labs.mlrg.olcut.config.Option; 024import com.oracle.labs.mlrg.olcut.config.Options; 025import com.oracle.labs.mlrg.olcut.config.UsageException; 026import org.tribuo.Model; 027import org.tribuo.Prediction; 028import org.tribuo.SparseModel; 029import org.tribuo.SparseTrainer; 030import org.tribuo.VariableInfo; 031import org.tribuo.classification.Label; 032import org.tribuo.classification.LabelFactory; 033import org.tribuo.data.text.TextFeatureExtractor; 034import org.tribuo.data.text.impl.BasicPipeline; 035import org.tribuo.data.text.impl.TextFeatureExtractorImpl; 036import org.tribuo.regression.Regressor; 037import org.tribuo.regression.rtree.CARTJointRegressionTrainer; 038import org.jline.builtins.Completers; 039import org.jline.reader.Completer; 040import org.jline.reader.impl.completer.NullCompleter; 041import org.tribuo.util.tokens.Tokenizer; 042import org.tribuo.util.tokens.universal.UniversalTokenizer; 043 044import java.io.BufferedInputStream; 045import java.io.File; 046import java.io.FileInputStream; 047import java.io.FileNotFoundException; 048import java.io.IOException; 049import java.io.ObjectInputStream; 050import java.util.SplittableRandom; 051import java.util.logging.Level; 052import java.util.logging.Logger; 053 054/** 055 * A CLI for interacting with {@link LIMEText}. Uses a simple tokenisation and text extraction pipeline. 056 */ 057public class LIMETextCLI implements CommandGroup { 058 private static final Logger logger = Logger.getLogger(LIMETextCLI.class.getName()); 059 060 protected CommandInterpreter shell; 061 062 private Model<Label> model; 063 064 private int numSamples = 100; 065 066 private int numFeatures = 10; 067 068 //private SparseTrainer<Regressor> limeTrainer = new LARSLassoTrainer(numFeatures); 069 private SparseTrainer<Regressor> limeTrainer = new CARTJointRegressionTrainer((int)Math.log(numFeatures),true); 070 071 private Tokenizer tokenizer = new UniversalTokenizer(); 072 073 private TextFeatureExtractor<Label> extractor = new TextFeatureExtractorImpl<>(new BasicPipeline(tokenizer,2)); 074 075 private LIMEText limeText = null; 076 077 public LIMETextCLI() { 078 shell = new CommandInterpreter(); 079 shell.setPrompt("lime-text sh% "); 080 } 081 082 @Override 083 public String getName() { 084 return "LIME Text CLI"; 085 } 086 087 @Override 088 public String getDescription() { 089 return "Commands for experimenting with LIME Text."; 090 } 091 092 public Completer[] fileCompleter() { 093 return new Completer[]{ 094 new Completers.FileNameCompleter(), 095 new NullCompleter() 096 }; 097 } 098 099 /** 100 * Start the command shell 101 */ 102 public void startShell() { 103 shell.add(this); 104 shell.start(); 105 } 106 107 @Command(usage = "<filename> - Load a model from disk.", completers="fileCompleter") 108 public String loadModel(CommandInterpreter ci, File path) { 109 try (ObjectInputStream ois = new ObjectInputStream(new BufferedInputStream(new FileInputStream(path)))) { 110 @SuppressWarnings("unchecked") // deserialising generically typed model. 111 Model<Label> m = (Model<Label>) ois.readObject(); 112 model = m; 113 } catch (ClassNotFoundException e) { 114 logger.log(Level.SEVERE,"Failed to load class from stream " + path.getAbsolutePath(),e); 115 return "Failed to load model"; 116 } catch (FileNotFoundException e) { 117 logger.log(Level.SEVERE,"Failed to open file " + path.getAbsolutePath(),e); 118 return "Failed to load model"; 119 } catch (IOException e) { 120 logger.log(Level.SEVERE,"IOException when reading from file " + path.getAbsolutePath(),e); 121 return "Failed to load model"; 122 } 123 124 limeText = new LIMEText(new SplittableRandom(1),model,limeTrainer,numSamples,extractor,tokenizer); 125 126 return "Loaded model from path " + path.toString(); 127 } 128 129 @Command(usage="Does the model generate probabilities") 130 public String generatesProbabilities(CommandInterpreter ci) { 131 return ""+model.generatesProbabilities(); 132 } 133 134 @Command(usage="Shows the model description") 135 public String modelDescription(CommandInterpreter ci) { 136 return model.toString(); 137 } 138 139 @Command(usage="Shows the information on a particular feature") 140 public String featureInfo(CommandInterpreter ci, String featureName) { 141 VariableInfo f = model.getFeatureIDMap().get(featureName); 142 if (f != null) { 143 return "" + f.toString(); 144 } else { 145 return "Feature " + featureName + " not found."; 146 } 147 } 148 149 @Command(usage="Shows the output information.") 150 public String outputInfo(CommandInterpreter ci) { 151 return model.getOutputIDInfo().toReadableString(); 152 } 153 154 @Command(usage="<int> - Shows the top N features in the model") 155 public String topFeatures(CommandInterpreter ci, int numFeatures) { 156 return ""+ model.getTopFeatures(numFeatures); 157 } 158 159 @Command(usage="Shows the number of features in the model") 160 public String numFeatures(CommandInterpreter ci) { 161 return ""+ model.getFeatureIDMap().size(); 162 } 163 164 @Command(usage="<min count> - Shows the number of features that occurred more than min count times.") 165 public String minCount(CommandInterpreter ci, int minCount) { 166 int counter = 0; 167 for (VariableInfo f : model.getFeatureIDMap()) { 168 if (f.getCount() > minCount) { 169 counter++; 170 } 171 } 172 return counter + " features occurred more than " + minCount + " times."; 173 } 174 175 @Command(usage="Shows the output statistics") 176 public String showLabelStats(CommandInterpreter ci) { 177 return "Label histogram : \n" + model.getOutputIDInfo().toReadableString(); 178 } 179 180 @Command(usage="Sets the number of samples to use in LIME") 181 public String setNumSamples(CommandInterpreter ci, int newNumSamples) { 182 numSamples = newNumSamples; 183 return "Set number of samples to " + numSamples; 184 } 185 186 @Command(usage="Explain a text classification") 187 public String explain(CommandInterpreter ci, String[] tokens) { 188 String text = String.join(" ",tokens); 189 190 LIMEExplanation explanation = limeText.explain(text); 191 192 SparseModel<Regressor> model = explanation.getModel(); 193 194 ci.out.println("Active features of the predicted class = " + model.getActiveFeatures().get(explanation.getPrediction().getOutput().getLabel())); 195 196 return "Explanation = " + explanation.toString(); 197 } 198 199 @Command(usage="Sets the number of features LIME should use in an explanation") 200 public String setNumFeatures(CommandInterpreter ci, int newNumFeatures) { 201 numFeatures = newNumFeatures; 202 //limeTrainer = new LARSLassoTrainer(numFeatures); 203 limeTrainer = new CARTJointRegressionTrainer((int)Math.log(numFeatures),true); 204 limeText = new LIMEText(new SplittableRandom(1),model,limeTrainer,numSamples,extractor, tokenizer); 205 return "Set the number of features in LIME to " + numFeatures; 206 } 207 208 @Command(usage="Make a prediction") 209 public String predict(CommandInterpreter ci, String[] tokens) { 210 String text = String.join(" ",tokens); 211 212 Prediction<Label> prediction = model.predict(extractor.extract(LabelFactory.UNKNOWN_LABEL,text)); 213 214 return "Prediction = " + prediction.toString(); 215 } 216 217 public static class LIMETextCLIOptions implements Options { 218 @Option(charName='f',longName="filename",usage="Model file to load. Optional.") 219 public String modelFilename; 220 } 221 222 public static void main(String[] args) { 223 LIMETextCLI.LIMETextCLIOptions options = new LIMETextCLI.LIMETextCLIOptions(); 224 try { 225 ConfigurationManager cm = new ConfigurationManager(args, options, false); 226 LIMETextCLI driver = new LIMETextCLI(); 227 if (options.modelFilename != null) { 228 logger.log(Level.INFO, driver.loadModel(driver.shell, new File(options.modelFilename))); 229 } 230 driver.startShell(); 231 } catch (UsageException e) { 232 System.out.println("Usage: " + e.getUsage()); 233 } 234 } 235}