001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.regression.impl;
018
019import com.oracle.labs.mlrg.olcut.config.Config;
020import com.oracle.labs.mlrg.olcut.provenance.Provenance;
021import org.tribuo.Dataset;
022import org.tribuo.Example;
023import org.tribuo.ImmutableFeatureMap;
024import org.tribuo.ImmutableOutputInfo;
025import org.tribuo.Model;
026import org.tribuo.Trainer;
027import org.tribuo.math.la.SparseVector;
028import org.tribuo.provenance.ModelProvenance;
029import org.tribuo.provenance.TrainerProvenance;
030import org.tribuo.regression.Regressor;
031
032import java.time.OffsetDateTime;
033import java.util.Collections;
034import java.util.LinkedHashMap;
035import java.util.Map;
036import java.util.Set;
037import java.util.SplittableRandom;
038
039/**
040 * Trains n independent binary {@link Model}s, each of which predicts a single {@link Regressor}.
041 * Generates the SparseVectors once to reduce allocation.
042 * <p>
043 * Then wraps it up in an {@link SkeletalIndependentRegressionModel} to provide a {@link Regressor}
044 * prediction.
045 * <p>
046 * It trains each model sequentially, and could be optimised to train in parallel.
047 */
048public abstract class SkeletalIndependentRegressionTrainer<T> implements Trainer<Regressor> {
049
050    @Config(description="Seed for the RNG, may be unused.")
051    private long seed = 1L;
052
053    private SplittableRandom rng;
054
055    private int trainInvocationCounter = 0;
056
057    /**
058     * for olcut.
059     */
060    protected SkeletalIndependentRegressionTrainer() {}
061
062    @Override
063    public synchronized void postConfig() {
064        this.rng = new SplittableRandom(seed);
065    }
066
067    @Override
068    public SkeletalIndependentRegressionModel train(Dataset<Regressor> examples) {
069        return train(examples, Collections.emptyMap());
070    }
071
072    @Override
073    public SkeletalIndependentRegressionModel train(Dataset<Regressor> examples, Map<String, Provenance> runProvenance) {
074        if (examples.getOutputInfo().getUnknownCount() > 0) {
075            throw new IllegalArgumentException("The supplied Dataset contained unknown Outputs, and this Trainer is supervised.");
076        }
077        SplittableRandom localRNG;
078        TrainerProvenance trainerProvenance;
079        synchronized(this) {
080            localRNG = rng.split();
081            trainerProvenance = getProvenance();
082            trainInvocationCounter++;
083        }
084        ImmutableOutputInfo<Regressor> outputInfo = examples.getOutputIDInfo();
085        ImmutableFeatureMap featureMap = examples.getFeatureIDMap();
086        Set<Regressor> domain = outputInfo.getDomain();
087        LinkedHashMap<String, T> models = new LinkedHashMap<>();
088        int numExamples = examples.size();
089        boolean needBias = useBias();
090        float[] weights = new float[numExamples];
091        double[][] outputs = new double[outputInfo.size()][numExamples];
092        SparseVector[] inputs = new SparseVector[numExamples];
093        int i = 0;
094        for (Example<Regressor> e : examples) {
095            inputs[i] = SparseVector.createSparseVector(e,featureMap,needBias);
096            weights[i] = e.getWeight();
097            for (Regressor.DimensionTuple r : e.getOutput()) {
098                int id = outputInfo.getID(r);
099                outputs[id][i] = r.getValue();
100            }
101            i++;
102        }
103        for (Regressor r : domain) {
104            int id = outputInfo.getID(r);
105            T innerModel = trainDimension(outputs[id],inputs,weights,localRNG);
106            models.put(r.getNames()[0],innerModel);
107        }
108        ModelProvenance provenance = new ModelProvenance(getModelClassName(), OffsetDateTime.now(), examples.getProvenance(), trainerProvenance, runProvenance);
109        return createModel(models,provenance,featureMap,outputInfo);
110    }
111
112    @Override
113    public int getInvocationCount() {
114        return trainInvocationCounter;
115    }
116
117    /**
118     * Constructs the appropriate subclass of {@link SkeletalIndependentRegressionModel} for this trainer.
119     * @param models The models to use.
120     * @param provenance The model provenance
121     * @param featureMap The feature map.
122     * @param outputInfo The regression info.
123     * @return A subclass of IndependentRegressionModel.
124     */
125    protected abstract SkeletalIndependentRegressionModel createModel(Map<String,T> models, ModelProvenance provenance, ImmutableFeatureMap featureMap, ImmutableOutputInfo<Regressor> outputInfo);
126
127    /**
128     * Trains a single dimension of the possibly multiple dimensions.
129     * @param outputs The regression targets for this dimension.
130     * @param features The features.
131     * @param weights The example weights.
132     * @param rng The RNG to use.
133     * @return An object representing the model. Should be the same type as that expected by {@link #createModel}.
134     */
135    protected abstract T trainDimension(double[] outputs, SparseVector[] features, float[] weights, SplittableRandom rng);
136
137    /**
138     * Returns true if the SparseVector should be constructed with a bias feature.
139     * @return True if the trainer needs a bias.
140     */
141    protected abstract boolean useBias();
142
143    /**
144     * Returns the class name of the model that this class produces.
145     * @return The class name of the model.
146     */
147    protected abstract String getModelClassName();
148}
149