001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.json; 018 019import com.fasterxml.jackson.databind.ObjectMapper; 020import com.fasterxml.jackson.databind.SerializationFeature; 021import com.oracle.labs.mlrg.olcut.config.ConfigurationManager; 022import com.oracle.labs.mlrg.olcut.config.Option; 023import com.oracle.labs.mlrg.olcut.config.Options; 024import com.oracle.labs.mlrg.olcut.config.UsageException; 025import com.oracle.labs.mlrg.olcut.config.json.JsonProvenanceModule; 026import com.oracle.labs.mlrg.olcut.provenance.ListProvenance; 027import com.oracle.labs.mlrg.olcut.provenance.ObjectProvenance; 028import com.oracle.labs.mlrg.olcut.provenance.Provenance; 029import com.oracle.labs.mlrg.olcut.provenance.ProvenanceUtil; 030import com.oracle.labs.mlrg.olcut.provenance.io.ObjectMarshalledProvenance; 031import com.oracle.labs.mlrg.olcut.provenance.primitives.HashProvenance; 032import com.oracle.labs.mlrg.olcut.util.IOUtil; 033import com.oracle.labs.mlrg.olcut.util.LabsLogFormatter; 034import org.tribuo.Model; 035import org.tribuo.Output; 036import org.tribuo.ensemble.EnsembleModel; 037import org.tribuo.provenance.DatasetProvenance; 038import org.tribuo.provenance.EnsembleModelProvenance; 039import org.tribuo.provenance.ModelProvenance; 040import org.tribuo.provenance.TrainerProvenance; 041import org.tribuo.provenance.impl.EmptyDatasetProvenance; 042import org.tribuo.provenance.impl.EmptyTrainerProvenance; 043 044import java.io.File; 045import java.io.FileNotFoundException; 046import java.io.FileOutputStream; 047import java.io.IOException; 048import java.io.ObjectInputStream; 049import java.io.ObjectOutputStream; 050import java.io.OutputStreamWriter; 051import java.io.PrintWriter; 052import java.io.UnsupportedEncodingException; 053import java.lang.reflect.InvocationTargetException; 054import java.lang.reflect.Method; 055import java.nio.charset.StandardCharsets; 056import java.security.MessageDigest; 057import java.time.OffsetDateTime; 058import java.util.ArrayList; 059import java.util.EnumSet; 060import java.util.HashMap; 061import java.util.List; 062import java.util.Map; 063import java.util.logging.Level; 064import java.util.logging.Logger; 065 066import static org.tribuo.json.StripProvenance.ProvenanceTypes.ALL; 067import static org.tribuo.json.StripProvenance.ProvenanceTypes.DATASET; 068import static org.tribuo.json.StripProvenance.ProvenanceTypes.INSTANCE; 069import static org.tribuo.json.StripProvenance.ProvenanceTypes.TRAINER; 070 071/** 072 * A main class for stripping out and storing provenance from a model. 073 * <p> 074 * Provenance stripping is useful for deploying models where others may 075 * be able to inspect the model metadata and discover things about the model's 076 * training procedure. 077 */ 078public final class StripProvenance { 079 private static final Logger logger = Logger.getLogger(StripProvenance.class.getName()); 080 081 private StripProvenance() { } 082 083 /** 084 * Types of provenance that can be removed. 085 */ 086 public enum ProvenanceTypes { 087 /** 088 * Select the dataset provenance. 089 */ 090 DATASET, 091 /** 092 * Select the trainer provenance. 093 */ 094 TRAINER, 095 /** 096 * Select any instance provenance from the specific training run that created this model. 097 */ 098 INSTANCE, 099 /** 100 * Selects all provenance stored in the model. 101 */ 102 ALL 103 } 104 105 /** 106 * Creates a new model provenance with the requested provenances stripped out. 107 * @param old The old model provenance. 108 * @param provenanceHash The hash of the provenance (if requested it can be written into the new provenance for tracking). 109 * @param opt The program options. 110 * @return A new model provenance. 111 */ 112 private static ModelProvenance cleanProvenance(ModelProvenance old, String provenanceHash, StripProvenanceOptions opt) { 113 // Dataset provenance 114 DatasetProvenance datasetProvenance; 115 if (opt.removeProvenances.contains(ALL) || opt.removeProvenances.contains(DATASET)) { 116 datasetProvenance = new EmptyDatasetProvenance(); 117 } else { 118 datasetProvenance = old.getDatasetProvenance(); 119 } 120 // Trainer provenance 121 TrainerProvenance trainerProvenance; 122 if (opt.removeProvenances.contains(ALL) || opt.removeProvenances.contains(TRAINER)) { 123 trainerProvenance = new EmptyTrainerProvenance(); 124 } else { 125 trainerProvenance = old.getTrainerProvenance(); 126 } 127 // Instance provenance 128 OffsetDateTime time; 129 Map<String, Provenance> instanceProvenance; 130 if (opt.removeProvenances.contains(ALL) || opt.removeProvenances.contains(INSTANCE)) { 131 instanceProvenance = new HashMap<>(); 132 time = OffsetDateTime.MIN; 133 } else { 134 instanceProvenance = new HashMap<>(old.getInstanceProvenance().getMap()); 135 time = old.getTrainingTime(); 136 } 137 if (opt.storeHash) { 138 logger.info("Writing provenance hash into instance map."); 139 instanceProvenance.put("original-provenance-hash",new HashProvenance(opt.hashType,"original-provenance-hash",provenanceHash)); 140 } 141 142 return new ModelProvenance(old.getClassName(),time,datasetProvenance,trainerProvenance,instanceProvenance); 143 } 144 145 /** 146 * Creates a new ensemble provenance with the requested information removed. 147 * @param old The old ensemble provenance. 148 * @param memberProvenance The new member provenances. 149 * @param provenanceHash The old ensemble provenance hash. 150 * @param opt The program options. 151 * @return The new ensemble provenance with the requested fields removed. 152 */ 153 private static EnsembleModelProvenance cleanEnsembleProvenance(EnsembleModelProvenance old, ListProvenance<ModelProvenance> memberProvenance, String provenanceHash, StripProvenanceOptions opt) { 154 // Dataset provenance 155 DatasetProvenance datasetProvenance; 156 if (opt.removeProvenances.contains(ALL) || opt.removeProvenances.contains(DATASET)) { 157 datasetProvenance = new EmptyDatasetProvenance(); 158 } else { 159 datasetProvenance = old.getDatasetProvenance(); 160 } 161 // Trainer provenance 162 TrainerProvenance trainerProvenance; 163 if (opt.removeProvenances.contains(ALL) || opt.removeProvenances.contains(TRAINER)) { 164 trainerProvenance = new EmptyTrainerProvenance(); 165 } else { 166 trainerProvenance = old.getTrainerProvenance(); 167 } 168 // Instance provenance 169 OffsetDateTime time; 170 Map<String, Provenance> instanceProvenance; 171 if (opt.removeProvenances.contains(ALL) || opt.removeProvenances.contains(INSTANCE)) { 172 instanceProvenance = new HashMap<>(); 173 time = OffsetDateTime.MIN; 174 } else { 175 instanceProvenance = new HashMap<>(old.getInstanceProvenance().getMap()); 176 time = old.getTrainingTime(); 177 } 178 if (opt.storeHash) { 179 logger.info("Writing provenance hash into instance map."); 180 instanceProvenance.put("original-provenance-hash",new HashProvenance(opt.hashType,"original-provenance-hash",provenanceHash)); 181 } 182 return new EnsembleModelProvenance(old.getClassName(),time,datasetProvenance,trainerProvenance,instanceProvenance,memberProvenance); 183 } 184 185 /** 186 * Creates a copy of the old model with the requested provenance removed. 187 * @param oldModel The model to remove provenance from. 188 * @param provenanceHash A hash of the old provenance. 189 * @param opt The program options. 190 * @param <T> The output type. 191 * @return A copy of the model with redacted provenance. 192 * @throws InvocationTargetException If the model doesn't expose a copy method (all models should do). 193 * @throws IllegalAccessException If the model's copy method is not accessible. 194 * @throws NoSuchMethodException If the model's copy method isn't present. 195 */ 196 @SuppressWarnings("unchecked") // cast of model after call to copy which returns model. 197 private static <T extends Output<T>> ModelTuple<T> convertModel(Model<T> oldModel, String provenanceHash, StripProvenanceOptions opt) throws InvocationTargetException, IllegalAccessException, NoSuchMethodException { 198 if (oldModel instanceof EnsembleModel) { 199 EnsembleModelProvenance oldProvenance = ((EnsembleModel<T>) oldModel).getProvenance(); 200 List<ModelProvenance> newProvenances = new ArrayList<>(); 201 List<Model<T>> newModels = new ArrayList<>(); 202 for (Model<T> e : ((EnsembleModel<T>) oldModel).getModels()) { 203 ModelTuple<T> tuple = convertModel(e,provenanceHash,opt); 204 newProvenances.add(tuple.provenance); 205 newModels.add(tuple.model); 206 } 207 ListProvenance<ModelProvenance> listProv = new ListProvenance<>(newProvenances); 208 EnsembleModelProvenance cleanedProvenance = cleanEnsembleProvenance(oldProvenance,listProv,provenanceHash,opt); 209 Class<? extends Model> clazz = oldModel.getClass(); 210 Method copyMethod = clazz.getDeclaredMethod("copy", String.class, ModelProvenance.class, List.class); 211 boolean accessible = copyMethod.isAccessible(); 212 copyMethod.setAccessible(true); 213 String newName = oldModel.getName().isEmpty() ? "deprovenanced" : oldModel.getName() + "-deprovenanced"; 214 EnsembleModel<T> output = (EnsembleModel<T>) copyMethod.invoke(oldModel, newName, cleanedProvenance, newModels); 215 copyMethod.setAccessible(accessible); 216 return new ModelTuple<>(output, cleanedProvenance); 217 } else { 218 ModelProvenance oldProvenance = oldModel.getProvenance(); 219 ModelProvenance cleanedProvenance = cleanProvenance(oldProvenance, provenanceHash, opt); 220 Class<? extends Model> clazz = oldModel.getClass(); 221 Method copyMethod = clazz.getDeclaredMethod("copy", String.class, ModelProvenance.class); 222 boolean accessible = copyMethod.isAccessible(); 223 copyMethod.setAccessible(true); 224 String newName = oldModel.getName().isEmpty() ? "deprovenanced" : oldModel.getName() + "-deprovenanced"; 225 Model<T> output = (Model<T>) copyMethod.invoke(oldModel, newName, cleanedProvenance); 226 copyMethod.setAccessible(accessible); 227 return new ModelTuple<>(output, cleanedProvenance); 228 } 229 } 230 231 public static class StripProvenanceOptions implements Options { 232 @Override 233 public String getOptionsDescription() { 234 return "A program for removing Provenance information from a Tribuo Model or SequenceModel."; 235 } 236 237 @Option(charName = 'h', longName = "store-provenance-hash", usage = "Stores a hash of the model provenance in the stripped model.") 238 public boolean storeHash; 239 @Option(charName = 'i', longName = "input-model-path", usage = "The model to load.") 240 public File inputModel; 241 @Option(charName = 'o', longName = "output-model-path", usage = "The location to write out the stripped model.") 242 public File outputModel; 243 @Option(charName = 'p', longName = "provenance-path", usage = "Write out the stripped provenance as json.") 244 public File provenanceFile; 245 @Option(charName = 'r', longName = "remove-provenances", usage = "The provenances to remove") 246 public EnumSet<ProvenanceTypes> removeProvenances = EnumSet.noneOf(ProvenanceTypes.class); 247 @Option(charName = 't', longName = "hash-type", usage = "The hash type to use.") 248 public ProvenanceUtil.HashType hashType = ObjectProvenance.DEFAULT_HASH_TYPE; 249 } 250 251 /** 252 * @param args the command line arguments 253 * @param <T> The {@link Output} subclass. 254 */ 255 @SuppressWarnings("unchecked") 256 public static <T extends Output<T>> void main(String[] args) { 257 258 // 259 // Use the labs format logging. 260 LabsLogFormatter.setAllLogFormatters(); 261 262 StripProvenanceOptions o = new StripProvenanceOptions(); 263 ConfigurationManager cm; 264 try { 265 cm = new ConfigurationManager(args, o); 266 } catch (UsageException e) { 267 logger.info(e.getMessage()); 268 return; 269 } 270 271 if (o.inputModel == null || o.outputModel == null) { 272 logger.info(cm.usage()); 273 System.exit(1); 274 } 275 276 try (ObjectInputStream ois = IOUtil.getObjectInputStream(o.inputModel); 277 ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(o.outputModel))) { 278 logger.info("Loading model from " + o.inputModel); 279 Model<T> input = (Model<T>) ois.readObject(); 280 281 ModelProvenance oldProvenance = input.getProvenance(); 282 283 logger.info("Marshalling provenance and creating JSON."); 284 List<ObjectMarshalledProvenance> list = ProvenanceUtil.marshalProvenance(oldProvenance); 285 ObjectMapper mapper = new ObjectMapper(); 286 mapper.registerModule(new JsonProvenanceModule()); 287 mapper.enable(SerializationFeature.INDENT_OUTPUT); 288 String jsonResult = mapper.writeValueAsString(list); 289 290 logger.info("Hashing JSON file"); 291 MessageDigest digest = o.hashType.getDigest(); 292 byte[] digestBytes = digest.digest(jsonResult.getBytes(StandardCharsets.UTF_8)); 293 String provenanceHash = ProvenanceUtil.bytesToHexString(digestBytes); 294 logger.info("Provenance hash = " + provenanceHash); 295 296 if (o.provenanceFile != null) { 297 logger.info("Writing JSON provenance to " + o.provenanceFile.toString()); 298 try (PrintWriter writer = new PrintWriter(new OutputStreamWriter(new FileOutputStream(o.provenanceFile), StandardCharsets.UTF_8))) { 299 writer.println(jsonResult); 300 } 301 } 302 303 ModelTuple<T> tuple = convertModel(input,provenanceHash,o); 304 logger.info("Writing model to " + o.outputModel); 305 oos.writeObject(tuple.model); 306 307 ModelProvenance newProvenance = tuple.provenance; 308 logger.info("Marshalling provenance and creating JSON."); 309 List<ObjectMarshalledProvenance> newList = ProvenanceUtil.marshalProvenance(newProvenance); 310 String newJsonResult = mapper.writeValueAsString(newList); 311 312 logger.info("Old provenance = \n" + jsonResult); 313 logger.info("New provenance = \n" + newJsonResult); 314 } catch (NoSuchMethodException e) { 315 logger.log(Level.SEVERE, "Model.copy method missing on a class which extends Model.",e); 316 } catch (IllegalAccessException e) { 317 logger.log(Level.SEVERE, "Failed to modify protection on inner copy method on Model.",e); 318 } catch (InvocationTargetException e) { 319 logger.log(Level.SEVERE, "Failed to invoke inner copy method on Model.",e); 320 } catch (UnsupportedEncodingException e) { 321 logger.log(Level.SEVERE, "Unsupported encoding exception.",e); 322 } catch (FileNotFoundException e) { 323 logger.log(Level.SEVERE, "Failed to find the input file.",e); 324 } catch (IOException e) { 325 logger.log(Level.SEVERE, "IO error when reading or writing a file.",e); 326 } catch (ClassNotFoundException e) { 327 logger.log(Level.SEVERE, "The model and/or provenance classes are not on the classpath.",e); 328 } 329 330 } 331 332 /** 333 * It's a record. Or at least it will be. 334 * @param <T> The ouput type. 335 */ 336 private static class ModelTuple<T extends Output<T>> { 337 public final Model<T> model; 338 public final ModelProvenance provenance; 339 340 public ModelTuple(Model<T> model, ModelProvenance provenance) { 341 this.model = model; 342 this.provenance = provenance; 343 } 344 } 345}