/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.hash;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import org.tribuo.Dataset;
import org.tribuo.ImmutableDataset;
import org.tribuo.Model;
import org.tribuo.Output;
import org.tribuo.Trainer;
import org.tribuo.provenance.TrainerProvenance;
import org.tribuo.provenance.impl.TrainerProvenanceImpl;

import java.util.Map;
import java.util.logging.Level;
import java.util.logging.Logger;

/**
 * A {@link Trainer} which hashes the {@link Dataset} before the {@link Model}
 * is produced. This means the model does not contain any feature names,
 * only one-way hashes of the names.
 * <p>
 * It wraps another Trainer which actually performs the training.
 * @param <T> The type of Output this trainer works with.
 */
public final class HashingTrainer<T extends Output<T>> implements Trainer<T> {
    private static final Logger logger = Logger.getLogger(HashingTrainer.class.getName());

    @Config(mandatory = true, description = "Trainer to use.")
    private Trainer<T> innerTrainer;

    @Config(mandatory = true, description = "Feature hashing function to use.")
    private Hasher hasher;

    /**
     * For olcut.
     */
    private HashingTrainer() {}

    /**
     * Constructs a hashing trainer which hashes the feature names using the supplied
     * {@link Hasher} before delegating training to the inner trainer.
     * @param trainer The trainer to wrap.
     * @param hasher The hashing function applied to the feature names.
     */
    public HashingTrainer(Trainer<T> trainer, Hasher hasher) {
        this.innerTrainer = trainer;
        this.hasher = hasher;
    }

    /**
     * This clones the {@link Dataset}, hashes each of the examples
     * and rewrites their feature ids before passing it to the inner trainer.
     * <p>
     * This ensures the trainer sees the data after the collisions, and thus
     * builds the correctly sized data structures.
     * @param dataset The input dataset.
     * @param instanceProvenance Provenance information specific to this execution of train (e.g., cross validation fold number).
     * @return A trained {@link Model}.
     */
    @Override
    public Model<T> train(Dataset<T> dataset, Map<String, Provenance> instanceProvenance) {
        logger.log(Level.INFO, "Before hashing, had " + dataset.getFeatureMap().size() + " features.");
        ImmutableDataset<T> hashedData = ImmutableDataset.hashFeatureMap(dataset, hasher);
        logger.log(Level.INFO, "After hashing, had " + hashedData.getFeatureMap().size() + " features.");
        Model<T> model = innerTrainer.train(hashedData, instanceProvenance);
        if (!(model.getFeatureIDMap() instanceof HashedFeatureMap)) {
            //
            // This exception is thrown when the innerTrainer did not copy the ImmutableFeatureMap from the
            // ImmutableDataset, but modified it in some way. For example, Viterbi does this.
            throw new IllegalStateException("Trainer " + innerTrainer.getClass().getName() + " does not support hashing.");
        }
        return model;
    }

    @Override
    public int getInvocationCount() {
        return innerTrainer.getInvocationCount();
    }

    @Override
    public TrainerProvenance getProvenance() {
        return new TrainerProvenanceImpl(this);
    }
}
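/*
 * Usage sketch (not part of the original class): the snippet below shows how a
 * HashingTrainer might wrap another trainer so that the resulting model only
 * stores one-way hashes of the feature names. The Label output type, the
 * LogisticRegressionTrainer inner trainer, and the MessageDigestHasher
 * constructor arguments are illustrative assumptions; substitute whichever
 * Trainer, Output type, and Hasher implementation you actually use.
 *
 *   import org.tribuo.Model;
 *   import org.tribuo.MutableDataset;
 *   import org.tribuo.Trainer;
 *   import org.tribuo.classification.Label;
 *   import org.tribuo.classification.sgd.linear.LogisticRegressionTrainer;
 *   import org.tribuo.hash.Hasher;
 *   import org.tribuo.hash.HashingTrainer;
 *   import org.tribuo.hash.MessageDigestHasher;
 *
 *   MutableDataset<Label> trainData = ...; // assumed to be loaded elsewhere
 *   // Hash type and salt are illustrative values.
 *   Hasher hasher = new MessageDigestHasher("SHA-256", "an-example-salt");
 *   Trainer<Label> trainer = new HashingTrainer<>(new LogisticRegressionTrainer(), hasher);
 *   // The trained model's feature map contains hashed names rather than the originals.
 *   Model<Label> model = trainer.train(trainData);
 */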