/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.util;

import com.oracle.labs.mlrg.olcut.util.Pair;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Random;
import java.util.Set;
import java.util.SplittableRandom;
import java.util.function.ToIntFunction;
import java.util.logging.Level;
import java.util.logging.Logger;

/**
 * Ye olde util class.
 * <p>
 * Basically full of vector and RNG operations.
 */
public final class Util {

    private static final Logger logger = Logger.getLogger(Util.class.getName());

    // private constructor of a final class, this is full of static methods so you can't instantiate it.
    private Util() {}

    /**
     * Find the index of the maximum value in a list.
     * <p>
     * If the maximum occurs more than once the index of its first occurrence is returned.
     * @param values list
     * @param <T> the type of the values (must implement Comparable)
     * @return a pair: (index of the max value, max value)
     * @throws IllegalArgumentException if the list is empty.
     */
    public static <T extends Comparable<T>> Pair<Integer, T> argmax(List<T> values) {
        if (values.isEmpty()) {
            throw new IllegalArgumentException("argmax on an empty list");
        }
        // There is no "globally min" value like -Inf for an arbitrary type T so we just pick the first list element
        T vmax = values.get(0);
        int imax = 0;
        for (int i = 1; i < values.size(); i++) {
            T v = values.get(i);
            if (v.compareTo(vmax) > 0) {
                vmax = v;
                imax = i;
            }
        }
        return new Pair<>(imax, vmax);
    }

    /**
     * Find the index of the minimum value in a list.
     * <p>
     * If the minimum occurs more than once the index of its first occurrence is returned.
     * @param values list
     * @param <T> the type of the values (must implement Comparable)
     * @return a pair: (index of the min value, min value)
     * @throws IllegalArgumentException if the list is empty.
     */
    public static <T extends Comparable<T>> Pair<Integer, T> argmin(List<T> values) {
        if (values.isEmpty()) {
            throw new IllegalArgumentException("argmin on an empty list");
        }
        // There is no "globally max" value like Inf for an arbitrary type T so we just pick the first list element
        T vmin = values.get(0);
        int imin = 0;
        for (int i = 1; i < values.size(); i++) {
            T v = values.get(i);
            if (v.compareTo(vmin) < 0) {
                vmin = v;
                imin = i;
            }
        }
        return new Pair<>(imin, vmin);
    }

    /**
     * Convert an array of doubles to an array of floats.
     *
     * @param doubles The array of doubles to convert.
     * @return An array of floats.
     */
    public static float[] toFloatArray(double[] doubles) {
        float[] floats = new float[doubles.length];
        for (int i = 0; i < doubles.length; i++) {
            floats[i] = (float) doubles[i];
        }
        return floats;
    }

    /**
     * Convert an array of floats to an array of doubles.
     *
     * @param floats The array of floats to convert.
     * @return An array of doubles.
     */
    public static double[] toDoubleArray(float[] floats) {
        double[] doubles = new double[floats.length];
        for (int i = 0; i < floats.length; i++) {
            doubles[i] = floats[i];
        }
        return doubles;
    }

    /**
     * Shuffles the indices in the range [0,size).
     * @param size The number of elements.
     * @param rng The random number generator to use.
     * @return A random permutation of the values in the range [0, size).
     */
    public static int[] randperm(int size, Random rng) {
        int[] array = new int[size];
        for (int i = 0; i < array.length; i++) {
            array[i] = i;
        }
        randpermInPlace(array, rng);
        return array;
    }

    /**
     * Shuffles the indices in the range [0,size).
     * @param size The number of elements.
     * @param rng The random number generator to use.
     * @return A random permutation of the values in the range [0, size).
     */
    public static int[] randperm(int size, SplittableRandom rng) {
        int[] array = new int[size];
        for (int i = 0; i < array.length; i++) {
            array[i] = i;
        }
        randpermInPlace(array, rng);
        return array;
    }

    /**
     * Shuffles the input.
     * @param input The array to shuffle.
     * @param rng The random number generator to use.
     */
    public static void randpermInPlace(int[] input, Random rng) {
        // Walk down from the end, swapping each slot with a uniformly chosen slot at or below it.
        for (int i = input.length; i > 1; i--) {
            int j = rng.nextInt(i);
            int tmp = input[i-1];
            input[i-1] = input[j];
            input[j] = tmp;
        }
    }

    /**
     * Shuffles the input.
     * @param input The array to shuffle.
     * @param rng The random number generator to use.
     */
    public static void randpermInPlace(int[] input, SplittableRandom rng) {
        // Walk down from the end, swapping each slot with a uniformly chosen slot at or below it.
        for (int i = input.length; i > 1; i--) {
            int j = rng.nextInt(i);
            int tmp = input[i-1];
            input[i-1] = input[j];
            input[j] = tmp;
        }
    }

