001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.math;
018
019import com.oracle.labs.mlrg.olcut.util.Pair;
020import org.tribuo.math.la.DenseMatrix;
021import org.tribuo.math.la.DenseSparseMatrix;
022import org.tribuo.math.la.SGDVector;
023import org.tribuo.math.la.SparseVector;
024import org.tribuo.math.la.Tensor;
025import org.tribuo.math.util.HeapMerger;
026import org.tribuo.math.util.Merger;
027
028/**
029 * A Parameters for producing single label linear models.
030 */
031public class LinearParameters implements Parameters {
032
033    private static final Merger merger = new HeapMerger();
034
035    // Last row in this DenseMatrix is the bias, added by calling new SparseVector(example,featureInfo,true);
036    private Tensor[] weights;
037    private DenseMatrix weightMatrix;
038
039    /**
040     * Constructor. The number of features and the number of outputs must be fixed and known in advance.
041     * @param numFeatures The number of features in the training dataset (excluding the bias).
042     * @param numLabels The number of outputs in the training dataset.
043     */
044    public LinearParameters(int numFeatures, int numLabels) {
045        weights = new Tensor[1];
046        weightMatrix = new DenseMatrix(numLabels,numFeatures);
047        weights[0] = weightMatrix;
048    }
049
050    /**
051     * Generates an unnormalised prediction by leftMultiply'ing the weights with the incoming features.
052     * @param example A feature vector
053     * @return A {@link org.tribuo.math.la.DenseVector} containing a score for each label.
054     */
055    public SGDVector predict(SparseVector example) {
056        return weightMatrix.leftMultiply(example);
057    }
058
059    /**
060     * Generate the gradients for a particular feature vector given
061     * the loss and the per output gradients.
062     *
063     * This parameters returns a single element {@link Tensor} array.
064     * @param score The Pair returned by the objective.
065     * @param features The feature vector.
066     * @return A {@link Tensor} array with a single {@link DenseSparseMatrix} containing all gradients.
067     */
068    public Tensor[] gradients(Pair<Double, SGDVector> score, SparseVector features) {
069        Tensor[] output = new Tensor[1];
070        output[0] = score.getB().outer(features);
071        return output;
072    }
073
074    /**
075     * This returns a {@link DenseMatrix} the same size as the Parameters.
076     * @return A {@link Tensor} array containing a single {@link DenseMatrix}.
077     */
078    @Override
079    public Tensor[] getEmptyCopy() {
080        DenseMatrix matrix = new DenseMatrix(weightMatrix.getDimension1Size(),weightMatrix.getDimension2Size());
081        Tensor[] output = new Tensor[1];
082        output[0] = matrix;
083        return output;
084    }
085
086    @Override
087    public Tensor[] get() {
088        return weights;
089    }
090
091    /**
092     * Returns the weight matrix.
093     * @return The weight matrix.
094     */
095    public DenseMatrix getWeightMatrix() {
096        return weightMatrix;
097    }
098
099    @Override
100    public void set(Tensor[] newWeights) {
101        if (newWeights.length == weights.length) {
102            weights = newWeights;
103            weightMatrix = (DenseMatrix) weights[0];
104        }
105    }
106
107    @Override
108    public void update(Tensor[] gradients) {
109        for (int i = 0; i < gradients.length; i++) {
110            weights[i].intersectAndAddInPlace(gradients[i]);
111        }
112    }
113
114    @Override
115    public Tensor[] merge(Tensor[][] gradients, int size) {
116        DenseSparseMatrix[] updates = new DenseSparseMatrix[size];
117        for (int j = 0; j < updates.length; j++) {
118            updates[j] = (DenseSparseMatrix) gradients[j][0];
119        }
120
121        DenseSparseMatrix update = merger.merge(updates);
122
123        return new Tensor[]{update};
124    }
125}