/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.math.util;

import org.tribuo.math.la.DenseSparseMatrix;
import org.tribuo.math.la.SparseVector;
import org.tribuo.math.la.VectorIterator;
import org.tribuo.math.la.VectorTuple;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.PriorityQueue;
import java.util.logging.Logger;

/**
 * Merges each {@link SparseVector} separately using a {@link PriorityQueue} as a heap.
 * <p>
 * Relies upon {@link VectorIterator#compareTo(VectorIterator)}.
034 */ 035public class HeapMerger implements Merger { 036 private static final long serialVersionUID = 1L; 037 private static final Logger logger = Logger.getLogger(HeapMerger.class.getName()); 038 039 @Override 040 public DenseSparseMatrix merge(DenseSparseMatrix[] inputs) { 041 int denseLength = inputs[0].getDimension1Size(); 042 int sparseLength = inputs[0].getDimension2Size(); 043 int[] totalLengths = new int[inputs[0].getDimension1Size()]; 044 045 for (int i = 0; i < inputs.length; i++) { 046 for (int j = 0; j < totalLengths.length; j++) { 047 totalLengths[j] += inputs[i].numActiveElements(j); 048 } 049 } 050 051 int maxLength = 0; 052 for (int i = 0; i < totalLengths.length; i++) { 053 if (totalLengths[i] > maxLength) { 054 maxLength = totalLengths[i]; 055 } 056 } 057 058 SparseVector[] output = new SparseVector[denseLength]; 059 060 int[] indicesBuffer = new int[maxLength]; 061 double[] valuesBuffer = new double[maxLength]; 062 063 List<SparseVector> vectors = new ArrayList<>(); 064 for (int i = 0; i < denseLength; i++) { 065 vectors.clear(); 066 for (DenseSparseMatrix m : inputs) { 067 SparseVector vec = m.getRow(i); 068 if (vec.numActiveElements() > 0) { 069 vectors.add(vec); 070 } 071 } 072 output[i] = merge(vectors,sparseLength,indicesBuffer,valuesBuffer); 073 } 074 075 return DenseSparseMatrix.createFromSparseVectors(output); 076 } 077 078 @Override 079 public SparseVector merge(SparseVector[] inputs) { 080 int maxLength = 0; 081 082 for (int i = 0; i < inputs.length; i++) { 083 maxLength += inputs[i].numActiveElements(); 084 } 085 086 return merge(Arrays.asList(inputs),inputs[0].size(),new int[maxLength],new double[maxLength]); 087 } 088 089 /** 090 * Merges a list of sparse vectors into a single sparse vector, summing the values. 091 * @param vectors The vectors to merge. 092 * @param dimension The dimension of the sparse vector. 093 * @param indicesBuffer A buffer for the indices. 094 * @param valuesBuffer A buffer for the values. 
095 * @return The merged SparseVector. 096 */ 097 public static SparseVector merge(List<SparseVector> vectors, int dimension, int[] indicesBuffer, double[] valuesBuffer) { 098 PriorityQueue<VectorIterator> queue = new PriorityQueue<>(); 099 Arrays.fill(valuesBuffer,0.0); 100 101 for (SparseVector vector : vectors) { 102 // Setup matrix iterators, call next to load the first value. 103 VectorIterator cur = vector.iterator(); 104 cur.next(); 105 queue.add(cur); 106 } 107 108 int sparseCounter = 0; 109 int sparseIndex = -1; 110 111 while (!queue.isEmpty()) { 112 VectorIterator cur = queue.peek(); 113 VectorTuple ref = cur.getReference(); 114 //logger.log(Level.INFO,"Tuple=" + ref.toString() + ", itrName="+((DenseSparseMatrix.DenseSparseMatrixIterator)cur).getName()+", sparseIndex="+sparseIndex+", sparseCounter="+sparseCounter+", denseCounter="+denseCounter); 115 //logger.log(Level.INFO,"Queue=" + queue.toString()); 116 117 if (sparseIndex == -1) { 118 //if we're at the start, store the first value 119 sparseIndex = ref.index; 120 indicesBuffer[sparseCounter] = sparseIndex; 121 valuesBuffer[sparseCounter] = ref.value; 122 } else if (ref.index == sparseIndex) { 123 //if we're already in the right place, aggregate value 124 valuesBuffer[sparseCounter] += ref.value; 125 } else { 126 //else increment the sparseCounter and store the new value 127 sparseIndex = ref.index; 128 sparseCounter++; 129 indicesBuffer[sparseCounter] = sparseIndex; 130 valuesBuffer[sparseCounter] = ref.value; 131 } 132 133 if (!cur.hasNext()) { 134 //Discard exhausted iterator 135 queue.poll(); 136 } else { 137 //consume the value and reheap 138 cur.next(); 139 VectorIterator tmp = queue.poll(); 140 queue.offer(tmp); 141 } 142 } 143 //Generate the final SparseVector 144 int[] indices = Arrays.copyOf(indicesBuffer,sparseCounter+1); 145 double[] values = Arrays.copyOf(valuesBuffer,sparseCounter+1); 146 return SparseVector.createSparseVector(dimension,indices,values); 147 } 148 149}