    /**
     * Draws a bootstrap sample of indices.
     * @param size Size of the sample to generate.
     * @param rng The RNG to use.
     * @return A bootstrap sample.
     */
    public static int[] generateBootstrapIndices(int size, Random rng) {
        int[] array = new int[size];
        for (int i = 0; i < size; i++) {
            array[i] = rng.nextInt(size);
        }
        return array;
    }

    /**
     * Draws a bootstrap sample of indices.
     * @param size Size of the sample to generate.
     * @param rng The RNG to use.
     * @return A bootstrap sample.
     */
    public static int[] generateBootstrapIndices(int size, SplittableRandom rng) {
        int[] array = new int[size];
        for (int i = 0; i < size; i++) {
            array[i] = rng.nextInt(size);
        }
        return array;
    }

    /**
     * Generates a sample of indices weighted by the provided weights.
     * @param size Size of the sample to generate.
     * @param weights A probability mass function of weights.
     * @param rng The RNG to use.
     * @return A sample with replacement from weights.
     * @throws IllegalStateException if the weights do not sum to 1 (within 1e-10).
     */
    public static int[] generateWeightedIndicesSample(int size, double[] weights, Random rng) {
        double[] cdf = generateCDF(weights);
        checkCDFSumsToOne(cdf, 1e-10);
        return generateWeightedIndicesSample(cdf, size, rng);
    }

    /**
     * Generates a sample of indices weighted by the provided weights.
     * @param size Size of the sample to generate.
     * @param weights A probability mass function of weights.
     * @param rng The RNG to use.
     * @return A sample with replacement from weights.
     * @throws IllegalStateException if the weights do not sum to 1 (within 1e-6).
     */
    public static int[] generateWeightedIndicesSample(int size, float[] weights, Random rng) {
        double[] cdf = generateCDF(weights);
        checkCDFSumsToOne(cdf, 1e-6);
        return generateWeightedIndicesSample(cdf, size, rng);
    }

    private static int[] generateWeightedIndicesSample(double[] cdf, int size, Random rng) {
        int[] output = new int[size];

        for (int i = 0; i < output.length; i++) {
            output[i] = indexFromCDF(cdf, rng.nextDouble());
        }
        return output;
    }

    /**
     * Generates a sample of indices weighted by the provided weights.
     * @param size Size of the sample to generate.
     * @param weights A probability mass function of weights.
     * @param rng The RNG to use.
     * @return A sample with replacement from weights.
     * @throws IllegalStateException if the weights do not sum to 1 (within 1e-10).
     */
    public static int[] generateWeightedIndicesSample(int size, double[] weights, SplittableRandom rng) {
        double[] cdf = generateCDF(weights);
        checkCDFSumsToOne(cdf, 1e-10);
        return generateWeightedIndicesSample(cdf, size, rng);
    }

    /**
     * Generates a sample of indices weighted by the provided weights.
     * @param size Size of the sample to generate.
     * @param weights A probability mass function of weights.
     * @param rng The RNG to use.
     * @return A sample with replacement from weights.
     * @throws IllegalStateException if the weights do not sum to 1 (within 1e-6).
     */
    public static int[] generateWeightedIndicesSample(int size, float[] weights, SplittableRandom rng) {
        double[] cdf = generateCDF(weights);
        checkCDFSumsToOne(cdf, 1e-6);
        return generateWeightedIndicesSample(cdf, size, rng);
    }

    private static int[] generateWeightedIndicesSample(double[] cdf, int size, SplittableRandom rng) {
        int[] output = new int[size];

        for (int i = 0; i < output.length; i++) {
            output[i] = indexFromCDF(cdf, rng.nextDouble());
        }
        return output;
    }

    /**
     * Generates a sample of indices weighted by the provided weights without replacement. Does not recalculate
     * proportions in-between samples. Use judiciously.
     * @param size Size of the sample to generate
     * @param weights A probability mass function of weights
     * @param rng The RNG to use
     * @return A sample without replacement from weights
     * @throws IllegalStateException if the weights do not sum to 1 (within 1e-6).
     * @throws IllegalArgumentException if size is larger than the number of indices with non-zero weight.
     */
    public static int[] generateWeightedIndicesSampleWithoutReplacement(int size, double[] weights, Random rng) {
        double[] cdf = generateCDF(weights);
        checkCDFSumsToOne(cdf, 1e-6);
        return sampleWithoutReplacement(cdf, size, rng);
    }

