001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.classification.sgd.objectives; 018 019import com.oracle.labs.mlrg.olcut.config.Config; 020import com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance; 021import com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl; 022import com.oracle.labs.mlrg.olcut.util.Pair; 023import org.tribuo.classification.sgd.LabelObjective; 024import org.tribuo.math.la.SGDVector; 025import org.tribuo.math.la.SparseVector; 026import org.tribuo.math.util.NoopNormalizer; 027import org.tribuo.math.util.VectorNormalizer; 028 029/** 030 * Hinge loss, scores the correct value margin and any incorrect predictions -margin. 031 * By default the margin is 1.0. 032 * <p> 033 * The Hinge loss does not generate a probabilistic model, and uses a {@link NoopNormalizer}. 034 */ 035public class Hinge implements LabelObjective { 036 037 @Config(description="The classification margin.") 038 private double margin = 1.0; 039 040 /** 041 * Construct a hinge objective with the supplied margin. 042 * @param margin The margin to use. 043 */ 044 public Hinge(double margin) { 045 this.margin = margin; 046 } 047 048 /** 049 * Construct a hinge objective with a margin of 1.0. 050 */ 051 public Hinge() { 052 this(1.0); 053 } 054 055 /** 056 * Returns a {@link Pair} of {@link Double} and {@link SparseVector}. 057 * @param truth The true label id. 058 * @param prediction The prediction for each label id. 059 * @return The loss and per label gradient. 060 */ 061 @Override 062 public Pair<Double,SGDVector> valueAndGradient(int truth, SGDVector prediction) { 063 prediction.add(truth,-margin); 064 int predIndex = prediction.indexOfMax(); 065 066 if (truth == predIndex) { 067 return new Pair<>(0.0, SparseVector.createSparseVector(prediction.size(),new int[0], new double[0])); 068 } else { 069 int[] indices = new int[2]; 070 double[] values = new double[2]; 071 if (truth < predIndex) { 072 indices[0] = truth; 073 values[0] = margin; 074 indices[1] = predIndex; 075 values[1] = -margin; 076 } else { 077 indices[0] = predIndex; 078 values[0] = -margin; 079 indices[1] = truth; 080 values[1] = margin; 081 } 082 SparseVector output = SparseVector.createSparseVector(prediction.size(),indices,values); 083 double loss = prediction.get(truth) - prediction.get(predIndex); 084 return new Pair<>(loss,output); 085 } 086 } 087 088 /** 089 * Returns a new {@link NoopNormalizer}. 090 * @return The vector normalizer. 091 */ 092 @Override 093 public VectorNormalizer getNormalizer() { 094 return new NoopNormalizer(); 095 } 096 097 /** 098 * Returns false. 099 * @return False. 100 */ 101 @Override 102 public boolean isProbabilistic() { 103 return false; 104 } 105 106 @Override 107 public String toString() { 108 return "Hinge(margin="+margin+")"; 109 } 110 111 @Override 112 public ConfiguredObjectProvenance getProvenance() { 113 return new ConfiguredObjectProvenanceImpl(this,"LabelObjective"); 114 } 115}