001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.transform; 018 019import com.oracle.labs.mlrg.olcut.config.Config; 020import com.oracle.labs.mlrg.olcut.provenance.Provenance; 021import org.tribuo.Dataset; 022import org.tribuo.Model; 023import org.tribuo.MutableDataset; 024import org.tribuo.Output; 025import org.tribuo.Trainer; 026import org.tribuo.provenance.ModelProvenance; 027import org.tribuo.provenance.TrainerProvenance; 028import org.tribuo.provenance.impl.TrainerProvenanceImpl; 029 030import java.time.OffsetDateTime; 031import java.util.Map; 032 033/** 034 * A {@link Trainer} which encapsulates another trainer plus a {@link TransformationMap} object 035 * to apply to each {@link Dataset} before training each {@link Model}. 036 * <p> 037 * Transformations only operate on observed values. To operate on implicit zeros then 038 * first call {@link MutableDataset#densify} on the datasets. 039 */ 040public final class TransformTrainer<T extends Output<T>> implements Trainer<T> { 041 042 @Config(mandatory = true,description="Trainer to use.") 043 private Trainer<T> innerTrainer; 044 045 @Config(mandatory = true,description="Transformations to apply.") 046 private TransformationMap transformations; 047 048 @Config(description="Densify all the features before applying transformations.") 049 private boolean densify; 050 051 /** 052 * For OLCUT. 053 */ 054 private TransformTrainer() {} 055 056 /** 057 * Creates a trainer which transforms the data before training, and stores 058 * the transformers along with the trained model in a {@link TransformedModel}. 059 * <p> 060 * This constructor makes a trainer which keeps the data sparse. 061 * @param innerTrainer The trainer to use. 062 * @param transformations The transformations to apply to the data first. 063 */ 064 public TransformTrainer(Trainer<T> innerTrainer, TransformationMap transformations) { 065 this(innerTrainer,transformations,false); 066 } 067 068 /** 069 * Creates a trainer which transforms the data before training, and stores 070 * the transformers along with the trained model in a {@link TransformedModel}. 071 * 072 * @param innerTrainer The trainer to use. 073 * @param transformations The transformations to apply to the data first. 074 * @param densify Densify the dataset (and any predict time data) before training/prediction. 075 */ 076 public TransformTrainer(Trainer<T> innerTrainer, TransformationMap transformations, boolean densify) { 077 this.innerTrainer = innerTrainer; 078 this.transformations = transformations; 079 this.densify = densify; 080 } 081 082 @Override 083 public TransformedModel<T> train(Dataset<T> examples, Map<String, Provenance> instanceProvenance) { 084 TransformerMap transformerMap = examples.createTransformers(transformations); 085 086 Dataset<T> transformedDataset = transformerMap.transformDataset(examples,densify); 087 088 Model<T> innerModel = innerTrainer.train(transformedDataset); 089 090 ModelProvenance provenance = new ModelProvenance(TransformedModel.class.getName(), OffsetDateTime.now(), transformedDataset.getProvenance(), getProvenance(), instanceProvenance); 091 092 return new TransformedModel<>(provenance,innerModel,transformerMap,densify); 093 } 094 095 @Override 096 public int getInvocationCount() { 097 return innerTrainer.getInvocationCount(); 098 } 099 100 @Override 101 public TrainerProvenance getProvenance() { 102 return new TrainerProvenanceImpl(this); 103 } 104}