001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.classification.explanations.lime;
018
019import org.tribuo.Example;
020import org.tribuo.Model;
021import org.tribuo.Prediction;
022import org.tribuo.SparseModel;
023import org.tribuo.SparseTrainer;
024import org.tribuo.classification.Label;
025import org.tribuo.classification.LabelFactory;
026import org.tribuo.classification.explanations.TextExplainer;
027import org.tribuo.data.text.TextFeatureExtractor;
028import org.tribuo.impl.ArrayExample;
029import org.tribuo.provenance.SimpleDataSourceProvenance;
030import org.tribuo.regression.Regressor;
031import org.tribuo.regression.evaluation.RegressionEvaluation;
032import org.tribuo.util.tokens.Token;
033import org.tribuo.util.tokens.Tokenizer;
034
035import java.util.ArrayList;
036import java.util.Arrays;
037import java.util.List;
038import java.util.Random;
039import java.util.SplittableRandom;
040import java.util.logging.Logger;
041
042/**
043 * Uses a Tribuo {@link TextFeatureExtractor} to explain the prediction for a given piece of text.
044 * <p>
045 * LIME uses a naive sampling procedure which blanks out words and trains the linear model on
046 * a set of binary features, each of which is the presence of a word+position combination. Words
047 * are not permuted, and new words are not added (so it's only explaining when the absence of a
048 * word would change the prediction).
049 * <p>
050 * See:
051 * <pre>
052 * Ribeiro MT, Singh S, Guestrin C.
053 * "Why should I trust you?: Explaining the predictions of any classifier"
054 * Proceedings of the 22nd ACM SIGKDD International Conference on Knowledge Discovery and Data Mining 2016.
055 * </pre>
056 */
057public class LIMEText extends LIMEBase implements TextExplainer<Regressor> {
058
059    private static final Logger logger = Logger.getLogger(LIMEText.class.getName());
060
061    private final TextFeatureExtractor<Label> extractor;
062
063    private final Tokenizer tokenizer;
064
065    private final ThreadLocal<Tokenizer> tokenizerThreadLocal;
066
067    /**
068     * Constructs a LIME explainer for a model which uses text data.
069     * @param rng The rng to use for sampling.
070     * @param innerModel The model to explain.
071     * @param explanationTrainer The sparse trainer to use to generate explanations.
072     * @param numSamples The number of samples to generate for each explanation.
073     * @param extractor The {@link TextFeatureExtractor} used to generate text features from a string.
074     * @param tokenizer The tokenizer used to tokenize the examples.
075     */
076    public LIMEText(SplittableRandom rng, Model<Label> innerModel, SparseTrainer<Regressor> explanationTrainer, int numSamples, TextFeatureExtractor<Label> extractor, Tokenizer tokenizer) {
077        super(rng, innerModel, explanationTrainer, numSamples);
078        this.extractor = extractor;
079        this.tokenizer = tokenizer;
080        this.tokenizerThreadLocal = ThreadLocal.withInitial(() -> {try { return this.tokenizer.clone(); } catch (CloneNotSupportedException e) { throw new IllegalArgumentException("Tokenizer not cloneable",e); }});
081    }
082
083    @Override
084    public LIMEExplanation explain(String inputText) {
085        Example<Label> trueExample = extractor.extract(LabelFactory.UNKNOWN_LABEL, inputText);
086        Prediction<Label> prediction = innerModel.predict(trueExample);
087
088        ArrayExample<Regressor> bowExample = new ArrayExample<>(transformOutput(prediction));
089        List<Token> tokens = tokenizerThreadLocal.get().tokenize(inputText);
090        for (int i = 0; i < tokens.size(); i++) {
091            bowExample.add(nameFeature(tokens.get(i).text,i),1.0);
092        }
093
094        // Sample a dataset.
095        List<Example<Regressor>> sample = sampleData(inputText,tokens);
096
097        // Generate a sparse model on the sampled data.
098        SparseModel<Regressor> model = trainExplainer(bowExample, sample);
099
100        // Test the sparse model against the predictions of the real model.
101        List<Prediction<Regressor>> predictions = new ArrayList<>(model.predict(sample));
102        predictions.add(model.predict(bowExample));
103        RegressionEvaluation evaluation = evaluator.evaluate(model,predictions,new SimpleDataSourceProvenance("LIMEText sampled data",regressionFactory));
104
105        return new LIMEExplanation(model, prediction, evaluation);
106    }
107
108    /**
109     * Generate the feature name by combining the word and index.
110     * @param name The word.
111     * @param idx The index.
112     * @return A string representing both of the inputs.
113     */
114    protected String nameFeature(String name, int idx) {
115        return name+"@idx"+idx;
116    }
117
118    /**
119     * Samples a new dataset from the input text. Uses the tokenized representation,
120     * removes words by blanking them out. Only removes words to generate a new sentence,
121     * and does not generate the empty sentence.
122     * @param inputText The input text.
123     * @param tokens The tokenized representation of the input text.
124     * @return A list of samples from the input text.
125     */
126    protected List<Example<Regressor>> sampleData(String inputText, List<Token> tokens) {
127        List<Example<Regressor>> output = new ArrayList<>();
128
129        Random innerRNG = new Random(rng.nextLong());
130        for (int i = 0; i < numSamples; i++) {
131            // Sample a new Example.
132            double distance = 0.0;
133            int[] activeFeatures = new int[tokens.size()];
134            char[] sampledText = inputText.toCharArray();
135            for (int j = 0; j < activeFeatures.length; j++) {
136                activeFeatures[j] = innerRNG.nextInt(2);
137                if (activeFeatures[j] == 0) {
138                    distance++;
139                    Token curToken = tokens.get(j);
140                    Arrays.fill(sampledText,curToken.start,curToken.end,'\0');
141                }
142            }
143            String sampledString = new String(sampledText);
144            sampledString = sampledString.replace("\0","");
145
146            Example<Label> sample = extractor.extract(LabelFactory.UNKNOWN_LABEL,sampledString);
147
148            // If the sample has features.
149            if (sample.size() > 0) {
150                // Label it using the full model.
151                Prediction<Label> samplePrediction = innerModel.predict(sample);
152
153                // Transform distance into a weight.
154                double weight = 1.0 - (distance / tokens.size());
155
156                // Generate the new sample with the appropriate label and weight.
157                ArrayExample<Regressor> labelledSample = new ArrayExample<>(transformOutput(samplePrediction), (float) weight);
158                for (int j = 0; j < activeFeatures.length; j++) {
159                    labelledSample.add(nameFeature(tokens.get(j).text, j), activeFeatures[j]);
160                }
161                output.add(labelledSample);
162            }
163        }
164
165        return output;
166    }
167}