001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.regression.impl; 018 019import org.tribuo.Example; 020import org.tribuo.ImmutableFeatureMap; 021import org.tribuo.ImmutableOutputInfo; 022import org.tribuo.Model; 023import org.tribuo.Prediction; 024import org.tribuo.SparseModel; 025import org.tribuo.math.la.SparseVector; 026import org.tribuo.provenance.ModelProvenance; 027import org.tribuo.regression.Regressor; 028 029import java.util.Arrays; 030import java.util.List; 031import java.util.Map; 032 033/** 034 * A {@link SparseModel} which wraps n independent regression models, where n is the 035 * size of the MultipleRegressor domain. Each model independently predicts 036 * a single regression dimension. 037 */ 038public abstract class SkeletalIndependentRegressionSparseModel extends SparseModel<Regressor> { 039 private static final long serialVersionUID = 1L; 040 041 protected final String[] dimensions; 042 043 /** 044 * models.size() must equal labelInfo.getDomain().size() 045 * @param name Model name. 046 * @param dimensions Dimension names. 047 * @param modelProvenance The model provenance. 048 * @param featureMap The feature domain used in training. 049 * @param outputInfo The output domain used in training. 050 * @param activeFeatures The active features in this model. 051 */ 052 protected SkeletalIndependentRegressionSparseModel(String name, String[] dimensions, ModelProvenance modelProvenance, ImmutableFeatureMap featureMap, ImmutableOutputInfo<Regressor> outputInfo, Map<String,List<String>> activeFeatures) { 053 super(name,modelProvenance,featureMap,outputInfo,false,activeFeatures); 054 this.dimensions = Arrays.copyOf(dimensions,dimensions.length); 055 } 056 057 @Override 058 public Prediction<Regressor> predict(Example<Regressor> example) { 059 SparseVector features = createFeatures(example); 060 061 Regressor.DimensionTuple[] outputs = new Regressor.DimensionTuple[dimensions.length]; 062 063 for (int i = 0; i < dimensions.length; i++) { 064 outputs[i] = scoreDimension(i,features); 065 } 066 067 return new Prediction<>(new Regressor(outputs),features.numActiveElements(),example); 068 } 069 070 /** 071 * Creates the feature vector. Does not include a bias term. 072 * <p> 073 * Designed to be overridden, called by the predict method. 074 * @param example The example to convert. 075 * @return The feature vector. 076 */ 077 protected SparseVector createFeatures(Example<Regressor> example) { 078 return SparseVector.createSparseVector(example,featureIDMap,false); 079 } 080 081 /** 082 * Makes a prediction for a single dimension. 083 * @param dimensionIdx The dimension index to predict. 084 * @param features The features to use. 085 * @return A single dimension prediction. 086 */ 087 protected abstract Regressor.DimensionTuple scoreDimension(int dimensionIdx, SparseVector features); 088}