001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.math.util;
018
019import org.tribuo.math.la.DenseSparseMatrix;
020import org.tribuo.math.la.SparseVector;
021import org.tribuo.math.la.VectorIterator;
022import org.tribuo.math.la.VectorTuple;
023
024import java.util.ArrayList;
025import java.util.Arrays;
026import java.util.List;
027import java.util.PriorityQueue;
028import java.util.logging.Logger;
029
030/**
031 * Merges each {@link SparseVector} separately using a {@link PriorityQueue} as a heap.
032 * <p>
033 * Relies upon {@link VectorIterator#compareTo(VectorIterator)}.
034 */
035public class HeapMerger implements Merger {
036    private static final long serialVersionUID = 1L;
037    private static final Logger logger = Logger.getLogger(HeapMerger.class.getName());
038
039    @Override
040    public DenseSparseMatrix merge(DenseSparseMatrix[] inputs) {
041        int denseLength = inputs[0].getDimension1Size();
042        int sparseLength = inputs[0].getDimension2Size();
043        int[] totalLengths = new int[inputs[0].getDimension1Size()];
044
045        for (int i = 0; i < inputs.length; i++) {
046            for (int j = 0; j < totalLengths.length; j++) {
047                totalLengths[j] += inputs[i].numActiveElements(j);
048            }
049        }
050
051        int maxLength = 0;
052        for (int i = 0; i < totalLengths.length; i++) {
053            if (totalLengths[i] > maxLength) {
054                maxLength = totalLengths[i];
055            }
056        }
057
058        SparseVector[] output = new SparseVector[denseLength];
059
060        int[] indicesBuffer = new int[maxLength];
061        double[] valuesBuffer = new double[maxLength];
062        
063        List<SparseVector> vectors = new ArrayList<>();
064        for (int i = 0; i < denseLength; i++) {
065            vectors.clear();
066            for (DenseSparseMatrix m : inputs) {
067                SparseVector vec = m.getRow(i);
068                if (vec.numActiveElements() > 0) {
069                    vectors.add(vec);
070                }
071            }
072            output[i] = merge(vectors,sparseLength,indicesBuffer,valuesBuffer);
073        }
074
075        return DenseSparseMatrix.createFromSparseVectors(output);
076    }
077
078    @Override
079    public SparseVector merge(SparseVector[] inputs) {
080        int maxLength = 0;
081
082        for (int i = 0; i < inputs.length; i++) {
083            maxLength += inputs[i].numActiveElements();
084        }
085
086        return merge(Arrays.asList(inputs),inputs[0].size(),new int[maxLength],new double[maxLength]);
087    }
088
089    /**
090     * Merges a list of sparse vectors into a single sparse vector, summing the values.
091     * @param vectors The vectors to merge.
092     * @param dimension The dimension of the sparse vector.
093     * @param indicesBuffer A buffer for the indices.
094     * @param valuesBuffer A buffer for the values.
095     * @return The merged SparseVector.
096     */
097    public static SparseVector merge(List<SparseVector> vectors, int dimension, int[] indicesBuffer, double[] valuesBuffer) {
098        PriorityQueue<VectorIterator> queue = new PriorityQueue<>();
099        Arrays.fill(valuesBuffer,0.0);
100
101        for (SparseVector vector : vectors) {
102            // Setup matrix iterators, call next to load the first value.
103            VectorIterator cur = vector.iterator();
104            cur.next();
105            queue.add(cur);
106        }
107
108        int sparseCounter = 0;
109        int sparseIndex = -1;
110
111        while (!queue.isEmpty()) {
112            VectorIterator cur = queue.peek();
113            VectorTuple ref = cur.getReference();
114            //logger.log(Level.INFO,"Tuple=" + ref.toString() + ", itrName="+((DenseSparseMatrix.DenseSparseMatrixIterator)cur).getName()+", sparseIndex="+sparseIndex+", sparseCounter="+sparseCounter+", denseCounter="+denseCounter);
115            //logger.log(Level.INFO,"Queue=" + queue.toString());
116
117            if (sparseIndex == -1) {
118                //if we're at the start, store the first value
119                sparseIndex = ref.index;
120                indicesBuffer[sparseCounter] = sparseIndex;
121                valuesBuffer[sparseCounter] = ref.value;
122            } else if (ref.index == sparseIndex) {
123                //if we're already in the right place, aggregate value
124                valuesBuffer[sparseCounter] += ref.value;
125            } else {
126                //else increment the sparseCounter and store the new value
127                sparseIndex = ref.index;
128                sparseCounter++;
129                indicesBuffer[sparseCounter] = sparseIndex;
130                valuesBuffer[sparseCounter] = ref.value;
131            }
132
133            if (!cur.hasNext()) {
134                //Discard exhausted iterator
135                queue.poll();
136            } else {
137                //consume the value and reheap
138                cur.next();
139                VectorIterator tmp = queue.poll();
140                queue.offer(tmp);
141            }
142        }
143        //Generate the final SparseVector
144        int[] indices = Arrays.copyOf(indicesBuffer,sparseCounter+1);
145        double[] values = Arrays.copyOf(valuesBuffer,sparseCounter+1);
146        return SparseVector.createSparseVector(dimension,indices,values);
147    }
148
149}