/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.math.la;

import org.tribuo.Dataset;
import org.tribuo.Example;
import org.tribuo.Feature;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.Output;
import org.tribuo.math.util.VectorNormalizer;
import org.tribuo.util.IntDoublePair;
import org.tribuo.util.Util;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Objects;
import java.util.function.DoubleUnaryOperator;
import java.util.stream.Collectors;

/**
 * A sparse vector. Stored as a sorted array of indices and an array of values.
 * <p>
 * Uses binary search to look up a specific index, so it's usually faster to
 * use the iterator to iterate the values.
 * <p>
 * This vector has immutable indices. It cannot get new indices after construction,
 * and will throw {@link IllegalArgumentException} if such an operation is tried.
 */
public class SparseVector implements SGDVector {
    private static final long serialVersionUID = 1L;

    private final int[] shape;
    protected final int[] indices;
    protected final double[] values;
    private final int size;

    /**
     * Used internally for performance.
     * Does not defensively copy the input, nor check it's sorted.
     * <p>
     * @param size The dimension of this vector.
     * @param indices The indices.
     * @param values The values.
     */
    SparseVector(int size, int[] indices, double[] values) {
        this.size = size;
        this.shape = new int[]{size};
        this.indices = indices;
        this.values = values;
    }

    /**
     * Returns a deep copy of the supplied sparse vector.
     * <p>
     * Copies the values by iterating its VectorTuples.
     * @param other The SparseVector to copy.
     */
    private SparseVector(SparseVector other) {
        this.size = other.size;
        int numActiveElements = other.numActiveElements();
        this.indices = new int[numActiveElements];
        this.values = new double[numActiveElements];

        int i = 0;
        for (VectorTuple tuple : other) {
            indices[i] = tuple.index;
            values[i] = tuple.value;
            i++;
        }
        this.shape = new int[]{size};
    }

    /**
     * Creates a sparse vector where every supplied index has the same value.
     * <p>
     * The index array is copied, but it is neither sorted nor validated here.
     * The lookup methods of this class binary search the indices, so the caller
     * is expected to supply them in ascending order.
     * @param size The dimension of this vector.
     * @param indices The active indices.
     * @param value The value stored at each active index.
     */
    public SparseVector(int size, int[] indices, double value) {
        this.indices = Arrays.copyOf(indices,indices.length);
        this.values = new double[indices.length];
        Arrays.fill(this.values,value);
        this.size = size;
        this.shape = new int[]{size};
    }

    /**
     * Builds a {@link SparseVector} from an {@link Example}.
     * <p>
     * Used in training and inference.
     * <p>
     * Features whose name maps to a negative id in {@code featureInfo} are skipped.
     * Features which map to the same id have their values summed.
     * <p>
     * Throws {@link IllegalArgumentException} if the Example contains NaN-valued features.
     * @param example The example to convert.
     * @param featureInfo The feature information, used to calculate the dimension of this SparseVector.
     * @param addBias Add a bias feature.
     * @param <T> The type parameter of the {@code example}.
     * @return A SparseVector representing the example's features.
     */
    public static <T extends Output<T>> SparseVector createSparseVector(Example<T> example, ImmutableFeatureMap featureInfo, boolean addBias) {
        int size;
        int numFeatures = example.size();
        if (addBias) {
            // The bias occupies one extra slot at the end of the vector.
            size = featureInfo.size() + 1;
            numFeatures++;
        } else {
            size = featureInfo.size();
        }
        int[] tmpIndices = new int[numFeatures];
        double[] tmpValues = new double[numFeatures];
        int i = 0;
        int prevIdx = -1;
        for (Feature f : example) {
            int index = featureInfo.getID(f.getName());
            if (index > prevIdx){
                // Fast path, the ids are arriving in ascending order so append.
                prevIdx = index;
                tmpIndices[i] = index;
                tmpValues[i] = f.getValue();
                if (Double.isNaN(tmpValues[i])) {
                    throw new IllegalArgumentException("Example contained a NaN feature, " + f.toString());
                }
                i++;
            } else if (index > -1) {
                //
                // Collision, deal with it.
                int collisionIdx = Arrays.binarySearch(tmpIndices,0,i,index);
                if (collisionIdx < 0) {
                    //
                    // Collision but not present in tmpIndices
                    // move data and bump i
                    collisionIdx = - (collisionIdx + 1);
                    System.arraycopy(tmpIndices,collisionIdx,tmpIndices,collisionIdx+1,i-collisionIdx);
                    System.arraycopy(tmpValues,collisionIdx,tmpValues,collisionIdx+1,i-collisionIdx);
                    tmpIndices[collisionIdx] = index;
                    tmpValues[collisionIdx] = f.getValue();
                    if (Double.isNaN(tmpValues[collisionIdx])) {
                        throw new IllegalArgumentException("Example contained a NaN feature, " + f.toString());
                    }
                    i++;
                } else {
                    //
                    // Collision present in tmpIndices
                    // add the values.
                    tmpValues[collisionIdx] += f.getValue();
                    if (Double.isNaN(tmpValues[collisionIdx])) {
                        throw new IllegalArgumentException("Example contained a NaN feature, " + f.toString());
                    }
                }
            }
        }
        if (addBias) {
            tmpIndices[i] = size - 1;
            tmpValues[i] = 1.0;
            i++;
        }
        // Trim the arrays as skipped or merged features leave unused slots at the end.
        return new SparseVector(size,Arrays.copyOf(tmpIndices,i),Arrays.copyOf(tmpValues,i));
    }

