/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.classification.explanations.lime;

import org.tribuo.Example;
import org.tribuo.Model;
import org.tribuo.Prediction;
import org.tribuo.SparseModel;
import org.tribuo.SparseTrainer;
import org.tribuo.classification.Label;
import org.tribuo.classification.LabelFactory;
import org.tribuo.classification.explanations.TextExplainer;
import org.tribuo.data.text.TextFeatureExtractor;
import org.tribuo.impl.ArrayExample;
import org.tribuo.provenance.SimpleDataSourceProvenance;
import org.tribuo.regression.Regressor;
import org.tribuo.regression.evaluation.RegressionEvaluation;
import org.tribuo.util.tokens.Token;
import org.tribuo.util.tokens.Tokenizer;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import java.util.SplittableRandom;
import java.util.logging.Logger;

/**
 * Uses a Tribuo {@link TextFeatureExtractor} to explain the prediction for a given piece of text.
 * <p>
 * LIME uses a naive sampling procedure which blanks out words and trains a linear model on
 * a set of binary features, each of which is the presence of a word+position combination. Words
 * are not permuted, and new words are not added, so it only explains cases where the absence of a
 * word would change the prediction.
 * <p>
 * See:
 * <pre>
 * Ribeiro MT, Singh S, Guestrin C.
 * "Why should I trust you?: Explaining the predictions of any classifier"
 * Proceedings of the 22nd ACM SIGKDD International Conference on Knowledge Discovery and Data Mining 2016.
 * </pre>
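 * <p>
 * A minimal usage sketch is shown below; the trainer, feature extractor and tokenizer are
 * illustrative choices, and any {@code SparseTrainer<Regressor>}, {@code TextFeatureExtractor<Label>}
 * and cloneable {@code Tokenizer} can be substituted:
 * <pre>{@code
 * Tokenizer tokenizer = new BreakIteratorTokenizer(Locale.US);
 * TextFeatureExtractor<Label> extractor = new TextFeatureExtractorImpl<>(new BasicPipeline(tokenizer,2));
 * Model<Label> model = ...; // a trained text classification model
 * LIMEText lime = new LIMEText(new SplittableRandom(1L), model, new LARSLassoTrainer(),
 *                              200, extractor, tokenizer);
 * LIMEExplanation explanation = lime.explain("Some text to explain.");
 * }</pre>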
 */
public class LIMEText extends LIMEBase implements TextExplainer<Regressor> {

    private static final Logger logger = Logger.getLogger(LIMEText.class.getName());

    private final TextFeatureExtractor<Label> extractor;

    private final Tokenizer tokenizer;

    private final ThreadLocal<Tokenizer> tokenizerThreadLocal;

    /**
     * Constructs a LIME explainer for a model which uses text data.
     * @param rng The rng to use for sampling.
     * @param innerModel The model to explain.
     * @param explanationTrainer The sparse trainer used to generate explanations.
     * @param numSamples The number of samples to generate for each explanation.
     * @param extractor The {@link TextFeatureExtractor} used to generate text features from a string.
     * @param tokenizer The tokenizer used to tokenize the examples.
     */
    public LIMEText(SplittableRandom rng, Model<Label> innerModel, SparseTrainer<Regressor> explanationTrainer, int numSamples, TextFeatureExtractor<Label> extractor, Tokenizer tokenizer) {
        super(rng, innerModel, explanationTrainer, numSamples);
        this.extractor = extractor;
        this.tokenizer = tokenizer;
        this.tokenizerThreadLocal = ThreadLocal.withInitial(() -> {
            try {
                return this.tokenizer.clone();
            } catch (CloneNotSupportedException e) {
                throw new IllegalArgumentException("Tokenizer not cloneable",e);
            }
        });
    }

    @Override
    public LIMEExplanation explain(String inputText) {
        Example<Label> trueExample = extractor.extract(LabelFactory.UNKNOWN_LABEL, inputText);
        Prediction<Label> prediction = innerModel.predict(trueExample);

        ArrayExample<Regressor> bowExample = new ArrayExample<>(transformOutput(prediction));
        List<Token> tokens = tokenizerThreadLocal.get().tokenize(inputText);
        for (int i = 0; i < tokens.size(); i++) {
            bowExample.add(nameFeature(tokens.get(i).text,i),1.0);
        }

        // Sample a dataset.
        List<Example<Regressor>> sample = sampleData(inputText,tokens);

        // Generate a sparse model on the sampled data.
        SparseModel<Regressor> model = trainExplainer(bowExample, sample);

        // Test the sparse model against the predictions of the real model.
        List<Prediction<Regressor>> predictions = new ArrayList<>(model.predict(sample));
        predictions.add(model.predict(bowExample));
        RegressionEvaluation evaluation = evaluator.evaluate(model,predictions,new SimpleDataSourceProvenance("LIMEText sampled data",regressionFactory));

        return new LIMEExplanation(model, prediction, evaluation);
    }

    /**
     * Generates the feature name by combining the word and index.
     * @param name The word.
     * @param idx The index.
     * @return A string representing both of the inputs.
     */
    protected String nameFeature(String name, int idx) {
        return name+"@idx"+idx;
    }

    /**
     * Samples a new dataset from the input text by blanking out words in the
     * tokenized representation. Words are only removed, never permuted or added,
     * and samples which produce no features are discarded.
     * @param inputText The input text.
     * @param tokens The tokenized representation of the input text.
     * @return A list of samples from the input text.
     */
    protected List<Example<Regressor>> sampleData(String inputText, List<Token> tokens) {
        List<Example<Regressor>> output = new ArrayList<>();

        Random innerRNG = new Random(rng.nextLong());
        for (int i = 0; i < numSamples; i++) {
            // Sample a new Example.
            double distance = 0.0;
            int[] activeFeatures = new int[tokens.size()];
            char[] sampledText = inputText.toCharArray();
            for (int j = 0; j < activeFeatures.length; j++) {
                activeFeatures[j] = innerRNG.nextInt(2);
                if (activeFeatures[j] == 0) {
                    distance++;
                    Token curToken = tokens.get(j);
                    Arrays.fill(sampledText,curToken.start,curToken.end,'\0');
                }
            }
            String sampledString = new String(sampledText);
            sampledString = sampledString.replace("\0","");

            Example<Label> sample = extractor.extract(LabelFactory.UNKNOWN_LABEL,sampledString);

            // If the sample has features.
            if (sample.size() > 0) {
                // Label it using the full model.
                Prediction<Label> samplePrediction = innerModel.predict(sample);

                // Transform distance into a weight.
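                // distance counts the blanked-out tokens, so this weight is the fraction of
                // tokens retained; samples closer to the original text get more influence in
                // the explanation regression.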
                double weight = 1.0 - (distance / tokens.size());

                // Generate the new sample with the appropriate label and weight.
                ArrayExample<Regressor> labelledSample = new ArrayExample<>(transformOutput(samplePrediction), (float) weight);
                for (int j = 0; j < activeFeatures.length; j++) {
                    labelledSample.add(nameFeature(tokens.get(j).text, j), activeFeatures[j]);
                }
                output.add(labelledSample);
            }
        }

        return output;
    }
}