001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.classification.sgd.objectives;
018
019import com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance;
020import com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl;
021import com.oracle.labs.mlrg.olcut.util.Pair;
022import org.tribuo.classification.sgd.LabelObjective;
023import org.tribuo.math.la.SGDVector;
024import org.tribuo.math.util.ExpNormalizer;
025import org.tribuo.math.util.VectorNormalizer;
026
027/**
028 * A multiclass version of the log loss.
029 * <p>
030 * Generates a probabilistic model, and uses an {@link ExpNormalizer}.
031 */
032public class LogMulticlass implements LabelObjective {
033
034    private final VectorNormalizer normalizer = new ExpNormalizer();
035
036    /**
037     * Returns a {@link Pair} of {@link Double} and the supplied prediction vector.
038     * <p>
039     * The prediction vector is transformed to produce the per label gradient.
040     * @param truth The true label id
041     * @param prediction The prediction for each label id
042     * @return A Pair of the score and per label gradient.
043     */
044    @Override
045    public Pair<Double,SGDVector> valueAndGradient(int truth, SGDVector prediction) {
046        prediction.normalize(normalizer);
047        double loss = Math.log(prediction.get(truth));
048        prediction.scaleInPlace(-1.0);
049        prediction.add(truth,1.0);
050        return new Pair<>(loss,prediction);
051    }
052
053    @Override
054    public VectorNormalizer getNormalizer() {
055        return new ExpNormalizer();
056    }
057
058    /**
059     * Returns true.
060     * @return True.
061     */
062    @Override
063    public boolean isProbabilistic() {
064        return true;
065    }
066
067    @Override
068    public String toString() {
069        return "LogMulticlass";
070    }
071
072    @Override
073    public ConfiguredObjectProvenance getProvenance() {
074        return new ConfiguredObjectProvenanceImpl(this,"LabelObjective");
075    }
076}