001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.regression.impl; 018 019import com.oracle.labs.mlrg.olcut.config.Config; 020import com.oracle.labs.mlrg.olcut.provenance.Provenance; 021import org.tribuo.Dataset; 022import org.tribuo.Example; 023import org.tribuo.ImmutableFeatureMap; 024import org.tribuo.ImmutableOutputInfo; 025import org.tribuo.SparseTrainer; 026import org.tribuo.math.la.SparseVector; 027import org.tribuo.provenance.ModelProvenance; 028import org.tribuo.provenance.TrainerProvenance; 029import org.tribuo.regression.Regressor; 030 031import java.time.OffsetDateTime; 032import java.util.Collections; 033import java.util.LinkedHashMap; 034import java.util.Map; 035import java.util.Set; 036import java.util.SplittableRandom; 037 038/** 039 * Base class for training n independent sparse models, one per dimension. Generates the SparseVectors 040 * once to reduce allocation. 041 * <p> 042 * Then wraps them in an {@link SkeletalIndependentRegressionSparseModel} to provide a {@link Regressor} 043 * prediction. 044 * <p> 045 * It trains each model sequentially, and could be optimised to train in parallel. 046 */ 047public abstract class SkeletalIndependentRegressionSparseTrainer<T> implements SparseTrainer<Regressor> { 048 049 @Config(description="Seed for the RNG, may be unused.") 050 private long seed = 1L; 051 052 private SplittableRandom rng; 053 054 private int trainInvocationCounter = 0; 055 056 /** 057 * for olcut. 058 */ 059 protected SkeletalIndependentRegressionSparseTrainer() {} 060 061 @Override 062 public synchronized void postConfig() { 063 this.rng = new SplittableRandom(seed); 064 } 065 066 @Override 067 public SkeletalIndependentRegressionSparseModel train(Dataset<Regressor> examples) { 068 return train(examples, Collections.emptyMap()); 069 } 070 071 @Override 072 public SkeletalIndependentRegressionSparseModel train(Dataset<Regressor> examples, Map<String, Provenance> runProvenance) { 073 if (examples.getOutputInfo().getUnknownCount() > 0) { 074 throw new IllegalArgumentException("The supplied Dataset contained unknown Outputs, and this Trainer is supervised."); 075 } 076 SplittableRandom localRNG; 077 TrainerProvenance trainerProvenance; 078 synchronized(this) { 079 localRNG = rng.split(); 080 trainerProvenance = getProvenance(); 081 trainInvocationCounter++; 082 } 083 ImmutableOutputInfo<Regressor> outputInfo = examples.getOutputIDInfo(); 084 ImmutableFeatureMap featureMap = examples.getFeatureIDMap(); 085 Set<Regressor> domain = outputInfo.getDomain(); 086 LinkedHashMap<String, T> models = new LinkedHashMap<>(); 087 int numExamples = examples.size(); 088 boolean needBias = useBias(); 089 float[] weights = new float[numExamples]; 090 double[][] outputs = new double[outputInfo.size()][numExamples]; 091 SparseVector[] inputs = new SparseVector[numExamples]; 092 int i = 0; 093 for (Example<Regressor> e : examples) { 094 inputs[i] = SparseVector.createSparseVector(e,featureMap,needBias); 095 weights[i] = e.getWeight(); 096 for (Regressor.DimensionTuple r : e.getOutput()) { 097 int id = outputInfo.getID(r); 098 outputs[id][i] = r.getValue(); 099 } 100 i++; 101 } 102 for (Regressor r : domain) { 103 int id = outputInfo.getID(r); 104 T innerModel = trainDimension(outputs[id],inputs,weights,localRNG); 105 models.put(r.getNames()[0],innerModel); 106 } 107 ModelProvenance provenance = new ModelProvenance(getModelClassName(), OffsetDateTime.now(), examples.getProvenance(), trainerProvenance, runProvenance); 108 return createModel(models,provenance,featureMap,outputInfo); 109 } 110 111 @Override 112 public int getInvocationCount() { 113 return trainInvocationCounter; 114 } 115 116 /** 117 * Constructs the appropriate subclass of {@link SkeletalIndependentRegressionModel} for this trainer. 118 * @param models The models to use. 119 * @param provenance The model provenance 120 * @param featureMap The feature map. 121 * @param outputInfo The regression info. 122 * @return A subclass of IndependentRegressionModel. 123 */ 124 protected abstract SkeletalIndependentRegressionSparseModel createModel(Map<String,T> models, ModelProvenance provenance, ImmutableFeatureMap featureMap, ImmutableOutputInfo<Regressor> outputInfo); 125 126 /** 127 * Trains a single dimension of the possibly multiple dimensions. 128 * @param outputs The regression targets for this dimension. 129 * @param features The features. 130 * @param weights The example weights. 131 * @param rng The RNG to use. 132 * @return An object representing the model. 133 */ 134 protected abstract T trainDimension(double[] outputs, SparseVector[] features, float[] weights, SplittableRandom rng); 135 136 /** 137 * Returns true if the SparseVector should be constructed with a bias feature. 138 * @return True if the trainer needs a bias. 139 */ 140 protected abstract boolean useBias(); 141 142 /** 143 * Returns the class name of the model that this class produces. 144 * @return The class name of the model. 145 */ 146 protected abstract String getModelClassName(); 147} 148