    /**
     * Defensively copies the input, and sorts the index/value pairs by index.
     * <p>
     * Throws {@link IllegalArgumentException} if the arrays are not the same length, or if size is less than
     * the max index.
     * <p>
     * NOTE(review): the check accepts a max index equal to {@code dimension}, even though
     * {@link #toDenseArray()} can only address indices up to {@code dimension - 1}. Duplicate
     * indices are also not detected. Confirm whether these are intended before tightening.
     * @param dimension The dimension of this vector.
     * @param indices The indices of the non-zero elements.
     * @param values The values of the non-zero elements.
     * @return A SparseVector encapsulating the indices and values.
     */
    public static SparseVector createSparseVector(int dimension, int[] indices, double[] values) {
        if (indices.length != values.length) {
            throw new IllegalArgumentException("Indices and values must be the same length, found indices.length = " + indices.length + " and values.length = " + values.length);
        } else if (indices.length == 0) {
            return new SparseVector(dimension,indices,values);
        } else {
            // Pair each index with its value so they stay aligned through the sort.
            IntDoublePair[] pairArray = new IntDoublePair[indices.length];
            for (int i = 0; i < pairArray.length; i++) {
                pairArray[i] = new IntDoublePair(indices[i], values[i]);
            }
            Arrays.sort(pairArray, IntDoublePair.pairIndexComparator());
            int[] newIndices = new int[indices.length];
            double[] newValues = new double[values.length];
            for (int i = 0; i < pairArray.length; i++) {
                newIndices[i] = pairArray[i].index;
                newValues[i] = pairArray[i].value;
            }
            int maxIndex = newIndices[newIndices.length - 1];
            if (dimension < maxIndex) {
                throw new IllegalArgumentException("Number of dimensions is less than the maximum index, dimensions = " + dimension + ", max index = " + maxIndex);
            }
            return new SparseVector(dimension, newIndices, newValues);
        }
    }

    /**
     * Builds a SparseVector from a map.
     * <p>
     * Throws {@link IllegalArgumentException} if dimension is less than the max index.
     * <p>
     * NOTE(review): as with the array based factory, a max index equal to
     * {@code dimension} is accepted by this check.
     * @param dimension The dimension of this vector.
     * @param indexMap The map from indices to values.
     * @return A SparseVector.
     */
    public static SparseVector createSparseVector(int dimension, Map<Integer, Double> indexMap) {
        if (indexMap.isEmpty()) {
            return new SparseVector(dimension,new int[0],new double[0]);
        } else {
            List<Map.Entry<Integer, Double>> sortedEntries = indexMap.entrySet()
                    .stream().sorted(Map.Entry.comparingByKey())
                    .collect(Collectors.toList());

            int[] indices = new int[sortedEntries.size()];
            double[] values = new double[sortedEntries.size()];
            int i = 0;
            for (Map.Entry<Integer, Double> entry : sortedEntries) {
                indices[i] = entry.getKey();
                values[i] = entry.getValue();
                i++;
            }
            int maxIndex = indices[indices.length - 1];
            if (dimension < maxIndex) {
                throw new IllegalArgumentException("Number of dimensions is less than the maximum index, dimensions = " + dimension + ", max index = " + maxIndex);
            }
            return new SparseVector(dimension, indices, values);
        }
    }

    /**
     * Returns a deep copy of this vector.
     * @return A copy of this vector.
     */
    @Override
    public SparseVector copy() {
        return new SparseVector(this);
    }

    /**
     * Returns the shape of this vector, a single element array containing the dimension.
     * <p>
     * The returned array is the internal one, it is not copied.
     * @return The shape.
     */
    @Override
    public int[] getShape() {
        return shape;
    }

    /**
     * Unsupported on sparse vectors.
     * @param newShape The requested shape.
     * @return Never returns, always throws {@link UnsupportedOperationException}.
     */
    @Override
    public Tensor reshape(int[] newShape) {
        throw new UnsupportedOperationException("Reshape not supported on sparse Tensors.");
    }

    /**
     * The dimension of this vector, including the elements which are not stored.
     * @return The dimension.
     */
    @Override
    public int size() {
        return size;
    }

    /**
     * The number of explicitly stored elements.
     * @return The number of active elements.
     */
    @Override
    public int numActiveElements() {
        return values.length;
    }

