001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.hash; 018 019import com.oracle.labs.mlrg.olcut.config.Config; 020import com.oracle.labs.mlrg.olcut.config.PropertyException; 021import com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance; 022import com.oracle.labs.mlrg.olcut.provenance.ObjectProvenance; 023import com.oracle.labs.mlrg.olcut.provenance.Provenance; 024import com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance; 025 026import java.io.IOException; 027import java.io.ObjectInputStream; 028import java.nio.charset.Charset; 029import java.nio.charset.StandardCharsets; 030import java.security.MessageDigest; 031import java.security.NoSuchAlgorithmException; 032import java.util.Base64; 033import java.util.HashMap; 034import java.util.Map; 035import java.util.Objects; 036import java.util.function.Supplier; 037 038/** 039 * Hashes Strings using the supplied MessageDigest type.
040 */ 041public final class MessageDigestHasher extends Hasher { 042 private static final long serialVersionUID = 3L; 043 044 public static final Charset utf8Charset = StandardCharsets.UTF_8; 045 046 static final String HASH_TYPE = "hashType"; 047 048 @Config(mandatory = true,description="MessageDigest hashing function.") 049 private String hashType; 050 051 private transient ThreadLocal<MessageDigest> md; 052 053 /** 054 * Only used by olcut. 055 */ 056 @Config(mandatory = true,description="Salt used in the hash.",redact=true) 057 private transient String saltStr = null; 058 059 private transient byte[] salt = null; 060 061 private MessageDigestHasherProvenance provenance; 062 063 /** 064 * For olcut. 065 */ 066 private MessageDigestHasher() {} 067 068 public MessageDigestHasher(String hashType, String salt) { 069 this.hashType = hashType; 070 this.salt = salt.getBytes(utf8Charset); 071 this.md = ThreadLocal.withInitial(getDigestSupplier(hashType)); 072 MessageDigest d = this.md.get(); // To trigger the unsupported digest exception. 073 this.provenance = new MessageDigestHasherProvenance(hashType); 074 } 075 076 /** 077 * Used by the OLCUT configuration system, and should not be called by external code. 078 */ 079 @Override 080 public void postConfig() throws PropertyException { 081 if (saltStr != null) { 082 salt = saltStr.getBytes(utf8Charset); 083 } else { 084 throw new PropertyException("","saltStr","Salt not set in MessageDigestHasher."); 085 } 086 md = ThreadLocal.withInitial(getDigestSupplier(hashType)); 087 try { 088 MessageDigest d = md.get();// To trigger the unsupported digest exception. 
089 } catch (IllegalArgumentException e) { 090 throw new PropertyException("","hashType","Unsupported hashType = " + hashType); 091 } 092 this.provenance = new MessageDigestHasherProvenance(hashType); 093 } 094 095 @Override 096 public String hash(String input) { 097 if (salt == null) { 098 throw new IllegalStateException("Salt not set."); 099 } 100 MessageDigest localDigest = md.get(); 101 localDigest.reset(); 102 localDigest.update(salt); 103 byte[] hash = localDigest.digest(input.getBytes(utf8Charset)); 104 return Base64.getEncoder().encodeToString(hash); 105 } 106 107 @Override 108 public void setSalt(String salt) { 109 if (Hasher.validateSalt(salt)) { 110 this.salt = salt.getBytes(utf8Charset); 111 } else { 112 throw new IllegalArgumentException("Salt: '" + salt + "', does not meet the requirements for a salt."); 113 } 114 } 115 116 private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException { 117 in.defaultReadObject(); 118 salt = null; 119 saltStr = null; 120 md = ThreadLocal.withInitial(getDigestSupplier(hashType)); 121 } 122 123 @Override 124 public String toString() { 125 return "MessageDigestHasher(algorithm="+md.get().getAlgorithm()+")"; 126 } 127 128 @Override 129 public ConfiguredObjectProvenance getProvenance() { 130 return provenance; 131 } 132 133 /** 134 * Creates a supplier for the specified hash type. 135 * @param hashType The hash type, used to specify the MessageDigest implementation. 136 * @return A supplier for the MessageDigest. 137 */ 138 public static Supplier<MessageDigest> getDigestSupplier(String hashType) { 139 return () -> { try { return MessageDigest.getInstance(hashType); } catch (NoSuchAlgorithmException e) { throw new IllegalArgumentException("Unsupported hashType = " + hashType,e);}}; 140 } 141 142 /** 143 * Provenance for {@link MessageDigestHasher}. 
144 */ 145 public final static class MessageDigestHasherProvenance implements ConfiguredObjectProvenance { 146 private static final long serialVersionUID = 1L; 147 148 private final StringProvenance hashType; 149 150 MessageDigestHasherProvenance(String hashType) { 151 this.hashType = new StringProvenance(HASH_TYPE,hashType); 152 } 153 154 public MessageDigestHasherProvenance(Map<String, Provenance> map) { 155 hashType = ObjectProvenance.checkAndExtractProvenance(map,HASH_TYPE,StringProvenance.class,MessageDigestHasherProvenance.class.getSimpleName()); 156 } 157 158 @Override 159 public Map<String, Provenance> getConfiguredParameters() { 160 Map<String,Provenance> map = new HashMap<>(); 161 map.put("saltStr",new StringProvenance("saltStr","")); 162 map.put(HASH_TYPE,hashType); 163 return map; 164 } 165 166 @Override 167 public boolean equals(Object o) { 168 if (this == o) return true; 169 if (!(o instanceof MessageDigestHasherProvenance)) return false; 170 MessageDigestHasherProvenance pairs = (MessageDigestHasherProvenance) o; 171 return hashType.equals(pairs.hashType); 172 } 173 174 @Override 175 public int hashCode() { 176 return Objects.hash(hashType); 177 } 178 179 @Override 180 public String getClassName() { 181 return MessageDigestHasher.class.getName(); 182 } 183 184 @Override 185 public String toString() { 186 return generateString("Hasher"); 187 } 188 } 189 190}