001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.transform;
018
import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import org.tribuo.Dataset;
import org.tribuo.Model;
import org.tribuo.MutableDataset;
import org.tribuo.Output;
import org.tribuo.Trainer;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.provenance.TrainerProvenance;
import org.tribuo.provenance.impl.TrainerProvenanceImpl;

import java.time.OffsetDateTime;
import java.util.Map;
import java.util.Objects;
032
033/**
034 * A {@link Trainer} which encapsulates another trainer plus a {@link TransformationMap} object
035 * to apply to each {@link Dataset} before training each {@link Model}.
036 * <p>
037 * Transformations only operate on observed values. To operate on implicit zeros then
038 * first call {@link MutableDataset#densify} on the datasets.
039 */
040public final class TransformTrainer<T extends Output<T>> implements Trainer<T> {
041
042    @Config(mandatory = true,description="Trainer to use.")
043    private Trainer<T> innerTrainer;
044
045    @Config(mandatory = true,description="Transformations to apply.")
046    private TransformationMap transformations;
047
048    @Config(description="Densify all the features before applying transformations.")
049    private boolean densify;
050
051    /**
052     * For OLCUT.
053     */
054    private TransformTrainer() {}
055
056    /**
057     * Creates a trainer which transforms the data before training, and stores
058     * the transformers along with the trained model in a {@link TransformedModel}.
059     * <p>
060     * This constructor makes a trainer which keeps the data sparse.
061     * @param innerTrainer The trainer to use.
062     * @param transformations The transformations to apply to the data first.
063     */
064    public TransformTrainer(Trainer<T> innerTrainer, TransformationMap transformations) {
065        this(innerTrainer,transformations,false);
066    }
067
068    /**
069     * Creates a trainer which transforms the data before training, and stores
070     * the transformers along with the trained model in a {@link TransformedModel}.
071     *
072     * @param innerTrainer The trainer to use.
073     * @param transformations The transformations to apply to the data first.
074     * @param densify Densify the dataset (and any predict time data) before training/prediction.
075     */
076    public TransformTrainer(Trainer<T> innerTrainer, TransformationMap transformations, boolean densify) {
077        this.innerTrainer = innerTrainer;
078        this.transformations = transformations;
079        this.densify = densify;
080    }
081
082    @Override
083    public TransformedModel<T> train(Dataset<T> examples, Map<String, Provenance> instanceProvenance) {
084        TransformerMap transformerMap = examples.createTransformers(transformations);
085
086        Dataset<T> transformedDataset = transformerMap.transformDataset(examples,densify);
087
088        Model<T> innerModel = innerTrainer.train(transformedDataset);
089
090        ModelProvenance provenance = new ModelProvenance(TransformedModel.class.getName(), OffsetDateTime.now(), transformedDataset.getProvenance(), getProvenance(), instanceProvenance);
091
092        return new TransformedModel<>(provenance,innerModel,transformerMap,densify);
093    }
094
095    @Override
096    public int getInvocationCount() {
097        return innerTrainer.getInvocationCount();
098    }
099
100    @Override
101    public TrainerProvenance getProvenance() {
102        return new TrainerProvenanceImpl(this);
103    }
104}