001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.classification.sgd; 018 019import com.oracle.labs.mlrg.olcut.config.Configurable; 020import com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance; 021import com.oracle.labs.mlrg.olcut.provenance.Provenancable; 022import com.oracle.labs.mlrg.olcut.util.Pair; 023import org.tribuo.math.la.SGDVector; 024import org.tribuo.math.util.VectorNormalizer; 025 026/** 027 * An interface for single label prediction objectives. 028 * <p> 029 * An objective knows if it generates a probabilistic model or not, 030 * and what kind of normalization needs to be applied to produce probability values. 031 */ 032public interface LabelObjective extends Configurable, Provenancable<ConfiguredObjectProvenance> { 033 034 /** 035 * Scores a prediction, returning the loss and a vector of per label gradients. 036 * @param truth The true label id. 037 * @param prediction The prediction for each label id. 038 * @return The score and per label gradient. 039 */ 040 public Pair<Double,SGDVector> valueAndGradient(int truth, SGDVector prediction); 041 042 /** 043 * Generates a new {@link VectorNormalizer} which normalizes the predictions into [0,1]. 044 * @return The vector normalizer for this objective. 045 */ 046 public VectorNormalizer getNormalizer(); 047 048 /** 049 * Does the objective function score probabilities or not? 050 * @return boolean. 051 */ 052 public boolean isProbabilistic(); 053 054}