001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.math; 018 019import org.tribuo.math.la.Tensor; 020 021/** 022 * An interface to a {@link Tensor}[] array which accepts updates to the parameters. 023 * <p> 024 * Parameters is essentially an SGD model at training time. 025 * <p> 026 * Subclasses of this should add methods for calculating gradients for 027 * their prediction task. 028 */ 029public interface Parameters { 030 031 /** 032 * Generates an empty copy of the underlying {@link Tensor} array. 033 * <p> 034 * It's the same size and shape as the parameters, but all the values are 0.0. 035 * @return A copy of the parameters where all values are 0.0. 036 */ 037 public Tensor[] getEmptyCopy(); 038 039 /** 040 * Get a reference to the underlying {@link Tensor} array. 041 * @return The parameters. 042 */ 043 public Tensor[] get(); 044 045 /** 046 * Set the underlying {@link Tensor} array to newWeights. 047 * @param newWeights New parameters to store in this object. 048 */ 049 public void set(Tensor[] newWeights); 050 051 /** 052 * Apply gradients to the parameters. Assumes that gradients is the same length as the parameters, 053 * and each {@link Tensor} is the same size as the corresponding one from the parameters. 054 * <p> 055 * The gradients are added to the parameters. 056 * @param gradients A {@link Tensor} array of updates, with the length equal to {@link Parameters#get()}.length. 057 */ 058 public void update(Tensor[] gradients); 059 060 /** 061 * Merge together an array of gradient arrays. Assumes the first dimension 062 * is the number of gradient arrays and the second dimension is the 063 * number of parameter {@link Tensor}s. 064 * @param gradients An array of gradient update arrays. 065 * @param size The number of elements of gradients to merge. Allows gradients to have unused elements. 066 * @return A single {@link Tensor} array of the summed gradients. 067 */ 068 public Tensor[] merge(Tensor[][] gradients, int size); 069 070}