/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
016
017package org.tribuo.math;
018
019import org.tribuo.math.la.Tensor;
020
021/**
022 * An interface to a {@link Tensor}[] array which accepts updates to the parameters.
023 * <p>
024 * Parameters is essentially an SGD model at training time.
025 * <p>
026 * Subclasses of this should add methods for calculating gradients for
027 * their prediction task.
028 */
029public interface Parameters {
030
031    /**
032     * Generates an empty copy of the underlying {@link Tensor} array.
033     * <p>
034     * It's the same size and shape as the parameters, but all the values are 0.0.
035     * @return A copy of the parameters where all values are 0.0.
036     */
037    public Tensor[] getEmptyCopy();
038
039    /**
040     * Get a reference to the underlying {@link Tensor} array.
041     * @return The parameters.
042     */
043    public Tensor[] get();
044
045    /**
046     * Set the underlying {@link Tensor} array to newWeights.
047     * @param newWeights New parameters to store in this object.
048     */
049    public void set(Tensor[] newWeights);
050
051    /**
052     * Apply gradients to the parameters. Assumes that gradients is the same length as the parameters,
053     * and each {@link Tensor} is the same size as the corresponding one from the parameters.
054     * <p>
055     * The gradients are added to the parameters.
056     * @param gradients A {@link Tensor} array of updates, with the length equal to {@link Parameters#get()}.length.
057     */
058    public void update(Tensor[] gradients);
059
060    /**
061     * Merge together an array of gradient arrays. Assumes the first dimension
062     * is the number of gradient arrays and the second dimension is the
063     * number of parameter {@link Tensor}s.
064     * @param gradients An array of gradient update arrays.
065     * @param size The number of elements of gradients to merge. Allows gradients to have unused elements.
066     * @return A single {@link Tensor} array of the summed gradients.
067     */
068    public Tensor[] merge(Tensor[][] gradients, int size);
069
070}