/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.regression.sgd.objectives;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl;
import com.oracle.labs.mlrg.olcut.util.Pair;
import org.tribuo.math.la.DenseVector;
import org.tribuo.math.la.SGDVector;
import org.tribuo.regression.sgd.RegressionObjective;

import java.util.function.DoubleUnaryOperator;

/**
 * Huber loss, i.e., a mixture of l2 and l1 losses.
 * <p>
 * The loss is quadratic for absolute errors below the cost threshold and linear above it.
 */
public class Huber implements RegressionObjective {

    @Config(description="Cost beyond which the loss function is linear.")
    private double cost = 5;

    private DoubleUnaryOperator lossFunc;

    /**
     * Constructs a Huber loss using the default cost of 5.
     */
    public Huber() {
        postConfig();
    }

    /**
     * Constructs a Huber loss using the supplied cost.
     * @param cost The cost beyond which the loss is linear.
     */
    public Huber(double cost) {
        this.cost = cost;
        postConfig();
    }

    /**
     * Used by the OLCUT configuration system, and should not be called by external code.
     */
    @Override
    public void postConfig() {
        // The operator is applied to absolute differences, so only the positive branch is needed.
        lossFunc = (a) -> {
            if (a > cost) {
                return (cost * a) - (0.5 * cost * cost);
            } else {
                return 0.5 * a * a;
            }
        };
    }

    @Override
    public Pair<Double, SGDVector> loss(DenseVector truth, SGDVector prediction) {
        DenseVector difference = truth.subtract(prediction);
        DenseVector absoluteDifference = difference.copy();
        absoluteDifference.foreachInPlace(Math::abs);

        double loss = absoluteDifference.reduce(0.0, lossFunc, Double::sum);
        // The gradient is the raw difference, clipped to the range [-cost, cost].
        difference.foreachInPlace((a) -> {
            if (Math.abs(a) > cost) {
                return Double.compare(a, 0.0) * cost;
            } else {
                return a;
            }
        });
        return new Pair<>(loss, difference);
    }

    @Override
    public String toString() {
        return "Huber(cost="+cost+")";
    }

    @Override
    public ConfiguredObjectProvenance getProvenance() {
        return new ConfiguredObjectProvenanceImpl(this, "RegressionObjective");
    }
}
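
/*
 * A minimal usage sketch for this objective, kept in a comment so the file still
 * compiles. It assumes only what is imported above plus the static factory
 * DenseVector.createDenseVector(double[]) and the olcut Pair accessors
 * getA()/getB(). With cost = 1.0, absolute errors below 1.0 contribute the
 * quadratic 0.5 * a * a term, larger errors contribute the linear term, and the
 * returned gradient is the per-dimension difference clipped to [-cost, cost].
 *
 *     Huber huber = new Huber(1.0);
 *     DenseVector truth = DenseVector.createDenseVector(new double[]{0.5, 2.0, -1.0});
 *     DenseVector prediction = DenseVector.createDenseVector(new double[]{0.4, 0.0, -3.5});
 *     Pair<Double, SGDVector> lossAndGradient = huber.loss(truth, prediction);
 *     double lossValue = lossAndGradient.getA();   // summed Huber loss over the three dimensions
 *     SGDVector gradient = lossAndGradient.getB(); // clipped differences, one per dimension
 */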