001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.regression.rtree.impurity;
018
019import com.oracle.labs.mlrg.olcut.config.Configurable;
020import com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance;
021import com.oracle.labs.mlrg.olcut.provenance.Provenancable;
022import org.tribuo.common.tree.impl.IntArrayContainer;
023
024import java.util.List;
025
026/**
027 * Calculates a tree impurity score based on the regression targets.
028 */
029public interface RegressorImpurity extends Configurable, Provenancable<ConfiguredObjectProvenance> {
030
031    /**
032     * Calculates the impurity based on the supplied weights and targets.
033     * @param targets The targets.
034     * @param weights The weights.
035     * @return The impurity.
036     */
037    public double impurity(float[] targets, float[] weights);
038
039    /**
040     * Calculates the weighted impurity of the targets specified in the indices array.
041     * @param indices The indices in the targets and weights arrays.
042     * @param indicesLength The number of values to use in indices.
043     * @param targets The regression targets.
044     * @param weights The example weights.
045     * @return A tuple containing the impurity and the used weight sum.
046     */
047    public ImpurityTuple impurityTuple(int[] indices, int indicesLength, float[] targets, float[] weights);
048
049    /**
050     * Calculates the weighted impurity of the targets specified in all the indices arrays.
051     * @param indices The indices in the targets and weights arrays.
052     * @param targets The regression targets.
053     * @param weights The example weights.
054     * @return A tuple containing the impurity and the used weight sum.
055     */
056    public ImpurityTuple impurityTuple(List<int[]> indices, float[] targets, float[] weights);
057
058    /**
059     * Calculates the weighted impurity of the targets specified in the indices array.
060     * @param indices The indices in the targets and weights arrays.
061     * @param indicesLength The number of values to use in indices.
062     * @param targets The regression targets.
063     * @param weights The example weights.
064     * @return The weighted impurity.
065     */
066    default public double impurity(int[] indices, int indicesLength, float[] targets, float[] weights) {
067        return impurityTuple(indices, indicesLength, targets, weights).impurity;
068    }
069
070    /**
071     * Calculates the weighted impurity of the targets specified in all the indices arrays.
072     * @param indices The indices in the targets and weights arrays.
073     * @param targets The regression targets.
074     * @param weights The example weights.
075     * @return The weighted impurity.
076     */
077    default public double impurity(List<int[]> indices, float[] targets, float[] weights) {
078        return impurityTuple(indices,targets,weights).impurity;
079    }
080
081    /**
082     * Calculates the weighted impurity of the targets specified in the indices array.
083     * @param indices The indices in the targets and weights arrays.
084     * @param targets The regression targets.
085     * @param weights The example weights.
086     * @return The weighted impurity.
087     */
088    default public double impurity(int[] indices, float[] targets, float[] weights) {
089        return impurity(indices, indices.length, targets, weights);
090    }
091
092    /**
093     * Calculates the weighted impurity of the targets specified in the indices container.
094     * @param indices The indices in the targets and weights arrays.
095     * @param targets The regression targets.
096     * @param weights The example weights.
097     * @return The weighted impurity.
098     */
099    default public double impurity(IntArrayContainer indices, float[] targets, float[] weights) {
100        return impurity(indices.array, indices.size, targets, weights);
101    }
102
103    /**
104     * Tuple class for the impurity and summed weight.
105     */
106    public static class ImpurityTuple {
107        public final float impurity;
108        public final float weight;
109
110        public ImpurityTuple(float impurity, float weight) {
111            this.impurity = impurity;
112            this.weight = weight;
113        }
114    }
115}