/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.impl;

import com.oracle.labs.mlrg.olcut.util.SortUtil;
import org.tribuo.Example;
import org.tribuo.Feature;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Output;
import org.tribuo.util.Merger;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Objects;
import java.util.PriorityQueue;

/**
 * A version of ArrayExample which also has the id numbers.
 * <p>
 * Used in feature selection to provide log n lookups. May be used
 * elsewhere in the future as a performance optimisation.
 * <p>
 * The feature names, ids and values are stored in three parallel arrays.
 * Only the first {@code size} slots of each array hold live features; any
 * slots beyond that are spare capacity and may contain stale data.
 *
 * @param <T> The output type.
 */
public class IndexedArrayExample<T extends Output<T>> extends ArrayExample<T> {
    private static final long serialVersionUID = 1L;

    /**
     * The feature ids, parallel to {@code featureNames} and {@code featureValues}.
     */
    protected int[] featureIDs;

    /**
     * The id of this example's output, as looked up in the output info.
     */
    protected final int outputID;

    private final ImmutableFeatureMap featureMap;

    private final ImmutableOutputInfo<T> outputMap;

    /**
     * Copy constructor.
     * <p>
     * All three backing arrays are copied at their full capacity so that they
     * stay the same length as each other.
     * @param other The example to copy.
     */
    public IndexedArrayExample(IndexedArrayExample<T> other) {
        super(other.getOutput(),other.getWeight(),other.getMetadata());
        featureNames = Arrays.copyOf(other.featureNames,other.featureNames.length);
        // Copy the full capacity (not just other.size()): add() and addAll() decide
        // whether to grow based on featureNames.length, so a shorter id array
        // would be written past its end.
        featureIDs = Arrays.copyOf(other.featureIDs,other.featureIDs.length);
        featureValues = Arrays.copyOf(other.featureValues,other.featureValues.length);
        featureMap = other.featureMap;
        outputMap = other.outputMap;
        outputID = outputMap.getID(output);
        size = other.size;
    }

    /**
     * This constructor removes unknown features.
     * <p>
     * A feature is treated as unknown when the feature map returns an id of -1
     * for its name.
     * @param other The example to copy from.
     * @param featureMap The feature map.
     * @param outputMap The output info.
     */
    public IndexedArrayExample(Example<T> other, ImmutableFeatureMap featureMap, ImmutableOutputInfo<T> outputMap) {
        super(other);
        this.featureMap = featureMap;
        this.outputMap = outputMap;
        this.outputID = outputMap.getID(output);

        final int numFeatures = other.size();
        int[] newIDs = new int[numFeatures];
        String[] newNames = new String[numFeatures];
        double[] newValues = new double[numFeatures];
        int counter = 0;
        // Look up each id and keep only the features the map knows about.
        for (int i = 0; i < numFeatures; i++) {
            int id = featureMap.getID(featureNames[i]);
            if (id != -1) {
                newIDs[counter] = id;
                newValues[counter] = featureValues[i];
                newNames[counter] = featureNames[i];
                counter++;
            }
        }
        size = counter;
        featureNames = newNames;
        featureIDs = newIDs;
        featureValues = newValues;
    }

