001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.hash; 018 019import org.tribuo.FeatureMap; 020import org.tribuo.ImmutableFeatureMap; 021import org.tribuo.Model; 022import org.tribuo.VariableIDInfo; 023import org.tribuo.VariableInfo; 024 025import java.util.Map; 026import java.util.TreeMap; 027 028/** 029 * A {@link FeatureMap} used by the {@link HashingTrainer} to 030 * provide feature name hashing and guarantee that the {@link Model} 031 * does not contain feature name information, but still works 032 * with unhashed features names. 033 */ 034public final class HashedFeatureMap extends ImmutableFeatureMap { 035 private static final long serialVersionUID = 1L; 036 037 private final Hasher hasher; 038 039 private HashedFeatureMap(Hasher hasher) { 040 super(); 041 this.hasher = hasher; 042 } 043 044 @Override 045 public VariableIDInfo get(String name) { 046 String hash = hasher.hash(name); 047 return (VariableIDInfo) m.get(hash); 048 } 049 050 /** 051 * Gets the id number for this feature, returns -1 if it's unknown. 052 * @param name The name of the feature. 053 * @return A non-negative integer if the feature is known, -1 otherwise. 054 */ 055 @Override 056 public int getID(String name) { 057 VariableIDInfo info = get(name); 058 if (info != null) { 059 return info.getID(); 060 } else { 061 return -1; 062 } 063 } 064 065 /** 066 * The salt is not serialised with the {@link Model}. 067 * It must be set after deserialisation to the same value from training time. 068 * <p> 069 * If the salt is invalid it will throw {@link IllegalArgumentException}. 070 * @param salt The salt value. Must be the same as the one from training time. 071 */ 072 public void setSalt(String salt) { 073 hasher.setSalt(salt); 074 } 075 076 /** 077 * Converts a standard {@link FeatureMap} by hashing each entry 078 * using the supplied hash function {@link Hasher}. 079 * <p> 080 * This preserves the index ordering of the original feature names, 081 * which is important for making sure test time performance is good. 082 * <p> 083 * It guarantees any collisions will produce an feature id number lower 084 * than the previous feature's number, and so can be easily removed. 085 * 086 * @param map The {@link FeatureMap} to hash. 087 * @param hasher The hashing function. 088 * @return A {@link HashedFeatureMap}. 089 */ 090 public static HashedFeatureMap generateHashedFeatureMap(FeatureMap map, Hasher hasher) { 091 HashedFeatureMap hashedMap = new HashedFeatureMap(hasher); 092 TreeMap<String,VariableInfo> treeHashMap = new TreeMap<>(); 093 for (VariableInfo f : map) { 094 String hash = hasher.hash(f.getName()); 095 if (!treeHashMap.containsKey(f.getName())) { 096 VariableInfo newF = f.rename(hash); 097 treeHashMap.put(f.getName(),newF); 098 } 099 } 100 int counter = 0; 101 for (Map.Entry<String,VariableInfo> e : treeHashMap.entrySet()) { 102 VariableIDInfo newF = e.getValue().makeIDInfo(counter); 103 if (!hashedMap.m.containsKey(newF.getName())) { 104 hashedMap.m.put(newF.getName(), newF); 105 hashedMap.idMap.put(newF.getID(), newF); 106 counter++; 107 } 108 } 109 hashedMap.size = hashedMap.m.size(); 110 return hashedMap; 111 } 112 113}