001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.classification.sgd.objectives;
018
019import com.oracle.labs.mlrg.olcut.config.Config;
020import com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance;
021import com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl;
022import com.oracle.labs.mlrg.olcut.util.Pair;
023import org.tribuo.classification.sgd.LabelObjective;
024import org.tribuo.math.la.SGDVector;
025import org.tribuo.math.la.SparseVector;
026import org.tribuo.math.util.NoopNormalizer;
027import org.tribuo.math.util.VectorNormalizer;
028
029/**
030 * Hinge loss, scores the correct value margin and any incorrect predictions -margin.
031 * By default the margin is 1.0.
032 * <p>
033 * The Hinge loss does not generate a probabilistic model, and uses a {@link NoopNormalizer}.
034 */
035public class Hinge implements LabelObjective {
036
037    @Config(description="The classification margin.")
038    private double margin = 1.0;
039
040    /**
041     * Construct a hinge objective with the supplied margin.
042     * @param margin The margin to use.
043     */
044    public Hinge(double margin) {
045        this.margin = margin;
046    }
047
048    /**
049     * Construct a hinge objective with a margin of 1.0.
050     */
051    public Hinge() {
052        this(1.0);
053    }
054
055    /**
056     * Returns a {@link Pair} of {@link Double} and {@link SparseVector}.
057     * @param truth The true label id.
058     * @param prediction The prediction for each label id.
059     * @return The loss and per label gradient.
060     */
061    @Override
062    public Pair<Double,SGDVector> valueAndGradient(int truth, SGDVector prediction) {
063        prediction.add(truth,-margin);
064        int predIndex = prediction.indexOfMax();
065
066        if (truth == predIndex) {
067            return new Pair<>(0.0, SparseVector.createSparseVector(prediction.size(),new int[0], new double[0]));
068        } else {
069            int[] indices = new int[2];
070            double[] values = new double[2];
071            if (truth < predIndex) {
072                indices[0] = truth;
073                values[0] = margin;
074                indices[1] = predIndex;
075                values[1] = -margin;
076            } else {
077                indices[0] = predIndex;
078                values[0] = -margin;
079                indices[1] = truth;
080                values[1] = margin;
081            }
082            SparseVector output = SparseVector.createSparseVector(prediction.size(),indices,values);
083            double loss = prediction.get(truth) - prediction.get(predIndex);
084            return new Pair<>(loss,output);
085        }
086    }
087
088    /**
089     * Returns a new {@link NoopNormalizer}.
090     * @return The vector normalizer.
091     */
092    @Override
093    public VectorNormalizer getNormalizer() {
094        return new NoopNormalizer();
095    }
096
097    /**
098     * Returns false.
099     * @return False.
100     */
101    @Override
102    public boolean isProbabilistic() {
103        return false;
104    }
105
106    @Override
107    public String toString() {
108        return "Hinge(margin="+margin+")";
109    }
110
111    @Override
112    public ConfiguredObjectProvenance getProvenance() {
113        return new ConfiguredObjectProvenanceImpl(this,"LabelObjective");
114    }
115}