    /**
     * Compares this example to another object.
     * <p>
     * Two indexed examples are equal when the superclass considers them equal,
     * they have the same output id, the same live feature ids in the same
     * order, and equal feature and output maps. Spare capacity in the id array
     * is ignored.
     * @param o The object to compare to.
     * @return True if the objects are equal.
     */
    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (!(o instanceof IndexedArrayExample)) {
            return false;
        }
        if (!super.equals(o)) {
            return false;
        }
        IndexedArrayExample<?> that = (IndexedArrayExample<?>) o;
        if (outputID != that.outputID || size != that.size) {
            return false;
        }
        // Only the first size ids are live, the rest is spare capacity.
        for (int i = 0; i < size; i++) {
            if (featureIDs[i] != that.featureIDs[i]) {
                return false;
            }
        }
        return featureMap.equals(that.featureMap) && outputMap.equals(that.outputMap);
    }

    /**
     * Computes a hash code consistent with {@link #equals(Object)}.
     * <p>
     * Only the live feature ids contribute, so the value matches
     * {@code Arrays.hashCode} applied to the id array trimmed to {@code size}.
     * @return The hash code.
     */
    @Override
    public int hashCode() {
        int idHash = 1;
        for (int i = 0; i < size; i++) {
            idHash = 31 * idHash + featureIDs[i];
        }
        int result = Objects.hash(super.hashCode(), outputID, featureMap, outputMap);
        result = 31 * result + idHash;
        return result;
    }

    /**
     * Grows all three backing arrays to the capacity chosen by
     * {@code newCapacity(minCapacity)}.
     * @param minCapacity The minimum capacity requested.
     */
    @Override
    protected void growArray(int minCapacity) {
        int newCapacity = newCapacity(minCapacity);
        featureNames = Arrays.copyOf(featureNames,newCapacity);
        featureIDs = Arrays.copyOf(featureIDs,newCapacity);
        featureValues = Arrays.copyOf(featureValues,newCapacity);
    }

    /**
     * Adds a feature, looking up its id in the feature map, then re-sorts
     * the example.
     * <p>
     * The id is stored exactly as returned by the feature map; no check is
     * made here for unknown features.
     * @param feature The feature to add.
     */
    @Override
    public void add(Feature feature) {
        if (size >= featureNames.length) {
            growArray();
        }
        featureNames[size] = feature.getName();
        featureIDs[size] = featureMap.getID(feature.getName());
        featureValues[size] = feature.getValue();
        size++;
        sort();
    }

    /**
     * Adds all the supplied features, looking up their ids in the feature map,
     * then re-sorts the example once.
     * @param features The features to add.
     */
    @Override
    public void addAll(Collection<? extends Feature> features) {
        if (size + features.size() >= featureNames.length) {
            growArray(size+features.size());
        }
        for (Feature f : features) {
            featureNames[size] = f.getName();
            featureIDs[size] = featureMap.getID(f.getName());
            featureValues[size] = f.getValue();
            size++;
        }
        sort();
    }

    /**
     * Reorders the live features by name, keeping the name, id and value
     * arrays aligned.
     * <p>
     * The ordering is whatever {@code SortUtil.argsort} returns for the
     * first {@code size} names.
     */
    @Override
    protected void sort() {
        int[] sortedIndices = SortUtil.argsort(featureNames,0,size,true);

        // Snapshot the live region so the permutation can be applied in place.
        String[] oldNames = Arrays.copyOf(featureNames,size);
        int[] oldIDs = Arrays.copyOf(featureIDs,size);
        double[] oldValues = Arrays.copyOf(featureValues,size);
        for (int i = 0; i < sortedIndices.length; i++) {
            int src = sortedIndices[i];
            featureNames[i] = oldNames[src];
            featureIDs[i] = oldIDs[src];
            featureValues[i] = oldValues[src];
        }
    }

    /**
     * Merges features which share a name into a single feature, combining
     * their values with the supplied merger.
     * <p>
     * The features are first ordered by name, then runs of equal names are
     * folded left to right. The merged feature keeps the id of the first
     * feature in its run. Slots beyond the new size are left untouched.
     * @param merger The function used to combine two values.
     */
    @Override
    public void reduceByName(Merger merger) {
        if (size > 0) {
            int[] sortedIndices = SortUtil.argsort(featureNames, 0, size, true);
            String[] sortedNames = new String[size];
            int[] sortedIDs = new int[size];
            double[] sortedValues = new double[size];
            for (int i = 0; i < sortedIndices.length; i++) {
                int src = sortedIndices[i];
                sortedNames[i] = featureNames[src];
                sortedIDs[i] = featureIDs[src];
                sortedValues[i] = featureValues[src];
            }
            featureNames[0] = sortedNames[0];
            featureIDs[0] = sortedIDs[0];
            featureValues[0] = sortedValues[0];
            int dest = 0;
            for (int i = 1; i < size; i++) {
                if (sortedNames[i].equals(featureNames[dest])) {
                    // Same name as the feature being built, fold the value in.
                    featureValues[dest] = merger.merge(featureValues[dest], sortedValues[i]);
                } else {
                    // New name, start the next output feature.
                    dest++;
                    featureNames[dest] = sortedNames[i];
                    featureIDs[dest] = sortedIDs[i];
                    featureValues[dest] = sortedValues[i];
                }
            }
            size = dest + 1;
        }
    }

    /**
     * Removes every feature whose name matches one of the supplied features.
     * <p>
     * All occurrences of a matching name are removed. Afterwards the backing
     * arrays are exactly {@code size} long.
     * @param featureList The features to remove; only their names are used.
     */
    @Override
    public void removeFeatures(List<Feature> featureList) {
        // Index only the live region: slots past size may hold stale or null names.
        Map<String,List<Integer>> indicesByName = new HashMap<>();
        for (int i = 0; i < size; i++) {
            indicesByName.computeIfAbsent(featureNames[i],(k) -> new ArrayList<>()).add(i);
        }

        // The queue yields the indices to drop in ascending order.
        PriorityQueue<Integer> removeQueue = new PriorityQueue<>();
        for (Feature f : featureList) {
            // Removing the entry as we go prevents double counting a repeated name.
            List<Integer> indices = indicesByName.remove(f.getName());
            if (indices != null) {
                removeQueue.addAll(indices);
            }
        }

        int newSize = size - removeQueue.size();
        String[] newNames = new String[newSize];
        int[] newIDs = new int[newSize];
        double[] newValues = new double[newSize];

        int source = 0;
        int dest = 0;
        while (!removeQueue.isEmpty()) {
            int curRemoveIdx = removeQueue.poll();
            // Copy the run of kept features before the next removed index.
            int run = curRemoveIdx - source;
            System.arraycopy(featureNames, source, newNames, dest, run);
            System.arraycopy(featureIDs, source, newIDs, dest, run);
            System.arraycopy(featureValues, source, newValues, dest, run);
            dest += run;
            // Skip the removed feature itself.
            source = curRemoveIdx + 1;
        }
        // Copy whatever remains after the last removed index.
        int tail = size - source;
        System.arraycopy(featureNames, source, newNames, dest, tail);
        System.arraycopy(featureIDs, source, newIDs, dest, tail);
        System.arraycopy(featureValues, source, newValues, dest, tail);

        featureNames = newNames;
        featureIDs = newIDs;
        featureValues = newValues;
        size = newSize;
    }

    /**
     * Does this example contain a feature with id i.
     * <p>
     * Uses a binary search over the live ids only.
     * NOTE(review): this relies on the ids being in ascending order when the
     * features are sorted by name - confirm against the feature map's id
     * assignment.
     * @param i The index to check.
     * @return True if the example contains the id.
     */
    public boolean contains(int i) {
        return Arrays.binarySearch(featureIDs,0,size,i) >= 0;
    }

    /**
     * Returns a copy of this example made with the copy constructor.
     * @return A copy of this example.
     */
    @Override
    public IndexedArrayExample<T> copy() {
        return new IndexedArrayExample<>(this);
    }

    /**
     * Adds a zero valued feature for each name in the list which is not
     * already present in this example, then re-sorts the example.
     * <p>
     * NOTE(review): the single pass below appears to assume that
     * {@code featureList} is in the same order as the sorted feature names and
     * that every feature in this example appears in the list - confirm with
     * callers.
     * @param featureList The feature names to densify with.
     * @throws IllegalArgumentException If the list is a different size to the feature map.
     */
    @Override
    public void densify(List<String> featureList) {
        if (featureList.size() != featureMap.size()) {
            throw new IllegalArgumentException("Densifying an example with a different feature map");
        }
        // Ensure we have enough space.
        if (featureList.size() > featureNames.length) {
            growArray(featureList.size());
        }
        int insertedCount = 0;
        int curPos = 0;
        for (String curName : featureList) {
            boolean insert;
            if (curPos == size) {
                // If we've reached the end of our old feature set, just insert.
                insert = true;
            } else {
                // Check to see if our insertion candidate is the same as the current feature name.
                int comparison = curName.compareTo(featureNames[curPos]);
                if (comparison == 0) {
                    // We've already got this feature, just bump our pointer.
                    curPos++;
                }
                // If it's earlier, insert it.
                insert = comparison < 0;
            }
            if (insert) {
                int slot = size + insertedCount;
                featureNames[slot] = curName;
                featureIDs[slot] = featureMap.getID(curName);
                // The slot is spare capacity and may hold a stale value, so zero it explicitly.
                featureValues[slot] = 0.0;
                insertedCount++;
            }
        }
        // Bump the size up by the number of inserted features.
        size += insertedCount;
        // Sort the features
        sort();
    }

    /**
     * Gets the feature id stored at internal index i.
     * <p>
     * No bounds check against {@code size} is performed.
     * @param i The internal index.
     * @return The feature id at that position.
     */
    public int getIdx(int i) {
        return featureIDs[i];
    }

    /**
     * Gets the output id dimension number.
     * @return The output id.
     */
    public int getOutputID() {
        return outputID;
    }

    /**
     * Iterator over the feature ids and values.
     * <p>
     * Each iterator reuses a single mutable {@link FeatureTuple}, so copy the
     * fields out if they are needed after the next call to {@code next()}.
     * @return The feature ids and values.
     */
    public Iterable<FeatureTuple> idIterator() {
        return ArrayIndexedExampleIterator::new;
    }

    /**
     * A tuple of the feature name, id and value.
     */
    public static class FeatureTuple {
        /**
         * The feature name.
         */
        public String name;
        /**
         * The feature id.
         */
        public int id;
        /**
         * The feature value.
         */
        public double value;

        /**
         * Constructs an empty tuple, with the fields left at their defaults.
         */
        public FeatureTuple() { }

        /**
         * Constructs a tuple with the supplied contents.
         * @param name The feature name.
         * @param id The feature id.
         * @param value The feature value.
         */
        public FeatureTuple(String name, int id, double value) {
            this.name = name;
            this.id = id;
            this.value = value;
        }
    }

    /**
     * Iterates over the live features of the enclosing example, writing each
     * one into the same {@link FeatureTuple} instance.
     */
    class ArrayIndexedExampleIterator implements Iterator<FeatureTuple> {
        int pos = 0;
        FeatureTuple tuple = new FeatureTuple();

        @Override
        public boolean hasNext() {
            return pos < size;
        }

        @Override
        public FeatureTuple next() {
            if (!hasNext()) {
                throw new NoSuchElementException("Off the end of the iterator.");
            }
            // Overwrite the shared tuple rather than allocating a new one.
            tuple.name = featureNames[pos];
            tuple.id = featureIDs[pos];
            tuple.value = featureValues[pos];
            pos++;
            return tuple;
        }
    }
}