001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.math.la;
018
019import org.tribuo.math.util.VectorNormalizer;
020import org.tribuo.util.Util;
021
022import java.util.ArrayList;
023import java.util.Arrays;
024import java.util.Iterator;
025import java.util.NoSuchElementException;
026import java.util.function.DoubleBinaryOperator;
027import java.util.function.DoubleUnaryOperator;
028
029/**
030 * A dense vector, backed by a double array.
031 */
032public class DenseVector implements SGDVector {
033    private static final long serialVersionUID = 1L;
034
035    private final int[] shape;
036    protected final double[] elements;
037
038    public DenseVector(int size) {
039        this(size,0.0);
040    }
041
042    public DenseVector(int size, double value) {
043        this.elements = new double[size];
044        Arrays.fill(this.elements,value);
045        this.shape = new int[]{size};
046    }
047
048    /**
049     * Does not defensively copy the input, used internally.
050     * @param values The values of this dense vector.
051     */
052    protected DenseVector(double[] values) {
053        this.elements = values;
054        this.shape = new int[]{elements.length};
055    }
056
057    protected DenseVector(DenseVector other) {
058        this(other.toArray());
059    }
060
061    /**
062     * Defensively copies the values before construction.
063     * @param values The values of this dense vector.
064     * @return A new dense vector.
065     */
066    public static DenseVector createDenseVector(double[] values) {
067        return new DenseVector(Arrays.copyOf(values,values.length));
068    }
069
070    /**
071     * Generates a copy of the values in this DenseVector.
072     * <p>
073     * This implementation uses Arrays.copyOf, and should be overridden if the
074     * get function has been modified.
075     * @return A copy of the values in this DenseVector.
076     */
077    public double[] toArray() {
078        return Arrays.copyOf(elements, elements.length);
079    }
080
081    @Override
082    public int[] getShape() {
083        return shape;
084    }
085
086    @Override
087    public Tensor reshape(int[] newShape) {
088        int sum = Tensor.shapeSum(newShape);
089        if (sum != elements.length) {
090            throw new IllegalArgumentException("Invalid shape " + Arrays.toString(newShape) + ", expected something with " + elements.length + " elements.");
091        }
092
093        if (newShape.length == 2) {
094            DenseMatrix matrix = new DenseMatrix(newShape[0],newShape[1]);
095
096            for (int a = 0; a < size(); a++) {
097                int i = a % newShape[0];
098                int j = a / newShape[0];
099                matrix.set(i,j,get(a));
100            }
101
102            return matrix;
103        } else if (newShape.length == 1) {
104            return new DenseVector(this);
105        } else {
106            throw new IllegalArgumentException("Only supports 1 or 2 dimensional tensors.");
107        }
108    }
109
110    @Override
111    public DenseVector copy() {
112        return new DenseVector(toArray());
113    }
114
115    @Override
116    public int size() {
117        return elements.length;
118    }
119
120    @Override
121    public int numActiveElements() {
122        return elements.length;
123    }
124
125    /**
126     * Performs a reduction from left to right of this vector.
127     * @param initialValue The initial value.
128     * @param op The element wise operation to apply before reducing.
129     * @param reduction The reduction operation (should be commutative).
130     * @return The reduced value.
131     */
132    public double reduce(double initialValue, DoubleUnaryOperator op, DoubleBinaryOperator reduction) {
133        double output = initialValue;
134        for (int i = 0; i < elements.length; i++) {
135            output = reduction.applyAsDouble(output,get(i));
136        }
137        return output;
138    }
139
140    /**
141     * Equals is defined mathematically, that is two SGDVectors are equal iff they have the same indices
142     * and the same values at those indices.
143     * @param other Object to compare against.
144     * @return True if this vector and the other vector contain the same values in the same order.
145     */
146    @Override
147    public boolean equals(Object other) {
148        if (other instanceof SGDVector) {
149            SGDVector otherVector = (SGDVector) other;
150            if (elements.length == otherVector.size()) {
151                Iterator<VectorTuple> ourItr = iterator();
152                Iterator<VectorTuple> otherItr = ((SGDVector) other).iterator();
153                VectorTuple ourTuple;
154                VectorTuple otherTuple;
155
156                while (ourItr.hasNext() && otherItr.hasNext()) {
157                    ourTuple = ourItr.next();
158                    otherTuple = otherItr.next();
159                    if (!ourTuple.equals(otherTuple)) {
160                        return false;
161                    }
162                }
163
164                // If one of the iterators still has elements then they are not the same.
165                return !(ourItr.hasNext() || otherItr.hasNext());
166            } else {
167                return false;
168            }
169        } else {
170            return false;
171        }
172    }
173
174    @Override
175    public int hashCode() {
176        return Arrays.hashCode(elements);
177    }
178
179    /**
180     * Adds {@code other} to this vector, producing a new {@link DenseVector}.
181     * @param other The vector to add.
182     * @return A new {@link DenseVector} where each element value = this.get(i) + other.get(i).
183     */
184    @Override
185    public DenseVector add(SGDVector other) {
186        if (other.size() != elements.length) {
187            throw new IllegalArgumentException("Can't add two vectors of different dimension, this = " + elements.length + ", other = " + other.size());
188        }
189        double[] newValues = toArray();
190        for (VectorTuple tuple : other) {
191            newValues[tuple.index] += tuple.value;
192        }
193        return new DenseVector(newValues);
194    }
195
196    /**
197     * Subtracts {@code other} from this vector, producing a new {@link DenseVector}.
198     * @param other The vector to subtract.
199     * @return A new {@link DenseVector} where each element value = this.get(i) - other.get(i).
200     */
201    @Override
202    public DenseVector subtract(SGDVector other) {
203        if (other.size() != elements.length) {
204            throw new IllegalArgumentException("Can't subtract two vectors of different dimension, this = " + elements.length + ", other = " + other.size());
205        }
206        double[] newValues = toArray();
207        for (VectorTuple tuple : other) {
208            newValues[tuple.index] -= tuple.value;
209        }
210        return new DenseVector(newValues);
211    }
212
213    @Override
214    public void intersectAndAddInPlace(Tensor other, DoubleUnaryOperator f) {
215        if (other instanceof SGDVector) {
216            SGDVector otherVec = (SGDVector) other;
217            if (otherVec.size() != elements.length) {
218                throw new IllegalArgumentException("Can't intersect two vectors of different dimension, this = " + elements.length + ", other = " + otherVec.size());
219            }
220            for (VectorTuple tuple : otherVec) {
221                elements[tuple.index] += f.applyAsDouble(tuple.value);
222            }
223        } else {
224            throw new IllegalArgumentException("Adding a non-Vector to a Vector");
225        }
226    }
227
228    @Override
229    public void hadamardProductInPlace(Tensor other, DoubleUnaryOperator f) {
230        if (other instanceof SGDVector) {
231            SGDVector otherVec = (SGDVector) other;
232            if (otherVec.size() != elements.length) {
233                throw new IllegalArgumentException("Can't hadamard product two vectors of different dimension, this = " + elements.length + ", other = " + otherVec.size());
234            }
235            for (VectorTuple tuple : otherVec) {
236                elements[tuple.index] *= f.applyAsDouble(tuple.value);
237            }
238        } else {
239            throw new IllegalArgumentException("Scaling a Vector by a non-Vector");
240        }
241    }
242
243    @Override
244    public void foreachInPlace(DoubleUnaryOperator f) {
245        for (int i = 0; i < elements.length; i++) {
246            elements[i] = f.applyAsDouble(elements[i]);
247        }
248    }
249
250    @Override
251    public DenseVector scale(double coefficient) {
252        DenseVector output = copy();
253        output.scaleInPlace(coefficient);
254        return output;
255    }
256
257    @Override
258    public void add(int index, double value) {
259        elements[index] += value;
260    }
261
262    @Override
263    public double dot(SGDVector other) {
264        if (other.size() != elements.length) {
265            throw new IllegalArgumentException("Can't dot two vectors of different dimension, this = " + elements.length + ", other = " + other.size());
266        }
267        double score = 0.0;
268
269        for (VectorTuple tuple : other) {
270            score += elements[tuple.index] * tuple.value;
271        }
272
273        return score;
274    }
275
276    @Override
277    public Matrix outer(SGDVector other) {
278        if (other instanceof DenseVector) {
279            //Outer product is a DenseMatrix
280            DenseVector otherVec = (DenseVector) other;
281            double[][] output = new double[elements.length][];
282            for (int i = 0; i < elements.length; i++) {
283                DenseVector tmp = otherVec.scale(elements[i]);
284                output[i] = tmp.elements;
285            }
286            return new DenseMatrix(output);
287        } else if (other instanceof SparseVector) {
288            //Outer product is a DenseSparseMatrix
289            SparseVector otherVec = (SparseVector) other;
290            SparseVector[] output = new SparseVector[elements.length];
291            for (int i = 0; i < elements.length; i++) {
292                output[i] = otherVec.scale(elements[i]);
293            }
294            return new DenseSparseMatrix(output);
295        } else {
296            throw new IllegalArgumentException("Invalid vector subclass " + other.getClass().getCanonicalName() + " for input");
297        }
298    }
299
300    @Override
301    public double sum() {
302        double sum = 0.0;
303        for (int i = 0; i < elements.length; i++) {
304            sum += elements[i];
305        }
306        return sum;
307    }
308
309    public double sum(DoubleUnaryOperator f) {
310        double sum = 0.0;
311        for (int i = 0; i < elements.length; i++) {
312            sum += f.applyAsDouble(elements[i]);
313        }
314        return sum;
315    }
316
317    @Override
318    public double twoNorm() {
319        double sum = 0.0;
320        for (int i = 0; i < elements.length; i++) {
321            sum += elements[i] * elements[i];
322        }
323        return Math.sqrt(sum);
324    }
325
326    @Override
327    public double oneNorm() {
328        double sum = 0.0;
329        for (int i = 0; i < elements.length; i++) {
330            sum += Math.abs(elements[i]);
331        }
332        return sum;
333    }
334
335    @Override
336    public double get(int index) {
337        return elements[index];
338    }
339
340    @Override
341    public void set(int index, double value) {
342        elements[index] = value;
343    }
344
345    /**
346     * Sets all the elements of this vector to be the same as {@code other}.
347     * @param other The {@link DenseVector} to copy.
348     */
349    public void setElements(DenseVector other) {
350        for (int i = 0; i < elements.length; i++) {
351            elements[i] = other.get(i);
352        }
353    }
354
355    /**
356     * Fills this {@link DenseVector} with {@code value}.
357     * @param value The value to store in this vector.
358     */
359    public void fill(double value) {
360        Arrays.fill(elements,value);
361    }
362
363    @Override
364    public int indexOfMax() {
365        int index = 0;
366        double value = Double.NEGATIVE_INFINITY;
367        for (int i = 0; i < elements.length; i++) {
368            double tmp = elements[i];
369            if (tmp > value) {
370                index = i;
371                value = tmp;
372            }
373        }
374        return index;
375    }
376
377    @Override
378    public double maxValue() {
379        double value = Double.NEGATIVE_INFINITY;
380        for (int i = 0; i < elements.length; i++) {
381            double tmp = elements[i];
382            if (tmp > value) {
383                value = tmp;
384            }
385        }
386        return value;
387    }
388
389    @Override
390    public double minValue() {
391        double value = Double.POSITIVE_INFINITY;
392        for (int i = 0; i < elements.length; i++) {
393            double tmp = elements[i];
394            if (tmp < value) {
395                value = tmp;
396            }
397        }
398        return value;
399    }
400
401    @Override
402    public void normalize(VectorNormalizer normalizer) {
403        double[] normed = normalizer.normalize(elements);
404        System.arraycopy(normed, 0, elements, 0, normed.length);
405    }
406
407    /**
408     * An optimisation for the exponential normalizer when
409     * you already know the normalization constant.
410     *
411     * Used in the CRF.
412     * @param total The normalization constant.
413     */
414    public void expNormalize(double total) {
415        for (int i = 0; i < elements.length; i++) {
416            elements[i] = Math.exp(elements[i] - total);
417        }
418    }
419
420    @Override
421    public String toString() {
422        StringBuilder buffer = new StringBuilder();
423
424        buffer.append("DenseVector(size=");
425        buffer.append(elements.length);
426        buffer.append(",values=[");
427
428        for (int i = 0; i < elements.length; i++) {
429            buffer.append(elements[i]);
430            buffer.append(",");
431        }
432        buffer.setCharAt(buffer.length()-1,']');
433        buffer.append(")");
434
435        return buffer.toString();
436    }
437
438    @Override
439    public double variance(double mean) {
440        double variance = 0.0;
441        for (int i = 0; i < elements.length; i++) {
442            variance += (elements[i] - mean) * (elements[i] - mean);
443        }
444        return variance;
445    }
446
447    @Override
448    public VectorIterator iterator() {
449        return new DenseVectorIterator(this);
450    }
451
452    /**
453     * Generates a {@link SparseVector} representation from this dense vector, removing all values
454     * with absolute value below {@link VectorTuple#DELTA}.
455     * @return A {@link SparseVector}.
456     */
457    public SparseVector sparsify() {
458        return sparsify(VectorTuple.DELTA);
459    }
460
461    /**
462     * Generates a {@link SparseVector} representation from this dense vector, removing all values
463     * with absolute value below the supplied tolerance.
464     * @param tolerance The threshold below which to set a value to zero.
465     * @return A {@link SparseVector}.
466     */
467    public SparseVector sparsify(double tolerance) {
468        ArrayList<Integer> indices = new ArrayList<>();
469        ArrayList<Double> values = new ArrayList<>();
470
471        for (int i = 0; i < elements.length; i++) {
472            double value = get(i);
473            if (Math.abs(value) > tolerance) {
474                indices.add(i);
475                values.add(value);
476            }
477        }
478
479        return new SparseVector(elements.length, Util.toPrimitiveInt(indices), Util.toPrimitiveDouble(values));
480    }
481
482    /**
483     * The l2 or euclidean distance between this vector and the other vector.
484     * @param other The other vector.
485     * @return The euclidean distance between them.
486     */
487    @Override
488    public double euclideanDistance(SGDVector other) {
489        if (other.size() != elements.length) {
490            throw new IllegalArgumentException("Can't measure distance of two vectors of different lengths, this = " + elements.length + ", other = " + other.size());
491        } else if (other instanceof DenseVector) {
492            double score = 0.0;
493
494            for (int i = 0; i < elements.length; i++) {
495                double tmp = elements[i] - other.get(i);
496                score += tmp * tmp;
497            }
498
499            return Math.sqrt(score);
500        } else if (other instanceof SparseVector) {
501            double score = 0.0;
502
503            int i = 0;
504            Iterator<VectorTuple> otherItr = other.iterator();
505            VectorTuple otherTuple;
506            while (i < elements.length && otherItr.hasNext()) {
507                otherTuple = otherItr.next();
508                //after this loop, either itr is out or tuple.index >= otherTuple.index
509                while (i < elements.length && (i < otherTuple.index)) {
510                    // as the other vector contains a zero.
511                    score += elements[i]*elements[i];
512                    i++;
513                }
514                if (i == otherTuple.index) {
515                    double tmp = elements[i] - otherTuple.value;
516                    score += tmp * tmp;
517                    i++;
518                }
519            }
520            for (; i < elements.length; i++) {
521                score += elements[i]*elements[i];
522            }
523
524            return Math.sqrt(score);
525        } else {
526            throw new IllegalArgumentException("Unknown vector subclass " + other.getClass().getCanonicalName() + " for input");
527        }
528    }
529
530    /**
531     * The l1 or Manhattan distance between this vector and the other vector.
532     * @param other The other vector.
533     * @return The l1 distance.
534     */
535    @Override
536    public double l1Distance(SGDVector other) {
537        if (other.size() != elements.length) {
538            throw new IllegalArgumentException("Can't measure distance of two vectors of different lengths, this = " + elements.length + ", other = " + other.size());
539        } else if (other instanceof DenseVector) {
540            double score = 0.0;
541
542            for (int i = 0; i < elements.length; i++) {
543                score += Math.abs(elements[i] - other.get(i));
544            }
545
546            return score;
547        } else if (other instanceof SparseVector) {
548            double score = 0.0;
549
550            int i = 0;
551            Iterator<VectorTuple> otherItr = other.iterator();
552            VectorTuple otherTuple;
553            while (i < elements.length && otherItr.hasNext()) {
554                otherTuple = otherItr.next();
555                //after this loop, either itr is out or tuple.index >= otherTuple.index
556                while (i < elements.length && (i < otherTuple.index)) {
557                    // as the other vector contains a zero.
558                    score += Math.abs(elements[i]);
559                    i++;
560                }
561                if (i == otherTuple.index) {
562                    score += Math.abs(elements[i] - otherTuple.value);
563                    i++;
564                }
565            }
566            for (; i < elements.length; i++) {
567                score += Math.abs(elements[i]);
568            }
569
570            return score;
571        } else {
572            throw new IllegalArgumentException("Unknown vector subclass " + other.getClass().getCanonicalName() + " for input");
573        }
574    }
575
576    private static class DenseVectorIterator implements VectorIterator {
577        private final DenseVector vector;
578        private final VectorTuple tuple;
579        private int index;
580
581        public DenseVectorIterator(DenseVector vector) {
582            this.vector = vector;
583            this.tuple = new VectorTuple();
584            this.index = 0;
585        }
586
587        @Override
588        public boolean hasNext() {
589            return index < vector.elements.length;
590        }
591
592        @Override
593        public VectorTuple next() {
594            if (!hasNext()) {
595                throw new NoSuchElementException("Off the end of the iterator.");
596            }
597            tuple.index = index;
598            tuple.value = vector.elements[index];
599            index++;
600            return tuple;
601        }
602
603        @Override
604        public VectorTuple getReference() {
605            return tuple;
606        }
607    }
608
609}