001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.classification.explanations.lime;
018
019import com.oracle.labs.mlrg.olcut.util.Pair;
020import org.tribuo.CategoricalInfo;
021import org.tribuo.Example;
022import org.tribuo.Feature;
023import org.tribuo.ImmutableFeatureMap;
024import org.tribuo.Model;
025import org.tribuo.Prediction;
026import org.tribuo.RealInfo;
027import org.tribuo.SparseModel;
028import org.tribuo.SparseTrainer;
029import org.tribuo.VariableIDInfo;
030import org.tribuo.VariableInfo;
031import org.tribuo.classification.Label;
032import org.tribuo.classification.LabelFactory;
033import org.tribuo.classification.explanations.ColumnarExplainer;
034import org.tribuo.data.columnar.ColumnarFeature;
035import org.tribuo.data.columnar.FieldProcessor;
036import org.tribuo.data.columnar.ResponseProcessor;
037import org.tribuo.data.columnar.RowProcessor;
038import org.tribuo.impl.ArrayExample;
039import org.tribuo.impl.ListExample;
040import org.tribuo.math.la.SparseVector;
041import org.tribuo.provenance.SimpleDataSourceProvenance;
042import org.tribuo.regression.Regressor;
043import org.tribuo.regression.evaluation.RegressionEvaluation;
044import org.tribuo.util.Util;
045import org.tribuo.util.tokens.Token;
046import org.tribuo.util.tokens.Tokenizer;
047
048import java.util.ArrayList;
049import java.util.Arrays;
050import java.util.HashMap;
051import java.util.List;
052import java.util.ListIterator;
053import java.util.Map;
054import java.util.Optional;
055import java.util.Random;
056import java.util.SplittableRandom;
057
058/**
059 * Uses the columnar data processing infrastructure to mix text and tabular data.
060 * <p>
061 * If the supplied {@link RowProcessor} doesn't reference any text or binarised fields
062 * then it delegates to {@link LIMEBase#explain}, though it's still more expensive at
063 * construction time.
064 * <p>
065 * See:
066 * <pre>
067 * Ribeiro MT, Singh S, Guestrin C.
068 * "Why should I trust you?: Explaining the predictions of any classifier"
069 * Proceedings of the 22nd ACM SIGKDD International Conference on Knowledge Discovery and Data Mining 2016.
070 * </pre>
071 */
072public class LIMEColumnar extends LIMEBase implements ColumnarExplainer<Regressor> {
073
074    private final RowProcessor<Label> generator;
075
076    private final Map<String,FieldProcessor> binarisedFields = new HashMap<>();
077
078    private final Map<String,FieldProcessor> tabularFields = new HashMap<>();
079
080    private final Map<String,FieldProcessor> textFields = new HashMap<>();
081
082    private final ResponseProcessor<Label> responseProcessor;
083
084    private final Map<String,List<VariableInfo>> binarisedInfos;
085
086    private final Map<String,double[]> binarisedCDFs;
087
088    private final ImmutableFeatureMap binarisedDomain;
089
090    private final ImmutableFeatureMap tabularDomain;
091
092    private final ImmutableFeatureMap textDomain;
093
094    private final Tokenizer tokenizer;
095
096    private final ThreadLocal<Tokenizer> tokenizerThreadLocal;
097
098    /**
099     * Constructs a LIME explainer for a model which uses the columnar data processing system.
100     * @param rng The rng to use for sampling.
101     * @param innerModel The model to explain.
102     * @param explanationTrainer The trainer for the sparse model used to explain.
103     * @param numSamples The number of samples to generate in each explanation.
104     * @param exampleGenerator The {@link RowProcessor} which converts columnar data into an {@link Example}.
105     * @param tokenizer The tokenizer to use on any text fields.
106     */
107    public LIMEColumnar(SplittableRandom rng, Model<Label> innerModel, SparseTrainer<Regressor> explanationTrainer,
108                        int numSamples, RowProcessor<Label> exampleGenerator, Tokenizer tokenizer) {
109        super(rng, innerModel, explanationTrainer, numSamples);
110        this.generator = exampleGenerator.copy();
111        this.responseProcessor = generator.getResponseProcessor();
112        this.tokenizer = tokenizer;
113        this.tokenizerThreadLocal = ThreadLocal.withInitial(() -> {try { return this.tokenizer.clone(); } catch (CloneNotSupportedException e) { throw new IllegalArgumentException("Tokenizer not cloneable",e); }});
114        if (!this.generator.isConfigured()) {
115            this.generator.expandRegexMapping(innerModel);
116        }
117        this.binarisedInfos = new HashMap<>();
118        ArrayList<VariableInfo> infos = new ArrayList<>();
119        for (VariableInfo i : innerModel.getFeatureIDMap()) {
120            infos.add(i);
121        }
122        ArrayList<VariableInfo> allBinarisedInfos = new ArrayList<>();
123        ArrayList<VariableInfo> tabularInfos = new ArrayList<>();
124        ArrayList<VariableInfo> textInfos = new ArrayList<>();
125        for (Map.Entry<String,FieldProcessor> p : generator.getFieldProcessors().entrySet()) {
126            String searchName = p.getKey() + ColumnarFeature.JOINER;
127            switch (p.getValue().getFeatureType()) {
128                case BINARISED_CATEGORICAL: {
129                    int numNamespaces = p.getValue().getNumNamespaces();
130                    if (numNamespaces > 1) {
131                        for (int i = 0; i < numNamespaces; i++) {
132                            String namespace = p.getKey() + FieldProcessor.NAMESPACE + i;
133                            String namespaceSearchName = namespace + ColumnarFeature.JOINER;
134                            binarisedFields.put(namespace, p.getValue());
135                            List<VariableInfo> binarisedInfoList = this.binarisedInfos.computeIfAbsent(namespace, (k) -> new ArrayList<>());
136                            ListIterator<VariableInfo> li = infos.listIterator();
137                            while (li.hasNext()) {
138                                VariableInfo info = li.next();
139                                if (info.getName().startsWith(namespaceSearchName)) {
140                                    if (((CategoricalInfo) info).getUniqueObservations() != 1) {
141                                        throw new IllegalStateException("Processor " + p.getKey() + ", should have been binary, but had " + ((CategoricalInfo) info).getUniqueObservations() + " unique values");
142                                    }
143                                    binarisedInfoList.add(info);
144                                    allBinarisedInfos.add(info);
145                                    li.remove();
146                                }
147                            }
148                        }
149                    } else {
150                        binarisedFields.put(p.getKey(), p.getValue());
151                        List<VariableInfo> binarisedInfoList = this.binarisedInfos.computeIfAbsent(p.getKey(), (k) -> new ArrayList<>());
152                        ListIterator<VariableInfo> li = infos.listIterator();
153                        while (li.hasNext()) {
154                            VariableInfo i = li.next();
155                            if (i.getName().startsWith(searchName)) {
156                                if (((CategoricalInfo) i).getUniqueObservations() != 1) {
157                                    throw new IllegalStateException("Processor " + p.getKey() + ", should have been binary, but had " + ((CategoricalInfo) i).getUniqueObservations() + " unique values");
158                                }
159                                binarisedInfoList.add(i);
160                                allBinarisedInfos.add(i);
161                                li.remove();
162                            }
163                        }
164                    }
165                    break;
166                }
167                case CATEGORICAL:
168                case REAL: {
169                    tabularFields.put(p.getKey(), p.getValue());
170                    ListIterator<VariableInfo> li = infos.listIterator();
171                    while (li.hasNext()) {
172                        VariableInfo i = li.next();
173                        if (i.getName().startsWith(searchName)) {
174                            tabularInfos.add(i);
175                            li.remove();
176                        }
177                    }
178                    break;
179                }
180                case TEXT: {
181                    textFields.put(p.getKey(), p.getValue());
182                    ListIterator<VariableInfo> li = infos.listIterator();
183                    while (li.hasNext()) {
184                        VariableInfo i = li.next();
185                        if (i.getName().startsWith(searchName)) {
186                            textInfos.add(i);
187                            li.remove();
188                        }
189                    }
190                    break;
191                }
192                default:
193                    throw new IllegalArgumentException("Unsupported feature type " + p.getValue().getFeatureType());
194            }
195        }
196        if (infos.size() != 0) {
197            throw new IllegalArgumentException("Found " + infos.size() + " unsupported features.");
198        }
199        if (generator.getFeatureProcessors().size() != 0) {
200            throw new IllegalArgumentException("LIMEColumnar does not support FeatureProcessors.");
201        }
202        this.tabularDomain = new ImmutableFeatureMap(tabularInfos);
203        this.textDomain = new ImmutableFeatureMap(textInfos);
204        this.binarisedDomain = new ImmutableFeatureMap(allBinarisedInfos);
205        this.binarisedCDFs = new HashMap<>();
206        for (Map.Entry<String,List<VariableInfo>> e : binarisedInfos.entrySet()) {
207            long totalCount = 0;
208            long[] counts = new long[e.getValue().size()+1];
209            int i = 0;
210            for (VariableInfo info : e.getValue()) {
211                long curCount = info.getCount();
212                counts[i] = curCount;
213                totalCount += curCount;
214                i++;
215            }
216            long zeroCount = numTrainingExamples - totalCount;
217            if (zeroCount < 0) {
218                throw new IllegalStateException("Processor " + e.getKey() + " purports to be a BINARISED_CATEGORICAL, but had overlap in it's elements");
219            }
220            counts[i] = zeroCount;
221            double[] cdf = Util.generateCDF(counts,numTrainingExamples);
222            binarisedCDFs.put(e.getKey(),cdf);
223        }
224    }
225
226    @Override
227    public LIMEExplanation explain(Map<String, String> input) {
228        return explainWithSamples(input).getA();
229    }
230
231    protected Pair<LIMEExplanation, List<Example<Regressor>>> explainWithSamples(Map<String, String> input) {
232        Optional<Example<Label>> optExample = generator.generateExample(input,false);
233        if (optExample.isPresent()) {
234            Example<Label> example = optExample.get();
235            if ((textDomain.size() == 0) && (binarisedCDFs.size() == 0)) {
236                // Short circuit if there are no text or binarised fields.
237                return explainWithSamples(example);
238            } else {
239                Prediction<Label> prediction = innerModel.predict(example);
240
241                // Build the input example with simplified text features
242                ArrayExample<Regressor> labelledExample = new ArrayExample<>(transformOutput(prediction));
243
244                // Add the tabular features
245                for (Feature f : example) {
246                    if (tabularDomain.getID(f.getName()) != -1) {
247                        labelledExample.add(f);
248                    }
249                }
250                // Extract the tabular features into a SparseVector for later
251                SparseVector tabularVector = SparseVector.createSparseVector(labelledExample,tabularDomain,false);
252
253                // Tokenize the text fields, and generate the perturbed text representation
254                Map<String, String> exampleTextValues = new HashMap<>();
255                Map<String, List<Token>> exampleTextTokens = new HashMap<>();
256                for (Map.Entry<String,FieldProcessor> e : textFields.entrySet()) {
257                    String value = input.get(e.getKey());
258                    if (value != null) {
259                        List<Token> tokens = tokenizerThreadLocal.get().tokenize(value);
260                        for (int i = 0; i < tokens.size(); i++) {
261                            labelledExample.add(nameFeature(e.getKey(),tokens.get(i).text,i),1.0);
262                        }
263                        exampleTextValues.put(e.getKey(),value);
264                        exampleTextTokens.put(e.getKey(),tokens);
265                    }
266                }
267
268                // Sample a dataset.
269                List<Example<Regressor>> sample = sampleData(tabularVector,exampleTextValues,exampleTextTokens);
270
271                // Generate a sparse model on the sampled data.
272                SparseModel<Regressor> model = trainExplainer(labelledExample, sample);
273
274                // Test the sparse model against the predictions of the real model.
275                List<Prediction<Regressor>> predictions = new ArrayList<>(model.predict(sample));
276                predictions.add(model.predict(labelledExample));
277                RegressionEvaluation evaluation = evaluator.evaluate(model,predictions,new SimpleDataSourceProvenance("LIMEColumnar sampled data",regressionFactory));
278
279                return new Pair<>(new LIMEExplanation(model, prediction, evaluation),sample);
280            }
281        } else {
282            throw new IllegalArgumentException("Label not found in input " + input.toString());
283        }
284    }
285
286    /**
287     * Generate the feature name by combining the word and index.
288     * @param fieldName The name of the column this text feature came from.
289     * @param name The word.
290     * @param idx The index.
291     * @return A string representing both of the inputs.
292     */
293    protected String nameFeature(String fieldName, String name, int idx) {
294        return fieldName + "@" + name+"@idx"+idx;
295    }
296
297    /**
298     * Samples a dataset based on the provided text, tokens and tabular features.
299     *
300     * The text features are sampled using the {@link LIMEText} sampling approach,
301     * and the tabular features are sampled using the {@link LIMEBase} approach.
302     *
303     * The weight for each example is based on the distance for the tabular features,
304     * combined with the distance for the text features (which is a hamming distance).
305     * These distances are averaged using a weight function representing how many tokens
306     * there are in the text fields, and how many tabular features there are.
307     *
308     * This weight calculation is subject to change, as it's not necessarily optimal.
309     * @param tabularVector The tabular (i.e., non-text) features.
310     * @param text A map from the field names to the field values for the text fields.
311     * @param textTokens A map from the field names to lists of tokens for those fields.
312     * @return A sampled dataset.
313     */
314    private List<Example<Regressor>> sampleData(SparseVector tabularVector, Map<String,String> text, Map<String,List<Token>> textTokens) {
315        List<Example<Regressor>> output = new ArrayList<>();
316
317        Random innerRNG = new Random(rng.nextLong());
318        for (int i = 0; i < numSamples; i++) {
319            // Create the full example
320            ListExample<Label> sampledExample = new ListExample<>(LabelFactory.UNKNOWN_LABEL);
321
322            // Tabular features.
323            List<Feature> tabularFeatures = new ArrayList<>();
324            // Sample the categorical and real features
325            for (VariableInfo info : tabularDomain) {
326                int id = ((VariableIDInfo) info).getID();
327                double inputValue = tabularVector.get(id);
328
329                if (info instanceof CategoricalInfo) {
330                    // This one is tricksy as categorical info essentially implicitly includes a zero.
331                    CategoricalInfo catInfo = (CategoricalInfo) info;
332                    double sample = catInfo.frequencyBasedSample(innerRNG,numTrainingExamples);
333                    // If we didn't sample zero.
334                    if (Math.abs(sample) > 1e-10) {
335                        Feature newFeature = new Feature(info.getName(),sample);
336                        tabularFeatures.add(newFeature);
337                    }
338                } else if (info instanceof RealInfo) {
339                    RealInfo realInfo = (RealInfo) info;
340                    // As realInfo is sparse we sample from the mixture distribution,
341                    // either 0 or N(inputValue,variance).
342                    // This assumes realInfo never observed a zero, which is enforced from v2.1
343                    // TODO check this makes sense. If the input value is zero do we still want to sample spike and slab?
344                    // If it's not zero do we want to?
345                    int count = realInfo.getCount();
346                    double threshold = count / ((double)numTrainingExamples);
347                    if (innerRNG.nextDouble() < threshold) {
348                        double variance = realInfo.getVariance();
349                        double sample = (innerRNG.nextGaussian() * Math.sqrt(variance)) + inputValue;
350                        Feature newFeature = new Feature(info.getName(),sample);
351                        tabularFeatures.add(newFeature);
352                    }
353                } else {
354                    throw new IllegalStateException("Unsupported info type, expected CategoricalInfo or RealInfo, found " + info.getClass().getName());
355                }
356            }
357            // Sample the binarised categorical features
358            for (Map.Entry<String,double[]> e : binarisedCDFs.entrySet()) {
359                // Sample from the CDF
360                int sample = Util.sampleFromCDF(e.getValue(),innerRNG);
361                // If the sample isn't zero (which is defined to be the last value to make the indices work)
362                if (sample != (e.getValue().length-1)) {
363                    VariableInfo info = binarisedInfos.get(e.getKey()).get(sample);
364                    Feature newFeature = new Feature(info.getName(),1);
365                    tabularFeatures.add(newFeature);
366                }
367            }
368            // Add the tabular features to the current example
369            sampledExample.addAll(tabularFeatures);
370            // Calculate tabular distance
371            double tabularDistance = measureDistance(tabularDomain,numTrainingExamples,tabularVector, SparseVector.createSparseVector(sampledExample,tabularDomain,false));
372
373            // features are the full text features
374            List<Feature> textFeatures = new ArrayList<>();
375            // Perturbed features are the binarised tokens
376            List<Feature> perturbedFeatures = new ArrayList<>();
377
378            // Sample the text features
379            double textDistance = 0.0;
380            long numTokens = 0;
381            for (Map.Entry<String, String> e : text.entrySet()) {
382                String curText = e.getValue();
383                List<Token> tokens = textTokens.get(e.getKey());
384                numTokens += tokens.size();
385
386                // Sample a new Example.
387                int[] activeFeatures = new int[tokens.size()];
388                char[] sampledText = curText.toCharArray();
389                for (int j = 0; j < activeFeatures.length; j++) {
390                    activeFeatures[j] = innerRNG.nextInt(2);
391                    if (activeFeatures[j] == 0) {
392                        textDistance++;
393                        Token curToken = tokens.get(j);
394                        Arrays.fill(sampledText, curToken.start, curToken.end, '\0');
395                    }
396                }
397                String sampledString = new String(sampledText);
398                sampledString = sampledString.replace("\0", "");
399
400                textFeatures.addAll(textFields.get(e.getKey()).process(sampledString));
401
402                for (int j = 0; j < activeFeatures.length; j++) {
403                    perturbedFeatures.add(new Feature(nameFeature(e.getKey(), tokens.get(j).text, j), activeFeatures[j]));
404                }
405            }
406            // Add the text features to the current example
407            sampledExample.addAll(textFeatures);
408            // Calculate text distance
409            double totalTextDistance = textDistance / numTokens;
410
411            // Label it using the full model.
412            Prediction<Label> samplePrediction = innerModel.predict(sampledExample);
413
414            double totalLength = tabularFeatures.size() + perturbedFeatures.size();
415            // Combine the distances and transform into a weight
416            // Currently this averages the two values based on their relative sizes.
417            double weight = 1.0 - ((tabularFeatures.size()*(kernelDist(tabularDistance,kernelWidth) + perturbedFeatures.size()*totalTextDistance) / totalLength));
418
419            // Generate the new sample with the appropriate label and weight.
420            ArrayExample<Regressor> labelledSample = new ArrayExample<>(transformOutput(samplePrediction), (float) weight);
421            labelledSample.addAll(tabularFeatures);
422            labelledSample.addAll(perturbedFeatures);
423            output.add(labelledSample);
424        }
425
426        return output;
427    }
428}