001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.classification.explanations.lime;
018
019import com.oracle.labs.mlrg.olcut.command.Command;
020import com.oracle.labs.mlrg.olcut.command.CommandGroup;
021import com.oracle.labs.mlrg.olcut.command.CommandInterpreter;
022import com.oracle.labs.mlrg.olcut.config.ConfigurationManager;
023import com.oracle.labs.mlrg.olcut.config.Option;
024import com.oracle.labs.mlrg.olcut.config.Options;
025import com.oracle.labs.mlrg.olcut.config.UsageException;
026import org.tribuo.Model;
027import org.tribuo.Prediction;
028import org.tribuo.SparseModel;
029import org.tribuo.SparseTrainer;
030import org.tribuo.VariableInfo;
031import org.tribuo.classification.Label;
032import org.tribuo.classification.LabelFactory;
033import org.tribuo.data.text.TextFeatureExtractor;
034import org.tribuo.data.text.impl.BasicPipeline;
035import org.tribuo.data.text.impl.TextFeatureExtractorImpl;
036import org.tribuo.regression.Regressor;
037import org.tribuo.regression.rtree.CARTJointRegressionTrainer;
038import org.jline.builtins.Completers;
039import org.jline.reader.Completer;
040import org.jline.reader.impl.completer.NullCompleter;
041import org.tribuo.util.tokens.Tokenizer;
042import org.tribuo.util.tokens.universal.UniversalTokenizer;
043
044import java.io.BufferedInputStream;
045import java.io.File;
046import java.io.FileInputStream;
047import java.io.FileNotFoundException;
048import java.io.IOException;
049import java.io.ObjectInputStream;
050import java.util.SplittableRandom;
051import java.util.logging.Level;
052import java.util.logging.Logger;
053
054/**
055 * A CLI for interacting with {@link LIMEText}. Uses a simple tokenisation and text extraction pipeline.
056 */
057public class LIMETextCLI implements CommandGroup {
058    private static final Logger logger = Logger.getLogger(LIMETextCLI.class.getName());
059
060    protected CommandInterpreter shell;
061
062    private Model<Label> model;
063
064    private int numSamples = 100;
065
066    private int numFeatures = 10;
067
068    //private SparseTrainer<Regressor> limeTrainer = new LARSLassoTrainer(numFeatures);
069    private SparseTrainer<Regressor> limeTrainer = new CARTJointRegressionTrainer((int)Math.log(numFeatures),true);
070
071    private Tokenizer tokenizer = new UniversalTokenizer();
072
073    private TextFeatureExtractor<Label> extractor = new TextFeatureExtractorImpl<>(new BasicPipeline(tokenizer,2));
074
075    private LIMEText limeText = null;
076
077    public LIMETextCLI() {
078        shell = new CommandInterpreter();
079        shell.setPrompt("lime-text sh% ");
080    }
081
082    @Override
083    public String getName() {
084        return "LIME Text CLI";
085    }
086
087    @Override
088    public String getDescription() {
089        return "Commands for experimenting with LIME Text.";
090    }
091
092    public Completer[] fileCompleter() {
093        return new Completer[]{
094                new Completers.FileNameCompleter(),
095                new NullCompleter()
096        };
097    }
098
099    /**
100     * Start the command shell
101     */
102    public void startShell() {
103        shell.add(this);
104        shell.start();
105    }
106
107    @Command(usage = "<filename> - Load a model from disk.", completers="fileCompleter")
108    public String loadModel(CommandInterpreter ci, File path) {
109        try (ObjectInputStream ois = new ObjectInputStream(new BufferedInputStream(new FileInputStream(path)))) {
110            @SuppressWarnings("unchecked") // deserialising generically typed model.
111            Model<Label> m = (Model<Label>) ois.readObject();
112            model = m;
113        } catch (ClassNotFoundException e) {
114            logger.log(Level.SEVERE,"Failed to load class from stream " + path.getAbsolutePath(),e);
115            return "Failed to load model";
116        } catch (FileNotFoundException e) {
117            logger.log(Level.SEVERE,"Failed to open file " + path.getAbsolutePath(),e);
118            return "Failed to load model";
119        } catch (IOException e) {
120            logger.log(Level.SEVERE,"IOException when reading from file " + path.getAbsolutePath(),e);
121            return "Failed to load model";
122        }
123
124        limeText = new LIMEText(new SplittableRandom(1),model,limeTrainer,numSamples,extractor,tokenizer);
125
126        return "Loaded model from path " + path.toString();
127    }
128
129    @Command(usage="Does the model generate probabilities")
130    public String generatesProbabilities(CommandInterpreter ci) {
131        return ""+model.generatesProbabilities();
132    }
133
134    @Command(usage="Shows the model description")
135    public String modelDescription(CommandInterpreter ci) {
136        return model.toString();
137    }
138
139    @Command(usage="Shows the information on a particular feature")
140    public String featureInfo(CommandInterpreter ci, String featureName) {
141        VariableInfo f = model.getFeatureIDMap().get(featureName);
142        if (f != null) {
143            return "" + f.toString();
144        } else {
145            return "Feature " + featureName + " not found.";
146        }
147    }
148
149    @Command(usage="Shows the output information.")
150    public String outputInfo(CommandInterpreter ci) {
151        return model.getOutputIDInfo().toReadableString();
152    }
153
154    @Command(usage="<int> - Shows the top N features in the model")
155    public String topFeatures(CommandInterpreter ci, int numFeatures) {
156        return ""+ model.getTopFeatures(numFeatures);
157    }
158
159    @Command(usage="Shows the number of features in the model")
160    public String numFeatures(CommandInterpreter ci) {
161        return ""+ model.getFeatureIDMap().size();
162    }
163
164    @Command(usage="<min count> - Shows the number of features that occurred more than min count times.")
165    public String minCount(CommandInterpreter ci, int minCount) {
166        int counter = 0;
167        for (VariableInfo f : model.getFeatureIDMap()) {
168            if (f.getCount() > minCount) {
169                counter++;
170            }
171        }
172        return counter + " features occurred more than " + minCount + " times.";
173    }
174
175    @Command(usage="Shows the output statistics")
176    public String showLabelStats(CommandInterpreter ci) {
177        return "Label histogram : \n" + model.getOutputIDInfo().toReadableString();
178    }
179
180    @Command(usage="Sets the number of samples to use in LIME")
181    public String setNumSamples(CommandInterpreter ci, int newNumSamples) {
182        numSamples = newNumSamples;
183        return "Set number of samples to " + numSamples;
184    }
185
186    @Command(usage="Explain a text classification")
187    public String explain(CommandInterpreter ci, String[] tokens) {
188        String text = String.join(" ",tokens);
189
190        LIMEExplanation explanation = limeText.explain(text);
191
192        SparseModel<Regressor> model = explanation.getModel();
193
194        ci.out.println("Active features of the predicted class = " + model.getActiveFeatures().get(explanation.getPrediction().getOutput().getLabel()));
195
196        return "Explanation = " + explanation.toString();
197    }
198
199    @Command(usage="Sets the number of features LIME should use in an explanation")
200    public String setNumFeatures(CommandInterpreter ci, int newNumFeatures) {
201        numFeatures = newNumFeatures;
202        //limeTrainer = new LARSLassoTrainer(numFeatures);
203        limeTrainer = new CARTJointRegressionTrainer((int)Math.log(numFeatures),true);
204        limeText = new LIMEText(new SplittableRandom(1),model,limeTrainer,numSamples,extractor, tokenizer);
205        return "Set the number of features in LIME to " + numFeatures;
206    }
207
208    @Command(usage="Make a prediction")
209    public String predict(CommandInterpreter ci, String[] tokens) {
210        String text = String.join(" ",tokens);
211
212        Prediction<Label> prediction = model.predict(extractor.extract(LabelFactory.UNKNOWN_LABEL,text));
213
214        return "Prediction = " + prediction.toString();
215    }
216
217    public static class LIMETextCLIOptions implements Options {
218        @Option(charName='f',longName="filename",usage="Model file to load. Optional.")
219        public String modelFilename;
220    }
221
222    public static void main(String[] args) {
223        LIMETextCLI.LIMETextCLIOptions options = new LIMETextCLI.LIMETextCLIOptions();
224        try {
225            ConfigurationManager cm = new ConfigurationManager(args, options, false);
226            LIMETextCLI driver = new LIMETextCLI();
227            if (options.modelFilename != null) {
228                logger.log(Level.INFO, driver.loadModel(driver.shell, new File(options.modelFilename)));
229            }
230            driver.startShell();
231        } catch (UsageException e) {
232            System.out.println("Usage: " + e.getUsage());
233        }
234    }
235}