/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.regression.rtree.impl;

import com.oracle.labs.mlrg.olcut.util.Pair;
import org.tribuo.common.tree.impl.IntArrayContainer;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;

/**
 * An inverted feature, which stores a reference to all the values of this feature.
 * <p>
 * Each distinct observed value is held in its own {@link InvertedFeature}, together with
 * the ids of the examples which had that value.
 * <p>
 * Can be split into two values based on an example index list.
 */
public class TreeFeature implements Iterable<InvertedFeature> {

    private final int id;

    private final List<InvertedFeature> feature;

    /**
     * Lookup from observed value to its {@link InvertedFeature}.
     * Is null for instances produced by {@link #split}, which cannot observe further values.
     */
    private final Map<Double,InvertedFeature> valueMap;

    /**
     * True when {@link #feature} is known to be in sorted order.
     * Starts as true because an empty list is trivially sorted.
     */
    private boolean sorted = true;

    /**
     * Constructs an empty tree feature which can accept values via {@link #observeValue}.
     * @param id The id number for this feature.
     */
    public TreeFeature(int id) {
        this.id = id;
        this.feature = new ArrayList<>();
        this.valueMap = new HashMap<>();
    }

    /**
     * This constructor doesn't make a valueMap, and is only used when all data has been observed.
     * So the constructed instance will throw NullPointerException if you call observeValue().
     * @param id The id number for this feature.
     * @param data The data.
     */
    private TreeFeature(int id, List<InvertedFeature> data) {
        this.id = id;
        this.feature = data;
        this.valueMap = null;
    }

    /**
     * Constructor used by {@link TreeFeature#deepCopy}.
     * @param id The id number for this feature.
     * @param data The data.
     * @param valueMap The value map.
     * @param sorted Is this data sorted.
     */
    private TreeFeature(int id, List<InvertedFeature> data, Map<Double,InvertedFeature> valueMap, boolean sorted) {
        this.id = id;
        this.feature = data;
        this.valueMap = valueMap;
        this.sorted = sorted;
    }

    /**
     * Iterates over the inverted features in their current list order.
     * @return An iterator over the backing list.
     */
    @Override
    public Iterator<InvertedFeature> iterator() {
        return feature.iterator();
    }

    /**
     * Gets the list of inverted features.
     * <p>
     * This is the internal list, not a copy, so modifications to it are visible to this object.
     * @return The backing list of inverted features.
     */
    public List<InvertedFeature> getFeature() {
        return feature;
    }

    /**
     * Observes a value for this feature.
     * <p>
     * A previously unseen value is appended to the feature list, which marks this
     * tree feature as unsorted until {@link #sort} is called.
     * @param value The value observed.
     * @param exampleID The example id number.
     * @throws NullPointerException If this instance has no value map (e.g. it was produced by {@link #split}).
     */
    public void observeValue(double value, int exampleID) {
        Double dValue = value;
        InvertedFeature f = valueMap.get(dValue);
        if (f == null) {
            f = new InvertedFeature(value,exampleID);
            valueMap.put(dValue,f);
            feature.add(f);
            // feature list is no longer guaranteed to be sorted
            sorted = false;
        } else {
            // Update currently known feature
            f.add(exampleID);
        }
    }

    /**
     * Sort the list using InvertedFeature's natural ordering. Must be done after all elements are inserted.
     */
    public void sort() {
        // A null comparator makes List.sort use the elements' natural ordering.
        feature.sort(null);
        sorted = true;
    }

    /**
     * Fixes the size of each {@link InvertedFeature}'s inner arrays.
     */
    public void fixSize() {
        feature.forEach(InvertedFeature::fixSize);
    }

    /**
     * Splits this tree feature into two.
     *
     * @param leftIndices The indices to go in the left branch.
     * @param rightIndices The indices to go in the right branch.
     * @param firstBuffer A buffer for temporary work.
     * @param secondBuffer A buffer for temporary work.
     * @return A pair of TreeFeatures, the first element is the left branch, the second the right.
     * @throws IllegalStateException If this tree feature has unsorted values.
     */
    public Pair<TreeFeature,TreeFeature> split(int[] leftIndices, int[] rightIndices, IntArrayContainer firstBuffer, IntArrayContainer secondBuffer) {
        if (!sorted) {
            throw new IllegalStateException("TreeFeature must be sorted before split is called");
        }

        List<InvertedFeature> leftFeatures;
        List<InvertedFeature> rightFeatures;
        if (feature.size() == 1) {
            // Only one distinct value, so each branch gets that value with its own index array.
            double value = feature.get(0).value;
            leftFeatures = Collections.singletonList(new InvertedFeature(value,leftIndices));
            rightFeatures = Collections.singletonList(new InvertedFeature(value,rightIndices));
        } else {
            leftFeatures = new ArrayList<>();
            rightFeatures = new ArrayList<>();
            firstBuffer.fill(leftIndices);
            for (InvertedFeature f : feature) {
                // Check if we've exhausted all the left side indices
                if (firstBuffer.size > 0) {
                    Pair<InvertedFeature, InvertedFeature> split = f.split(firstBuffer, secondBuffer);
                    // Swap the buffers so the next iteration reads from the one just passed as the
                    // second argument (presumably the left indices still unmatched after this
                    // split - confirm against InvertedFeature.split).
                    IntArrayContainer tmp = secondBuffer;
                    secondBuffer = firstBuffer;
                    firstBuffer = tmp;
                    InvertedFeature left = split.getA();
                    InvertedFeature right = split.getB();
                    if (left != null) {
                        leftFeatures.add(left);
                    }
                    if (right != null) {
                        rightFeatures.add(right);
                    }
                } else {
                    // No left indices remain, so this value goes to the right branch as is.
                    rightFeatures.add(f);
                }
            }
        }

        return new Pair<>(new TreeFeature(id,leftFeatures),new TreeFeature(id,rightFeatures));
    }

    /**
     * Returns a string containing the feature id and the list of inverted features.
     * @return A string representation of this tree feature.
     */
    @Override
    public String toString() {
        return "TreeFeature(id="+id+",values="+feature.toString()+")";
    }

    /**
     * Returns a deep copy of this tree feature.
     * <p>
     * The copy is always marked as sorted: when a value map is present the copied
     * list is rebuilt from the map and sorted, otherwise the existing list order is kept.
     * @return A deep copy.
     */
    public TreeFeature deepCopy() {
        Map<Double,InvertedFeature> newValueMap;
        List<InvertedFeature> newFeature = new ArrayList<>(feature.size());
        if (valueMap != null) {
            newValueMap = new HashMap<>();
            for (Map.Entry<Double,InvertedFeature> e : valueMap.entrySet()) {
                InvertedFeature featureCopy = e.getValue().deepCopy();
                newValueMap.put(e.getKey(),featureCopy);
                newFeature.add(featureCopy);
            }
            // HashMap iteration order is unspecified, so sort once after all copies are added
            // rather than re-sorting on every insertion.
            newFeature.sort(null);
        } else {
            newValueMap = null;
            for (InvertedFeature f : feature) {
                newFeature.add(f.deepCopy());
            }
        }
        return new TreeFeature(id,newFeature,newValueMap,true);
    }
}