001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.util;
018
019import com.oracle.labs.mlrg.olcut.util.Pair;
020
021import java.util.ArrayList;
022import java.util.Arrays;
023import java.util.Collection;
024import java.util.HashSet;
025import java.util.List;
026import java.util.Random;
027import java.util.Set;
028import java.util.SplittableRandom;
029import java.util.function.ToIntFunction;
030import java.util.logging.Level;
031import java.util.logging.Logger;
032
033/**
034 * Ye olde util class.
035 * <p>
036 * Basically full of vector and RNG operations.
037 */
038public final class Util {
039
040    private static final Logger logger = Logger.getLogger(Util.class.getName());
041
042    // private constructor of a final class, this is full of static methods so you can't instantiate it.
043    private Util() {}
044
045    /**
046     * Find the index of the maximum value in a list.
047     * @param values list
048     * @param <T> the type of the values (must implement Comparable)
049     * @return a pair: (index of the max value, max value)
050     */
051    public static <T extends Comparable<T>> Pair<Integer, T> argmax(List<T> values) {
052        if (values.isEmpty()) {
053            throw new IllegalArgumentException("argmax on an empty list");
054        }
055        //
056        // There is no "globally min" value like -Inf for an arbitrary type T so we just pick the first list element
057        T vmax = values.get(0);
058        int imax = 0;
059        for (int i = 1; i < values.size(); i++) {
060            T v = values.get(i);
061            if (v.compareTo(vmax) > 0) {
062                vmax = v;
063                imax = i;
064            }
065        }
066        return new Pair<>(imax, vmax);
067    }
068
069    /**
070     * Find the index of the minimum value in a list.
071     * @param values list
072     * @param <T> the type of the values (must implement Comparable)
073     * @return a pair: (index of the min value, min value)
074     */
075    public static <T extends Comparable<T>> Pair<Integer, T> argmin(List<T> values) {
076        if (values.isEmpty()) {
077            throw new IllegalArgumentException("argmin on an empty list");
078        }
079        //
080        // There is no "globally max" value like Inf for an arbitrary type T so we just pick the first list element
081        T vmin = values.get(0);
082        int imin = 0;
083        for (int i = 1; i < values.size(); i++) {
084            T v = values.get(i);
085            if (v.compareTo(vmin) < 0) {
086                vmin = v;
087                imin = i;
088            }
089        }
090        return new Pair<>(imin, vmin);
091    }
092
093    /**
094     * Convert an array of doubles to an array of floats.
095     *
096     * @param doubles The array of doubles to convert.
097     * @return An array of floats.
098     */
099    public static float[] toFloatArray(double[] doubles) {
100        float[] floats = new float[doubles.length];
101        for (int i = 0; i < doubles.length; i++) {
102            floats[i] = (float) doubles[i];
103        }
104        return floats;
105    }
106
107    /**
108     * Convert an array of floats to an array of doubles.
109     *
110     * @param floats The array of floats to convert.
111     * @return An array of doubles.
112     */
113    public static double[] toDoubleArray(float[] floats) {
114        double[] doubles = new double[floats.length];
115        for (int i = 0; i < floats.length; i++) {
116            doubles[i] = floats[i];
117        }
118        return doubles;
119    }
120
121    /**
122     * Shuffles the indices in the range [0,size).
123     * @param size The number of elements.
124     * @param rng The random number generator to use.
125     * @return A random permutation of the values in the range (0, size-1).
126     */
127    public static int[] randperm(int size, Random rng) {
128        int[] array = new int[size];
129        for (int i = 0; i < array.length; i++) {
130            array[i] = i;
131        }
132        // Shuffle array
133        for (int i = size; i > 1; i--) {
134            int j = rng.nextInt(i);
135            int tmp = array[i-1];
136            array[i-1] = array[j];
137            array[j] = tmp;
138        }
139        return array;
140    }
141
142    /**
143     * Shuffles the indices in the range [0,size).
144     * @param size The number of elements.
145     * @param rng The random number generator to use.
146     * @return A random permutation of the values in the range (0, size-1).
147     */
148    public static int[] randperm(int size, SplittableRandom rng) {
149        int[] array = new int[size];
150        for (int i = 0; i < array.length; i++) {
151            array[i] = i;
152        }
153        // Shuffle array
154        for (int i = size; i > 1; i--) {
155            int j = rng.nextInt(i);
156            int tmp = array[i-1];
157            array[i-1] = array[j];
158            array[j] = tmp;
159        }
160        return array;
161    }
162
163    /**
164     * Shuffles the input.
165     * @param input The array to shuffle.
166     * @param rng The random number generator to use.
167     */
168    public static void randpermInPlace(int[] input, Random rng) {
169        // Shuffle array
170        for (int i = input.length; i > 1; i--) {
171            int j = rng.nextInt(i);
172            int tmp = input[i-1];
173            input[i-1] = input[j];
174            input[j] = tmp;
175        }
176    }
177
178    /**
179     * Shuffles the input.
180     * @param input The array to shuffle.
181     * @param rng The random number generator to use.
182     */
183    public static void randpermInPlace(int[] input, SplittableRandom rng) {
184        // Shuffle array
185        for (int i = input.length; i > 1; i--) {
186            int j = rng.nextInt(i);
187            int tmp = input[i-1];
188            input[i-1] = input[j];
189            input[j] = tmp;
190        }
191    }
192
193    /**
194     * Draws a bootstrap sample of indices.
195     * @param size Size of the sample to generate.
196     * @param rng The RNG to use.
197     * @return A bootstrap sample.
198     */
199    public static int[] generateBootstrapIndices(int size, Random rng) {
200        int[] array = new int[size];
201        for (int i = 0; i < size; i++) {
202            array[i] = rng.nextInt(size);
203        }
204        return array;
205    }
206
207    /**
208     * Draws a bootstrap sample of indices.
209     * @param size Size of the sample to generate.
210     * @param rng The RNG to use.
211     * @return A bootstrap sample.
212     */
213    public static int[] generateBootstrapIndices(int size, SplittableRandom rng) {
214        int[] array = new int[size];
215        for (int i = 0; i < size; i++) {
216            array[i] = rng.nextInt(size);
217        }
218        return array;
219    }
220
221    /**
222     * Generates a sample of indices weighted by the provided weights.
223     * @param size Size of the sample to generate.
224     * @param weights A probability mass function of weights.
225     * @param rng The RNG to use.
226     * @return A sample with replacement from weights.
227     */
228    public static int[] generateWeightedIndicesSample(int size, double[] weights, Random rng) {
229        double[] cdf = generateCDF(weights);
230        if (Math.abs(cdf[cdf.length-1] - 1.0) > 1e-10) {
231            throw new IllegalStateException("Weights do not sum to 1, cdf[cdf.length-1] = " + cdf[cdf.length-1]);
232        }
233        return generateWeightedIndicesSample(cdf, size, rng);
234    }
235
236    /**
237     * Generates a sample of indices weighted by the provided weights.
238     * @param size Size of the sample to generate.
239     * @param weights A probability mass function of weights.
240     * @param rng The RNG to use.
241     * @return A sample with replacement from weights.
242     */
243    public static int[] generateWeightedIndicesSample(int size, float[] weights, Random rng) {
244        double[] cdf = generateCDF(weights);
245        if (Math.abs(cdf[cdf.length - 1] - 1.0) > 1e-6) {
246            throw new IllegalStateException("Weights do not sum to 1, cdf[cdf.length-1] = " + cdf[cdf.length - 1]);
247        }
248        return generateWeightedIndicesSample(cdf, size, rng);
249    }
250
251    private static int[] generateWeightedIndicesSample(double[] cdf, int size, Random rng) {
252        int[] output = new int[size];
253
254        for (int i = 0; i < output.length; i++) {
255            double uniform = rng.nextDouble();
256            int searchVal = Arrays.binarySearch(cdf, uniform);
257            if (searchVal < 0) {
258                output[i] = - 1 - searchVal;
259            } else {
260                output[i] = searchVal;
261            }
262        }
263        return output;
264    }
265
266    /**
267     * Generates a sample of indices weighted by the provided weights.
268     * @param size Size of the sample to generate.
269     * @param weights A probability mass function of weights.
270     * @param rng The RNG to use.
271     * @return A sample with replacement from weights.
272     */
273    public static int[] generateWeightedIndicesSample(int size, double[] weights, SplittableRandom rng) {
274        double[] cdf = generateCDF(weights);
275        if (Math.abs(cdf[cdf.length-1] - 1.0) > 1e-10) {
276            throw new IllegalStateException("Weights do not sum to 1, cdf[cdf.length-1] = " + cdf[cdf.length-1]);
277        }
278        return generateWeightedIndicesSample(cdf, size, rng);
279    }
280
281    /**
282     * Generates a sample of indices weighted by the provided weights.
283     * @param size Size of the sample to generate.
284     * @param weights A probability mass function of weights.
285     * @param rng The RNG to use.
286     * @return A sample with replacement from weights.
287     */
288    public static int[] generateWeightedIndicesSample(int size, float[] weights, SplittableRandom rng) {
289        double[] cdf = generateCDF(weights);
290        if (Math.abs(cdf[cdf.length - 1] - 1.0) > 1e-6) {
291            throw new IllegalStateException("Weights do not sum to 1, cdf[cdf.length-1] = " + cdf[cdf.length - 1]);
292        }
293        return generateWeightedIndicesSample(cdf, size, rng);
294    }
295
296    private static int[] generateWeightedIndicesSample(double[] cdf, int size, SplittableRandom rng) {
297        int[] output = new int[size];
298
299        for (int i = 0; i < output.length; i++) {
300            double uniform = rng.nextDouble();
301            int searchVal = Arrays.binarySearch(cdf, uniform);
302            if (searchVal < 0) {
303                output[i] = - 1 - searchVal;
304            } else {
305                output[i] = searchVal;
306            }
307        }
308        return output;
309    }
310
311    /**
312     * Generates a sample of indices weighted by the provided weights without replacement. Does not recalculate
313     * proportions in-between samples. Use judiciously.
314     * @param size Size of the sample to generate
315     * @param weights A probability mass function of weights
316     * @param rng The RNG to use
317     * @return A sample without replacement from weights
318     */
319    public static int[] generateWeightedIndicesSampleWithoutReplacement(int size, double[] weights, Random rng) {
320        double[] cdf = generateCDF(weights);
321        if (Math.abs(cdf[cdf.length-1] - 1.0) > 1e-6) {
322            throw new IllegalStateException("Weights do not sum to 1, cdf[cdf.length-1] = " + cdf[cdf.length-1]);
323        }
324        int[] output = new int[size];
325        Set<Integer> seenIdxs = new HashSet<>();
326        int i = 0;
327        while(i < output.length) {
328            double uniform = rng.nextDouble();
329            int searchVal = Arrays.binarySearch(cdf, uniform);
330            int candidateSample = searchVal < 0 ? - 1 - searchVal : searchVal;
331            if(!seenIdxs.contains(candidateSample)) {
332                seenIdxs.add(candidateSample);
333                output[i] = candidateSample;
334                i++;
335            }
336        }
337        return output;
338    }
339
340    /**
341     * Generates a sample of indices weighted by the provided weights without replacement. Does not recalculate
342     * proportions in-between samples. Use judiciously.
343     * @param size Size of the sample to generate
344     * @param weights A probability mass function of weights
345     * @param rng The RNG to use
346     * @return A sample without replacement from weights
347     */
348    public static int[] generateWeightedIndicesSampleWithoutReplacement(int size, float[] weights, Random rng) {
349        double[] cdf = generateCDF(weights);
350        if (Math.abs(cdf[cdf.length-1] - 1.0) > 1e-6) {
351            throw new IllegalStateException("Weights do not sum to 1, cdf[cdf.length-1] = " + cdf[cdf.length-1]);
352        }
353        int[] output = new int[size];
354        Set<Integer> seenIdxs = new HashSet<>();
355        int i = 0;
356        while(i < output.length) {
357            double uniform = rng.nextDouble();
358            int searchVal = Arrays.binarySearch(cdf, uniform);
359            int candidateSample = searchVal < 0 ? - 1 - searchVal : searchVal;
360            if(!seenIdxs.contains(candidateSample)) {
361                seenIdxs.add(candidateSample);
362                output[i] = candidateSample;
363                i++;
364            }
365        }
366        return output;
367    }
368
369    /**
370     * Generates a cumulative distribution function from the supplied probability mass function.
371     * @param pmf The probability mass function (i.e., the probability distribution).
372     * @return The CDF.
373     */
374    public static double[] generateCDF(double[] pmf) {
375        return cumulativeSum(pmf);
376    }
377
378    /**
379     * Produces a cumulative sum array.
380     * @param input The input to sum.
381     * @return The cumulative sum.
382     */
383    public static double[] cumulativeSum(double[] input) {
384        double[] cdf = new double[input.length];
385
386        double sum = 0;
387        for (int i = 0; i < input.length; i++) {
388            sum += input[i];
389            cdf[i] = sum;
390        }
391
392        return cdf;
393    }
394
395    /**
396     * Produces a cumulative sum array.
397     * @param input The input to sum.
398     * @return The cumulative sum.
399     */
400    public static int[] cumulativeSum(boolean[] input) {
401        int[] cumulativeSum = new int[input.length];
402
403        int sum = 0;
404        for (int i = 0; i < input.length; i++) {
405            sum += input[i] ? 1 : 0;
406            cumulativeSum[i] = sum;
407        }
408
409        return cumulativeSum;
410    }
411
412    /**
413     * Generates a cumulative distribution function from the supplied probability mass function.
414     * @param pmf The probability mass function (i.e., the probability distribution).
415     * @return The CDF.
416     */
417    public static double[] generateCDF(float[] pmf) {
418        double[] cdf = new double[pmf.length];
419
420        double sum = 0;
421        for (int i = 0; i < pmf.length; i++) {
422            sum += pmf[i];
423            cdf[i] = sum;
424        }
425        
426        return cdf;
427    }
428
429    /**
430     * Generates a cumulative distribution function from the supplied probability mass function.
431     * @param counts The frequency counts.
432     * @param countSum The sum of the counts.
433     * @return The CDF.
434     */
435    public static double[] generateCDF(long[] counts, long countSum) {
436        double[] cdf = new double[counts.length];
437
438        double countSumD = countSum;
439        double probSum = 0.0;
440        for (int i = 0; i < counts.length; i++) {
441            probSum += counts[i] / countSumD;
442            cdf[i] = probSum;
443        }
444
445        return cdf;
446    }
447
448    /**
449     * Samples an index from the supplied cdf.
450     * @param cdf The cdf to sample from.
451     * @param rng The rng to use.
452     * @return A sample.
453     */
454    public static int sampleFromCDF(double[] cdf, Random rng) {
455        if (Math.abs(cdf[cdf.length-1] - 1.0) > 1e-6) {
456            throw new IllegalStateException("Weights do not sum to 1, cdf[cdf.length-1] = " + cdf[cdf.length-1]);
457        }
458        double uniform = rng.nextDouble();
459        int searchVal = Arrays.binarySearch(cdf, uniform);
460        if (searchVal < 0) {
461            return - 1 - searchVal;
462        } else {
463            return searchVal;
464        }
465    }
466
467    /**
468     * Samples an index from the supplied cdf.
469     * @param cdf The cdf to sample from.
470     * @param rng The rng to use.
471     * @return A sample.
472     */
473    public static int sampleFromCDF(double[] cdf, SplittableRandom rng) {
474        if (Math.abs(cdf[cdf.length-1] - 1.0) > 1e-6) {
475            throw new IllegalStateException("Weights do not sum to 1, cdf[cdf.length-1] = " + cdf[cdf.length-1]);
476        }
477        double uniform = rng.nextDouble();
478        int searchVal = Arrays.binarySearch(cdf, uniform);
479        if (searchVal < 0) {
480            return - 1 - searchVal;
481        } else {
482            return searchVal;
483        }
484    }
485
486    public static double[] generateUniformVector(int length, double value) {
487        double[] output = new double[length];
488
489        Arrays.fill(output, value);
490
491        return output;
492    }
493
494    public static float[] generateUniformVector(int length, float value) {
495        float[] output = new float[length];
496
497        Arrays.fill(output, value);
498
499        return output;
500    }
501
502    public static double[] normalizeToDistribution(double[] input) {
503        double[] output = new double[input.length];
504        double sum = 0.0;
505
506        for (int i = 0; i < input.length; i++) {
507            output[i] = input[i];
508            sum += output[i];
509        }
510
511        for (int i = 0; i < input.length; i++) {
512            output[i] /= sum;
513        }
514
515        return output;
516    }
517
518    public static double[] normalizeToDistribution(float[] input) {
519        double[] output = new double[input.length];
520        double sum = 0.0;
521
522        for (int i = 0; i < input.length; i++) {
523            output[i] = input[i];
524            sum += output[i];
525        }
526
527        for (int i = 0; i < input.length; i++) {
528            output[i] /= sum;
529        }
530
531        return output;
532    }
533
534    public static double[] inplaceNormalizeToDistribution(double[] input) {
535        double sum = 0.0;
536
537        for (int i = 0; i < input.length; i++) {
538            sum += input[i];
539        }
540
541        for (int i = 0; i < input.length; i++) {
542            input[i] /= sum;
543        }
544
545        return input;
546    }
547
548    public static void inplaceNormalizeToDistribution(float[] input) {
549        float sum = 0.0f;
550
551        for (int i = 0; i < input.length; i++) {
552            sum += input[i];
553        }
554
555        for (int i = 0; i < input.length; i++) {
556            input[i] /= sum;
557        }
558
559    }
560
561    public static void logVector(Logger otherLogger, Level level, double[] input) {
562        StringBuilder buffer = new StringBuilder();
563
564        for (int i = 0; i < input.length; i++) {
565            buffer.append("(");
566            buffer.append(i);
567            buffer.append(",");
568            buffer.append(input[i]);
569            buffer.append(") ");
570        }
571        buffer.deleteCharAt(buffer.length()-1);
572        otherLogger.log(level, buffer.toString());
573    }
574
575    public static void logVector(Logger otherLogger, Level level, float[] input) {
576        StringBuilder buffer = new StringBuilder();
577
578        for (int i = 0; i < input.length; i++) {
579            buffer.append("(");
580            buffer.append(i);
581            buffer.append(",");
582            buffer.append(input[i]);
583            buffer.append(") ");
584        }
585        buffer.deleteCharAt(buffer.length()-1);
586        otherLogger.log(level, buffer.toString());
587    }
588
589    public static double[] toPrimitiveDoubleFromInteger(List<Integer> input) {
590        double[] output = new double[input.size()];
591
592        for (int i = 0; i < input.size(); i++) {
593            output[i] = input.get(i);
594        }
595
596        return output;
597    }
598
599    public static double[] toPrimitiveDouble(List<Double> input) {
600        double[] output = new double[input.size()];
601
602        for (int i = 0; i < input.size(); i++) {
603            output[i] = input.get(i);
604        }
605        
606        return output;
607    }
608
609    public static float[] toPrimitiveFloat(List<Float> input) {
610        float[] output = new float[input.size()];
611
612        for (int i = 0; i < input.size(); i++) {
613            output[i] = input.get(i);
614        }
615        
616        return output;
617    }
618
619    public static int[] toPrimitiveInt(List<Integer> input) {
620        int[] output = new int[input.size()];
621
622        for (int i = 0; i < input.size(); i++) {
623            output[i] = input.get(i);
624        }
625        
626        return output;
627    }
628
629    public static long[] toPrimitiveLong(List<Long> input) {
630        long[] output = new long[input.size()];
631
632        for (int i = 0; i < input.size(); i++) {
633            output[i] = input.get(i);
634        }
635        
636        return output;
637    }
638
639    public static int[] sampleInts(Random rng, int size, int range) {
640        int[] output = new int[size];
641
642        for (int i = 0; i < output.length; i++) {
643            output[i] = rng.nextInt(range);
644        }
645
646        return output;
647    }
648
649    public static void inPlaceAdd(double[] input, double[] update) {
650        for (int i = 0; i < input.length; i++) {
651            input[i] += update[i];
652        }
653    }
654
655    public static void inPlaceSubtract(double[] input, double[] update) {
656        for (int i = 0; i < input.length; i++) {
657            input[i] -= update[i];
658        }
659    }
660
661    public static void inPlaceAdd(float[] input, float[] update) {
662        for (int i = 0; i < input.length; i++) {
663            input[i] += update[i];
664        }
665    }
666
667    public static void inPlaceSubtract(float[] input, float[] update) {
668        for (int i = 0; i < input.length; i++) {
669            input[i] -= update[i];
670        }
671    }
672
673    public static double vectorNorm(double[] input) {
674        double norm = 0.0;
675        for (double d : input) {
676            norm += d * d;
677        }
678        return norm;
679    }
680
681    public static double sum(double[] input) {
682        double sum = 0.0;
683        for (double d : input) {
684            sum += d;
685        }
686        return sum;
687    }
688
689    public static float sum(float[] input) {
690        float sum = 0.0f;
691        for (float d : input) {
692            sum += d;
693        }
694        return sum;
695    }
696
697    public static double sum(double[] array, int length) {
698        double sum = 0.0;
699        for (int i = 0; i < length; i++) {
700            sum += array[i];
701        }
702        return sum;
703    }
704
705    public static float sum(float[] array, int length) {
706        float sum = 0.0f;
707        for (int i = 0; i < length; i++) {
708            sum += array[i];
709        }
710        return sum;
711    }
712
713    public static float sum(int[] indices, int indicesLength, float[] input) {
714        float sum = 0.0f;
715        for (int i = 0; i < indicesLength; i++) {
716            sum += input[indices[i]];
717        }
718        return sum;
719    }
720
721    public static float sum(int[] indices, float[] input) {
722        return sum(indices,indices.length,input);
723    }
724
725    public static float[] generateUniformFloatVector(int length, float value) {
726        float[] output = new float[length];
727
728        Arrays.fill(output, value);
729
730        return output;
731    }
732
733    /**
734     * A binary search function.
735     * @param list Input list, must be ordered.
736     * @param key Key to search for.
737     * @param <T> Type of the list, must implement Comparable.
738     * @return the index of the search key, if it is contained in the list;
739     *         otherwise, (-(insertion point) - 1). The insertion point is
740     *         defined as the point at which the key would be inserted into
741     *         the list: the index of the first element greater than the key,
742     *         or list.size() if all elements in the list are less than the
743     *         specified key. Note that this guarantees that the return value
744     *         will be &gt;= 0 if and only if the key is found.
745     */
746    public static <T> int binarySearch(List<? extends Comparable<? super T>> list, T key) {
747        return binarySearch(list,key,0,list.size()-1);
748    }
749
750    /**
751     * A binary search function.
752     * @param list Input list, must be ordered.
753     * @param key Key to search for.
754     * @param low Starting index.
755     * @param high End index (will be searched).
756     * @param <T> Type of the list, must implement Comparable.
757     * @return the index of the search key, if it is contained in the list;
758     *         otherwise, (-(insertion point) - 1). The insertion point is
759     *         defined as the point at which the key would be inserted into
760     *         the list: the index of the first element greater than the key,
761     *         or high if all elements in the list are less than the
762     *         specified key. Note that this guarantees that the return value
763     *         will be &gt;= 0 if and only if the key is found.
764     */
765    public static <T> int binarySearch(List<? extends Comparable<? super T>> list, T key, int low, int high) {
766        while (low <= high) {
767            int mid = (low + high) >>> 1;
768            Comparable<? super T> midVal = list.get(mid);
769            int cmp = midVal.compareTo(key);
770            if (cmp < 0) {
771                low = mid + 1;
772            } else if (cmp > 0) {
773                high = mid - 1;
774            } else {
775                return mid; // key found
776            }
777        }
778        return -(low + 1);  // key not found
779    }
780
781    /**
782     * A binary search function.
783     * @param list Input list, must be ordered.
784     * @param key Key to search for.
785     * @param extractionFunc Takes a T and generates an int
786     *                       which can be used for comparison using int's natural ordering.
787     * @param <T> Type of the list, must implement Comparable.
788     * @return the index of the search key, if it is contained in the list;
789     *         otherwise, (-(insertion point) - 1). The insertion point is
790     *         defined as the point at which the key would be inserted into
791     *         the list: the index of the first element greater than the key,
792     *         or high if all elements in the list are less than the
793     *         specified key. Note that this guarantees that the return value
794     *         will be &gt;= 0 if and only if the key is found.
795     */
796    public static <T> int binarySearch(List<? extends T> list, int key, ToIntFunction<T> extractionFunc) {
797        int low = 0;
798        int high = list.size()-1;
799        while (low <= high) {
800            int mid = (low + high) >>> 1;
801            int midVal = extractionFunc.applyAsInt(list.get(mid));
802            int cmp = Integer.compare(midVal, key);
803            if (cmp < 0) {
804                low = mid + 1;
805            } else if (cmp > 0) {
806                high = mid - 1;
807            } else {
808                return mid; // key found
809            }
810        }
811        return -(low + 1);  // key not found
812    }
813
814    /**
815     * Calculates the area under the curve, bounded below by the x axis.
816     * <p>
817     * Uses linear interpolation between the points on the x axis,
818     * i.e., trapezoidal integration.
819     * <p>
820     * The x axis must be increasing.
821     * @param x The x points to evaluate.
822     * @param y The corresponding heights.
823     * @return The AUC.
824     */
825    public static double auc(double[] x, double[] y) {
826        if (x.length != y.length) {
827            throw new IllegalArgumentException("x and y must be the same length, x.length = " + x.length + ", y.length = " + y.length);
828        }
829        double output = 0.0;
830
831        for (int i = 1; i < x.length; i++) {
832            double ySum = y[i] + y[i-1];
833            double xDiff = x[i] - x[i-1];
834            if (xDiff < -1e-12) {
835                throw new IllegalStateException(String.format("X is not increasing, x[%d]=%f, x[%d]=%f",i,x[i],i-1,x[i-1]));
836            }
837            output += (ySum * xDiff) / 2.0;
838        }
839
840        return output;
841    }
842
843    /**
844     * Returns the mean and variance of the input.
845     * @param inputs The input array.
846     * @return The mean and variance of the inputs. The mean is the first element, the variance is the second.
847     */
848    public static Pair<Double,Double> meanAndVariance(double[] inputs) {
849        return meanAndVariance(inputs,inputs.length);
850    }
851
852    /**
853     * Returns the mean and variance of the input's first length elements.
854     * @param inputs The input array.
855     * @param length The number of elements to use.
856     * @return The mean and variance of the inputs. The mean is the first element, the variance is the second.
857     */
858    public static Pair<Double,Double> meanAndVariance(double[] inputs, int length) {
859        double mean = 0.0;
860        double sumSquares = 0.0;
861        for (int i = 0; i < length; i++) {
862            double value = inputs[i];
863            double delta = value - mean;
864            mean += delta / (i+1);
865            double delta2 = value - mean;
866            sumSquares += delta * delta2;
867        }
868        return new Pair<>(mean,sumSquares/(length-1));
869    }
870
871    /**
872     * Returns the weighted mean of the input.
873     * <p>
874     * Throws IllegalArgumentException if the two arrays are not the same length.
875     * @param inputs The input array.
876     * @param weights The weights to use.
877     * @return The weighted mean.
878     */
879    public static double weightedMean(double[] inputs, double[] weights) {
880        if (inputs.length != weights.length) {
881            throw new IllegalArgumentException("inputs and weights must be the same length, inputs.length = " + inputs.length + ", weights.length = " + weights.length);
882        }
883
884        double output = 0.0;
885        double sum = 0.0;
886        for (int i = 0; i < inputs.length; i++) {
887            output += inputs[i] * weights[i];
888            sum += weights[i];
889        }
890
891        return output/sum;
892    }
893
894    /**
895     * Returns the mean of the input array.
896     * @param inputs The input array.
897     * @return The mean of inputs.
898     */
899    public static double mean(double[] inputs) {
900        double output = 0.0;
901        for (int i = 0; i < inputs.length; i++) {
902            output += inputs[i];
903        }
904        return output / inputs.length;
905    }
906
907    public static double mean(double[] array, int length) {
908        double sum = sum(array,length);
909        return sum / length;
910    }
911
912    public static <V extends Number> double mean(Collection<V> values) {
913        double total = 0d;
914        for (V v : values) {
915            total += v.doubleValue();
916        }
917        return total / values.size();
918    }
919
920    public static <V extends Number> double sampleVariance(Collection<V> values) {
921        double mean = mean(values);
922        double total = 0d;
923        for (V v : values) {
924            total += Math.pow(v.doubleValue()-mean, 2);
925        }
926        return total / (values.size() - 1);
927    }
928
929    public static <V extends Number> double sampleStandardDeviation(Collection<V> values) {
930        return Math.sqrt(sampleVariance(values));
931    }
932
933    public static double weightedMean(double[] array, float[] weights, int length) {
934        double sum = weightedSum(array,weights,length);
935        return sum / sum(weights,length);
936    }
937
938    public static double weightedSum(double[] array, float[] weights, int length) {
939        if (array.length != weights.length) {
940            throw new IllegalArgumentException("array and weights must be the same length, array.length = " + array.length + ", weights.length = " + weights.length);
941        }
942
943        double sum = 0.0;
944        for (int i = 0; i < length; i++) {
945            sum += weights[i] * array[i];
946        }
947        return sum;
948    }
949
950    /**
951     * Returns an array containing the indices where values are different.
952     * Basically a combination of np.where and np.diff.
953     * <p>
954     * Stores an index if the value after it is different. Always stores the
955     * final index.
956     * <p>
957     * Uses a default tolerance of 1e-12.
958     * @param input Input array.
959     * @return An array containing the indices where the input changes.
960     */
961    public static int[] differencesIndices(double[] input) {
962        return differencesIndices(input,1e-12);
963    }
964
965    /**
966     * Returns an array containing the indices where values are different.
967     * Basically a combination of np.where and np.diff.
968     * <p>
969     * Stores an index if the value after it is different. Always stores the
970     * final index.
971     * @param input Input array.
972     * @param tolerance Tolerance to determine a difference.
973     * @return An array containing the indices where the input changes.
974     */
975    public static int[] differencesIndices(double[] input, double tolerance) {
976        List<Integer> indices = new ArrayList<>();
977
978        for (int i = 0; i < input.length-1; i++) {
979            double diff = Math.abs(input[i+1] - input[i]);
980            if (diff > tolerance) {
981                indices.add(i);
982            }
983        }
984        indices.add(input.length-1);
985
986        return Util.toPrimitiveInt(indices);
987    }
988
989    /**
990     * Formats a duration given two times in milliseconds.
991     * <p>
992     * Format string is - (%02d:%02d:%02d:%03d) or (%d days, %02d:%02d:%02d:%03d)
993     *
994     * @param startMillis Start time in ms.
995     * @param stopMillis End time in ms.
996     * @return A formatted string measuring time in hours, minutes, second and milliseconds.
997     */
998    public static String formatDuration(long startMillis, long stopMillis) {
999        long millis = stopMillis - startMillis;
1000        long second = (millis / 1000) % 60;
1001        long minute = (millis / (1000 * 60)) % 60;
1002        long hour = (millis / (1000 * 60 * 60)) % 24;
1003        long days = (millis / (1000 * 60 * 60)) / 24;
1004
1005        if (days == 0) {
1006            return String.format("(%02d:%02d:%02d:%03d)", hour, minute, second, millis % 1000);
1007        } else {
1008            return String.format("(%d days, %02d:%02d:%02d:%03d)", days, hour, minute, second, millis % 1000);
1009        }
1010    }
1011
1012    /**
1013     * Expects sorted input arrays. Returns an array containing all the elements in first that are not in second.
1014     * @param first The first sorted array.
1015     * @param second The second sorted array.
1016     * @return An array containing all the elements of first that aren't in second.
1017     */
1018    public static int[] sortedDifference(int[] first, int[] second) {
1019        List<Integer> diffIndicesList = new ArrayList<>();
1020
1021        int i = 0;
1022        int j = 0;
1023        while (i < first.length && j < second.length) {
1024            //after this loop, either itr is out or tuple.index >= otherTuple.index
1025            while (i < first.length && (first[i] < second[j])) {
1026                diffIndicesList.add(first[i]);
1027                i++;
1028            }
1029            //after this loop, either otherItr is out or tuple.index <= otherTuple.index
1030            while (j < second.length && (first[i] > second[j])) {
1031                j++;
1032            }
1033            if (first[i] != second[j]) {
1034                diffIndicesList.add(first[i]);
1035            }
1036        }
1037        for (; i < first.length; i++) {
1038            diffIndicesList.add(first[i]);
1039        }
1040        return diffIndicesList.stream().mapToInt(Integer::intValue).toArray();
1041    }
1042
1043}