/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.regression.rtree.impl;

import com.oracle.labs.mlrg.olcut.util.Pair;
import org.tribuo.common.tree.impl.IntArrayContainer;

import java.util.Arrays;

/**
 * Internal datastructure for implementing a decision tree.
 * <p>
 * Represents a single value and feature tuple, with associated arrays for
 * the indices where that combination occurs.
 * <p>
 * Indices and values must be inserted in sorted ascending order or everything will break.
 * This code does not check that this invariant is maintained.
 * <p>
 * Note: this class has a natural ordering that is inconsistent with equals.
 */
public class InvertedFeature implements Comparable<InvertedFeature> {

    private static final int DEFAULT_SIZE = 8;

    public final double value;

    /**
     * Indices must be inserted in a sorted order.
     * Null while this feature holds only the single {@link #index}.
     */
    private int[] indices = null;

    /**
     * The number of used elements in {@link #indices}; remains -1 while only
     * the single {@link #index} is stored.
     */
    private int curSize = -1;

    /**
     * This is a short circuit in case there is a single index in this feature.
     */
    private int index;

    /**
     * Constructs an inverted feature for the supplied value and indices.
     * <p>
     * The array is stored directly, it is not copied, and every element is treated as in use.
     * @param value The feature value.
     * @param indices The indices where the value occurs, in ascending order.
     */
    public InvertedFeature(double value, int[] indices) {
        this.value = value;
        this.indices = indices;
        this.curSize = indices.length;
    }

    /**
     * Constructs an inverted feature which holds a single index.
     * @param value The feature value.
     * @param index The index where the value occurs.
     */
    public InvertedFeature(double value, int index) {
        this.value = value;
        this.index = index;
    }

    /**
     * Copy constructor, duplicates the index array if there is one.
     * @param other The feature to copy.
     */
    private InvertedFeature(InvertedFeature other) {
        this.value = other.value;
        this.curSize = other.curSize;
        this.index = other.index;
        if (other.indices != null) {
            this.indices = Arrays.copyOf(other.indices,other.indices.length);
        } else {
            this.indices = null;
        }
    }

    /**
     * Adds an index to this feature.
     * <p>
     * The first call on a single index feature switches it over to array storage,
     * moving the existing index into the first slot.
     * @param index The index to add, which must not be smaller than those already present.
     */
    public void add(int index) {
        if (indices == null) {
            initArrays();
        }
        append(index);
    }

    /**
     * Writes the index into the next free slot, growing the array when it is full.
     * @param index The index to append.
     */
    private void append(int index) {
        if (curSize == indices.length) {
            // Grow by 50%, but always by at least one slot. For arrays of length 0 or 1
            // (which the public array constructor permits) length + (length >> 1) == length,
            // so without the floor the array would never grow and the write below would fail.
            int newSize = Math.max(indices.length + (indices.length >> 1), indices.length + 1);
            indices = Arrays.copyOf(indices,newSize);
        }
        indices[curSize] = index;
        curSize++;
    }

    /**
     * Gets the indices stored in this feature.
     * <p>
     * When array storage is in use this returns the internal array itself (not a copy),
     * which can contain unused trailing elements until {@link #fixSize()} has been called.
     * Otherwise it returns a fresh single element array holding the lone index.
     * @return The indices.
     */
    public int[] indices() {
        if (indices != null) {
            return indices;
        } else {
            int[] ret = new int[1];
            ret[0] = index;
            return ret;
        }
    }

    /**
     * Trims the index array so its length equals the number of stored indices.
     * Does nothing when only a single index is stored.
     */
    public void fixSize() {
        if (indices != null) {
            indices = Arrays.copyOf(indices, curSize);
        }
    }