    /**
     * Generates a sample of indices weighted by the provided weights without replacement. Does not recalculate
     * proportions in-between samples. Use judiciously.
     * @param size Size of the sample to generate
     * @param weights A probability mass function of weights
     * @param rng The RNG to use
     * @return A sample without replacement from weights
     * @throws IllegalStateException if the weights do not sum to 1 (within 1e-6).
     * @throws IllegalArgumentException if size is larger than the number of indices with non-zero weight.
     */
    public static int[] generateWeightedIndicesSampleWithoutReplacement(int size, float[] weights, Random rng) {
        double[] cdf = generateCDF(weights);
        checkCDFSumsToOne(cdf, 1e-6);
        return sampleWithoutReplacement(cdf, size, rng);
    }

    /**
     * Rejection samples distinct indices from the cdf until size of them have been collected.
     */
    private static int[] sampleWithoutReplacement(double[] cdf, int size, Random rng) {
        // An index can only be drawn if the cdf steps up at it, i.e. it carries non-zero weight.
        // Asking for more distinct indices than that would make the rejection loop spin forever.
        int drawable = 0;
        double previous = 0.0;
        for (double value : cdf) {
            if (value > previous) {
                drawable++;
            }
            previous = value;
        }
        if (size > drawable) {
            throw new IllegalArgumentException("Cannot sample " + size + " indices without replacement, only " + drawable + " indices have non-zero weight");
        }

        int[] output = new int[size];
        Set<Integer> seenIdxs = new HashSet<>();
        int i = 0;
        while (i < output.length) {
            int candidateSample = indexFromCDF(cdf, rng.nextDouble());
            // Set.add returns false when the index has been drawn before, in which case we redraw.
            if (seenIdxs.add(candidateSample)) {
                output[i] = candidateSample;
                i++;
            }
        }
        return output;
    }

    /**
     * Throws if the final element of the cdf is further than tolerance from 1.0.
     */
    private static void checkCDFSumsToOne(double[] cdf, double tolerance) {
        if (Math.abs(cdf[cdf.length-1] - 1.0) > tolerance) {
            throw new IllegalStateException("Weights do not sum to 1, cdf[cdf.length-1] = " + cdf[cdf.length-1]);
        }
    }

    /**
     * Maps a uniform draw onto an index of the cdf using binary search.
     */
    private static int indexFromCDF(double[] cdf, double uniform) {
        int searchVal = Arrays.binarySearch(cdf, uniform);
        // A negative result encodes -(insertion point) - 1, the insertion point is the sampled index.
        int index = searchVal < 0 ? -1 - searchVal : searchVal;
        // The cdf only has to end within a tolerance of 1.0, so a draw above its final
        // entry would otherwise give an index one past the end of the array.
        return Math.min(index, cdf.length - 1);
    }

    /**
     * Generates a cumulative distribution function from the supplied probability mass function.
     * @param pmf The probability mass function (i.e., the probability distribution).
     * @return The CDF.
     */
    public static double[] generateCDF(double[] pmf) {
        return cumulativeSum(pmf);
    }

    /**
     * Produces a cumulative sum array.
     * @param input The input to sum.
     * @return The cumulative sum.
     */
    public static double[] cumulativeSum(double[] input) {
        double[] cdf = new double[input.length];

        double sum = 0;
        for (int i = 0; i < input.length; i++) {
            sum += input[i];
            cdf[i] = sum;
        }

        return cdf;
    }

    /**
     * Produces a cumulative sum array, counting each true as 1 and each false as 0.
     * @param input The input to sum.
     * @return The cumulative sum.
     */
    public static int[] cumulativeSum(boolean[] input) {
        int[] cumulativeSum = new int[input.length];

        int sum = 0;
        for (int i = 0; i < input.length; i++) {
            sum += input[i] ? 1 : 0;
            cumulativeSum[i] = sum;
        }

        return cumulativeSum;
    }

    /**
     * Generates a cumulative distribution function from the supplied probability mass function.
     * @param pmf The probability mass function (i.e., the probability distribution).
     * @return The CDF.
     */
    public static double[] generateCDF(float[] pmf) {
        double[] cdf = new double[pmf.length];

        double sum = 0;
        for (int i = 0; i < pmf.length; i++) {
            sum += pmf[i];
            cdf[i] = sum;
        }

        return cdf;
    }

    /**
     * Generates a cumulative distribution function from the supplied frequency counts.
     * @param counts The frequency counts.
     * @param countSum The sum of the counts.
     * @return The CDF.
     */
    public static double[] generateCDF(long[] counts, long countSum) {
        double[] cdf = new double[counts.length];

        double countSumD = countSum;
        double probSum = 0.0;
        for (int i = 0; i < counts.length; i++) {
            probSum += counts[i] / countSumD;
            cdf[i] = probSum;
        }

        return cdf;
    }

