001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.regression.impl;
018
019import org.tribuo.Example;
020import org.tribuo.ImmutableFeatureMap;
021import org.tribuo.ImmutableOutputInfo;
022import org.tribuo.Model;
023import org.tribuo.Prediction;
024import org.tribuo.SparseModel;
025import org.tribuo.math.la.SparseVector;
026import org.tribuo.provenance.ModelProvenance;
027import org.tribuo.regression.Regressor;
028
029import java.util.Arrays;
030import java.util.List;
031import java.util.Map;
032
033/**
034 * A {@link SparseModel} which wraps n independent regression models, where n is the
035 * size of the MultipleRegressor domain. Each model independently predicts
036 * a single regression dimension.
037 */
038public abstract class SkeletalIndependentRegressionSparseModel extends SparseModel<Regressor> {
039    private static final long serialVersionUID = 1L;
040
041    protected final String[] dimensions;
042
043    /**
044     * models.size() must equal labelInfo.getDomain().size()
045     * @param name Model name.
046     * @param dimensions Dimension names.
047     * @param modelProvenance The model provenance.
048     * @param featureMap The feature domain used in training.
049     * @param outputInfo The output domain used in training.
050     * @param activeFeatures The active features in this model.
051     */
052    protected SkeletalIndependentRegressionSparseModel(String name, String[] dimensions, ModelProvenance modelProvenance, ImmutableFeatureMap featureMap, ImmutableOutputInfo<Regressor> outputInfo, Map<String,List<String>> activeFeatures) {
053        super(name,modelProvenance,featureMap,outputInfo,false,activeFeatures);
054        this.dimensions = Arrays.copyOf(dimensions,dimensions.length);
055    }
056
057    @Override
058    public Prediction<Regressor> predict(Example<Regressor> example) {
059        SparseVector features = createFeatures(example);
060
061        Regressor.DimensionTuple[] outputs = new Regressor.DimensionTuple[dimensions.length];
062
063        for (int i = 0; i < dimensions.length; i++) {
064            outputs[i] = scoreDimension(i,features);
065        }
066
067        return new Prediction<>(new Regressor(outputs),features.numActiveElements(),example);
068    }
069
070    /**
071     * Creates the feature vector. Does not include a bias term.
072     * <p>
073     * Designed to be overridden, called by the predict method.
074     * @param example The example to convert.
075     * @return The feature vector.
076     */
077    protected SparseVector createFeatures(Example<Regressor> example) {
078        return SparseVector.createSparseVector(example,featureIDMap,false);
079    }
080
081    /**
082     * Makes a prediction for a single dimension.
083     * @param dimensionIdx The dimension index to predict.
084     * @param features The features to use.
085     * @return A single dimension prediction.
086     */
087    protected abstract Regressor.DimensionTuple scoreDimension(int dimensionIdx, SparseVector features);
088}