001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.hash;
018
019import com.oracle.labs.mlrg.olcut.config.Config;
020import com.oracle.labs.mlrg.olcut.config.PropertyException;
021import com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance;
022import com.oracle.labs.mlrg.olcut.provenance.ObjectProvenance;
023import com.oracle.labs.mlrg.olcut.provenance.Provenance;
024import com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance;
025
026import java.io.IOException;
027import java.io.ObjectInputStream;
028import java.nio.charset.Charset;
029import java.nio.charset.StandardCharsets;
030import java.security.MessageDigest;
031import java.security.NoSuchAlgorithmException;
032import java.util.Base64;
033import java.util.HashMap;
034import java.util.Map;
035import java.util.Objects;
036import java.util.function.Supplier;
037
038/**
039 * Hashes Strings using the supplied MessageDigest type.
040 */
041public final class MessageDigestHasher extends Hasher {
042    private static final long serialVersionUID = 3L;
043
044    public static final Charset utf8Charset = StandardCharsets.UTF_8;
045
046    static final String HASH_TYPE = "hashType";
047
048    @Config(mandatory = true,description="MessageDigest hashing function.")
049    private String hashType;
050
051    private transient ThreadLocal<MessageDigest> md;
052
053    /**
054     * Only used by olcut.
055     */
056    @Config(mandatory = true,description="Salt used in the hash.",redact=true)
057    private transient String saltStr = null;
058
059    private transient byte[] salt = null;
060
061    private MessageDigestHasherProvenance provenance;
062
063    /**
064     * For olcut.
065     */
066    private MessageDigestHasher() {}
067
068    public MessageDigestHasher(String hashType, String salt) {
069        this.hashType = hashType;
070        this.salt = salt.getBytes(utf8Charset);
071        this.md = ThreadLocal.withInitial(getDigestSupplier(hashType));
072        MessageDigest d = this.md.get(); // To trigger the unsupported digest exception.
073        this.provenance = new MessageDigestHasherProvenance(hashType);
074    }
075
076    /**
077     * Used by the OLCUT configuration system, and should not be called by external code.
078     */
079    @Override
080    public void postConfig() throws PropertyException {
081        if (saltStr != null) {
082            salt = saltStr.getBytes(utf8Charset);
083        } else {
084            throw new PropertyException("","saltStr","Salt not set in MessageDigestHasher.");
085        }
086        md = ThreadLocal.withInitial(getDigestSupplier(hashType));
087        try {
088            MessageDigest d = md.get();// To trigger the unsupported digest exception.
089        } catch (IllegalArgumentException e) {
090            throw new PropertyException("","hashType","Unsupported hashType = " + hashType);
091        }
092        this.provenance = new MessageDigestHasherProvenance(hashType);
093    }
094
095    @Override
096    public String hash(String input) {
097        if (salt == null) {
098            throw new IllegalStateException("Salt not set.");
099        }
100        MessageDigest localDigest = md.get();
101        localDigest.reset();
102        localDigest.update(salt);
103        byte[] hash = localDigest.digest(input.getBytes(utf8Charset));
104        return Base64.getEncoder().encodeToString(hash);
105    }
106
107    @Override
108    public void setSalt(String salt) {
109        if (Hasher.validateSalt(salt)) {
110            this.salt = salt.getBytes(utf8Charset);
111        } else {
112            throw new IllegalArgumentException("Salt: '" + salt + "', does not meet the requirements for a salt.");
113        }
114    }
115
116    private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
117        in.defaultReadObject();
118        salt = null;
119        saltStr = null;
120        md = ThreadLocal.withInitial(getDigestSupplier(hashType));
121    }
122
123    @Override
124    public String toString() {
125        return "MessageDigestHasher(algorithm="+md.get().getAlgorithm()+")";
126    }
127
128    @Override
129    public ConfiguredObjectProvenance getProvenance() {
130        return provenance;
131    }
132
133    /**
134     * Creates a supplier for the specified hash type.
135     * @param hashType The hash type, used to specify the MessageDigest implementation.
136     * @return A supplier for the MessageDigest.
137     */
138    public static Supplier<MessageDigest> getDigestSupplier(String hashType) {
139        return () -> { try { return MessageDigest.getInstance(hashType); } catch (NoSuchAlgorithmException e) { throw new IllegalArgumentException("Unsupported hashType = " + hashType,e);}};
140    }
141
142    /**
143     * Provenance for {@link MessageDigestHasher}.
144     */
145    public final static class MessageDigestHasherProvenance implements ConfiguredObjectProvenance {
146        private static final long serialVersionUID = 1L;
147
148        private final StringProvenance hashType;
149
150        MessageDigestHasherProvenance(String hashType) {
151            this.hashType = new StringProvenance(HASH_TYPE,hashType);
152        }
153
154        public MessageDigestHasherProvenance(Map<String, Provenance> map) {
155            hashType = ObjectProvenance.checkAndExtractProvenance(map,HASH_TYPE,StringProvenance.class,MessageDigestHasherProvenance.class.getSimpleName());
156        }
157
158        @Override
159        public Map<String, Provenance> getConfiguredParameters() {
160            Map<String,Provenance> map = new HashMap<>();
161            map.put("saltStr",new StringProvenance("saltStr",""));
162            map.put(HASH_TYPE,hashType);
163            return map;
164        }
165
166        @Override
167        public boolean equals(Object o) {
168            if (this == o) return true;
169            if (!(o instanceof MessageDigestHasherProvenance)) return false;
170            MessageDigestHasherProvenance pairs = (MessageDigestHasherProvenance) o;
171            return hashType.equals(pairs.hashType);
172        }
173
174        @Override
175        public int hashCode() {
176            return Objects.hash(hashType);
177        }
178
179        @Override
180        public String getClassName() {
181            return MessageDigestHasher.class.getName();
182        }
183
184        @Override
185        public String toString() {
186            return generateString("Hasher");
187        }
188    }
189
190}