/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.regression.rtree.impurity;

import com.oracle.labs.mlrg.olcut.config.Configurable;
import com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.Provenancable;
import org.tribuo.common.tree.impl.IntArrayContainer;

import java.util.List;

/**
 * Calculates a tree impurity score based on the regression targets.
 * <p>
 * Implementations must provide {@link #impurity(float[], float[])} and the two
 * {@code impurityTuple} methods; every other {@code impurity} overload is a default
 * method that ultimately delegates to one of the {@code impurityTuple} methods and
 * returns only its impurity component.
 */
public interface RegressorImpurity extends Configurable, Provenancable<ConfiguredObjectProvenance> {

    /**
     * Calculates the impurity based on the supplied weights and targets.
     * @param targets The targets.
     * @param weights The weights.
     * @return The impurity.
     */
    double impurity(float[] targets, float[] weights);

    /**
     * Calculates the weighted impurity of the targets specified in the indices array.
     * @param indices The indices in the targets and weights arrays.
     * @param indicesLength The number of values to use in indices.
     * @param targets The regression targets.
     * @param weights The example weights.
     * @return A tuple containing the impurity and the used weight sum.
     */
    ImpurityTuple impurityTuple(int[] indices, int indicesLength, float[] targets, float[] weights);

    /**
     * Calculates the weighted impurity of the targets specified in all the indices arrays.
     * @param indices The indices in the targets and weights arrays.
     * @param targets The regression targets.
     * @param weights The example weights.
     * @return A tuple containing the impurity and the used weight sum.
     */
    ImpurityTuple impurityTuple(List<int[]> indices, float[] targets, float[] weights);

    /**
     * Calculates the weighted impurity of the targets specified in the indices array.
     * <p>
     * The default implementation delegates to
     * {@link #impurityTuple(int[], int, float[], float[])} and discards the weight sum.
     * @param indices The indices in the targets and weights arrays.
     * @param indicesLength The number of values to use in indices.
     * @param targets The regression targets.
     * @param weights The example weights.
     * @return The weighted impurity (the tuple's {@code float} impurity widened to {@code double}).
     */
    default double impurity(int[] indices, int indicesLength, float[] targets, float[] weights) {
        return impurityTuple(indices, indicesLength, targets, weights).impurity;
    }

    /**
     * Calculates the weighted impurity of the targets specified in all the indices arrays.
     * <p>
     * The default implementation delegates to
     * {@link #impurityTuple(List, float[], float[])} and discards the weight sum.
     * @param indices The indices in the targets and weights arrays.
     * @param targets The regression targets.
     * @param weights The example weights.
     * @return The weighted impurity (the tuple's {@code float} impurity widened to {@code double}).
     */
    default double impurity(List<int[]> indices, float[] targets, float[] weights) {
        return impurityTuple(indices, targets, weights).impurity;
    }

    /**
     * Calculates the weighted impurity of the targets specified by every entry of the indices array.
     * <p>
     * The default implementation calls {@link #impurity(int[], int, float[], float[])} with
     * {@code indices.length}, so an implementation that overrides that method is used here too.
     * @param indices The indices in the targets and weights arrays.
     * @param targets The regression targets.
     * @param weights The example weights.
     * @return The weighted impurity.
     */
    default double impurity(int[] indices, float[] targets, float[] weights) {
        return impurity(indices, indices.length, targets, weights);
    }

    /**
     * Calculates the weighted impurity of the targets specified in the indices container.
     * <p>
     * Only the first {@code indices.size} entries of {@code indices.array} are used; the
     * default implementation calls {@link #impurity(int[], int, float[], float[])} with them.
     * @param indices The indices in the targets and weights arrays.
     * @param targets The regression targets.
     * @param weights The example weights.
     * @return The weighted impurity.
     */
    default double impurity(IntArrayContainer indices, float[] targets, float[] weights) {
        return impurity(indices.array, indices.size, targets, weights);
    }

    /**
     * Tuple class for the impurity and summed weight.
     * <p>
     * Instances are immutable; both components are stored as {@code float}.
     * (Member classes of an interface are implicitly {@code public static}.)
     */
    class ImpurityTuple {
        /**
         * The impurity value.
         */
        public final float impurity;
        /**
         * The sum of the example weights used to compute {@link #impurity}.
         */
        public final float weight;

        /**
         * Constructs an impurity tuple.
         * @param impurity The impurity value.
         * @param weight The summed weight of the examples used.
         */
        public ImpurityTuple(float impurity, float weight) {
            this.impurity = impurity;
            this.weight = weight;
        }

        @Override
        public String toString() {
            return "ImpurityTuple(impurity=" + impurity + ",weight=" + weight + ")";
        }
    }
}