    /**
     * Samples an index from the supplied cdf.
     * @param cdf The cdf to sample from.
     * @param rng The rng to use.
     * @return A sample.
     * @throws IllegalStateException if the cdf does not end at 1 (within 1e-6).
     */
    public static int sampleFromCDF(double[] cdf, Random rng) {
        checkCDFSumsToOne(cdf, 1e-6);
        return indexFromCDF(cdf, rng.nextDouble());
    }

    /**
     * Samples an index from the supplied cdf.
     * @param cdf The cdf to sample from.
     * @param rng The rng to use.
     * @return A sample.
     * @throws IllegalStateException if the cdf does not end at 1 (within 1e-6).
     */
    public static int sampleFromCDF(double[] cdf, SplittableRandom rng) {
        checkCDFSumsToOne(cdf, 1e-6);
        return indexFromCDF(cdf, rng.nextDouble());
    }

    /**
     * Creates an array of the requested length with every element set to value.
     * @param length The length of the array.
     * @param value The value to fill it with.
     * @return The filled array.
     */
    public static double[] generateUniformVector(int length, double value) {
        double[] output = new double[length];

        Arrays.fill(output, value);

        return output;
    }

    /**
     * Creates an array of the requested length with every element set to value.
     * @param length The length of the array.
     * @param value The value to fill it with.
     * @return The filled array.
     */
    public static float[] generateUniformVector(int length, float value) {
        float[] output = new float[length];

        Arrays.fill(output, value);

        return output;
    }

    /**
     * Returns a copy of the input where each element has been divided by the sum of the input.
     * @param input The values to normalize, this array is not modified.
     * @return A new array containing the normalized values.
     */
    public static double[] normalizeToDistribution(double[] input) {
        double[] output = new double[input.length];
        double sum = 0.0;

        for (int i = 0; i < input.length; i++) {
            output[i] = input[i];
            sum += output[i];
        }

        for (int i = 0; i < input.length; i++) {
            output[i] /= sum;
        }

        return output;
    }

    /**
     * Returns a double copy of the input where each element has been divided by the sum of the input.
     * @param input The values to normalize, this array is not modified.
     * @return A new array containing the normalized values.
     */
    public static double[] normalizeToDistribution(float[] input) {
        double[] output = new double[input.length];
        double sum = 0.0;

        for (int i = 0; i < input.length; i++) {
            output[i] = input[i];
            sum += output[i];
        }

        for (int i = 0; i < input.length; i++) {
            output[i] /= sum;
        }

        return output;
    }

    /**
     * Divides each element of the input by the sum of the input, overwriting the input.
     * @param input The values to normalize in place.
     * @return The input array.
     */
    public static double[] inplaceNormalizeToDistribution(double[] input) {
        double sum = 0.0;

        for (int i = 0; i < input.length; i++) {
            sum += input[i];
        }

        for (int i = 0; i < input.length; i++) {
            input[i] /= sum;
        }

        return input;
    }

    /**
     * Divides each element of the input by the sum of the input, overwriting the input.
     * @param input The values to normalize in place.
     */
    public static void inplaceNormalizeToDistribution(float[] input) {
        float sum = 0.0f;

        for (int i = 0; i < input.length; i++) {
            sum += input[i];
        }

        for (int i = 0; i < input.length; i++) {
            input[i] /= sum;
        }

    }

    /**
     * Logs the input as space separated "(index,value)" tuples.
     * <p>
     * An empty input logs an empty message.
     * @param otherLogger The logger to write to.
     * @param level The level to log at.
     * @param input The vector to log.
     */
    public static void logVector(Logger otherLogger, Level level, double[] input) {
        StringBuilder buffer = new StringBuilder();

        for (int i = 0; i < input.length; i++) {
            buffer.append("(");
            buffer.append(i);
            buffer.append(",");
            buffer.append(input[i]);
            buffer.append(") ");
        }
        // Drop the trailing space, there is nothing to drop when the input was empty.
        if (buffer.length() > 0) {
            buffer.deleteCharAt(buffer.length()-1);
        }
        otherLogger.log(level, buffer.toString());
    }

    /**
     * Logs the input as space separated "(index,value)" tuples.
     * <p>
     * An empty input logs an empty message.
     * @param otherLogger The logger to write to.
     * @param level The level to log at.
     * @param input The vector to log.
     */
    public static void logVector(Logger otherLogger, Level level, float[] input) {
        StringBuilder buffer = new StringBuilder();

        for (int i = 0; i < input.length; i++) {
            buffer.append("(");
            buffer.append(i);
            buffer.append(",");
            buffer.append(input[i]);
            buffer.append(") ");
        }
        // Drop the trailing space, there is nothing to drop when the input was empty.
        if (buffer.length() > 0) {
            buffer.deleteCharAt(buffer.length()-1);
        }
        otherLogger.log(level, buffer.toString());
    }

