001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.regression.impl;
018
019import com.oracle.labs.mlrg.olcut.config.Config;
020import com.oracle.labs.mlrg.olcut.provenance.Provenance;
021import org.tribuo.Dataset;
022import org.tribuo.Example;
023import org.tribuo.ImmutableFeatureMap;
024import org.tribuo.ImmutableOutputInfo;
025import org.tribuo.SparseTrainer;
026import org.tribuo.math.la.SparseVector;
027import org.tribuo.provenance.ModelProvenance;
028import org.tribuo.provenance.TrainerProvenance;
029import org.tribuo.regression.Regressor;
030
031import java.time.OffsetDateTime;
032import java.util.Collections;
033import java.util.LinkedHashMap;
034import java.util.Map;
035import java.util.Set;
036import java.util.SplittableRandom;
037
038/**
039 * Base class for training n independent sparse models, one per dimension. Generates the SparseVectors
040 * once to reduce allocation.
041 * <p>
042 * Then wraps them in an {@link SkeletalIndependentRegressionSparseModel} to provide a {@link Regressor}
043 * prediction.
044 * <p>
045 * It trains each model sequentially, and could be optimised to train in parallel.
046 */
047public abstract class SkeletalIndependentRegressionSparseTrainer<T> implements SparseTrainer<Regressor> {
048
049    @Config(description="Seed for the RNG, may be unused.")
050    private long seed = 1L;
051
052    private SplittableRandom rng;
053
054    private int trainInvocationCounter = 0;
055
056    /**
057     * for olcut.
058     */
059    protected SkeletalIndependentRegressionSparseTrainer() {}
060
061    @Override
062    public synchronized void postConfig() {
063        this.rng = new SplittableRandom(seed);
064    }
065
066    @Override
067    public SkeletalIndependentRegressionSparseModel train(Dataset<Regressor> examples) {
068        return train(examples, Collections.emptyMap());
069    }
070
071    @Override
072    public SkeletalIndependentRegressionSparseModel train(Dataset<Regressor> examples, Map<String, Provenance> runProvenance) {
073        if (examples.getOutputInfo().getUnknownCount() > 0) {
074            throw new IllegalArgumentException("The supplied Dataset contained unknown Outputs, and this Trainer is supervised.");
075        }
076        SplittableRandom localRNG;
077        TrainerProvenance trainerProvenance;
078        synchronized(this) {
079            localRNG = rng.split();
080            trainerProvenance = getProvenance();
081            trainInvocationCounter++;
082        }
083        ImmutableOutputInfo<Regressor> outputInfo = examples.getOutputIDInfo();
084        ImmutableFeatureMap featureMap = examples.getFeatureIDMap();
085        Set<Regressor> domain = outputInfo.getDomain();
086        LinkedHashMap<String, T> models = new LinkedHashMap<>();
087        int numExamples = examples.size();
088        boolean needBias = useBias();
089        float[] weights = new float[numExamples];
090        double[][] outputs = new double[outputInfo.size()][numExamples];
091        SparseVector[] inputs = new SparseVector[numExamples];
092        int i = 0;
093        for (Example<Regressor> e : examples) {
094            inputs[i] = SparseVector.createSparseVector(e,featureMap,needBias);
095            weights[i] = e.getWeight();
096            for (Regressor.DimensionTuple r : e.getOutput()) {
097                int id = outputInfo.getID(r);
098                outputs[id][i] = r.getValue();
099            }
100            i++;
101        }
102        for (Regressor r : domain) {
103            int id = outputInfo.getID(r);
104            T innerModel = trainDimension(outputs[id],inputs,weights,localRNG);
105            models.put(r.getNames()[0],innerModel);
106        }
107        ModelProvenance provenance = new ModelProvenance(getModelClassName(), OffsetDateTime.now(), examples.getProvenance(), trainerProvenance, runProvenance);
108        return createModel(models,provenance,featureMap,outputInfo);
109    }
110
111    @Override
112    public int getInvocationCount() {
113        return trainInvocationCounter;
114    }
115
116    /**
117     * Constructs the appropriate subclass of {@link SkeletalIndependentRegressionModel} for this trainer.
118     * @param models The models to use.
119     * @param provenance The model provenance
120     * @param featureMap The feature map.
121     * @param outputInfo The regression info.
122     * @return A subclass of IndependentRegressionModel.
123     */
124    protected abstract SkeletalIndependentRegressionSparseModel createModel(Map<String,T> models, ModelProvenance provenance, ImmutableFeatureMap featureMap, ImmutableOutputInfo<Regressor> outputInfo);
125
126    /**
127     * Trains a single dimension of the possibly multiple dimensions.
128     * @param outputs The regression targets for this dimension.
129     * @param features The features.
130     * @param weights The example weights.
131     * @param rng The RNG to use.
132     * @return An object representing the model.
133     */
134    protected abstract T trainDimension(double[] outputs, SparseVector[] features, float[] weights, SplittableRandom rng);
135
136    /**
137     * Returns true if the SparseVector should be constructed with a bias feature.
138     * @return True if the trainer needs a bias.
139     */
140    protected abstract boolean useBias();
141
142    /**
143     * Returns the class name of the model that this class produces.
144     * @return The class name of the model.
145     */
146    protected abstract String getModelClassName();
147}
148