001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.hash;
018
019import org.tribuo.FeatureMap;
020import org.tribuo.ImmutableFeatureMap;
021import org.tribuo.Model;
022import org.tribuo.VariableIDInfo;
023import org.tribuo.VariableInfo;
024
025import java.util.Map;
026import java.util.TreeMap;
027
028/**
029 * A {@link FeatureMap} used by the {@link HashingTrainer} to
030 * provide feature name hashing and guarantee that the {@link Model}
031 * does not contain feature name information, but still works
032 * with unhashed features names.
033 */
034public final class HashedFeatureMap extends ImmutableFeatureMap {
035    private static final long serialVersionUID = 1L;
036
037    private final Hasher hasher;
038
039    private HashedFeatureMap(Hasher hasher) {
040        super();
041        this.hasher = hasher;
042    }
043
044    @Override
045    public VariableIDInfo get(String name) {
046        String hash = hasher.hash(name);
047        return (VariableIDInfo) m.get(hash);
048    }
049
050    /**
051     * Gets the id number for this feature, returns -1 if it's unknown.
052     * @param name The name of the feature.
053     * @return A non-negative integer if the feature is known, -1 otherwise.
054     */
055    @Override
056    public int getID(String name) {
057        VariableIDInfo info = get(name);
058        if (info != null) {
059            return info.getID();
060        } else {
061            return -1;
062        }
063    }
064
065    /**
066     * The salt is not serialised with the {@link Model}.
067     * It must be set after deserialisation to the same value from training time.
068     * <p>
069     * If the salt is invalid it will throw {@link IllegalArgumentException}.
070     * @param salt The salt value. Must be the same as the one from training time.
071     */
072    public void setSalt(String salt) {
073        hasher.setSalt(salt);
074    }
075
076    /**
077     * Converts a standard {@link FeatureMap} by hashing each entry
078     * using the supplied hash function {@link Hasher}.
079     * <p>
080     * This preserves the index ordering of the original feature names,
081     * which is important for making sure test time performance is good.
082     * <p>
083     * It guarantees any collisions will produce an feature id number lower
084     * than the previous feature's number, and so can be easily removed.
085     *
086     * @param map The {@link FeatureMap} to hash.
087     * @param hasher The hashing function.
088     * @return A {@link HashedFeatureMap}.
089     */
090    public static HashedFeatureMap generateHashedFeatureMap(FeatureMap map, Hasher hasher) {
091        HashedFeatureMap hashedMap = new HashedFeatureMap(hasher);
092        TreeMap<String,VariableInfo> treeHashMap = new TreeMap<>();
093        for (VariableInfo f : map) {
094            String hash = hasher.hash(f.getName());
095            if (!treeHashMap.containsKey(f.getName())) {
096                VariableInfo newF = f.rename(hash);
097                treeHashMap.put(f.getName(),newF);
098            }
099        }
100        int counter = 0;
101        for (Map.Entry<String,VariableInfo> e : treeHashMap.entrySet()) {
102            VariableIDInfo newF = e.getValue().makeIDInfo(counter);
103            if (!hashedMap.m.containsKey(newF.getName())) {
104                hashedMap.m.put(newF.getName(), newF);
105                hashedMap.idMap.put(newF.getID(), newF);
106                counter++;
107            }
108        }
109        hashedMap.size = hashedMap.m.size();
110        return hashedMap;
111    }
112
113}