    /**
     * Unboxes a list of integers into a double array.
     * @param input The list to convert.
     * @return The values as primitive doubles, in list order.
     */
    public static double[] toPrimitiveDoubleFromInteger(List<Integer> input) {
        double[] output = new double[input.size()];

        // Iterate rather than call get(i) so lists without random access are not quadratic.
        int i = 0;
        for (Integer value : input) {
            output[i] = value;
            i++;
        }

        return output;
    }

    /**
     * Unboxes a list of doubles into a double array.
     * @param input The list to convert.
     * @return The values as primitive doubles, in list order.
     */
    public static double[] toPrimitiveDouble(List<Double> input) {
        double[] output = new double[input.size()];

        int i = 0;
        for (Double value : input) {
            output[i] = value;
            i++;
        }

        return output;
    }

    /**
     * Unboxes a list of floats into a float array.
     * @param input The list to convert.
     * @return The values as primitive floats, in list order.
     */
    public static float[] toPrimitiveFloat(List<Float> input) {
        float[] output = new float[input.size()];

        int i = 0;
        for (Float value : input) {
            output[i] = value;
            i++;
        }

        return output;
    }

    /**
     * Unboxes a list of integers into an int array.
     * @param input The list to convert.
     * @return The values as primitive ints, in list order.
     */
    public static int[] toPrimitiveInt(List<Integer> input) {
        int[] output = new int[input.size()];

        int i = 0;
        for (Integer value : input) {
            output[i] = value;
            i++;
        }

        return output;
    }

    /**
     * Unboxes a list of longs into a long array.
     * @param input The list to convert.
     * @return The values as primitive longs, in list order.
     */
    public static long[] toPrimitiveLong(List<Long> input) {
        long[] output = new long[input.size()];

        int i = 0;
        for (Long value : input) {
            output[i] = value;
            i++;
        }

        return output;
    }

    /**
     * Draws size ints uniformly from [0, range).
     * @param rng The RNG to use.
     * @param size The number of ints to draw.
     * @param range The exclusive upper bound of each draw.
     * @return The sampled ints.
     */
    public static int[] sampleInts(Random rng, int size, int range) {
        int[] output = new int[size];

        for (int i = 0; i < output.length; i++) {
            output[i] = rng.nextInt(range);
        }

        return output;
    }

    /**
     * Adds update to input element-wise, overwriting input.
     * @param input The array to update.
     * @param update The values to add, must be at least as long as input.
     */
    public static void inPlaceAdd(double[] input, double[] update) {
        for (int i = 0; i < input.length; i++) {
            input[i] += update[i];
        }
    }

    /**
     * Subtracts update from input element-wise, overwriting input.
     * @param input The array to update.
     * @param update The values to subtract, must be at least as long as input.
     */
    public static void inPlaceSubtract(double[] input, double[] update) {
        for (int i = 0; i < input.length; i++) {
            input[i] -= update[i];
        }
    }

    /**
     * Adds update to input element-wise, overwriting input.
     * @param input The array to update.
     * @param update The values to add, must be at least as long as input.
     */
    public static void inPlaceAdd(float[] input, float[] update) {
        for (int i = 0; i < input.length; i++) {
            input[i] += update[i];
        }
    }

    /**
     * Subtracts update from input element-wise, overwriting input.
     * @param input The array to update.
     * @param update The values to subtract, must be at least as long as input.
     */
    public static void inPlaceSubtract(float[] input, float[] update) {
        for (int i = 0; i < input.length; i++) {
            input[i] -= update[i];
        }
    }

    /**
     * Computes the Euclidean (L2) norm of the input, i.e. the square root of the sum of squares.
     * @param input The vector.
     * @return The L2 norm.
     */
    public static double vectorNorm(double[] input) {
        double sumSquares = 0.0;
        for (double d : input) {
            sumSquares += d * d;
        }
        // Without the square root this would be the squared norm rather than the norm.
        return Math.sqrt(sumSquares);
    }

    /**
     * Sums the input.
     * @param input The values to sum.
     * @return The sum.
     */
    public static double sum(double[] input) {
        double sum = 0.0;
        for (double d : input) {
            sum += d;
        }
        return sum;
    }

    /**
     * Sums the input.
     * @param input The values to sum.
     * @return The sum.
     */
    public static float sum(float[] input) {
        float sum = 0.0f;
        for (float d : input) {
            sum += d;
        }
        return sum;
    }