    /**
     * Equals is defined mathematically, that is two SGDVectors are equal iff they have the same
     * dimension, the same indices and the same values at those indices.
     * <p>
     * The dimension is compared so that vectors which are equal also agree on
     * {@link #hashCode()}, which includes the dimension.
     * @param other Object to compare against.
     * @return True if this vector and the other vector contain the same values in the same order.
     */
    @Override
    public boolean equals(Object other) {
        if (!(other instanceof SGDVector)) {
            return false;
        }
        SGDVector otherVec = (SGDVector) other;
        if (otherVec.size() != size) {
            return false;
        }
        Iterator<VectorTuple> ourItr = iterator();
        Iterator<VectorTuple> otherItr = otherVec.iterator();

        while (ourItr.hasNext() && otherItr.hasNext()) {
            VectorTuple ourTuple = ourItr.next();
            VectorTuple otherTuple = otherItr.next();
            if (!ourTuple.equals(otherTuple)) {
                return false;
            }
        }

        // If one of the iterators still has elements then they are not the same.
        return !(ourItr.hasNext() || otherItr.hasNext());
    }

    /**
     * Hashes the dimension, the indices and the values.
     * @return The hash code.
     */
    @Override
    public int hashCode() {
        int result = Objects.hash(size);
        result = 31 * result + Arrays.hashCode(indices);
        result = 31 * result + Arrays.hashCode(values);
        return result;
    }

    /**
     * Adds {@code other} to this vector, producing a new {@link SGDVector}.
     * If {@code other} is a {@link SparseVector} then the returned vector is also
     * a {@link SparseVector} otherwise it's a {@link DenseVector}.
     * <p>
     * Throws {@link IllegalArgumentException} if the dimensions differ or {@code other}
     * is neither dense nor sparse.
     * @param other The vector to add.
     * @return A new {@link SGDVector} where each element value = this.get(i) + other.get(i).
     */
    @Override
    public SGDVector add(SGDVector other) {
        if (other.size() != size) {
            throw new IllegalArgumentException("Can't add two vectors of different dimension, this = " + size + ", other = " + other.size());
        }
        if (other instanceof DenseVector) {
            // Addition commutes, so delegate to the dense implementation.
            return other.add(this);
        } else if (other instanceof SparseVector) {
            Map<Integer, Double> values = new HashMap<>();
            for (VectorTuple tuple : this) {
                values.put(tuple.index, tuple.value);
            }
            for (VectorTuple tuple : other) {
                values.merge(tuple.index, tuple.value, Double::sum);
            }
            return createSparseVector(size, values);
        } else {
            throw new IllegalArgumentException("Vector other is not dense or sparse.");
        }
    }

    /**
     * Subtracts {@code other} from this vector, producing a new {@link SGDVector}.
     * If {@code other} is a {@link SparseVector} then the returned vector is also
     * a {@link SparseVector} otherwise it's a {@link DenseVector}.
     * <p>
     * Throws {@link IllegalArgumentException} if the dimensions differ or {@code other}
     * is neither dense nor sparse.
     * @param other The vector to subtract.
     * @return A new {@link SGDVector} where each element value = this.get(i) - other.get(i).
     */
    @Override
    public SGDVector subtract(SGDVector other) {
        if (other.size() != size) {
            throw new IllegalArgumentException("Can't subtract two vectors of different dimension, this = " + size + ", other = " + other.size());
        }
        if (other instanceof DenseVector) {
            DenseVector output = ((DenseVector)other).copy();
            // Positions which are not active in this vector are implicitly zero, so the
            // result there is 0 - other[i]. Negate everything first, then overwrite the
            // active positions with this[i] - other[i].
            output.foreachInPlace(a -> 0.0 - a);
            for (VectorTuple tuple : this) {
                output.set(tuple.index,tuple.value-other.get(tuple.index));
            }
            return output;
        } else if (other instanceof SparseVector) {
            Map<Integer, Double> values = new HashMap<>();
            for (VectorTuple tuple : this) {
                values.put(tuple.index, tuple.value);
            }
            for (VectorTuple tuple : other) {
                values.merge(tuple.index, -tuple.value, Double::sum);
            }
            return createSparseVector(size, values);
        } else {
            throw new IllegalArgumentException("Vector other is not dense or sparse.");
        }
    }

    /**
     * Updates this vector in place, adding {@code f(other[i])} to each element which is
     * active in this vector. When {@code other} is sparse only the indices active in both
     * vectors are updated.
     * <p>
     * Throws {@link IllegalArgumentException} if the dimensions differ, and
     * {@link IllegalStateException} if {@code other} is not a sparse or dense vector.
     * @param other The vector to add.
     * @param f The function applied to each element of {@code other} before it is added.
     */
    @Override
    public void intersectAndAddInPlace(Tensor other, DoubleUnaryOperator f) {
        if (other instanceof SparseVector) {
            SparseVector otherVec = (SparseVector) other;
            if (otherVec.size() != size) {
                throw new IllegalArgumentException("Can't intersect two vectors of different dimension, this = " + size + ", other = " + otherVec.size());
            }
            // Two pointer merge over the sorted index arrays, naturally a no-op
            // if either vector has no active elements.
            int ourPos = 0;
            int otherPos = 0;
            while (ourPos < indices.length && otherPos < otherVec.indices.length) {
                int ourIndex = indices[ourPos];
                int otherIndex = otherVec.indices[otherPos];
                if (ourIndex == otherIndex) {
                    values[ourPos] += f.applyAsDouble(otherVec.values[otherPos]);
                    ourPos++;
                    otherPos++;
                } else if (ourIndex < otherIndex) {
                    ourPos++;
                } else {
                    otherPos++;
                }
            }
        } else if (other instanceof DenseVector) {
            DenseVector otherVec = (DenseVector) other;
            if (otherVec.size() != size) {
                throw new IllegalArgumentException("Can't intersect two vectors of different dimension, this = " + size + ", other = " + otherVec.size());
            }
            for (int i = 0; i < indices.length; i++) {
                values[i] += f.applyAsDouble(otherVec.get(indices[i]));
            }
        } else {
            throw new IllegalStateException("Unknown Tensor subclass " + other.getClass().getCanonicalName() + " for input");
        }
    }

