001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.hash;
018
019import com.oracle.labs.mlrg.olcut.config.Option;
020import com.oracle.labs.mlrg.olcut.config.Options;
021import org.tribuo.Output;
022import org.tribuo.Trainer;
023
024import java.util.Optional;
025import java.util.logging.Logger;
026
027/**
028 * An Options implementation which provides CLI arguments for the model hashing functionality.
029 */
030public class HashingOptions implements Options {
031    private static final Logger logger = Logger.getLogger(HashingOptions.class.getName());
032
033    /**
034     * Supported types of hashes in CLI programs.
035     */
036    public enum ModelHashingType { NONE, MOD, HC, SHA1, SHA256 }
037
038    @Option(longName="model-hashing-algorithm",usage="Hash the model during training, options are {NONE,MOD,HC,SHA1,SHA256}")
039    public ModelHashingType modelHashingAlgorithm = ModelHashingType.NONE;
040    @Option(longName="model-hashing-salt",usage="Salt for hashing the model")
041    public String modelHashingSalt = "";
042
043    /**
044     * Get the specified hasher.
045     * @return The configured hasher.
046     */
047    public Optional<Hasher> getHasher() {
048        if (modelHashingAlgorithm == ModelHashingType.NONE) {
049            return Optional.empty();
050        } else if (Hasher.validateSalt(modelHashingSalt)) {
051            switch (modelHashingAlgorithm) {
052                case MOD:
053                    return Optional.of(new ModHashCodeHasher(modelHashingSalt));
054                case HC:
055                    return Optional.of(new HashCodeHasher(modelHashingSalt));
056                case SHA1:
057                    return Optional.of(new MessageDigestHasher("SHA1", modelHashingSalt));
058                case SHA256:
059                    return Optional.of(new MessageDigestHasher("SHA-256", modelHashingSalt));
060                default:
061                    logger.info("Unknown hasher " + modelHashingAlgorithm);
062                    return Optional.empty();
063            }
064        } else {
065            logger.info("Invalid salt");
066            return Optional.empty();
067        }
068    }
069
070    /**
071     * Gets the trainer wrapped in a hashing trainer.
072     * @param innerTrainer The inner trainer.
073     * @param <T> The output type.
074     * @return The hashing trainer.
075     */
076    public <T extends Output<T>> Trainer<T> getHashedTrainer(Trainer<T> innerTrainer) {
077        Optional<Hasher> hasherOpt = getHasher();
078        if (hasherOpt.isPresent()) {
079            return new HashingTrainer<>(innerTrainer,hasherOpt.get());
080        } else {
081            throw new IllegalArgumentException("Invalid Hasher");
082        }
083    }
084}