    /**
     * Sums the first length elements of the array.
     * @param array The values to sum.
     * @param length The number of leading elements to use.
     * @return The sum.
     */
    public static double sum(double[] array, int length) {
        double sum = 0.0;
        for (int i = 0; i < length; i++) {
            sum += array[i];
        }
        return sum;
    }

    /**
     * Sums the first length elements of the array.
     * @param array The values to sum.
     * @param length The number of leading elements to use.
     * @return The sum.
     */
    public static float sum(float[] array, int length) {
        float sum = 0.0f;
        for (int i = 0; i < length; i++) {
            sum += array[i];
        }
        return sum;
    }

    /**
     * Sums the elements of input selected by the first indicesLength entries of indices.
     * @param indices The positions in input to sum.
     * @param indicesLength The number of leading entries of indices to use.
     * @param input The values.
     * @return The sum of the selected values.
     */
    public static float sum(int[] indices, int indicesLength, float[] input) {
        float sum = 0.0f;
        for (int i = 0; i < indicesLength; i++) {
            sum += input[indices[i]];
        }
        return sum;
    }

    /**
     * Sums the elements of input selected by indices.
     * @param indices The positions in input to sum.
     * @param input The values.
     * @return The sum of the selected values.
     */
    public static float sum(int[] indices, float[] input) {
        return sum(indices,indices.length,input);
    }

    /**
     * Creates a float array of the requested length with every element set to value.
     * @param length The length of the array.
     * @param value The value to fill it with.
     * @return The filled array.
     */
    public static float[] generateUniformFloatVector(int length, float value) {
        float[] output = new float[length];

        Arrays.fill(output, value);

        return output;
    }

    /**
     * A binary search function.
     * @param list Input list, must be ordered.
     * @param key Key to search for.
     * @param <T> Type of the key, the list elements must be Comparable with it.
     * @return the index of the search key, if it is contained in the list;
     *               otherwise, (-(insertion point) - 1).  The insertion point is
     *               defined as the point at which the key would be inserted into
     *               the list: the index of the first element greater than the key,
     *               or list.size() if all elements in the list are less than the
     *               specified key.  Note that this guarantees that the return value
     *               will be {@code >= 0} if and only if the key is found.
     */
    public static <T> int binarySearch(List<? extends Comparable<? super T>> list, T key) {
        return binarySearch(list,key,0,list.size()-1);
    }

    /**
     * A binary search function.
     * @param list Input list, must be ordered.
     * @param key Key to search for.
     * @param low Starting index.
     * @param high End index (will be searched).
     * @param <T> Type of the key, the list elements must be Comparable with it.
     * @return the index of the search key, if it is contained in the list;
     *               otherwise, (-(insertion point) - 1).  The insertion point is
     *               defined as the point at which the key would be inserted into
     *               the list: the index of the first element greater than the key,
     *               or high + 1 if all elements in the searched range are less than the
     *               specified key.  Note that this guarantees that the return value
     *               will be {@code >= 0} if and only if the key is found.
     */
    public static <T> int binarySearch(List<? extends Comparable<? super T>> list, T key, int low, int high) {
        while (low <= high) {
            // Unsigned shift keeps the midpoint correct even if low + high overflows.
            int mid = (low + high) >>> 1;
            Comparable<? super T> midVal = list.get(mid);
            int cmp = midVal.compareTo(key);
            if (cmp < 0) {
                low = mid + 1;
            } else if (cmp > 0) {
                high = mid - 1;
            } else {
                return mid; // key found
            }
        }
        return -(low + 1);  // key not found
    }

    /**
     * A binary search function.
     * @param list Input list, must be ordered by the value extractionFunc produces.
     * @param key Key to search for.
     * @param extractionFunc Takes a T and generates an int
     *                       which can be used for comparison using int's natural ordering.
     * @param <T> Type of the list elements.
     * @return the index of the search key, if it is contained in the list;
     *               otherwise, (-(insertion point) - 1).  The insertion point is
     *               defined as the point at which the key would be inserted into
     *               the list: the index of the first element greater than the key,
     *               or list.size() if all elements in the list are less than the
     *               specified key.  Note that this guarantees that the return value
     *               will be {@code >= 0} if and only if the key is found.
     */
    public static <T> int binarySearch(List<? extends T> list, int key, ToIntFunction<T> extractionFunc) {
        int low = 0;
        int high = list.size()-1;
        while (low <= high) {
            // Unsigned shift keeps the midpoint correct even if low + high overflows.
            int mid = (low + high) >>> 1;
            int midVal = extractionFunc.applyAsInt(list.get(mid));
            int cmp = Integer.compare(midVal, key);
            if (cmp < 0) {
                low = mid + 1;
            } else if (cmp > 0) {
                high = mid - 1;
            } else {
                return mid; // key found
            }
        }
        return -(low + 1);  // key not found
    }