    /**
     * Updates this vector in place, multiplying each element which is active in this vector
     * by {@code f(other[i])}. When {@code other} is sparse only the indices active in both
     * vectors are updated, the remaining elements of this vector are left untouched.
     * <p>
     * Throws {@link IllegalArgumentException} if the dimensions differ or if {@code other}
     * is not a sparse or dense vector.
     * @param other The vector to multiply by.
     * @param f The function applied to each element of {@code other} before the multiplication.
     */
    @Override
    public void hadamardProductInPlace(Tensor other, DoubleUnaryOperator f) {
        if (other instanceof SparseVector) {
            SparseVector otherVec = (SparseVector) other;
            if (otherVec.size() != size) {
                throw new IllegalArgumentException("Can't hadamard product two vectors of different dimension, this = " + size + ", other = " + otherVec.size());
            }
            // Two pointer merge over the sorted index arrays, naturally a no-op
            // if either vector has no active elements.
            int ourPos = 0;
            int otherPos = 0;
            while (ourPos < indices.length && otherPos < otherVec.indices.length) {
                int ourIndex = indices[ourPos];
                int otherIndex = otherVec.indices[otherPos];
                if (ourIndex == otherIndex) {
                    values[ourPos] *= f.applyAsDouble(otherVec.values[otherPos]);
                    ourPos++;
                    otherPos++;
                } else if (ourIndex < otherIndex) {
                    ourPos++;
                } else {
                    otherPos++;
                }
            }
        } else if (other instanceof DenseVector) {
            DenseVector otherVec = (DenseVector) other;
            if (otherVec.size() != size) {
                throw new IllegalArgumentException("Can't hadamard product two vectors of different dimension, this = " + size + ", other = " + otherVec.size());
            }
            for (int i = 0; i < indices.length; i++) {
                values[i] *= f.applyAsDouble(otherVec.get(indices[i]));
            }
        } else {
            throw new IllegalArgumentException("Invalid Tensor subclass " + other.getClass().getCanonicalName() + " for input");
        }
    }

    /**
     * Applies {@code f} to each active element in place. Inactive elements are not visited.
     * @param f The function to apply.
     */
    @Override
    public void foreachInPlace(DoubleUnaryOperator f) {
        for (int i = 0; i < values.length; i++) {
            values[i] = f.applyAsDouble(values[i]);
        }
    }

    /**
     * Returns a new vector with each active element multiplied by {@code coefficient}.
     * This vector is not modified.
     * @param coefficient The scaling coefficient.
     * @return A scaled copy of this vector.
     */
    @Override
    public SparseVector scale(double coefficient) {
        double[] newValues = Arrays.copyOf(values, values.length);
        for (int i = 0; i < values.length; i++) {
            newValues[i] *= coefficient;
        }
        return new SparseVector(size, Arrays.copyOf(indices, indices.length), newValues);
    }

    /**
     * Adds {@code value} to the element at {@code index}.
     * <p>
     * Throws {@link IllegalArgumentException} if {@code index} is not active in this vector.
     * @param index The index to update.
     * @param value The value to add.
     */
    @Override
    public void add(int index, double value) {
        int foundIndex = Arrays.binarySearch(indices, index);
        if (foundIndex < 0) {
            throw new IllegalArgumentException("SparseVector cannot have new elements added.");
        } else {
            values[foundIndex] += value;
        }
    }