    /**
     * Relies upon allLeftIndices being sorted in ascending order. Undefined when it's not.
     * <p>
     * As a side effect the elements of allLeftIndices which do not appear in this feature
     * are placed in buffer (with buffer.size updated to match), and allLeftIndices.size is
     * set to zero. When this feature holds a single index which is not in allLeftIndices
     * the two containers swap their arrays rather than copying.
     * The buffer's array is assumed to be large enough to hold the unused indices,
     * this is not checked.
     * @param allLeftIndices The indices of the left branch.
     * @param buffer The buffer to write out the unused indices to.
     * @return A pair, with the first element the left branch and the second element the right branch.
     *         An element is null when no index from this feature falls in that branch.
     */
    public Pair<InvertedFeature,InvertedFeature> split(IntArrayContainer allLeftIndices, IntArrayContainer buffer) {
        int[] allLeftArray = allLeftIndices.array;
        int allLeftSize = allLeftIndices.size;
        int[] bufferArray = buffer.array;
        if (indices != null) {
            // These are init'd to indices.length as allLeftIndices may contain indices not in this InvertedFeature.
            int[] leftIndices = new int[indices.length];
            int leftSize = 0;
            int[] rightIndices = new int[indices.length];
            int rightSize = 0;

            int bufferIdx = 0;
            int curIndex = 0;
            int j = 0;
            for (int i = 0; i < curSize; i++) {
                //relying on the shortcut evaluation so we don't pop out of allLeftArray
                while ((j < allLeftSize) && ((curIndex = allLeftArray[j]) < indices[i])) {
                    // This left index is smaller than our current index, so it is unused here.
                    bufferArray[bufferIdx] = curIndex;
                    bufferIdx++;
                    j++;
                }
                if ((j < allLeftSize) && (allLeftArray[j] == indices[i])) {
                    //in the left indices, put in left array
                    leftIndices[leftSize] = indices[i];
                    leftSize++;
                    j++; // consume the value in allLeftIndices[j]
                } else {
                    //allLeftIndices[j] now greater than indices[i], so must not include it
                    //put in right array.
                    rightIndices[rightSize] = indices[i];
                    rightSize++;
                }
            }

            // Anything remaining in allLeftArray is larger than all our indices, so it is unused.
            if (j < allLeftSize) {
                System.arraycopy(allLeftArray, j, bufferArray, bufferIdx, allLeftSize - j);
            }
            buffer.size = bufferIdx + (allLeftSize - j);
            allLeftIndices.size = 0;

            return new Pair<>(build(leftIndices, leftSize), build(rightIndices, rightSize));
        } else {
            //In this case this inverted feature only holds one value, so check for it in left indices
            boolean found = false;
            int i = 0;
            while (!found && i < allLeftSize) {
                if (allLeftArray[i] == index) {
                    found = true;
                } else {
                    i++;
                }
            }
            if (found) {
                // Copy everything except the matched element at position i into the buffer:
                // first the prefix before it, then the tail after it shifted down by one.
                System.arraycopy(allLeftArray, 0, bufferArray, 0, i);
                System.arraycopy(allLeftArray, i + 1, bufferArray, i, allLeftSize - i - 1);
                buffer.size = allLeftSize-1;
                allLeftIndices.size = 0;
                return new Pair<>(new InvertedFeature(value,index),null);
            } else {
                // Nothing was consumed, so swap the arrays instead of copying every element.
                allLeftIndices.array = bufferArray;
                allLeftIndices.size = 0;
                buffer.array = allLeftArray;
                buffer.size = allLeftSize;
                return new Pair<>(null,new InvertedFeature(value,index));
            }
        }
    }

    /**
     * Builds one side of a split from the first {@code size} elements of the scratch array.
     * @param scratch The array holding the indices for this side.
     * @param size The number of valid elements in scratch.
     * @return Null if size is zero, a single index feature if size is one,
     *         otherwise a feature backed by a trimmed copy of the scratch array.
     */
    private InvertedFeature build(int[] scratch, int size) {
        if (size == 0) {
            return null;
        } else if (size == 1) {
            return new InvertedFeature(value,scratch[0]);
        } else {
            return new InvertedFeature(value, Arrays.copyOf(scratch, size));
        }
    }

    /**
     * Switches from the single index representation to array storage,
     * placing the existing index in the first slot.
     */
    private void initArrays() {
        indices = new int[DEFAULT_SIZE];
        indices[0] = index;
        curSize = 1;
    }

    /**
     * Orders features by their value only, the indices are ignored.
     * @param o The feature to compare against.
     * @return The result of {@link Double#compare(double, double)} on the two values.
     */
    @Override
    public int compareTo(InvertedFeature o) {
        return Double.compare(value, o.value);
    }

    @Override
    public String toString() {
        if (indices != null) {
            return "InvertedFeature(value=" + value + ",size=" + curSize + ",indices=" + Arrays.toString(indices) + ")";
        } else {
            return "InvertedFeature(value=" + value + ",size=" + curSize + ",index=" + index + ")";
        }
    }

    /**
     * Copies this feature, including its index array if it has one.
     * @return A copy which shares no mutable state with this feature.
     */
    public InvertedFeature deepCopy() {
        return new InvertedFeature(this);
    }
}