001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.classification.explanations.lime; 018 019import com.oracle.labs.mlrg.olcut.util.Pair; 020import org.tribuo.CategoricalInfo; 021import org.tribuo.Example; 022import org.tribuo.Feature; 023import org.tribuo.ImmutableFeatureMap; 024import org.tribuo.Model; 025import org.tribuo.Prediction; 026import org.tribuo.RealInfo; 027import org.tribuo.SparseModel; 028import org.tribuo.SparseTrainer; 029import org.tribuo.VariableIDInfo; 030import org.tribuo.VariableInfo; 031import org.tribuo.classification.Label; 032import org.tribuo.classification.LabelFactory; 033import org.tribuo.classification.explanations.ColumnarExplainer; 034import org.tribuo.data.columnar.ColumnarFeature; 035import org.tribuo.data.columnar.FieldProcessor; 036import org.tribuo.data.columnar.ResponseProcessor; 037import org.tribuo.data.columnar.RowProcessor; 038import org.tribuo.impl.ArrayExample; 039import org.tribuo.impl.ListExample; 040import org.tribuo.math.la.SparseVector; 041import org.tribuo.provenance.SimpleDataSourceProvenance; 042import org.tribuo.regression.Regressor; 043import org.tribuo.regression.evaluation.RegressionEvaluation; 044import org.tribuo.util.Util; 045import org.tribuo.util.tokens.Token; 046import org.tribuo.util.tokens.Tokenizer; 047 048import java.util.ArrayList; 049import java.util.Arrays; 050import java.util.HashMap; 051import java.util.List; 052import java.util.ListIterator; 053import java.util.Map; 054import java.util.Optional; 055import java.util.Random; 056import java.util.SplittableRandom; 057 058/** 059 * Uses the columnar data processing infrastructure to mix text and tabular data. 060 * <p> 061 * If the supplied {@link RowProcessor} doesn't reference any text or binarised fields 062 * then it delegates to {@link LIMEBase#explain}, though it's still more expensive at 063 * construction time. 064 * <p> 065 * See: 066 * <pre> 067 * Ribeiro MT, Singh S, Guestrin C. 068 * "Why should I trust you?: Explaining the predictions of any classifier" 069 * Proceedings of the 22nd ACM SIGKDD International Conference on Knowledge Discovery and Data Mining 2016. 070 * </pre> 071 */ 072public class LIMEColumnar extends LIMEBase implements ColumnarExplainer<Regressor> { 073 074 private final RowProcessor<Label> generator; 075 076 private final Map<String,FieldProcessor> binarisedFields = new HashMap<>(); 077 078 private final Map<String,FieldProcessor> tabularFields = new HashMap<>(); 079 080 private final Map<String,FieldProcessor> textFields = new HashMap<>(); 081 082 private final ResponseProcessor<Label> responseProcessor; 083 084 private final Map<String,List<VariableInfo>> binarisedInfos; 085 086 private final Map<String,double[]> binarisedCDFs; 087 088 private final ImmutableFeatureMap binarisedDomain; 089 090 private final ImmutableFeatureMap tabularDomain; 091 092 private final ImmutableFeatureMap textDomain; 093 094 private final Tokenizer tokenizer; 095 096 private final ThreadLocal<Tokenizer> tokenizerThreadLocal; 097 098 /** 099 * Constructs a LIME explainer for a model which uses the columnar data processing system. 100 * @param rng The rng to use for sampling. 101 * @param innerModel The model to explain. 102 * @param explanationTrainer The trainer for the sparse model used to explain. 103 * @param numSamples The number of samples to generate in each explanation. 104 * @param exampleGenerator The {@link RowProcessor} which converts columnar data into an {@link Example}. 105 * @param tokenizer The tokenizer to use on any text fields. 106 */ 107 public LIMEColumnar(SplittableRandom rng, Model<Label> innerModel, SparseTrainer<Regressor> explanationTrainer, 108 int numSamples, RowProcessor<Label> exampleGenerator, Tokenizer tokenizer) { 109 super(rng, innerModel, explanationTrainer, numSamples); 110 this.generator = exampleGenerator.copy(); 111 this.responseProcessor = generator.getResponseProcessor(); 112 this.tokenizer = tokenizer; 113 this.tokenizerThreadLocal = ThreadLocal.withInitial(() -> {try { return this.tokenizer.clone(); } catch (CloneNotSupportedException e) { throw new IllegalArgumentException("Tokenizer not cloneable",e); }}); 114 if (!this.generator.isConfigured()) { 115 this.generator.expandRegexMapping(innerModel); 116 } 117 this.binarisedInfos = new HashMap<>(); 118 ArrayList<VariableInfo> infos = new ArrayList<>(); 119 for (VariableInfo i : innerModel.getFeatureIDMap()) { 120 infos.add(i); 121 } 122 ArrayList<VariableInfo> allBinarisedInfos = new ArrayList<>(); 123 ArrayList<VariableInfo> tabularInfos = new ArrayList<>(); 124 ArrayList<VariableInfo> textInfos = new ArrayList<>(); 125 for (Map.Entry<String,FieldProcessor> p : generator.getFieldProcessors().entrySet()) { 126 String searchName = p.getKey() + ColumnarFeature.JOINER; 127 switch (p.getValue().getFeatureType()) { 128 case BINARISED_CATEGORICAL: { 129 int numNamespaces = p.getValue().getNumNamespaces(); 130 if (numNamespaces > 1) { 131 for (int i = 0; i < numNamespaces; i++) { 132 String namespace = p.getKey() + FieldProcessor.NAMESPACE + i; 133 String namespaceSearchName = namespace + ColumnarFeature.JOINER; 134 binarisedFields.put(namespace, p.getValue()); 135 List<VariableInfo> binarisedInfoList = this.binarisedInfos.computeIfAbsent(namespace, (k) -> new ArrayList<>()); 136 ListIterator<VariableInfo> li = infos.listIterator(); 137 while (li.hasNext()) { 138 VariableInfo info = li.next(); 139 if (info.getName().startsWith(namespaceSearchName)) { 140 if (((CategoricalInfo) info).getUniqueObservations() != 1) { 141 throw new IllegalStateException("Processor " + p.getKey() + ", should have been binary, but had " + ((CategoricalInfo) info).getUniqueObservations() + " unique values"); 142 } 143 binarisedInfoList.add(info); 144 allBinarisedInfos.add(info); 145 li.remove(); 146 } 147 } 148 } 149 } else { 150 binarisedFields.put(p.getKey(), p.getValue()); 151 List<VariableInfo> binarisedInfoList = this.binarisedInfos.computeIfAbsent(p.getKey(), (k) -> new ArrayList<>()); 152 ListIterator<VariableInfo> li = infos.listIterator(); 153 while (li.hasNext()) { 154 VariableInfo i = li.next(); 155 if (i.getName().startsWith(searchName)) { 156 if (((CategoricalInfo) i).getUniqueObservations() != 1) { 157 throw new IllegalStateException("Processor " + p.getKey() + ", should have been binary, but had " + ((CategoricalInfo) i).getUniqueObservations() + " unique values"); 158 } 159 binarisedInfoList.add(i); 160 allBinarisedInfos.add(i); 161 li.remove(); 162 } 163 } 164 } 165 break; 166 } 167 case CATEGORICAL: 168 case REAL: { 169 tabularFields.put(p.getKey(), p.getValue()); 170 ListIterator<VariableInfo> li = infos.listIterator(); 171 while (li.hasNext()) { 172 VariableInfo i = li.next(); 173 if (i.getName().startsWith(searchName)) { 174 tabularInfos.add(i); 175 li.remove(); 176 } 177 } 178 break; 179 } 180 case TEXT: { 181 textFields.put(p.getKey(), p.getValue()); 182 ListIterator<VariableInfo> li = infos.listIterator(); 183 while (li.hasNext()) { 184 VariableInfo i = li.next(); 185 if (i.getName().startsWith(searchName)) { 186 textInfos.add(i); 187 li.remove(); 188 } 189 } 190 break; 191 } 192 default: 193 throw new IllegalArgumentException("Unsupported feature type " + p.getValue().getFeatureType()); 194 } 195 } 196 if (infos.size() != 0) { 197 throw new IllegalArgumentException("Found " + infos.size() + " unsupported features."); 198 } 199 if (generator.getFeatureProcessors().size() != 0) { 200 throw new IllegalArgumentException("LIMEColumnar does not support FeatureProcessors."); 201 } 202 this.tabularDomain = new ImmutableFeatureMap(tabularInfos); 203 this.textDomain = new ImmutableFeatureMap(textInfos); 204 this.binarisedDomain = new ImmutableFeatureMap(allBinarisedInfos); 205 this.binarisedCDFs = new HashMap<>(); 206 for (Map.Entry<String,List<VariableInfo>> e : binarisedInfos.entrySet()) { 207 long totalCount = 0; 208 long[] counts = new long[e.getValue().size()+1]; 209 int i = 0; 210 for (VariableInfo info : e.getValue()) { 211 long curCount = info.getCount(); 212 counts[i] = curCount; 213 totalCount += curCount; 214 i++; 215 } 216 long zeroCount = numTrainingExamples - totalCount; 217 if (zeroCount < 0) { 218 throw new IllegalStateException("Processor " + e.getKey() + " purports to be a BINARISED_CATEGORICAL, but had overlap in it's elements"); 219 } 220 counts[i] = zeroCount; 221 double[] cdf = Util.generateCDF(counts,numTrainingExamples); 222 binarisedCDFs.put(e.getKey(),cdf); 223 } 224 } 225 226 @Override 227 public LIMEExplanation explain(Map<String, String> input) { 228 return explainWithSamples(input).getA(); 229 } 230 231 protected Pair<LIMEExplanation, List<Example<Regressor>>> explainWithSamples(Map<String, String> input) { 232 Optional<Example<Label>> optExample = generator.generateExample(input,false); 233 if (optExample.isPresent()) { 234 Example<Label> example = optExample.get(); 235 if ((textDomain.size() == 0) && (binarisedCDFs.size() == 0)) { 236 // Short circuit if there are no text or binarised fields. 237 return explainWithSamples(example); 238 } else { 239 Prediction<Label> prediction = innerModel.predict(example); 240 241 // Build the input example with simplified text features 242 ArrayExample<Regressor> labelledExample = new ArrayExample<>(transformOutput(prediction)); 243 244 // Add the tabular features 245 for (Feature f : example) { 246 if (tabularDomain.getID(f.getName()) != -1) { 247 labelledExample.add(f); 248 } 249 } 250 // Extract the tabular features into a SparseVector for later 251 SparseVector tabularVector = SparseVector.createSparseVector(labelledExample,tabularDomain,false); 252 253 // Tokenize the text fields, and generate the perturbed text representation 254 Map<String, String> exampleTextValues = new HashMap<>(); 255 Map<String, List<Token>> exampleTextTokens = new HashMap<>(); 256 for (Map.Entry<String,FieldProcessor> e : textFields.entrySet()) { 257 String value = input.get(e.getKey()); 258 if (value != null) { 259 List<Token> tokens = tokenizerThreadLocal.get().tokenize(value); 260 for (int i = 0; i < tokens.size(); i++) { 261 labelledExample.add(nameFeature(e.getKey(),tokens.get(i).text,i),1.0); 262 } 263 exampleTextValues.put(e.getKey(),value); 264 exampleTextTokens.put(e.getKey(),tokens); 265 } 266 } 267 268 // Sample a dataset. 269 List<Example<Regressor>> sample = sampleData(tabularVector,exampleTextValues,exampleTextTokens); 270 271 // Generate a sparse model on the sampled data. 272 SparseModel<Regressor> model = trainExplainer(labelledExample, sample); 273 274 // Test the sparse model against the predictions of the real model. 275 List<Prediction<Regressor>> predictions = new ArrayList<>(model.predict(sample)); 276 predictions.add(model.predict(labelledExample)); 277 RegressionEvaluation evaluation = evaluator.evaluate(model,predictions,new SimpleDataSourceProvenance("LIMEColumnar sampled data",regressionFactory)); 278 279 return new Pair<>(new LIMEExplanation(model, prediction, evaluation),sample); 280 } 281 } else { 282 throw new IllegalArgumentException("Label not found in input " + input.toString()); 283 } 284 } 285 286 /** 287 * Generate the feature name by combining the word and index. 288 * @param fieldName The name of the column this text feature came from. 289 * @param name The word. 290 * @param idx The index. 291 * @return A string representing both of the inputs. 292 */ 293 protected String nameFeature(String fieldName, String name, int idx) { 294 return fieldName + "@" + name+"@idx"+idx; 295 } 296 297 /** 298 * Samples a dataset based on the provided text, tokens and tabular features. 299 * 300 * The text features are sampled using the {@link LIMEText} sampling approach, 301 * and the tabular features are sampled using the {@link LIMEBase} approach. 302 * 303 * The weight for each example is based on the distance for the tabular features, 304 * combined with the distance for the text features (which is a hamming distance). 305 * These distances are averaged using a weight function representing how many tokens 306 * there are in the text fields, and how many tabular features there are. 307 * 308 * This weight calculation is subject to change, as it's not necessarily optimal. 309 * @param tabularVector The tabular (i.e., non-text) features. 310 * @param text A map from the field names to the field values for the text fields. 311 * @param textTokens A map from the field names to lists of tokens for those fields. 312 * @return A sampled dataset. 313 */ 314 private List<Example<Regressor>> sampleData(SparseVector tabularVector, Map<String,String> text, Map<String,List<Token>> textTokens) { 315 List<Example<Regressor>> output = new ArrayList<>(); 316 317 Random innerRNG = new Random(rng.nextLong()); 318 for (int i = 0; i < numSamples; i++) { 319 // Create the full example 320 ListExample<Label> sampledExample = new ListExample<>(LabelFactory.UNKNOWN_LABEL); 321 322 // Tabular features. 323 List<Feature> tabularFeatures = new ArrayList<>(); 324 // Sample the categorical and real features 325 for (VariableInfo info : tabularDomain) { 326 int id = ((VariableIDInfo) info).getID(); 327 double inputValue = tabularVector.get(id); 328 329 if (info instanceof CategoricalInfo) { 330 // This one is tricksy as categorical info essentially implicitly includes a zero. 331 CategoricalInfo catInfo = (CategoricalInfo) info; 332 double sample = catInfo.frequencyBasedSample(innerRNG,numTrainingExamples); 333 // If we didn't sample zero. 334 if (Math.abs(sample) > 1e-10) { 335 Feature newFeature = new Feature(info.getName(),sample); 336 tabularFeatures.add(newFeature); 337 } 338 } else if (info instanceof RealInfo) { 339 RealInfo realInfo = (RealInfo) info; 340 // As realInfo is sparse we sample from the mixture distribution, 341 // either 0 or N(inputValue,variance). 342 // This assumes realInfo never observed a zero, which is enforced from v2.1 343 // TODO check this makes sense. If the input value is zero do we still want to sample spike and slab? 344 // If it's not zero do we want to? 345 int count = realInfo.getCount(); 346 double threshold = count / ((double)numTrainingExamples); 347 if (innerRNG.nextDouble() < threshold) { 348 double variance = realInfo.getVariance(); 349 double sample = (innerRNG.nextGaussian() * Math.sqrt(variance)) + inputValue; 350 Feature newFeature = new Feature(info.getName(),sample); 351 tabularFeatures.add(newFeature); 352 } 353 } else { 354 throw new IllegalStateException("Unsupported info type, expected CategoricalInfo or RealInfo, found " + info.getClass().getName()); 355 } 356 } 357 // Sample the binarised categorical features 358 for (Map.Entry<String,double[]> e : binarisedCDFs.entrySet()) { 359 // Sample from the CDF 360 int sample = Util.sampleFromCDF(e.getValue(),innerRNG); 361 // If the sample isn't zero (which is defined to be the last value to make the indices work) 362 if (sample != (e.getValue().length-1)) { 363 VariableInfo info = binarisedInfos.get(e.getKey()).get(sample); 364 Feature newFeature = new Feature(info.getName(),1); 365 tabularFeatures.add(newFeature); 366 } 367 } 368 // Add the tabular features to the current example 369 sampledExample.addAll(tabularFeatures); 370 // Calculate tabular distance 371 double tabularDistance = measureDistance(tabularDomain,numTrainingExamples,tabularVector, SparseVector.createSparseVector(sampledExample,tabularDomain,false)); 372 373 // features are the full text features 374 List<Feature> textFeatures = new ArrayList<>(); 375 // Perturbed features are the binarised tokens 376 List<Feature> perturbedFeatures = new ArrayList<>(); 377 378 // Sample the text features 379 double textDistance = 0.0; 380 long numTokens = 0; 381 for (Map.Entry<String, String> e : text.entrySet()) { 382 String curText = e.getValue(); 383 List<Token> tokens = textTokens.get(e.getKey()); 384 numTokens += tokens.size(); 385 386 // Sample a new Example. 387 int[] activeFeatures = new int[tokens.size()]; 388 char[] sampledText = curText.toCharArray(); 389 for (int j = 0; j < activeFeatures.length; j++) { 390 activeFeatures[j] = innerRNG.nextInt(2); 391 if (activeFeatures[j] == 0) { 392 textDistance++; 393 Token curToken = tokens.get(j); 394 Arrays.fill(sampledText, curToken.start, curToken.end, '\0'); 395 } 396 } 397 String sampledString = new String(sampledText); 398 sampledString = sampledString.replace("\0", ""); 399 400 textFeatures.addAll(textFields.get(e.getKey()).process(sampledString)); 401 402 for (int j = 0; j < activeFeatures.length; j++) { 403 perturbedFeatures.add(new Feature(nameFeature(e.getKey(), tokens.get(j).text, j), activeFeatures[j])); 404 } 405 } 406 // Add the text features to the current example 407 sampledExample.addAll(textFeatures); 408 // Calculate text distance 409 double totalTextDistance = textDistance / numTokens; 410 411 // Label it using the full model. 412 Prediction<Label> samplePrediction = innerModel.predict(sampledExample); 413 414 double totalLength = tabularFeatures.size() + perturbedFeatures.size(); 415 // Combine the distances and transform into a weight 416 // Currently this averages the two values based on their relative sizes. 417 double weight = 1.0 - ((tabularFeatures.size()*(kernelDist(tabularDistance,kernelWidth) + perturbedFeatures.size()*totalTextDistance) / totalLength)); 418 419 // Generate the new sample with the appropriate label and weight. 420 ArrayExample<Regressor> labelledSample = new ArrayExample<>(transformOutput(samplePrediction), (float) weight); 421 labelledSample.addAll(tabularFeatures); 422 labelledSample.addAll(perturbedFeatures); 423 output.add(labelledSample); 424 } 425 426 return output; 427 } 428}