001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.sequence;
018
019import com.oracle.labs.mlrg.olcut.config.Config;
020import com.oracle.labs.mlrg.olcut.provenance.Provenance;
021import org.tribuo.Output;
022import org.tribuo.hash.HashedFeatureMap;
023import org.tribuo.hash.Hasher;
024import org.tribuo.provenance.SkeletalTrainerProvenance;
025import org.tribuo.provenance.TrainerProvenance;
026
027import java.util.Map;
028import java.util.logging.Level;
029import java.util.logging.Logger;
030
031/**
032 * A SequenceTrainer that hashes all the feature names on the way in.
033 * <p>
034 * It wraps another SequenceTrainer which actually builds the {@link SequenceModel}.
035 * @param <T> The type of the output.
036 */
037public final class HashingSequenceTrainer<T extends Output<T>> implements SequenceTrainer<T> {
038    private static final Logger logger = Logger.getLogger(HashingSequenceTrainer.class.getName());
039
040    @Config(mandatory = true,description="Trainer to use.")
041    private SequenceTrainer<T> innerTrainer;
042
043    @Config(mandatory = true,description="Feature hashing function to use.")
044    private Hasher hasher;
045
046    /**
047     * For olcut.
048     */
049    private HashingSequenceTrainer() {}
050
051    public HashingSequenceTrainer(SequenceTrainer<T> trainer, Hasher hasher) {
052        this.innerTrainer = trainer;
053        this.hasher = hasher;
054    }
055
056    /**
057     * This clones the {@link SequenceDataset}, hashes each of the examples
058     * and rewrites their feature ids before passing it to the inner trainer.
059     * <p>
060     * This ensures the Trainer sees the data after the collisions, and thus
061     * builds the correct size data structures.
062     * @param sequenceExamples The input dataset.
063     * @param instanceProvenance Training run specific provenance information.
064     * @return A trained {@link SequenceModel}.
065     */
066    @Override
067    public SequenceModel<T> train(SequenceDataset<T> sequenceExamples, Map<String, Provenance> instanceProvenance) {
068        logger.log(Level.INFO,"Before hashing, had " + sequenceExamples.getFeatureIDMap().size() + " features.");
069        SequenceDataset<T> hashedData = ImmutableSequenceDataset.changeFeatureMap(sequenceExamples, HashedFeatureMap.generateHashedFeatureMap(sequenceExamples.getFeatureIDMap(),hasher));
070        logger.log(Level.INFO,"After hashing, had " + hashedData.getFeatureIDMap().size() + " features.");
071        SequenceModel<T> model = innerTrainer.train(hashedData,instanceProvenance);
072        if (!(model.featureIDMap instanceof HashedFeatureMap)) {
073            //
074            // This exception is thrown when the innerTrainer did not copy the ImmutableFeatureMap from the
075            // ImmutableDataset, but modified it in some way. For example Viterbi will do this.
076            throw new IllegalStateException("Trainer " + innerTrainer.getClass().getName() + " does not support hashing.");
077        }
078        return model;
079    }
080
081    @Override
082    public int getInvocationCount() {
083        return innerTrainer.getInvocationCount();
084    }
085
086    @Override
087    public String toString() {
088        return "HashingSequenceTrainer(trainer="+innerTrainer.toString()+",hasher="+hasher.toString()+")";
089    }
090
091    @Override
092    public TrainerProvenance getProvenance() {
093        return new HashingSequenceTrainerProvenance(this);
094    }
095
096    /**
097     * Provenance for {@link HashingSequenceTrainer}.
098     */
099    public static class HashingSequenceTrainerProvenance extends SkeletalTrainerProvenance {
100        private static final long serialVersionUID = 1L;
101
102        <T extends Output<T>> HashingSequenceTrainerProvenance(HashingSequenceTrainer<T> host) {
103            super(host);
104        }
105
106        public HashingSequenceTrainerProvenance(Map<String, Provenance> map) {
107            super(extractProvenanceInfo(map));
108        }
109    }
110}