    /**
     * Computes the dot product of this vector and {@code other}.
     * <p>
     * Throws {@link IllegalArgumentException} if the dimensions differ or if {@code other}
     * is not a sparse or dense vector.
     * @param other The other vector.
     * @return The dot product.
     */
    @Override
    public double dot(SGDVector other) {
        if (other.size() != size) {
            throw new IllegalArgumentException("Can't dot two vectors of different lengths, this = " + size + ", other = " + other.size());
        } else if (other instanceof SparseVector) {
            double score = 0.0;

            // If there are elements, calculate the dot product.
            if ((other.numActiveElements() != 0) && (indices.length != 0)) {
                Iterator<VectorTuple> itr = iterator();
                Iterator<VectorTuple> otherItr = other.iterator();
                VectorTuple tuple = itr.next();
                VectorTuple otherTuple = otherItr.next();
                // Merge while neither iterator is on its final element.
                while (itr.hasNext() && otherItr.hasNext()) {
                    if (tuple.index == otherTuple.index) {
                        score += tuple.value * otherTuple.value;
                        tuple = itr.next();
                        otherTuple = otherItr.next();
                    } else if (tuple.index < otherTuple.index) {
                        tuple = itr.next();
                    } else {
                        otherTuple = otherItr.next();
                    }
                }
                // At most one of these two loops runs, draining the longer iterator
                // against the final element of the shorter one.
                while (itr.hasNext()) {
                    if (tuple.index == otherTuple.index) {
                        score += tuple.value * otherTuple.value;
                    }
                    tuple = itr.next();
                }
                while (otherItr.hasNext()) {
                    if (tuple.index == otherTuple.index) {
                        score += tuple.value * otherTuple.value;
                    }
                    otherTuple = otherItr.next();
                }
                // Compare the final pair.
                if (tuple.index == otherTuple.index) {
                    score += tuple.value * otherTuple.value;
                }
            }

            return score;
        } else if (other instanceof DenseVector) {
            double score = 0.0;

            for (int i = 0; i < indices.length; i++) {
                score += other.get(indices[i]) * values[i];
            }

            return score;
        } else {
            throw new IllegalArgumentException("Unknown vector subclass " + other.getClass().getCanonicalName() + " for input");
        }
    }

    /**
     * This generates the outer product when dotted with another {@link SparseVector}.
     * <p>
     * It throws an {@link IllegalArgumentException} if used with a {@link DenseVector}.
     *
     * @param other A vector.
     * @return A {@link DenseSparseMatrix} representing the outer product.
     */
    @Override
    public Matrix outer(SGDVector other) {
        if (other instanceof SparseVector) {
            //This horrible mess is why there should be a sparse-sparse matrix type.
            SparseVector otherVec = (SparseVector) other;
            SparseVector[] output = new SparseVector[size];
            int i = 0;
            for (VectorTuple tuple : this) {
                // Fill the rows before the next active index with empty vectors.
                while (i < tuple.index) {
                    output[i] = new SparseVector(other.size(), new int[0], new double[0]);
                    i++;
                }
                output[tuple.index] = otherVec.scale(tuple.value);
                i++;
            }
            // Fill the rows after the final active index with empty vectors.
            while (i < output.length) {
                output[i] = new SparseVector(other.size(), new int[0], new double[0]);
                i++;
            }
            //TODO this is suboptimal if there are lots of missing rows.
            return new DenseSparseMatrix(output);
        } else if (other instanceof DenseVector) {
            throw new IllegalArgumentException("sparse.outer(dense) is currently not implemented.");
        } else {
            throw new IllegalArgumentException("Unknown vector subclass " + other.getClass().getCanonicalName() + " for input");
        }
    }

    /**
     * Sums the active elements.
     * @return The sum of the elements.
     */
    @Override
    public double sum() {
        double total = 0.0;
        for (double value : values) {
            total += value;
        }
        return total;
    }

    /**
     * Computes the L2 norm, the square root of the sum of the squared elements.
     * @return The two norm.
     */
    @Override
    public double twoNorm() {
        double sumOfSquares = 0.0;
        for (double value : values) {
            sumOfSquares += value * value;
        }
        return Math.sqrt(sumOfSquares);
    }

    /**
     * Computes the L1 norm, the sum of the absolute values of the elements.
     * @return The one norm.
     */
    @Override
    public double oneNorm() {
        double total = 0.0;
        for (double value : values) {
            total += Math.abs(value);
        }
        return total;
    }

    /**
     * Gets the value at {@code index}, found by binary search.
     * @param index The index to look up.
     * @return The stored value, or zero if the index is not active.
     */
    @Override
    public double get(int index) {
        int foundIndex = Arrays.binarySearch(indices, index);
        if (foundIndex < 0) {
            return 0;
        } else {
            return values[foundIndex];
        }
    }

    /**
     * Sets the value at {@code index}.
     * <p>
     * Throws {@link IllegalArgumentException} if {@code index} is not active in this vector.
     * @param index The index to update.
     * @param value The new value.
     */
    @Override
    public void set(int index, double value) {
        int foundIndex = Arrays.binarySearch(indices, index);
        if (foundIndex < 0) {
            throw new IllegalArgumentException("SparseVector cannot have new elements added.");
        } else {
            values[foundIndex] = value;
        }
    }