    /**
     * Calculates the area under the curve, bounded below by the x axis.
     * <p>
     * Uses linear interpolation between the points on the x axis,
     * i.e., trapezoidal integration.
     * <p>
     * The x axis must be increasing.
     * @param x The x points to evaluate.
     * @param y The corresponding heights.
     * @return The AUC.
     * @throws IllegalArgumentException if x and y have different lengths.
     * @throws IllegalStateException if x decreases by more than 1e-12 between adjacent points.
     */
    public static double auc(double[] x, double[] y) {
        if (x.length != y.length) {
            throw new IllegalArgumentException("x and y must be the same length, x.length = " + x.length + ", y.length = " + y.length);
        }
        double output = 0.0;

        for (int i = 1; i < x.length; i++) {
            double ySum = y[i] + y[i-1];
            double xDiff = x[i] - x[i-1];
            if (xDiff < -1e-12) {
                throw new IllegalStateException(String.format("X is not increasing, x[%d]=%f, x[%d]=%f",i,x[i],i-1,x[i-1]));
            }
            output += (ySum * xDiff) / 2.0;
        }

        return output;
    }

    /**
     * Returns the mean and variance of the input.
     * @param inputs The input array.
     * @return The mean and variance of the inputs. The mean is the first element, the variance is the second.
     */
    public static Pair<Double,Double> meanAndVariance(double[] inputs) {
        return meanAndVariance(inputs,inputs.length);
    }

    /**
     * Returns the mean and variance of the input's first length elements.
     * <p>
     * The variance is the sample variance (it divides by length - 1), so it is
     * NaN when length is 1.
     * @param inputs The input array.
     * @param length The number of elements to use.
     * @return The mean and variance of the inputs. The mean is the first element, the variance is the second.
     */
    public static Pair<Double,Double> meanAndVariance(double[] inputs, int length) {
        double mean = 0.0;
        double sumSquares = 0.0;
        // Single pass running update: the squared deviations are accumulated against the
        // mean before and after it absorbs each value.
        for (int i = 0; i < length; i++) {
            double value = inputs[i];
            double delta = value - mean;
            mean += delta / (i+1);
            double delta2 = value - mean;
            sumSquares += delta * delta2;
        }
        return new Pair<>(mean,sumSquares/(length-1));
    }

    /**
     * Returns the weighted mean of the input.
     * <p>
     * Throws IllegalArgumentException if the two arrays are not the same length.
     * @param inputs The input array.
     * @param weights The weights to use.
     * @return The weighted mean.
     */
    public static double weightedMean(double[] inputs, double[] weights) {
        if (inputs.length != weights.length) {
            throw new IllegalArgumentException("inputs and weights must be the same length, inputs.length = " + inputs.length + ", weights.length = " + weights.length);
        }

        double output = 0.0;
        double sum = 0.0;
        for (int i = 0; i < inputs.length; i++) {
            output += inputs[i] * weights[i];
            sum += weights[i];
        }

        return output/sum;
    }

    /**
     * Returns the mean of the input array.
     * @param inputs The input array.
     * @return The mean of inputs.
     */
    public static double mean(double[] inputs) {
        double output = 0.0;
        for (int i = 0; i < inputs.length; i++) {
            output += inputs[i];
        }
        return output / inputs.length;
    }

    /**
     * Returns the mean of the first length elements of the array.
     * @param array The input array.
     * @param length The number of leading elements to use.
     * @return The mean of those elements.
     */
    public static double mean(double[] array, int length) {
        double sum = sum(array,length);
        return sum / length;
    }

    /**
     * Returns the mean of the collection, using each element's double value.
     * @param values The values to average.
     * @param <V> The numeric type of the values.
     * @return The mean.
     */
    public static <V extends Number> double mean(Collection<V> values) {
        double total = 0d;
        for (V v : values) {
            total += v.doubleValue();
        }
        return total / values.size();
    }

    /**
     * Returns the sample variance of the collection (dividing by size - 1).
     * @param values The values.
     * @param <V> The numeric type of the values.
     * @return The sample variance.
     */
    public static <V extends Number> double sampleVariance(Collection<V> values) {
        double mean = mean(values);
        double total = 0d;
        for (V v : values) {
            total += Math.pow(v.doubleValue()-mean, 2);
        }
        return total / (values.size() - 1);
    }

    /**
     * Returns the sample standard deviation, the square root of {@link #sampleVariance}.
     * @param values The values.
     * @param <V> The numeric type of the values.
     * @return The sample standard deviation.
     */
    public static <V extends Number> double sampleStandardDeviation(Collection<V> values) {
        return Math.sqrt(sampleVariance(values));
    }

