001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.math.util; 018 019import org.tribuo.math.la.DenseSparseMatrix; 020import org.tribuo.math.la.MatrixIterator; 021import org.tribuo.math.la.MatrixTuple; 022import org.tribuo.math.la.SparseVector; 023 024import java.util.Arrays; 025import java.util.PriorityQueue; 026import java.util.logging.Logger; 027 028/** 029 * Merges each {@link DenseSparseMatrix} using a {@link PriorityQueue} as a heap on the {@link MatrixIterator}. 030 * <p> 031 * Relies upon {@link MatrixIterator#compareTo(MatrixIterator)}. 032 */ 033public class MatrixHeapMerger implements Merger { 034 private static final long serialVersionUID = 1L; 035 036 private static final Logger logger = Logger.getLogger(MatrixHeapMerger.class.getName()); 037 038 @Override 039 public DenseSparseMatrix merge(DenseSparseMatrix[] inputs) { 040 int sparseLength = inputs[0].getDimension2Size(); 041 PriorityQueue<MatrixIterator> queue = new PriorityQueue<>(); 042 int[] totalLengths = new int[inputs[0].getDimension1Size()]; 043 044 for (int i = 0; i < inputs.length; i++) { 045 for (int j = 0; j < totalLengths.length; j++) { 046 totalLengths[j] += inputs[i].numActiveElements(j); 047 } 048 // Setup matrix iterators, call next to load the first value. 049 MatrixIterator cur = inputs[i].iterator(); 050 cur.next(); 051 queue.add(cur); 052 } 053 054 int maxLength = 0; 055 for (int i = 0; i < totalLengths.length; i++) { 056 if (totalLengths[i] > maxLength) { 057 maxLength = totalLengths[i]; 058 } 059 } 060 061 SparseVector[] output = new SparseVector[totalLengths.length]; 062 063 int denseCounter = 0; 064 int sparseCounter = 0; 065 int sparseIndex = -1; 066 067 int[] curIndices = new int[maxLength]; 068 double[] curValues = new double[maxLength]; 069 070 while (!queue.isEmpty()) { 071 MatrixIterator cur = queue.peek(); 072 MatrixTuple ref = cur.getReference(); 073 //logger.log(Level.INFO,"Tuple=" + ref.toString() + ", itrName="+((DenseSparseMatrix.DenseSparseMatrixIterator)cur).getName()+", sparseIndex="+sparseIndex+", sparseCounter="+sparseCounter+", denseCounter="+denseCounter); 074 //logger.log(Level.INFO,"Queue=" + queue.toString()); 075 if (ref.i > denseCounter) { 076 //Reached the end of the current SparseVector, so generate it and reset 077 int[] indices = Arrays.copyOf(curIndices,sparseCounter+1); 078 double[] values = Arrays.copyOf(curValues,sparseCounter+1); 079 output[denseCounter] = SparseVector.createSparseVector(sparseLength,indices,values); 080 Arrays.fill(curIndices,0); 081 Arrays.fill(curValues,0); 082 sparseIndex = -1; 083 sparseCounter = 0; 084 denseCounter++; 085 } 086 087 if (sparseIndex == -1) { 088 //if we're at the start, store the first value 089 sparseIndex = ref.j; 090 curIndices[sparseCounter] = sparseIndex; 091 curValues[sparseCounter] = ref.value; 092 } else if (ref.j == sparseIndex) { 093 //if we're already in the right place, aggregate value 094 curValues[sparseCounter] += ref.value; 095 } else { 096 //else increment the sparseCounter and store the new value 097 sparseIndex = ref.j; 098 sparseCounter++; 099 curIndices[sparseCounter] = sparseIndex; 100 curValues[sparseCounter] = ref.value; 101 } 102 103 if (!cur.hasNext()) { 104 //Discard exhausted iterator 105 queue.poll(); 106 } else { 107 //consume the value and reheap 108 cur.next(); 109 MatrixIterator tmp = queue.poll(); 110 queue.offer(tmp); 111 } 112 } 113 //Generate the final SparseVector 114 int[] indices = Arrays.copyOf(curIndices,sparseCounter+1); 115 double[] values = Arrays.copyOf(curValues,sparseCounter+1); 116 output[denseCounter] = SparseVector.createSparseVector(sparseLength,indices,values); 117 118 return DenseSparseMatrix.createFromSparseVectors(output); 119 } 120 121 @Override 122 public SparseVector merge(SparseVector[] inputs) { 123 int maxLength = 0; 124 125 for (int i = 0; i < inputs.length; i++) { 126 maxLength += inputs[i].numActiveElements(); 127 } 128 129 return HeapMerger.merge(Arrays.asList(inputs),inputs[0].size(),new int[maxLength],new double[maxLength]); 130 } 131}