    /**
     * Returns the index of the largest active element.
     * <p>
     * Only the active elements are inspected, the implicit zeros are ignored. If there
     * are no active elements the lookup into the index array throws
     * {@link ArrayIndexOutOfBoundsException}.
     * @return The index of the maximum active value.
     */
    @Override
    public int indexOfMax() {
        int index = 0;
        double value = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < values.length; i++) {
            double tmp = values[i];
            if (tmp > value) {
                index = i;
                value = tmp;
            }
        }
        // index is a position in the storage arrays, translate it to a vector index.
        return indices[index];
    }

    /**
     * Returns the largest active element. The implicit zeros are ignored.
     * @return The maximum active value, or {@link Double#NEGATIVE_INFINITY} if there are no active elements.
     */
    @Override
    public double maxValue() {
        double value = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < values.length; i++) {
            double tmp = values[i];
            if (tmp > value) {
                value = tmp;
            }
        }
        return value;
    }

    /**
     * Returns the smallest active element. The implicit zeros are ignored.
     * @return The minimum active value, or {@link Double#POSITIVE_INFINITY} if there are no active elements.
     */
    @Override
    public double minValue() {
        double value = Double.POSITIVE_INFINITY;
        for (int i = 0; i < values.length; i++) {
            double tmp = values[i];
            if (tmp < value) {
                value = tmp;
            }
        }
        return value;
    }

    /**
     * Generates an array of the indices that are active in this vector
     * but are not present in {@code other}.
     *
     * @param other The vector to compare.
     * @return An array of indices that are active only in this vector.
     */
    public int[] difference(SparseVector other) {
        List<Integer> diffIndicesList = new ArrayList<>();

        if (other.numActiveElements() == 0) {
            return Arrays.copyOf(indices,indices.length);
        } else if (indices.length == 0) {
            return new int[0];
        } else {
            Iterator<VectorTuple> itr = iterator();
            Iterator<VectorTuple> otherItr = other.iterator();
            VectorTuple tuple = itr.next();
            VectorTuple otherTuple = otherItr.next();
            // Merge while neither iterator is on its final element.
            while (itr.hasNext() && otherItr.hasNext()) {
                if (tuple.index == otherTuple.index) {
                    tuple = itr.next();
                    otherTuple = otherItr.next();
                } else if (tuple.index < otherTuple.index) {
                    diffIndicesList.add(tuple.index);
                    tuple = itr.next();
                } else {
                    otherTuple = otherItr.next();
                }
            }
            // other is on its final element, test the rest of this vector against it.
            while (itr.hasNext()) {
                if (tuple.index != otherTuple.index) {
                    diffIndicesList.add(tuple.index);
                }
                tuple = itr.next();
            }
            // This vector is on its final element, scan other looking for it.
            while (otherItr.hasNext()) {
                if (tuple.index == otherTuple.index) {
                    break; // break out of loop as we've found the last value.
                }
                otherTuple = otherItr.next();
            }
            if (tuple.index != otherTuple.index) {
                diffIndicesList.add(tuple.index);
            }
        }

        return Util.toPrimitiveInt(diffIndicesList);
    }

    /**
     * Generates an array of the indices that are active in both this
     * vector and {@code other}
     *
     * @param other The vector to intersect.
     * @return An array of indices that are active in both vectors.
     */
    public int[] intersection(SparseVector other) {
        List<Integer> sharedIndicesList = new ArrayList<>();

        Iterator<VectorTuple> itr = iterator();
        Iterator<VectorTuple> otherItr = other.iterator();
        if (itr.hasNext() && otherItr.hasNext()) {
            VectorTuple tuple = itr.next();
            VectorTuple otherTuple = otherItr.next();
            // Merge while neither iterator is on its final element.
            while (itr.hasNext() && otherItr.hasNext()) {
                if (tuple.index == otherTuple.index) {
                    sharedIndicesList.add(tuple.index);
                    tuple = itr.next();
                    otherTuple = otherItr.next();
                } else if (tuple.index < otherTuple.index) {
                    tuple = itr.next();
                } else {
                    otherTuple = otherItr.next();
                }
            }
            // At most one of these two loops runs, draining the longer iterator
            // against the final element of the shorter one.
            while (itr.hasNext()) {
                if (tuple.index == otherTuple.index) {
                    sharedIndicesList.add(tuple.index);
                }
                tuple = itr.next();
            }
            while (otherItr.hasNext()) {
                if (tuple.index == otherTuple.index) {
                    sharedIndicesList.add(tuple.index);
                }
                otherTuple = otherItr.next();
            }
            // Compare the final pair.
            if (tuple.index == otherTuple.index) {
                sharedIndicesList.add(tuple.index);
            }
        }

        return Util.toPrimitiveInt(sharedIndicesList);
    }


    /**
     * Unsupported on sparse vectors, always throws {@link IllegalStateException}.
     * @param normalizer The normalizer to use.
     */
    @Override
    public void normalize(VectorNormalizer normalizer) {
        throw new IllegalStateException("Can't normalize a sparse array");
    }

    /**
     * Computes the euclidean (L2) distance between this vector and {@code other}.
     * @param other The other vector.
     * @return The L2 distance.
     */
    @Override
    public double euclideanDistance(SGDVector other) {
        return distance(other,(double a) -> a*a, Math::sqrt);
    }

    /**
     * Computes the L1 distance between this vector and {@code other}.
     * @param other The other vector.
     * @return The L1 distance.
     */
    @Override
    public double l1Distance(SGDVector other) {
        return distance(other,Math::abs,DoubleUnaryOperator.identity());
    }

    /**
     * Computes a distance between this vector and {@code other}.
     * <p>
     * {@code transformFunc} is applied to the element-wise difference at every index which
     * is produced by either vector's iterator, the results are summed, and
     * {@code normalizeFunc} is applied to the sum.
     * <p>
     * Throws {@link IllegalArgumentException} if the dimensions differ.
     * @param other The other vector.
     * @param transformFunc The function applied to each element-wise difference.
     * @param normalizeFunc The function applied to the accumulated sum.
     * @return The distance.
     */
    public double distance(SGDVector other, DoubleUnaryOperator transformFunc, DoubleUnaryOperator normalizeFunc) {
        if (other.size() != size) {
            throw new IllegalArgumentException("Can't measure the distance between two vectors of different lengths, this = " + size + ", other = " + other.size());
        }
        double score = 0.0;

        if ((other.numActiveElements() != 0) && (indices.length != 0)){
            Iterator<VectorTuple> itr = iterator();
            Iterator<VectorTuple> otherItr = other.iterator();
            VectorTuple tuple = itr.next();
            VectorTuple otherTuple = otherItr.next();
            // Merge while neither iterator is on its final element.
            while (itr.hasNext() && otherItr.hasNext()) {
                if (tuple.index == otherTuple.index) {
                    score += transformFunc.applyAsDouble(tuple.value - otherTuple.value);
                    tuple = itr.next();
                    otherTuple = otherItr.next();
                } else if (tuple.index < otherTuple.index) {
                    score += transformFunc.applyAsDouble(tuple.value);
                    tuple = itr.next();
                } else {
                    score += transformFunc.applyAsDouble(otherTuple.value);
                    otherTuple = otherItr.next();
                }
            }
            // other is on its final element, drain this vector.
            while (itr.hasNext()) {
                if (tuple.index == otherTuple.index) {
                    score += transformFunc.applyAsDouble(tuple.value - otherTuple.value);
                    otherTuple = new VectorTuple(); // Consumed this value, replace with sentinel
                } else {
                    score += transformFunc.applyAsDouble(tuple.value);
                }
                tuple = itr.next();
            }
            // This vector is on its final element, drain other.
            while (otherItr.hasNext()) {
                if (tuple.index == otherTuple.index) {
                    score += transformFunc.applyAsDouble(tuple.value - otherTuple.value);
                    tuple = new VectorTuple(); // Consumed this value, replace with sentinel
                } else {
                    score += transformFunc.applyAsDouble(otherTuple.value);
                }
                otherTuple = otherItr.next();
            }
            // Account for the final pair, skipping a tuple which was replaced by the
            // sentinel above (recognised by its index of -1).
            if (tuple.index == otherTuple.index) {
                score += transformFunc.applyAsDouble(tuple.value - otherTuple.value);
            } else {
                if (tuple.index != -1) {
                    score += transformFunc.applyAsDouble(tuple.value);
                }
                if (otherTuple.index != -1) {
                    score += transformFunc.applyAsDouble(otherTuple.value);
                }
            }
        } else if (indices.length != 0) {
            // other has no active elements, so the differences are this vector's values.
            for (VectorTuple tuple : this) {
                score += transformFunc.applyAsDouble(tuple.value);
            }
        } else {
            // This vector has no active elements, so the differences are other's values.
            for (VectorTuple tuple : other) {
                score += transformFunc.applyAsDouble(tuple.value);
            }
        }

        return normalizeFunc.applyAsDouble(score);
    }

    /**
     * Formats the vector as {@code SparseVector(size=N,tuples=[index,value],...)}.
     * @return A String representation of this vector.
     */
    @Override
    public String toString() {
        StringBuilder buffer = new StringBuilder();

        buffer.append("SparseVector(size=");
        buffer.append(size);
        buffer.append(",tuples=");

        for (int i = 0; i < indices.length; i++) {
            buffer.append("[");
            buffer.append(indices[i]);
            buffer.append(",");
            buffer.append(values[i]);
            buffer.append("],");
        }
        if (indices.length == 0) {
            // Nothing was appended, so there is no trailing comma to overwrite.
            buffer.append(')');
        } else {
            // Replace the trailing comma with the closing bracket.
            buffer.setCharAt(buffer.length() - 1, ')');
        }

        return buffer.toString();
    }

    /**
     * Expands this vector into a dense array of length {@link #size()}, with zeros
     * in the inactive positions.
     * @return A dense array containing this vector's values.
     */
    public double[] toDenseArray() {
        double[] output = new double[size];
        for (int i = 0; i < values.length; i++) {
            output[indices[i]] = values[i];
        }
        return output;
    }

    /**
     * Computes the sum of the squared differences between each element and {@code mean},
     * counting the inactive elements as zeros.
     * <p>
     * The sum is not divided by the number of elements.
     * @param mean The mean to measure the differences from.
     * @return The sum of squared differences from the mean.
     */
    @Override
    public double variance(double mean) {
        double variance = 0.0;
        for (int i = 0; i < values.length; i++) {
            variance += (values[i] - mean) * (values[i] - mean);
        }
        // Each inactive element is zero and so contributes (0 - mean)^2.
        variance += (size - values.length) * mean * mean;
        return variance;
    }

    /**
     * Returns an iterator over the active elements in index array order.
     * <p>
     * The iterator reuses a single {@link VectorTuple}, which is overwritten on each
     * call to {@code next()}.
     * @return An iterator over the active elements.
     */
    @Override
    public VectorIterator iterator() {
        return new SparseVectorIterator(this);
    }

    /**
     * Iterates the active elements of a {@link SparseVector}, reusing a single
     * mutable {@link VectorTuple} to avoid allocating per element.
     */
    private static class SparseVectorIterator implements VectorIterator {
        private final SparseVector vector;
        private final VectorTuple tuple;
        private int index;

        public SparseVectorIterator(SparseVector vector) {
            this.vector = vector;
            this.tuple = new VectorTuple();
            this.index = 0;
        }

        @Override
        public boolean hasNext() {
            return index < vector.indices.length;
        }

        @Override
        public VectorTuple next() {
            if (!hasNext()) {
                throw new NoSuchElementException("Off the end of the iterator.");
            }
            // Overwrite the shared tuple rather than allocating a new one.
            tuple.index = vector.indices[index];
            tuple.value = vector.values[index];
            index++;
            return tuple;
        }

        @Override
        public VectorTuple getReference() {
            return tuple;
        }
    }

    /**
     * Transposes an array of sparse vectors from row-major to column-major or
     * vice versa.
     * <p>
     * The dimension of the first vector is used as the number of output vectors.
     * An empty input produces an empty output.
     * @param input Input sparse vectors.
     * @return A column-major array of SparseVectors.
     */
    public static SparseVector[] transpose(SparseVector[] input) {
        if (input.length == 0) {
            // There is no first vector to read the second dimension from.
            return new SparseVector[0];
        }
        int firstDimension = input.length;
        int secondDimension = input[0].size;

        ArrayList<ArrayList<Integer>> indices = new ArrayList<>();
        ArrayList<ArrayList<Double>> values = new ArrayList<>();

        for (int i = 0; i < secondDimension; i++) {
            indices.add(new ArrayList<>());
            values.add(new ArrayList<>());
        }

        // Rows are visited in ascending order, so each column's indices are built sorted.
        for (int i = 0; i < firstDimension; i++) {
            for (VectorTuple f : input[i]) {
                indices.get(f.index).add(i);
                values.get(f.index).add(f.value);
            }
        }

        SparseVector[] output = new SparseVector[secondDimension];

        for (int i = 0; i < secondDimension; i++) {
            output[i] = new SparseVector(firstDimension,Util.toPrimitiveInt(indices.get(i)),Util.toPrimitiveDouble(values.get(i)));
        }

        return output;
    }

    /**
     * Converts a dataset of row-major examples into an array of column-major
     * sparse vectors.
     * @param dataset Input dataset.
     * @param <T> The type of the dataset.
     * @return A column-major array of SparseVectors.
     */
    public static <T extends Output<T>> SparseVector[] transpose(Dataset<T> dataset) {
        ImmutableFeatureMap fMap = dataset.getFeatureIDMap();
        return transpose(dataset,fMap);
    }

    /**
     * Converts a dataset of row-major examples into an array of column-major
     * sparse vectors.
     * <p>
     * Throws {@link IllegalArgumentException} if the supplied feature map is a different
     * size to the dataset's feature map.
     * @param dataset Input dataset.
     * @param fMap The feature map to use. If it's different to the feature map used by the dataset then behaviour is undefined.
     * @param <T> The type of the dataset.
     * @return A column-major array of SparseVectors.
     */
    public static <T extends Output<T>> SparseVector[] transpose(Dataset<T> dataset, ImmutableFeatureMap fMap) {
        if (dataset.getFeatureMap().size() != fMap.size()) {
            throw new IllegalArgumentException(
                    "The dataset's internal feature map and the supplied feature map have different sizes. dataset = "
                            + dataset.getFeatureMap().size() + ", fMap = " + fMap.size());
        }
        int numExamples = dataset.size();
        int numFeatures = fMap.size();

        ArrayList<ArrayList<Integer>> indices = new ArrayList<>();
        ArrayList<ArrayList<Double>> values = new ArrayList<>();

        for (int i = 0; i < numFeatures; i++) {
            indices.add(new ArrayList<>());
            values.add(new ArrayList<>());
        }

        // j is the example number, which becomes the index in the column vectors.
        int j = 0;
        for (Example<T> e : dataset) {
            for (Feature f : e) {
                int index = fMap.getID(f.getName());
                indices.get(index).add(j);
                values.get(index).add(f.getValue());
            }
            j++;
        }

        SparseVector[] output = new SparseVector[numFeatures];

        for (int i = 0; i < fMap.size(); i++) {
            output[i] = new SparseVector(numExamples,Util.toPrimitiveInt(indices.get(i)),Util.toPrimitiveDouble(values.get(i)));
        }

        return output;
    }
}