001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.math.la;
018
019import org.tribuo.Dataset;
020import org.tribuo.Example;
021import org.tribuo.Feature;
022import org.tribuo.ImmutableFeatureMap;
023import org.tribuo.Output;
024import org.tribuo.math.util.VectorNormalizer;
025import org.tribuo.util.IntDoublePair;
026import org.tribuo.util.Util;
027
028import java.util.ArrayList;
029import java.util.Arrays;
030import java.util.HashMap;
031import java.util.Iterator;
032import java.util.List;
033import java.util.Map;
034import java.util.NoSuchElementException;
035import java.util.Objects;
036import java.util.function.DoubleUnaryOperator;
037import java.util.stream.Collectors;
038
039/**
040 * A sparse vector. Stored as a sorted array of indices and an array of values.
041 * <p>
042 * Uses binary search to look up a specific index, so it's usually faster to
043 * use the iterator to iterate the values.
044 * <p>
045 * This vector has immutable indices. It cannot get new indices after construction,
046 * and will throw {@link IllegalArgumentException} if such an operation is tried.
047 */
048public class SparseVector implements SGDVector {
    private static final long serialVersionUID = 1L;

    // Single-element shape array holding {size}; returned directly by getShape().
    private final int[] shape;
    // Indices of the active (stored) elements. Expected to be sorted ascending,
    // because get/set/add look indices up with Arrays.binarySearch.
    protected final int[] indices;
    // Values of the active elements, aligned position-by-position with indices.
    protected final double[] values;
    // The dimension of this vector; indices not stored are read back as 0 by get().
    private final int size;
055
056    /**
057     * Used internally for performance.
058     * Does not defensively copy the input, nor check it's sorted.
059     * <p>
060     * @param size The dimension of this vector.
061     * @param indices The indices.
062     * @param values The values.
063     */
064    SparseVector(int size, int[] indices, double[] values) {
065        this.size = size;
066        this.shape = new int[]{size};
067        this.indices = indices;
068        this.values = values;
069    }
070
071    /**
072     * Returns a deep copy of the supplied sparse vector.
073     * <p>
074     * Copies the value by iterating it's VectorTuple.
075     * @param other The SparseVector to copy.
076     */
077    private SparseVector(SparseVector other) {
078        this.size = other.size;
079        int numActiveElements = other.numActiveElements();
080        this.indices = new int[numActiveElements];
081        this.values = new double[numActiveElements];
082
083        int i = 0;
084        for (VectorTuple tuple : other) {
085            indices[i] = tuple.index;
086            values[i] = tuple.value;
087            i++;
088        }
089        this.shape = new int[]{size};
090    }
091
092    public SparseVector(int size, int[] indices, double value) {
093        this.indices = Arrays.copyOf(indices,indices.length);
094        this.values = new double[indices.length];
095        Arrays.fill(this.values,value);
096        this.size = size;
097        this.shape = new int[]{size};
098    }
099
100    /**
101     * Builds a {@link SparseVector} from an {@link Example}.
102     * <p>
103     * Used in training and inference.
104     * <p>
105     * Throws {@link IllegalArgumentException} if the Example contains NaN-valued features.
106     * @param example     The example to convert.
107     * @param featureInfo The feature information, used to calculate the dimension of this SparseVector.
108     * @param addBias     Add a bias feature.
109     * @param <T>         The type parameter of the {@code example}.
110     * @return A SparseVector representing the example's features.
111     */
112    public static <T extends Output<T>> SparseVector createSparseVector(Example<T> example, ImmutableFeatureMap featureInfo, boolean addBias) {
113        int size;
114        int numFeatures = example.size();
115        if (addBias) {
116            size = featureInfo.size() + 1;
117            numFeatures++;
118        } else {
119            size = featureInfo.size();
120        }
121        int[] tmpIndices = new int[numFeatures];
122        double[] tmpValues = new double[numFeatures];
123        int i = 0;
124        int prevIdx = -1;
125        for (Feature f : example) {
126            int index = featureInfo.getID(f.getName());
127            if (index > prevIdx){
128                prevIdx = index;
129                tmpIndices[i] = index;
130                tmpValues[i] = f.getValue();
131                if (Double.isNaN(tmpValues[i])) {
132                    throw new IllegalArgumentException("Example contained a NaN feature, " + f.toString());
133                }
134                i++;
135            } else if (index > -1) {
136                //
137                // Collision, deal with it.
138                int collisionIdx = Arrays.binarySearch(tmpIndices,0,i,index);
139                if (collisionIdx < 0) {
140                    //
141                    // Collision but not present in tmpIndices
142                    // move data and bump i
143                    collisionIdx = - (collisionIdx + 1);
144                    System.arraycopy(tmpIndices,collisionIdx,tmpIndices,collisionIdx+1,i-collisionIdx);
145                    System.arraycopy(tmpValues,collisionIdx,tmpValues,collisionIdx+1,i-collisionIdx);
146                    tmpIndices[collisionIdx] = index;
147                    tmpValues[collisionIdx] = f.getValue();
148                    if (Double.isNaN(tmpValues[collisionIdx])) {
149                        throw new IllegalArgumentException("Example contained a NaN feature, " + f.toString());
150                    }
151                    i++;
152                } else {
153                    //
154                    // Collision present in tmpIndices
155                    // add the values.
156                    tmpValues[collisionIdx] += f.getValue();
157                    if (Double.isNaN(tmpValues[collisionIdx])) {
158                        throw new IllegalArgumentException("Example contained a NaN feature, " + f.toString());
159                    }
160                }
161            }
162        }
163        if (addBias) {
164            tmpIndices[i] = size - 1;
165            tmpValues[i] = 1.0;
166            i++;
167        }
168        return new SparseVector(size,Arrays.copyOf(tmpIndices,i),Arrays.copyOf(tmpValues,i));
169    }
170
171    /**
172     * Defensively copies the input, and checks that the indices are sorted. If not,
173     * it sorts them.
174     *  <p>
175     * Throws {@link IllegalArgumentException} if the arrays are not the same length, or if size is less than
176     * the max index.
177     * @param dimension The dimension of this vector.
178     * @param indices The indices of the non-zero elements.
179     * @param values The values of the non-zero elements.
180     * @return A SparseVector encapsulating the indices and values.
181     */
182    public static SparseVector createSparseVector(int dimension, int[] indices, double[] values) {
183        if (indices.length != values.length) {
184            throw new IllegalArgumentException("Indices and values must be the same length, found indices.length = " + indices.length + " and values.length = " + values.length);
185        } else if (indices.length == 0) {
186            return new SparseVector(dimension,indices,values);
187        } else {
188            IntDoublePair[] pairArray = new IntDoublePair[indices.length];
189            for (int i = 0; i < pairArray.length; i++) {
190                pairArray[i] = new IntDoublePair(indices[i], values[i]);
191            }
192            Arrays.sort(pairArray, IntDoublePair.pairIndexComparator());
193            int[] newIndices = new int[indices.length];
194            double[] newValues = new double[values.length];
195            for (int i = 0; i < pairArray.length; i++) {
196                newIndices[i] = pairArray[i].index;
197                newValues[i] = pairArray[i].value;
198            }
199            if (dimension < newIndices[newIndices.length - 1]) {
200                throw new IllegalArgumentException("Number of dimensions is less than the maximum index, dimensions = " + dimension + ", max index = " + newIndices[newIndices.length - 1]);
201            }
202            return new SparseVector(dimension, newIndices, newValues);
203        }
204    }
205
206    /**
207     * Builds a SparseVector from a map.
208     * <p>
209     * Throws {@link IllegalArgumentException} if dimension is less than the max index.
210     * @param dimension The dimension of this vector.
211     * @param indexMap The map from indices to values.
212     * @return A SparseVector.
213     */
214    public static SparseVector createSparseVector(int dimension, Map<Integer, Double> indexMap) {
215        if (indexMap.isEmpty()) {
216            return new SparseVector(dimension,new int[0],new double[0]);
217        } else {
218            List<Map.Entry<Integer, Double>> sortedEntries = indexMap.entrySet()
219                    .stream().sorted(Map.Entry.comparingByKey())
220                    .collect(Collectors.toList());
221
222            int[] indices = new int[sortedEntries.size()];
223            double[] values = new double[sortedEntries.size()];
224            for (int i = 0; i < sortedEntries.size(); i++) {
225                indices[i] = sortedEntries.get(i).getKey();
226                values[i] = sortedEntries.get(i).getValue();
227            }
228            if (dimension < indices[indices.length - 1]) {
229                throw new IllegalArgumentException("Number of dimensions is less than the maximum index, dimensions = " + dimension + ", max index = " + indices[indices.length - 1]);
230            }
231            return new SparseVector(dimension, indices, values);
232        }
233    }
234
235    @Override
236    public SparseVector copy() {
237        return new SparseVector(this);
238    }
239
240    @Override
241    public int[] getShape() {
242        return shape;
243    }
244
245    @Override
246    public Tensor reshape(int[] newShape) {
247        throw new UnsupportedOperationException("Reshape not supported on sparse Tensors.");
248    }
249
250    @Override
251    public int size() {
252        return size;
253    }
254
255    @Override
256    public int numActiveElements() {
257        return values.length;
258    }
259
260    /**
261     * Equals is defined mathematically, that is two SGDVectors are equal iff they have the same indices
262     * and the same values at those indices.
263     * @param other Object to compare against.
264     * @return True if this vector and the other vector contain the same values in the same order.
265     */
266    @Override
267    public boolean equals(Object other) {
268        if (other instanceof SGDVector) {
269            Iterator<VectorTuple> ourItr = iterator();
270            Iterator<VectorTuple> otherItr = ((SGDVector)other).iterator();
271            VectorTuple ourTuple;
272            VectorTuple otherTuple;
273
274            while (ourItr.hasNext() && otherItr.hasNext()) {
275                ourTuple = ourItr.next();
276                otherTuple = otherItr.next();
277                if (!ourTuple.equals(otherTuple)) {
278                    return false;
279                }
280            }
281
282            // If one of the iterators still has elements then they are not the same.
283            return !(ourItr.hasNext() || otherItr.hasNext());
284        } else {
285            return false;
286        }
287    }
288
289    @Override
290    public int hashCode() {
291        int result = Objects.hash(size);
292        result = 31 * result + Arrays.hashCode(indices);
293        result = 31 * result + Arrays.hashCode(values);
294        return result;
295    }
296
297    /**
298     * Adds {@code other} to this vector, producing a new {@link SGDVector}.
299     * If {@code other} is a {@link SparseVector} then the returned vector is also
300     * a {@link SparseVector} otherwise it's a {@link DenseVector}.
301     * @param other The vector to add.
302     * @return A new {@link SGDVector} where each element value = this.get(i) + other.get(i).
303     */
304    @Override
305    public SGDVector add(SGDVector other) {
306        if (other.size() != size) {
307            throw new IllegalArgumentException("Can't add two vectors of different dimension, this = " + size + ", other = " + other.size());
308        }
309        if (other instanceof DenseVector) {
310            return other.add(this);
311        } else if (other instanceof SparseVector) {
312            Map<Integer, Double> values = new HashMap<>();
313            for (VectorTuple tuple : this) {
314                values.put(tuple.index, tuple.value);
315            }
316            for (VectorTuple tuple : other) {
317                values.merge(tuple.index, tuple.value, Double::sum);
318            }
319            return createSparseVector(size, values);
320        } else {
321            throw new IllegalArgumentException("Vector other is not dense or sparse.");
322        }
323    }
324
325    /**
326     * Subtracts {@code other} from this vector, producing a new {@link SGDVector}.
327     * If {@code other} is a {@link SparseVector} then the returned vector is also
328     * a {@link SparseVector} otherwise it's a {@link DenseVector}.
329     * @param other The vector to subtract.
330     * @return A new {@link SGDVector} where each element value = this.get(i) - other.get(i).
331     */
332    @Override
333    public SGDVector subtract(SGDVector other) {
334        if (other.size() != size) {
335            throw new IllegalArgumentException("Can't subtract two vectors of different dimension, this = " + size + ", other = " + other.size());
336        }
337        if (other instanceof DenseVector) {
338            DenseVector output = ((DenseVector)other).copy();
339            for (VectorTuple tuple : this) {
340                output.set(tuple.index,tuple.value-output.get(tuple.index));
341            }
342            return output;
343        } else if (other instanceof SparseVector) {
344            Map<Integer, Double> values = new HashMap<>();
345            for (VectorTuple tuple : this) {
346                values.put(tuple.index, tuple.value);
347            }
348            for (VectorTuple tuple : other) {
349                values.merge(tuple.index, -tuple.value, Double::sum);
350            }
351            return createSparseVector(size, values);
352        } else {
353            throw new IllegalArgumentException("Vector other is not dense or sparse.");
354        }
355    }
356
357    @Override
358    public void intersectAndAddInPlace(Tensor other, DoubleUnaryOperator f) {
359        if (other instanceof SparseVector) {
360            SparseVector otherVec = (SparseVector) other;
361            if (otherVec.size() != size) {
362                throw new IllegalArgumentException("Can't intersect two vectors of different dimension, this = " + size + ", other = " + otherVec.size());
363            } else if (otherVec.numActiveElements() > 0) {
364                int i = 0;
365                Iterator<VectorTuple> otherItr = otherVec.iterator();
366                VectorTuple tuple = otherItr.next();
367                while (i < (indices.length-1) && otherItr.hasNext()) {
368                    if (indices[i] == tuple.index) {
369                        values[i] += f.applyAsDouble(tuple.value);
370                        i++;
371                        tuple = otherItr.next();
372                    } else if (indices[i] < tuple.index) {
373                        i++;
374                    } else {
375                        tuple = otherItr.next();
376                    }
377                }
378                for (; i < indices.length-1; i++) {
379                    if (indices[i] == tuple.index) {
380                        values[i] += f.applyAsDouble(tuple.value);
381                    }
382                }
383                while (otherItr.hasNext()) {
384                    if (indices[i] == tuple.index) {
385                        values[i] += f.applyAsDouble(tuple.value);
386                    }
387                    tuple = otherItr.next();
388                }
389                if (indices[i] == tuple.index) {
390                    values[i] += f.applyAsDouble(tuple.value);
391                }
392            }
393        } else if (other instanceof DenseVector) {
394            DenseVector otherVec = (DenseVector) other;
395            if (otherVec.size() != size) {
396                throw new IllegalArgumentException("Can't intersect two vectors of different dimension, this = " + size + ", other = " + otherVec.size());
397            }
398            for (int i = 0; i < indices.length; i++) {
399                values[i] += f.applyAsDouble(otherVec.get(indices[i]));
400            }
401        } else {
402            throw new IllegalStateException("Unknown Tensor subclass " + other.getClass().getCanonicalName() + " for input");
403        }
404    }
405
406    @Override
407    public void hadamardProductInPlace(Tensor other, DoubleUnaryOperator f) {
408        if (other instanceof SparseVector) {
409            SparseVector otherVec = (SparseVector) other;
410            if (otherVec.size() != size) {
411                throw new IllegalArgumentException("Can't hadamard product two vectors of different dimension, this = " + size + ", other = " + otherVec.size());
412            } else if (otherVec.numActiveElements() > 0) {
413                int i = 0;
414                Iterator<VectorTuple> otherItr = otherVec.iterator();
415                VectorTuple tuple = otherItr.next();
416                while (i < (indices.length-1) && otherItr.hasNext()) {
417                    if (indices[i] == tuple.index) {
418                        values[i] *= f.applyAsDouble(tuple.value);
419                        i++;
420                        tuple = otherItr.next();
421                    } else if (indices[i] < tuple.index) {
422                        i++;
423                    } else {
424                        tuple = otherItr.next();
425                    }
426                }
427                for (; i < indices.length-1; i++) {
428                    if (indices[i] == tuple.index) {
429                        values[i] *= f.applyAsDouble(tuple.value);
430                    }
431                }
432                while (otherItr.hasNext()) {
433                    if (indices[i] == tuple.index) {
434                        values[i] *= f.applyAsDouble(tuple.value);
435                    }
436                    tuple = otherItr.next();
437                }
438                if (indices[i] == tuple.index) {
439                    values[i] *= f.applyAsDouble(tuple.value);
440                }
441            }
442        } else if (other instanceof DenseVector) {
443            DenseVector otherVec = (DenseVector) other;
444            if (otherVec.size() != size) {
445                throw new IllegalArgumentException("Can't hadamard product two vectors of different dimension, this = " + size + ", other = " + otherVec.size());
446            }
447            for (int i = 0; i < indices.length; i++) {
448                values[i] *= f.applyAsDouble(otherVec.get(indices[i]));
449            }
450        } else {
451            throw new IllegalArgumentException("Invalid Tensor subclass " + other.getClass().getCanonicalName() + " for input");
452        }
453    }
454
455    @Override
456    public void foreachInPlace(DoubleUnaryOperator f) {
457        for (int i = 0; i < values.length; i++) {
458            values[i] = f.applyAsDouble(values[i]);
459        }
460    }
461
462    @Override
463    public SparseVector scale(double coefficient) {
464        double[] newValues = Arrays.copyOf(values, values.length);
465        for (int i = 0; i < values.length; i++) {
466            newValues[i] *= coefficient;
467        }
468        return new SparseVector(size, Arrays.copyOf(indices, indices.length), newValues);
469    }
470
471    @Override
472    public void add(int index, double value) {
473        int foundIndex = Arrays.binarySearch(indices, index);
474        if (foundIndex < 0) {
475            throw new IllegalArgumentException("SparseVector cannot have new elements added.");
476        } else {
477            values[foundIndex] += value;
478        }
479    }
480
481    @Override
482    public double dot(SGDVector other) {
483        if (other.size() != size) {
484            throw new IllegalArgumentException("Can't dot two vectors of different lengths, this = " + size + ", other = " + other.size());
485        } else if (other instanceof SparseVector) {
486            double score = 0.0;
487
488            // If there are elements, calculate the dot product.
489            if ((other.numActiveElements() != 0) && (indices.length != 0)) {
490                Iterator<VectorTuple> itr = iterator();
491                Iterator<VectorTuple> otherItr = other.iterator();
492                VectorTuple tuple = itr.next();
493                VectorTuple otherTuple = otherItr.next();
494                while (itr.hasNext() && otherItr.hasNext()) {
495                    if (tuple.index == otherTuple.index) {
496                        score += tuple.value * otherTuple.value;
497                        tuple = itr.next();
498                        otherTuple = otherItr.next();
499                    } else if (tuple.index < otherTuple.index) {
500                        tuple = itr.next();
501                    } else {
502                        otherTuple = otherItr.next();
503                    }
504                }
505                while (itr.hasNext()) {
506                    if (tuple.index == otherTuple.index) {
507                        score += tuple.value * otherTuple.value;
508                    }
509                    tuple = itr.next();
510                }
511                while (otherItr.hasNext()) {
512                    if (tuple.index == otherTuple.index) {
513                        score += tuple.value * otherTuple.value;
514                    }
515                    otherTuple = otherItr.next();
516                }
517                if (tuple.index == otherTuple.index) {
518                    score += tuple.value * otherTuple.value;
519                }
520            }
521
522            return score;
523        } else if (other instanceof DenseVector) {
524            double score = 0.0;
525
526            for (int i = 0; i < indices.length; i++) {
527                score += other.get(indices[i]) * values[i];
528            }
529
530            return score;
531        } else {
532            throw new IllegalArgumentException("Unknown vector subclass " + other.getClass().getCanonicalName() + " for input");
533        }
534    }
535
536    /**
537     * This generates the outer product when dotted with another {@link SparseVector}.
538     * <p>
539     * It throws an {@link IllegalArgumentException} if used with a {@link DenseVector}.
540     *
541     * @param other A vector.
542     * @return A {@link DenseSparseMatrix} representing the outer product.
543     */
544    @Override
545    public Matrix outer(SGDVector other) {
546        if (other instanceof SparseVector) {
547            //This horrible mess is why there should be a sparse-sparse matrix type.
548            SparseVector otherVec = (SparseVector) other;
549            SparseVector[] output = new SparseVector[size];
550            int i = 0;
551            for (VectorTuple tuple : this) {
552                while (i < tuple.index) {
553                    output[i] = new SparseVector(other.size(), new int[0], new double[0]);
554                    i++;
555                }
556                output[tuple.index] = otherVec.scale(tuple.value);
557                i++;
558            }
559            while (i < output.length) {
560                output[i] = new SparseVector(other.size(), new int[0], new double[0]);
561                i++;
562            }
563            //TODO this is suboptimal if there are lots of missing rows.
564            return new DenseSparseMatrix(output);
565        } else if (other instanceof DenseVector) {
566            throw new IllegalArgumentException("sparse.outer(dense) is currently not implemented.");
567        } else {
568            throw new IllegalArgumentException("Unknown vector subclass " + other.getClass().getCanonicalName() + " for input");
569        }
570    }
571
572    @Override
573    public double sum() {
574        double sum = 0.0;
575        for (int i = 0; i < values.length; i++) {
576            sum += values[i];
577        }
578        return sum;
579    }
580
581    @Override
582    public double twoNorm() {
583        double sum = 0.0;
584        for (int i = 0; i < values.length; i++) {
585            sum += values[i] * values[i];
586        }
587        return Math.sqrt(sum);
588    }
589
590    @Override
591    public double oneNorm() {
592        double sum = 0.0;
593        for (int i = 0; i < values.length; i++) {
594            sum += Math.abs(values[i]);
595        }
596        return sum;
597    }
598
599    @Override
600    public double get(int index) {
601        int foundIndex = Arrays.binarySearch(indices, index);
602        if (foundIndex < 0) {
603            return 0;
604        } else {
605            return values[foundIndex];
606        }
607    }
608
609    @Override
610    public void set(int index, double value) {
611        int foundIndex = Arrays.binarySearch(indices, index);
612        if (foundIndex < 0) {
613            throw new IllegalArgumentException("SparseVector cannot have new elements added.");
614        } else {
615            values[foundIndex] = value;
616        }
617    }
618
619    @Override
620    public int indexOfMax() {
621        int index = 0;
622        double value = Double.NEGATIVE_INFINITY;
623        for (int i = 0; i < values.length; i++) {
624            double tmp = values[i];
625            if (tmp > value) {
626                index = i;
627                value = tmp;
628            }
629        }
630        return indices[index];
631    }
632
633    @Override
634    public double maxValue() {
635        double value = Double.NEGATIVE_INFINITY;
636        for (int i = 0; i < values.length; i++) {
637            double tmp = values[i];
638            if (tmp > value) {
639                value = tmp;
640            }
641        }
642        return value;
643    }
644
645    @Override
646    public double minValue() {
647        double value = Double.POSITIVE_INFINITY;
648        for (int i = 0; i < values.length; i++) {
649            double tmp = values[i];
650            if (tmp < value) {
651                value = tmp;
652            }
653        }
654        return value;
655    }
656
657    /**
658     * Generates an array of the indices that are active in this vector
659     * but are not present in {@code other}.
660     *
661     * @param other The vector to compare.
662     * @return An array of indices that are active only in this vector.
663     */
664    public int[] difference(SparseVector other) {
665        List<Integer> diffIndicesList = new ArrayList<>();
666
667        if (other.numActiveElements() == 0) {
668            return Arrays.copyOf(indices,indices.length);
669        } else if (indices.length == 0) {
670            return new int[0];
671        } else {
672            Iterator<VectorTuple> itr = iterator();
673            Iterator<VectorTuple> otherItr = other.iterator();
674            VectorTuple tuple = itr.next();
675            VectorTuple otherTuple = otherItr.next();
676            while (itr.hasNext() && otherItr.hasNext()) {
677                if (tuple.index == otherTuple.index) {
678                    tuple = itr.next();
679                    otherTuple = otherItr.next();
680                } else if (tuple.index < otherTuple.index) {
681                    diffIndicesList.add(tuple.index);
682                    tuple = itr.next();
683                } else {
684                    otherTuple = otherItr.next();
685                }
686            }
687            while (itr.hasNext()) {
688                if (tuple.index != otherTuple.index) {
689                    diffIndicesList.add(tuple.index);
690                }
691                tuple = itr.next();
692            }
693            while (otherItr.hasNext()) {
694                if (tuple.index == otherTuple.index) {
695                    break; // break out of loop as we've found the last value.
696                }
697                otherTuple = otherItr.next();
698            }
699            if (tuple.index != otherTuple.index) {
700                diffIndicesList.add(tuple.index);
701            }
702        }
703
704        return Util.toPrimitiveInt(diffIndicesList);
705    }
706
707    /**
708     * Generates an array of the indices that are active in both this
709     * vector and {@code other}
710     *
711     * @param other The vector to intersect.
712     * @return An array of indices that are active in both vectors.
713     */
714    public int[] intersection(SparseVector other) {
715        List<Integer> diffIndicesList = new ArrayList<>();
716
717        Iterator<VectorTuple> itr = iterator();
718        Iterator<VectorTuple> otherItr = other.iterator();
719        if (itr.hasNext() && otherItr.hasNext()) {
720            VectorTuple tuple = itr.next();
721            VectorTuple otherTuple = otherItr.next();
722            while (itr.hasNext() && otherItr.hasNext()) {
723                if (tuple.index == otherTuple.index) {
724                    diffIndicesList.add(tuple.index);
725                    tuple = itr.next();
726                    otherTuple = otherItr.next();
727                } else if (tuple.index < otherTuple.index) {
728                    tuple = itr.next();
729                } else {
730                    otherTuple = otherItr.next();
731                }
732            }
733            while (itr.hasNext()) {
734                if (tuple.index == otherTuple.index) {
735                    diffIndicesList.add(tuple.index);
736                }
737                tuple = itr.next();
738            }
739            while (otherItr.hasNext()) {
740                if (tuple.index == otherTuple.index) {
741                    diffIndicesList.add(tuple.index);
742                }
743                otherTuple = otherItr.next();
744            }
745            if (tuple.index == otherTuple.index) {
746                diffIndicesList.add(tuple.index);
747            }
748        }
749
750        return Util.toPrimitiveInt(diffIndicesList);
751    }
752
753
754    @Override
755    public void normalize(VectorNormalizer normalizer) {
756        throw new IllegalStateException("Can't normalize a sparse array");
757    }
758
759    @Override
760    public double euclideanDistance(SGDVector other) {
761        return distance(other,(double a) -> a*a, Math::sqrt);
762    }
763
764    @Override
765    public double l1Distance(SGDVector other) {
766        return distance(other,Math::abs,DoubleUnaryOperator.identity());
767    }
768
    /**
     * Computes a distance between this vector and {@code other} by merging their active elements.
     * <p>
     * For an index active in both vectors, {@code transformFunc} is applied to
     * {@code this.value - other.value}. For an index active in only one vector,
     * {@code transformFunc} is applied to that vector's value, the other side being an implicit zero.
     * The transformed terms are summed, and the sum is passed through {@code normalizeFunc}.
     * <p>
     * Indices active in neither vector contribute nothing. The result is therefore only
     * correct for a {@code transformFunc} with {@code transformFunc(0) == 0}.
     * <p>
     * The merge assumes {@code other.iterator()} yields indices in ascending order, as this
     * vector's iterator does.
     * @param other The vector to measure against.
     * @param transformFunc Applied to each per-index difference before summing.
     * @param normalizeFunc Applied to the summed score to produce the returned distance.
     * @return The normalized distance.
     * @throws IllegalArgumentException If {@code other.size()} differs from this vector's size.
     */
    public double distance(SGDVector other, DoubleUnaryOperator transformFunc, DoubleUnaryOperator normalizeFunc) {
        if (other.size() != size) {
            throw new IllegalArgumentException("Can't measure the distance between two vectors of different lengths, this = " + size + ", other = " + other.size());
        }
        double score = 0.0;

        if ((other.numActiveElements() != 0) && (indices.length != 0)){
            // Both vectors have active elements: walk them together in index order.
            // tuple / otherTuple hold the current, not-yet-scored element of each side.
            // Note: this vector's iterator overwrites and returns the same tuple object on
            // every next(), so each value must be read before that iterator is advanced.
            Iterator<VectorTuple> itr = iterator();
            Iterator<VectorTuple> otherItr = other.iterator();
            VectorTuple tuple = itr.next();
            VectorTuple otherTuple = otherItr.next();
            while (itr.hasNext() && otherItr.hasNext()) {
                if (tuple.index == otherTuple.index) {
                    score += transformFunc.applyAsDouble(tuple.value - otherTuple.value);
                    tuple = itr.next();
                    otherTuple = otherItr.next();
                } else if (tuple.index < otherTuple.index) {
                    score += transformFunc.applyAsDouble(tuple.value);
                    tuple = itr.next();
                } else {
                    score += transformFunc.applyAsDouble(otherTuple.value);
                    otherTuple = otherItr.next();
                }
            }
            // At least one iterator is now exhausted, but its last element may still be unscored.
            // Only this vector has elements left. otherTuple is other's final element, which may
            // match one of them or be left for the final check below.
            while (itr.hasNext()) {
                if (tuple.index == otherTuple.index) {
                    score += transformFunc.applyAsDouble(tuple.value - otherTuple.value);
                    otherTuple = new VectorTuple(); // Consumed this value, replace with sentinel
                } else {
                    score += transformFunc.applyAsDouble(tuple.value);
                }
                tuple = itr.next();
            }
            // Only other has elements left; the mirror image of the loop above.
            while (otherItr.hasNext()) {
                if (tuple.index == otherTuple.index) {
                    score += transformFunc.applyAsDouble(tuple.value - otherTuple.value);
                    tuple = new VectorTuple(); // Consumed this value, replace with sentinel
                } else {
                    score += transformFunc.applyAsDouble(otherTuple.value);
                }
                otherTuple = otherItr.next();
            }
            // Score whatever remains of the two final elements, skipping sentinels.
            // NOTE(review): the -1 checks assume a default-constructed VectorTuple has index -1;
            // confirm against VectorTuple's no-arg constructor.
            if (tuple.index == otherTuple.index) {
                score += transformFunc.applyAsDouble(tuple.value - otherTuple.value);
            } else {
                if (tuple.index != -1) {
                    score += transformFunc.applyAsDouble(tuple.value);
                }
                if (otherTuple.index != -1) {
                    score += transformFunc.applyAsDouble(otherTuple.value);
                }
            }
        } else if (indices.length != 0) {
            // other has no active elements, so only this vector's values contribute.
            for (VectorTuple tuple : this) {
                score += transformFunc.applyAsDouble(tuple.value);
            }
        } else {
            // This vector has no active elements, so only other's values (if any) contribute.
            for (VectorTuple tuple : other) {
                score += transformFunc.applyAsDouble(tuple.value);
            }
        }

        return normalizeFunc.applyAsDouble(score);
    }
833
834    @Override
835    public String toString() {
836        StringBuilder buffer = new StringBuilder();
837
838        buffer.append("SparseVector(size=");
839        buffer.append(size);
840        buffer.append(",tuples=");
841
842        for (int i = 0; i < indices.length; i++) {
843            buffer.append("[");
844            buffer.append(indices[i]);
845            buffer.append(",");
846            buffer.append(values[i]);
847            buffer.append("],");
848        }
849        buffer.setCharAt(buffer.length() - 1, ')');
850
851        return buffer.toString();
852    }
853
854    public double[] toDenseArray() {
855        double[] output = new double[size];
856        for (int i = 0; i < values.length; i++) {
857            output[indices[i]] = values[i];
858        }
859        return output;
860    }
861
862    @Override
863    public double variance(double mean) {
864        double variance = 0.0;
865        for (int i = 0; i < values.length; i++) {
866            variance += (values[i] - mean) * (values[i] - mean);
867        }
868        variance += (size - values.length) * mean * mean;
869        return variance;
870    }
871
872    @Override
873    public VectorIterator iterator() {
874        return new SparseVectorIterator(this);
875    }
876
877    private static class SparseVectorIterator implements VectorIterator {
878        private final SparseVector vector;
879        private final VectorTuple tuple;
880        private int index;
881
882        public SparseVectorIterator(SparseVector vector) {
883            this.vector = vector;
884            this.tuple = new VectorTuple();
885            this.index = 0;
886        }
887
888        @Override
889        public boolean hasNext() {
890            return index < vector.indices.length;
891        }
892
893        @Override
894        public VectorTuple next() {
895            if (!hasNext()) {
896                throw new NoSuchElementException("Off the end of the iterator.");
897            }
898            tuple.index = vector.indices[index];
899            tuple.value = vector.values[index];
900            index++;
901            return tuple;
902        }
903
904        @Override
905        public VectorTuple getReference() {
906            return tuple;
907        }
908    }
909
910    /**
911     * Transposes an array of sparse vectors from row-major to column-major or
912     * vice versa.
913     * @param input Input sparse vectors.
914     * @return A column-major array of SparseVectors.
915     */
916    public static SparseVector[] transpose(SparseVector[] input) {
917        int firstDimension = input.length;
918        int secondDimension = input[0].size;
919
920        ArrayList<ArrayList<Integer>> indices = new ArrayList<>();
921        ArrayList<ArrayList<Double>> values = new ArrayList<>();
922
923        for (int i = 0; i < secondDimension; i++) {
924            indices.add(new ArrayList<>());
925            values.add(new ArrayList<>());
926        }
927
928        for (int i = 0; i < firstDimension; i++) {
929            for (VectorTuple f : input[i]) {
930                indices.get(f.index).add(i);
931                values.get(f.index).add(f.value);
932            }
933        }
934
935        SparseVector[] output = new SparseVector[secondDimension];
936
937        for (int i = 0; i < secondDimension; i++) {
938            output[i] = new SparseVector(firstDimension,Util.toPrimitiveInt(indices.get(i)),Util.toPrimitiveDouble(values.get(i)));
939        }
940
941        return output;
942    }
943
944    /**
945     * Converts a dataset of row-major examples into an array of column-major
946     * sparse vectors.
947     * @param dataset Input dataset.
948     * @param <T> The type of the dataset.
949     * @return A column-major array of SparseVectors.
950     */
951    public static <T extends Output<T>> SparseVector[] transpose(Dataset<T> dataset) {
952        ImmutableFeatureMap fMap = dataset.getFeatureIDMap();
953        return transpose(dataset,fMap);
954    }
955
956    /**
957     * Converts a dataset of row-major examples into an array of column-major
958     * sparse vectors.
959     * @param dataset Input dataset.
960     * @param fMap The feature map to use. If it's different to the feature map used by the dataset then behaviour is undefined.
961     * @param <T> The type of the dataset.
962     * @return A column-major array of SparseVectors.
963     */
964    public static <T extends Output<T>> SparseVector[] transpose(Dataset<T> dataset, ImmutableFeatureMap fMap) {
965        if (dataset.getFeatureMap().size() != fMap.size()) {
966            throw new IllegalArgumentException(
967                    "The dataset's internal feature map and the supplied feature map have different sizes. dataset = "
968                    + dataset.getFeatureMap().size() + ", fMap = " + fMap.size());
969        }
970        int numExamples = dataset.size();
971        int numFeatures = fMap.size();
972
973        ArrayList<ArrayList<Integer>> indices = new ArrayList<>();
974        ArrayList<ArrayList<Double>> values = new ArrayList<>();
975
976        for (int i = 0; i < numFeatures; i++) {
977            indices.add(new ArrayList<>());
978            values.add(new ArrayList<>());
979        }
980
981        int j = 0;
982        for (Example<T> e : dataset) {
983            for (Feature f : e) {
984                int index = fMap.getID(f.getName());
985                indices.get(index).add(j);
986                values.get(index).add(f.getValue());
987            }
988            j++;
989        }
990
991        SparseVector[] output = new SparseVector[numFeatures];
992
993        for (int i = 0; i < fMap.size(); i++) {
994            output[i] = new SparseVector(numExamples,Util.toPrimitiveInt(indices.get(i)),Util.toPrimitiveDouble(values.get(i)));
995        }
996
997        return output;
998    }
999}
1000