001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.json;
018
019import com.fasterxml.jackson.databind.ObjectMapper;
020import com.fasterxml.jackson.databind.SerializationFeature;
021import com.oracle.labs.mlrg.olcut.config.ConfigurationManager;
022import com.oracle.labs.mlrg.olcut.config.Option;
023import com.oracle.labs.mlrg.olcut.config.Options;
024import com.oracle.labs.mlrg.olcut.config.UsageException;
025import com.oracle.labs.mlrg.olcut.config.json.JsonProvenanceModule;
026import com.oracle.labs.mlrg.olcut.provenance.ListProvenance;
027import com.oracle.labs.mlrg.olcut.provenance.ObjectProvenance;
028import com.oracle.labs.mlrg.olcut.provenance.Provenance;
029import com.oracle.labs.mlrg.olcut.provenance.ProvenanceUtil;
030import com.oracle.labs.mlrg.olcut.provenance.io.ObjectMarshalledProvenance;
031import com.oracle.labs.mlrg.olcut.provenance.primitives.HashProvenance;
032import com.oracle.labs.mlrg.olcut.util.IOUtil;
033import com.oracle.labs.mlrg.olcut.util.LabsLogFormatter;
034import org.tribuo.Model;
035import org.tribuo.Output;
036import org.tribuo.ensemble.EnsembleModel;
037import org.tribuo.provenance.DatasetProvenance;
038import org.tribuo.provenance.EnsembleModelProvenance;
039import org.tribuo.provenance.ModelProvenance;
040import org.tribuo.provenance.TrainerProvenance;
041import org.tribuo.provenance.impl.EmptyDatasetProvenance;
042import org.tribuo.provenance.impl.EmptyTrainerProvenance;
043
044import java.io.File;
045import java.io.FileNotFoundException;
046import java.io.FileOutputStream;
047import java.io.IOException;
048import java.io.ObjectInputStream;
049import java.io.ObjectOutputStream;
050import java.io.OutputStreamWriter;
051import java.io.PrintWriter;
052import java.io.UnsupportedEncodingException;
053import java.lang.reflect.InvocationTargetException;
054import java.lang.reflect.Method;
055import java.nio.charset.StandardCharsets;
056import java.security.MessageDigest;
057import java.time.OffsetDateTime;
058import java.util.ArrayList;
059import java.util.EnumSet;
060import java.util.HashMap;
061import java.util.List;
062import java.util.Map;
063import java.util.logging.Level;
064import java.util.logging.Logger;
065
066import static org.tribuo.json.StripProvenance.ProvenanceTypes.ALL;
067import static org.tribuo.json.StripProvenance.ProvenanceTypes.DATASET;
068import static org.tribuo.json.StripProvenance.ProvenanceTypes.INSTANCE;
069import static org.tribuo.json.StripProvenance.ProvenanceTypes.TRAINER;
070
071/**
072 * A main class for stripping out and storing provenance from a model.
073 * <p>
074 * Provenance stripping is useful for deploying models where others may
075 * be able to inspect the model metadata and discover things about the model's
076 * training procedure.
077 */
078public final class StripProvenance {
079    private static final Logger logger = Logger.getLogger(StripProvenance.class.getName());
080
081    private StripProvenance() { }
082
083    /**
084     * Types of provenance that can be removed.
085     */
086    public enum ProvenanceTypes {
087        /**
088         * Select the dataset provenance.
089         */
090        DATASET,
091        /**
092         * Select the trainer provenance.
093         */
094        TRAINER,
095        /**
096         * Select any instance provenance from the specific training run that created this model.
097         */
098        INSTANCE,
099        /**
100         * Selects all provenance stored in the model.
101         */
102        ALL
103    }
104
105    /**
106     * Creates a new model provenance with the requested provenances stripped out.
107     * @param old The old model provenance.
108     * @param provenanceHash The hash of the provenance (if requested it can be written into the new provenance for tracking).
109     * @param opt The program options.
110     * @return A new model provenance.
111     */
112    private static ModelProvenance cleanProvenance(ModelProvenance old, String provenanceHash, StripProvenanceOptions opt) {
113        // Dataset provenance
114        DatasetProvenance datasetProvenance;
115        if (opt.removeProvenances.contains(ALL) || opt.removeProvenances.contains(DATASET)) {
116            datasetProvenance = new EmptyDatasetProvenance();
117        } else {
118            datasetProvenance = old.getDatasetProvenance();
119        }
120        // Trainer provenance
121        TrainerProvenance trainerProvenance;
122        if (opt.removeProvenances.contains(ALL) || opt.removeProvenances.contains(TRAINER)) {
123            trainerProvenance = new EmptyTrainerProvenance();
124        } else {
125            trainerProvenance = old.getTrainerProvenance();
126        }
127        // Instance provenance
128        OffsetDateTime time;
129        Map<String, Provenance> instanceProvenance;
130        if (opt.removeProvenances.contains(ALL) || opt.removeProvenances.contains(INSTANCE)) {
131            instanceProvenance = new HashMap<>();
132            time = OffsetDateTime.MIN;
133        } else {
134            instanceProvenance = new HashMap<>(old.getInstanceProvenance().getMap());
135            time = old.getTrainingTime();
136        }
137        if (opt.storeHash) {
138            logger.info("Writing provenance hash into instance map.");
139            instanceProvenance.put("original-provenance-hash",new HashProvenance(opt.hashType,"original-provenance-hash",provenanceHash));
140        }
141
142        return new ModelProvenance(old.getClassName(),time,datasetProvenance,trainerProvenance,instanceProvenance);
143    }
144
145    /**
146     * Creates a new ensemble provenance with the requested information removed.
147     * @param old The old ensemble provenance.
148     * @param memberProvenance The new member provenances.
149     * @param provenanceHash The old ensemble provenance hash.
150     * @param opt The program options.
151     * @return The new ensemble provenance with the requested fields removed.
152     */
153    private static EnsembleModelProvenance cleanEnsembleProvenance(EnsembleModelProvenance old, ListProvenance<ModelProvenance> memberProvenance, String provenanceHash, StripProvenanceOptions opt) {
154        // Dataset provenance
155        DatasetProvenance datasetProvenance;
156        if (opt.removeProvenances.contains(ALL) || opt.removeProvenances.contains(DATASET)) {
157            datasetProvenance = new EmptyDatasetProvenance();
158        } else {
159            datasetProvenance = old.getDatasetProvenance();
160        }
161        // Trainer provenance
162        TrainerProvenance trainerProvenance;
163        if (opt.removeProvenances.contains(ALL) || opt.removeProvenances.contains(TRAINER)) {
164            trainerProvenance = new EmptyTrainerProvenance();
165        } else {
166            trainerProvenance = old.getTrainerProvenance();
167        }
168        // Instance provenance
169        OffsetDateTime time;
170        Map<String, Provenance> instanceProvenance;
171        if (opt.removeProvenances.contains(ALL) || opt.removeProvenances.contains(INSTANCE)) {
172            instanceProvenance = new HashMap<>();
173            time = OffsetDateTime.MIN;
174        } else {
175            instanceProvenance = new HashMap<>(old.getInstanceProvenance().getMap());
176            time = old.getTrainingTime();
177        }
178        if (opt.storeHash) {
179            logger.info("Writing provenance hash into instance map.");
180            instanceProvenance.put("original-provenance-hash",new HashProvenance(opt.hashType,"original-provenance-hash",provenanceHash));
181        }
182        return new EnsembleModelProvenance(old.getClassName(),time,datasetProvenance,trainerProvenance,instanceProvenance,memberProvenance);
183    }
184
185    /**
186     * Creates a copy of the old model with the requested provenance removed.
187     * @param oldModel The model to remove provenance from.
188     * @param provenanceHash A hash of the old provenance.
189     * @param opt The program options.
190     * @param <T> The output type.
191     * @return A copy of the model with redacted provenance.
192     * @throws InvocationTargetException If the model doesn't expose a copy method (all models should do).
193     * @throws IllegalAccessException If the model's copy method is not accessible.
194     * @throws NoSuchMethodException If the model's copy method isn't present.
195     */
196    @SuppressWarnings("unchecked") // cast of model after call to copy which returns model.
197    private static <T extends Output<T>> ModelTuple<T> convertModel(Model<T> oldModel, String provenanceHash, StripProvenanceOptions opt) throws InvocationTargetException, IllegalAccessException, NoSuchMethodException {
198        if (oldModel instanceof EnsembleModel) {
199            EnsembleModelProvenance oldProvenance = ((EnsembleModel<T>) oldModel).getProvenance();
200            List<ModelProvenance> newProvenances = new ArrayList<>();
201            List<Model<T>> newModels = new ArrayList<>();
202            for (Model<T> e : ((EnsembleModel<T>) oldModel).getModels()) {
203                ModelTuple<T> tuple = convertModel(e,provenanceHash,opt);
204                newProvenances.add(tuple.provenance);
205                newModels.add(tuple.model);
206            }
207            ListProvenance<ModelProvenance> listProv = new ListProvenance<>(newProvenances);
208            EnsembleModelProvenance cleanedProvenance = cleanEnsembleProvenance(oldProvenance,listProv,provenanceHash,opt);
209            Class<? extends Model> clazz = oldModel.getClass();
210            Method copyMethod = clazz.getDeclaredMethod("copy", String.class, ModelProvenance.class, List.class);
211            boolean accessible = copyMethod.isAccessible();
212            copyMethod.setAccessible(true);
213            String newName = oldModel.getName().isEmpty() ? "deprovenanced" : oldModel.getName() + "-deprovenanced";
214            EnsembleModel<T> output = (EnsembleModel<T>) copyMethod.invoke(oldModel, newName, cleanedProvenance, newModels);
215            copyMethod.setAccessible(accessible);
216            return new ModelTuple<>(output, cleanedProvenance);
217        } else {
218            ModelProvenance oldProvenance = oldModel.getProvenance();
219            ModelProvenance cleanedProvenance = cleanProvenance(oldProvenance, provenanceHash, opt);
220            Class<? extends Model> clazz = oldModel.getClass();
221            Method copyMethod = clazz.getDeclaredMethod("copy", String.class, ModelProvenance.class);
222            boolean accessible = copyMethod.isAccessible();
223            copyMethod.setAccessible(true);
224            String newName = oldModel.getName().isEmpty() ? "deprovenanced" : oldModel.getName() + "-deprovenanced";
225            Model<T> output = (Model<T>) copyMethod.invoke(oldModel, newName, cleanedProvenance);
226            copyMethod.setAccessible(accessible);
227            return new ModelTuple<>(output, cleanedProvenance);
228        }
229    }
230
231    public static class StripProvenanceOptions implements Options {
232        @Override
233        public String getOptionsDescription() {
234            return "A program for removing Provenance information from a Tribuo Model or SequenceModel.";
235        }
236
237        @Option(charName = 'h', longName = "store-provenance-hash", usage = "Stores a hash of the model provenance in the stripped model.")
238        public boolean storeHash;
239        @Option(charName = 'i', longName = "input-model-path", usage = "The model to load.")
240        public File inputModel;
241        @Option(charName = 'o', longName = "output-model-path", usage = "The location to write out the stripped model.")
242        public File outputModel;
243        @Option(charName = 'p', longName = "provenance-path", usage = "Write out the stripped provenance as json.")
244        public File provenanceFile;
245        @Option(charName = 'r', longName = "remove-provenances", usage = "The provenances to remove")
246        public EnumSet<ProvenanceTypes> removeProvenances = EnumSet.noneOf(ProvenanceTypes.class);
247        @Option(charName = 't', longName = "hash-type", usage = "The hash type to use.")
248        public ProvenanceUtil.HashType hashType = ObjectProvenance.DEFAULT_HASH_TYPE;
249    }
250
251    /**
252     * @param args the command line arguments
253     * @param <T>  The {@link Output} subclass.
254     */
255    @SuppressWarnings("unchecked")
256    public static <T extends Output<T>> void main(String[] args) {
257
258        //
259        // Use the labs format logging.
260        LabsLogFormatter.setAllLogFormatters();
261
262        StripProvenanceOptions o = new StripProvenanceOptions();
263        ConfigurationManager cm;
264        try {
265            cm = new ConfigurationManager(args, o);
266        } catch (UsageException e) {
267            logger.info(e.getMessage());
268            return;
269        }
270
271        if (o.inputModel == null || o.outputModel == null) {
272            logger.info(cm.usage());
273            System.exit(1);
274        }
275
276        try (ObjectInputStream ois = IOUtil.getObjectInputStream(o.inputModel);
277             ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(o.outputModel))) {
278            logger.info("Loading model from " + o.inputModel);
279            Model<T> input = (Model<T>) ois.readObject();
280
281            ModelProvenance oldProvenance = input.getProvenance();
282
283            logger.info("Marshalling provenance and creating JSON.");
284            List<ObjectMarshalledProvenance> list = ProvenanceUtil.marshalProvenance(oldProvenance);
285            ObjectMapper mapper = new ObjectMapper();
286            mapper.registerModule(new JsonProvenanceModule());
287            mapper.enable(SerializationFeature.INDENT_OUTPUT);
288            String jsonResult = mapper.writeValueAsString(list);
289
290            logger.info("Hashing JSON file");
291            MessageDigest digest = o.hashType.getDigest();
292            byte[] digestBytes = digest.digest(jsonResult.getBytes(StandardCharsets.UTF_8));
293            String provenanceHash = ProvenanceUtil.bytesToHexString(digestBytes);
294            logger.info("Provenance hash = " + provenanceHash);
295
296            if (o.provenanceFile != null) {
297                logger.info("Writing JSON provenance to " + o.provenanceFile.toString());
298                try (PrintWriter writer = new PrintWriter(new OutputStreamWriter(new FileOutputStream(o.provenanceFile), StandardCharsets.UTF_8))) {
299                    writer.println(jsonResult);
300                }
301            }
302
303            ModelTuple<T> tuple = convertModel(input,provenanceHash,o);
304            logger.info("Writing model to " + o.outputModel);
305            oos.writeObject(tuple.model);
306
307            ModelProvenance newProvenance = tuple.provenance;
308            logger.info("Marshalling provenance and creating JSON.");
309            List<ObjectMarshalledProvenance> newList = ProvenanceUtil.marshalProvenance(newProvenance);
310            String newJsonResult = mapper.writeValueAsString(newList);
311
312            logger.info("Old provenance = \n" + jsonResult);
313            logger.info("New provenance = \n" + newJsonResult);
314        } catch (NoSuchMethodException e) {
315            logger.log(Level.SEVERE, "Model.copy method missing on a class which extends Model.",e);
316        } catch (IllegalAccessException e) {
317            logger.log(Level.SEVERE, "Failed to modify protection on inner copy method on Model.",e);
318        } catch (InvocationTargetException e) {
319            logger.log(Level.SEVERE, "Failed to invoke inner copy method on Model.",e);
320        } catch (UnsupportedEncodingException e) {
321            logger.log(Level.SEVERE, "Unsupported encoding exception.",e);
322        } catch (FileNotFoundException e) {
323            logger.log(Level.SEVERE, "Failed to find the input file.",e);
324        } catch (IOException e) {
325            logger.log(Level.SEVERE, "IO error when reading or writing a file.",e);
326        } catch (ClassNotFoundException e) {
327            logger.log(Level.SEVERE, "The model and/or provenance classes are not on the classpath.",e);
328        }
329
330    }
331
332    /**
333     * It's a record. Or at least it will be.
334     * @param <T> The ouput type.
335     */
336    private static class ModelTuple<T extends Output<T>> {
337        public final Model<T> model;
338        public final ModelProvenance provenance;
339
340        public ModelTuple(Model<T> model, ModelProvenance provenance) {
341            this.model = model;
342            this.provenance = provenance;
343        }
344    }
345}