    /**
     * Returns the weighted mean of the first length elements of the array.
     * @param array The input array.
     * @param weights The weights to use.
     * @param length The number of leading elements to use.
     * @return The weighted sum divided by the sum of the used weights.
     * @throws IllegalArgumentException if array and weights have different lengths.
     */
    public static double weightedMean(double[] array, float[] weights, int length) {
        double sum = weightedSum(array,weights,length);
        return sum / sum(weights,length);
    }

    /**
     * Returns the weighted sum of the first length elements of the array.
     * @param array The input array.
     * @param weights The weights to use.
     * @param length The number of leading elements to use.
     * @return The weighted sum.
     * @throws IllegalArgumentException if array and weights have different lengths.
     */
    public static double weightedSum(double[] array, float[] weights, int length) {
        if (array.length != weights.length) {
            throw new IllegalArgumentException("array and weights must be the same length, array.length = " + array.length + ", weights.length = " + weights.length);
        }

        double sum = 0.0;
        for (int i = 0; i < length; i++) {
            sum += weights[i] * array[i];
        }
        return sum;
    }

    /**
     * Returns an array containing the indices where values are different.
     * Basically a combination of np.where and np.diff.
     * <p>
     * Stores an index if the value after it is different. Always stores the
     * final index.
     * <p>
     * Uses a default tolerance of 1e-12.
     * @param input Input array.
     * @return An array containing the indices where the input changes.
     */
    public static int[] differencesIndices(double[] input) {
        return differencesIndices(input,1e-12);
    }

    /**
     * Returns an array containing the indices where values are different.
     * Basically a combination of np.where and np.diff.
     * <p>
     * Stores an index if the value after it is different. Always stores the
     * final index, unless the input is empty in which case the result is empty.
     * @param input Input array.
     * @param tolerance Tolerance to determine a difference.
     * @return An array containing the indices where the input changes.
     */
    public static int[] differencesIndices(double[] input, double tolerance) {
        // An empty input has no final index, storing input.length-1 would emit a bogus -1.
        if (input.length == 0) {
            return new int[0];
        }

        List<Integer> indices = new ArrayList<>();

        for (int i = 0; i < input.length-1; i++) {
            double diff = Math.abs(input[i+1] - input[i]);
            if (diff > tolerance) {
                indices.add(i);
            }
        }
        indices.add(input.length-1);

        return Util.toPrimitiveInt(indices);
    }

    /**
     * Formats a duration given two times in milliseconds.
     * <p>
     * Format string is - (%02d:%02d:%02d:%03d) or (%d days, %02d:%02d:%02d:%03d)
     *
     * @param startMillis Start time in ms.
     * @param stopMillis End time in ms.
     * @return A formatted string measuring time in hours, minutes, second and milliseconds.
     */
    public static String formatDuration(long startMillis, long stopMillis) {
        long millis = stopMillis - startMillis;
        long second = (millis / 1000) % 60;
        long minute = (millis / (1000 * 60)) % 60;
        long hour = (millis / (1000 * 60 * 60)) % 24;
        long days = (millis / (1000 * 60 * 60)) / 24;

        if (days == 0) {
            return String.format("(%02d:%02d:%02d:%03d)", hour, minute, second, millis % 1000);
        } else {
            return String.format("(%d days, %02d:%02d:%02d:%03d)", days, hour, minute, second, millis % 1000);
        }
    }

    /**
     * Expects sorted input arrays. Returns an array containing all the elements in first that are not in second.
     * <p>
     * Every occurrence in first of a value that appears in second is dropped.
     * @param first The first sorted array.
     * @param second The second sorted array.
     * @return An array containing all the elements of first that aren't in second.
     */
    public static int[] sortedDifference(int[] first, int[] second) {
        // The difference can never hold more elements than first does.
        int[] buffer = new int[first.length];
        int count = 0;

        int i = 0;
        int j = 0;
        // Merge-style walk, each iteration advances exactly one of the cursors so it always terminates.
        while (i < first.length && j < second.length) {
            if (first[i] < second[j]) {
                // second has already moved past this value, so it is not in second.
                buffer[count] = first[i];
                count++;
                i++;
            } else if (first[i] > second[j]) {
                j++;
            } else {
                // Present in both, skip it. j stays put so repeats of the value in first are skipped too.
                i++;
            }
        }
        // Anything left in first is larger than every element of second.
        for (; i < first.length; i++) {
            buffer[count] = first[i];
            count++;
        }
        return Arrays.copyOf(buffer, count);
    }

}