/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.math;

import com.oracle.labs.mlrg.olcut.util.Pair;
import org.tribuo.math.la.DenseMatrix;
import org.tribuo.math.la.DenseSparseMatrix;
import org.tribuo.math.la.SGDVector;
import org.tribuo.math.la.SparseVector;
import org.tribuo.math.la.Tensor;
import org.tribuo.math.util.HeapMerger;
import org.tribuo.math.util.Merger;

/**
 * A {@link Parameters} for producing single-label linear models.
 */
public class LinearParameters implements Parameters {

    private static final Merger merger = new HeapMerger();

    // Last row in this DenseMatrix is the bias, added by calling new SparseVector(example,featureInfo,true);
    private Tensor[] weights;
    private DenseMatrix weightMatrix;

    /**
     * Constructor. The number of features and the number of outputs must be fixed and known in advance.
     * @param numFeatures The number of features in the training dataset (excluding the bias).
     * @param numLabels The number of outputs in the training dataset.
     */
    public LinearParameters(int numFeatures, int numLabels) {
        weights = new Tensor[1];
        weightMatrix = new DenseMatrix(numLabels,numFeatures);
        weights[0] = weightMatrix;
    }

    /**
     * Generates an unnormalised prediction by leftMultiply'ing the weights with the incoming features.
     * @param example A feature vector.
     * @return A {@link org.tribuo.math.la.DenseVector} containing a score for each label.
     */
    public SGDVector predict(SparseVector example) {
        return weightMatrix.leftMultiply(example);
    }

    /**
     * Generates the gradients for a particular feature vector given
     * the loss and the per output gradients.
     * <p>
     * This implementation returns a single element {@link Tensor} array.
     * @param score The Pair returned by the objective.
     * @param features The feature vector.
     * @return A {@link Tensor} array with a single {@link DenseSparseMatrix} containing all gradients.
     */
    public Tensor[] gradients(Pair<Double, SGDVector> score, SparseVector features) {
        Tensor[] output = new Tensor[1];
        output[0] = score.getB().outer(features);
        return output;
    }

    /**
     * This returns a {@link DenseMatrix} the same size as the Parameters.
     * @return A {@link Tensor} array containing a single {@link DenseMatrix}.
     */
    @Override
    public Tensor[] getEmptyCopy() {
        DenseMatrix matrix = new DenseMatrix(weightMatrix.getDimension1Size(),weightMatrix.getDimension2Size());
        Tensor[] output = new Tensor[1];
        output[0] = matrix;
        return output;
    }

    @Override
    public Tensor[] get() {
        return weights;
    }

    /**
     * Returns the weight matrix.
     * @return The weight matrix.
     */
    public DenseMatrix getWeightMatrix() {
        return weightMatrix;
    }

    @Override
    public void set(Tensor[] newWeights) {
        // Silently ignores the new weights if the array lengths don't match.
        if (newWeights.length == weights.length) {
            weights = newWeights;
            weightMatrix = (DenseMatrix) weights[0];
        }
    }

    @Override
    public void update(Tensor[] gradients) {
        // Adds each gradient Tensor to the corresponding weight Tensor in place.
        for (int i = 0; i < gradients.length; i++) {
            weights[i].intersectAndAddInPlace(gradients[i]);
        }
    }

    @Override
    public Tensor[] merge(Tensor[][] gradients, int size) {
        // Merges the first size gradient arrays into a single gradient update.
        DenseSparseMatrix[] updates = new DenseSparseMatrix[size];
        for (int j = 0; j < updates.length; j++) {
            updates[j] = (DenseSparseMatrix) gradients[j][0];
        }

        DenseSparseMatrix update = merger.merge(updates);

        return new Tensor[]{update};
    }
}
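
/*
 * Usage sketch (an editor's illustration, not part of the Tribuo API): a single
 * hand-rolled SGD step driven through LinearParameters. The dimensions, feature
 * values, learning rate, and per-output gradient below are hypothetical stand-ins
 * for what a real objective and dataset would supply.
 */
class LinearParametersUsageSketch {
    public static void main(String[] args) {
        int numFeatures = 5;
        int numLabels = 3;
        LinearParameters lp = new LinearParameters(numFeatures, numLabels);

        // A toy input with two active features. A real trainer builds this via
        // new SparseVector(example, featureInfo, true), as noted in the field comment above.
        SparseVector features = SparseVector.createSparseVector(
                numFeatures, new int[]{0, 2}, new double[]{1.0, 0.5});

        System.out.println("Scores before update: " + lp.predict(features));

        // A made-up per-output gradient, pre-scaled by a negative learning rate so
        // that update() takes a descent step; gradients() ignores the loss value (A).
        // DenseVector is fully qualified to leave the file's imports untouched.
        double learningRate = 0.1;
        SGDVector outputGradient =
                org.tribuo.math.la.DenseVector.createDenseVector(new double[]{0.5, -1.0, 0.5});
        Pair<Double, SGDVector> lossAndGradient = new Pair<>(0.0, outputGradient.scale(-learningRate));

        // gradients() yields a single numLabels x numFeatures matrix: outputGradient outer features.
        lp.update(lp.gradients(lossAndGradient, features));
        System.out.println("Scores after update:  " + lp.predict(features));
    }
}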