001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.regression.impl; 018 019import com.oracle.labs.mlrg.olcut.config.Config; 020import com.oracle.labs.mlrg.olcut.provenance.Provenance; 021import org.tribuo.Dataset; 022import org.tribuo.Example; 023import org.tribuo.ImmutableFeatureMap; 024import org.tribuo.ImmutableOutputInfo; 025import org.tribuo.Model; 026import org.tribuo.Trainer; 027import org.tribuo.math.la.SparseVector; 028import org.tribuo.provenance.ModelProvenance; 029import org.tribuo.provenance.TrainerProvenance; 030import org.tribuo.regression.Regressor; 031 032import java.time.OffsetDateTime; 033import java.util.Collections; 034import java.util.LinkedHashMap; 035import java.util.Map; 036import java.util.Set; 037import java.util.SplittableRandom; 038 039/** 040 * Trains n independent binary {@link Model}s, each of which predicts a single {@link Regressor}. 041 * Generates the SparseVectors once to reduce allocation. 042 * <p> 043 * Then wraps it up in an {@link SkeletalIndependentRegressionModel} to provide a {@link Regressor} 044 * prediction. 045 * <p> 046 * It trains each model sequentially, and could be optimised to train in parallel. 047 */ 048public abstract class SkeletalIndependentRegressionTrainer<T> implements Trainer<Regressor> { 049 050 @Config(description="Seed for the RNG, may be unused.") 051 private long seed = 1L; 052 053 private SplittableRandom rng; 054 055 private int trainInvocationCounter = 0; 056 057 /** 058 * for olcut. 059 */ 060 protected SkeletalIndependentRegressionTrainer() {} 061 062 @Override 063 public synchronized void postConfig() { 064 this.rng = new SplittableRandom(seed); 065 } 066 067 @Override 068 public SkeletalIndependentRegressionModel train(Dataset<Regressor> examples) { 069 return train(examples, Collections.emptyMap()); 070 } 071 072 @Override 073 public SkeletalIndependentRegressionModel train(Dataset<Regressor> examples, Map<String, Provenance> runProvenance) { 074 if (examples.getOutputInfo().getUnknownCount() > 0) { 075 throw new IllegalArgumentException("The supplied Dataset contained unknown Outputs, and this Trainer is supervised."); 076 } 077 SplittableRandom localRNG; 078 TrainerProvenance trainerProvenance; 079 synchronized(this) { 080 localRNG = rng.split(); 081 trainerProvenance = getProvenance(); 082 trainInvocationCounter++; 083 } 084 ImmutableOutputInfo<Regressor> outputInfo = examples.getOutputIDInfo(); 085 ImmutableFeatureMap featureMap = examples.getFeatureIDMap(); 086 Set<Regressor> domain = outputInfo.getDomain(); 087 LinkedHashMap<String, T> models = new LinkedHashMap<>(); 088 int numExamples = examples.size(); 089 boolean needBias = useBias(); 090 float[] weights = new float[numExamples]; 091 double[][] outputs = new double[outputInfo.size()][numExamples]; 092 SparseVector[] inputs = new SparseVector[numExamples]; 093 int i = 0; 094 for (Example<Regressor> e : examples) { 095 inputs[i] = SparseVector.createSparseVector(e,featureMap,needBias); 096 weights[i] = e.getWeight(); 097 for (Regressor.DimensionTuple r : e.getOutput()) { 098 int id = outputInfo.getID(r); 099 outputs[id][i] = r.getValue(); 100 } 101 i++; 102 } 103 for (Regressor r : domain) { 104 int id = outputInfo.getID(r); 105 T innerModel = trainDimension(outputs[id],inputs,weights,localRNG); 106 models.put(r.getNames()[0],innerModel); 107 } 108 ModelProvenance provenance = new ModelProvenance(getModelClassName(), OffsetDateTime.now(), examples.getProvenance(), trainerProvenance, runProvenance); 109 return createModel(models,provenance,featureMap,outputInfo); 110 } 111 112 @Override 113 public int getInvocationCount() { 114 return trainInvocationCounter; 115 } 116 117 /** 118 * Constructs the appropriate subclass of {@link SkeletalIndependentRegressionModel} for this trainer. 119 * @param models The models to use. 120 * @param provenance The model provenance 121 * @param featureMap The feature map. 122 * @param outputInfo The regression info. 123 * @return A subclass of IndependentRegressionModel. 124 */ 125 protected abstract SkeletalIndependentRegressionModel createModel(Map<String,T> models, ModelProvenance provenance, ImmutableFeatureMap featureMap, ImmutableOutputInfo<Regressor> outputInfo); 126 127 /** 128 * Trains a single dimension of the possibly multiple dimensions. 129 * @param outputs The regression targets for this dimension. 130 * @param features The features. 131 * @param weights The example weights. 132 * @param rng The RNG to use. 133 * @return An object representing the model. Should be the same type as that expected by {@link #createModel}. 134 */ 135 protected abstract T trainDimension(double[] outputs, SparseVector[] features, float[] weights, SplittableRandom rng); 136 137 /** 138 * Returns true if the SparseVector should be constructed with a bias feature. 139 * @return True if the trainer needs a bias. 140 */ 141 protected abstract boolean useBias(); 142 143 /** 144 * Returns the class name of the model that this class produces. 145 * @return The class name of the model. 146 */ 147 protected abstract String getModelClassName(); 148} 149