001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.math.util;
018
019import org.tribuo.math.la.DenseSparseMatrix;
020import org.tribuo.math.la.MatrixIterator;
021import org.tribuo.math.la.MatrixTuple;
022import org.tribuo.math.la.SparseVector;
023
024import java.util.Arrays;
025import java.util.PriorityQueue;
026import java.util.logging.Logger;
027
028/**
029 * Merges each {@link DenseSparseMatrix} using a {@link PriorityQueue} as a heap on the {@link MatrixIterator}.
030 * <p>
031 * Relies upon {@link MatrixIterator#compareTo(MatrixIterator)}.
032 */
033public class MatrixHeapMerger implements Merger {
034    private static final long serialVersionUID = 1L;
035
036    private static final Logger logger = Logger.getLogger(MatrixHeapMerger.class.getName());
037
038    @Override
039    public DenseSparseMatrix merge(DenseSparseMatrix[] inputs) {
040        int sparseLength = inputs[0].getDimension2Size();
041        PriorityQueue<MatrixIterator> queue = new PriorityQueue<>();
042        int[] totalLengths = new int[inputs[0].getDimension1Size()];
043
044        for (int i = 0; i < inputs.length; i++) {
045            for (int j = 0; j < totalLengths.length; j++) {
046                totalLengths[j] += inputs[i].numActiveElements(j);
047            }
048            // Setup matrix iterators, call next to load the first value.
049            MatrixIterator cur = inputs[i].iterator();
050            cur.next();
051            queue.add(cur);
052        }
053
054        int maxLength = 0;
055        for (int i = 0; i < totalLengths.length; i++) {
056            if (totalLengths[i] > maxLength) {
057                maxLength = totalLengths[i];
058            }
059        }
060
061        SparseVector[] output = new SparseVector[totalLengths.length];
062
063        int denseCounter = 0;
064        int sparseCounter = 0;
065        int sparseIndex = -1;
066
067        int[] curIndices = new int[maxLength];
068        double[] curValues = new double[maxLength];
069
070        while (!queue.isEmpty()) {
071            MatrixIterator cur = queue.peek();
072            MatrixTuple ref = cur.getReference();
073            //logger.log(Level.INFO,"Tuple=" + ref.toString() + ", itrName="+((DenseSparseMatrix.DenseSparseMatrixIterator)cur).getName()+", sparseIndex="+sparseIndex+", sparseCounter="+sparseCounter+", denseCounter="+denseCounter);
074            //logger.log(Level.INFO,"Queue=" + queue.toString());
075            if (ref.i > denseCounter) {
076                //Reached the end of the current SparseVector, so generate it and reset
077                int[] indices = Arrays.copyOf(curIndices,sparseCounter+1);
078                double[] values = Arrays.copyOf(curValues,sparseCounter+1);
079                output[denseCounter] = SparseVector.createSparseVector(sparseLength,indices,values);
080                Arrays.fill(curIndices,0);
081                Arrays.fill(curValues,0);
082                sparseIndex = -1;
083                sparseCounter = 0;
084                denseCounter++;
085            }
086
087            if (sparseIndex == -1) {
088                //if we're at the start, store the first value
089                sparseIndex = ref.j;
090                curIndices[sparseCounter] = sparseIndex;
091                curValues[sparseCounter] = ref.value;
092            } else if (ref.j == sparseIndex) {
093                //if we're already in the right place, aggregate value
094                curValues[sparseCounter] += ref.value;
095            } else {
096                //else increment the sparseCounter and store the new value
097                sparseIndex = ref.j;
098                sparseCounter++;
099                curIndices[sparseCounter] = sparseIndex;
100                curValues[sparseCounter] = ref.value;
101            }
102
103            if (!cur.hasNext()) {
104                //Discard exhausted iterator
105                queue.poll();
106            } else {
107                //consume the value and reheap
108                cur.next();
109                MatrixIterator tmp = queue.poll();
110                queue.offer(tmp);
111            }
112        }
113        //Generate the final SparseVector
114        int[] indices = Arrays.copyOf(curIndices,sparseCounter+1);
115        double[] values = Arrays.copyOf(curValues,sparseCounter+1);
116        output[denseCounter] = SparseVector.createSparseVector(sparseLength,indices,values);
117
118        return DenseSparseMatrix.createFromSparseVectors(output);
119    }
120
121    @Override
122    public SparseVector merge(SparseVector[] inputs) {
123        int maxLength = 0;
124
125        for (int i = 0; i < inputs.length; i++) {
126            maxLength += inputs[i].numActiveElements();
127        }
128
129        return HeapMerger.merge(Arrays.asList(inputs),inputs[0].size(),new int[maxLength],new double[maxLength]);
130    }
131}