001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.hash;
018
019import com.oracle.labs.mlrg.olcut.config.Config;
020import com.oracle.labs.mlrg.olcut.provenance.Provenance;
021import org.tribuo.Dataset;
022import org.tribuo.ImmutableDataset;
023import org.tribuo.Model;
024import org.tribuo.Output;
025import org.tribuo.Trainer;
026import org.tribuo.provenance.TrainerProvenance;
027import org.tribuo.provenance.impl.TrainerProvenanceImpl;
028
029import java.util.Map;
030import java.util.logging.Level;
031import java.util.logging.Logger;
032
033/**
034 * A {@link Trainer} which hashes the {@link Dataset} before the {@link Model}
035 * is produced. This means the model does not contain any feature names,
036 * only one way hashes of names.
037 * <p>
038 * It wraps another Trainer which actually performs the training.
039 * @param <T> The type of Output this trainer works with.
040 */
041public final class HashingTrainer<T extends Output<T>> implements Trainer<T> {
042    private static final Logger logger = Logger.getLogger(HashingTrainer.class.getName());
043
044    @Config(mandatory = true,description="Trainer to use.")
045    private Trainer<T> innerTrainer;
046
047    @Config(mandatory = true,description="Feature hashing function to use.")
048    private Hasher hasher;
049
050    /**
051     * For olcut.
052     */
053    private HashingTrainer() {}
054
055    public HashingTrainer(Trainer<T> trainer, Hasher hasher) {
056        this.innerTrainer = trainer;
057        this.hasher = hasher;
058    }
059
060    /**
061     * This clones the {@link Dataset}, hashes each of the examples
062     * and rewrites their feature ids before passing it to the inner trainer.
063     * <p>
064     * This ensures the Trainer sees the data after the collisions, and thus
065     * builds the correct size data structures.
066     * @param dataset The input dataset.
067     * @param instanceProvenance Provenance information specific to this execution of train (e.g., cross validation fold number).
068     * @return A trained {@link Model}.
069     */
070    @Override
071    public Model<T> train(Dataset<T> dataset,Map<String, Provenance> instanceProvenance) {
072        logger.log(Level.INFO,"Before hashing, had " + dataset.getFeatureMap().size() + " features.");
073        ImmutableDataset<T> hashedData = ImmutableDataset.hashFeatureMap(dataset, hasher);
074        logger.log(Level.INFO,"After hashing, had " + hashedData.getFeatureMap().size() + " features.");
075        Model<T> model = innerTrainer.train(hashedData,instanceProvenance);
076        if (!(model.getFeatureIDMap() instanceof HashedFeatureMap)) {
077            //
078            // This exception is thrown when the innerTrainer did not copy the ImmutableFeatureMap from the
079            // ImmutableDataset, but modified it in some way. For example Viterbi will do this.
080            throw new IllegalStateException("Trainer " + innerTrainer.getClass().getName() + " does not support hashing.");
081        }
082        return model;
083    }
084
085    @Override
086    public int getInvocationCount() {
087        return innerTrainer.getInvocationCount();
088    }
089
090    @Override
091    public TrainerProvenance getProvenance() {
092        return new TrainerProvenanceImpl(this);
093    }
094}