001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.hash; 018 019import com.oracle.labs.mlrg.olcut.config.Option; 020import com.oracle.labs.mlrg.olcut.config.Options; 021import org.tribuo.Output; 022import org.tribuo.Trainer; 023 024import java.util.Optional; 025import java.util.logging.Logger; 026 027/** 028 * An Options implementation which provides CLI arguments for the model hashing functionality. 029 */ 030public class HashingOptions implements Options { 031 private static final Logger logger = Logger.getLogger(HashingOptions.class.getName()); 032 033 /** 034 * Supported types of hashes in CLI programs. 035 */ 036 public enum ModelHashingType { NONE, MOD, HC, SHA1, SHA256 } 037 038 @Option(longName="model-hashing-algorithm",usage="Hash the model during training, options are {NONE,MOD,HC,SHA1,SHA256}") 039 public ModelHashingType modelHashingAlgorithm = ModelHashingType.NONE; 040 @Option(longName="model-hashing-salt",usage="Salt for hashing the model") 041 public String modelHashingSalt = ""; 042 043 /** 044 * Get the specified hasher. 045 * @return The configured hasher. 046 */ 047 public Optional<Hasher> getHasher() { 048 if (modelHashingAlgorithm == ModelHashingType.NONE) { 049 return Optional.empty(); 050 } else if (Hasher.validateSalt(modelHashingSalt)) { 051 switch (modelHashingAlgorithm) { 052 case MOD: 053 return Optional.of(new ModHashCodeHasher(modelHashingSalt)); 054 case HC: 055 return Optional.of(new HashCodeHasher(modelHashingSalt)); 056 case SHA1: 057 return Optional.of(new MessageDigestHasher("SHA1", modelHashingSalt)); 058 case SHA256: 059 return Optional.of(new MessageDigestHasher("SHA-256", modelHashingSalt)); 060 default: 061 logger.info("Unknown hasher " + modelHashingAlgorithm); 062 return Optional.empty(); 063 } 064 } else { 065 logger.info("Invalid salt"); 066 return Optional.empty(); 067 } 068 } 069 070 /** 071 * Gets the trainer wrapped in a hashing trainer. 072 * @param innerTrainer The inner trainer. 073 * @param <T> The output type. 074 * @return The hashing trainer. 075 */ 076 public <T extends Output<T>> Trainer<T> getHashedTrainer(Trainer<T> innerTrainer) { 077 Optional<Hasher> hasherOpt = getHasher(); 078 if (hasherOpt.isPresent()) { 079 return new HashingTrainer<>(innerTrainer,hasherOpt.get()); 080 } else { 081 throw new IllegalArgumentException("Invalid Hasher"); 082 } 083 } 084}