001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.evaluation;
018
019import org.tribuo.util.Util;
020
021import java.util.ArrayList;
022import java.util.Arrays;
023import java.util.Comparator;
024import java.util.List;
025import java.util.Objects;
026
027/**
028 * Descriptive statistics calculated across a list of doubles.
029 */
030public final class DescriptiveStats {
031
032    private final List<Double> samples = new ArrayList<>();
033
034    public DescriptiveStats() {}
035
036    public DescriptiveStats(List<Double> values) {
037        this.samples.addAll(values);
038    }
039
040    /**
041     * Package private method for appending a value to a DescriptiveStats.
042     * @param value A value to append.
043     */
044    void addValue(double value) {
045        samples.add(value);
046    }
047
048    /**
049     * Calculates the mean of the values.
050     * @return The mean.
051     */
052    public double getMean() {
053        return Util.mean(samples);
054    }
055
056    /**
057     * Calculates the sample variance of the values.
058     * @return The sample variance.
059     */
060    public double getVariance() {
061        return Util.sampleVariance(samples);
062    }
063
064    /**
065     * Calculates the standard deviation of the values.
066     * @return The standard deviation.
067     */
068    public double getStandardDeviation() {
069        return Util.sampleStandardDeviation(samples);
070    }
071
072    /**
073     * Calculates the max of the values.
074     * @return The maximum value.
075     */
076    public double getMax() {
077        return Util.argmax(samples).getB();
078    }
079
080    /**
081     * Calculates the min of the values.
082     * @return The minimum value.
083     */
084    public double getMin() {
085        return Util.argmin(samples).getB();
086    }
087
088    /**
089     * Returns the number of values.
090     * @return The number of values.
091     */
092    public long getN() {
093        return samples.size();
094    }
095
096    /**
097     * Returns a copy of the values.
098     * @return A copy of the values.
099     */
100    public List<Double> values() {
101        return new ArrayList<>(samples);
102    }
103
104    @Override
105    public boolean equals(Object o) {
106        if (this == o) return true;
107        if (o == null || getClass() != o.getClass()) return false;
108        DescriptiveStats that = (DescriptiveStats) o;
109        return samples.equals(that.samples);
110    }
111
112    @Override
113    public int hashCode() {
114        return Objects.hash(samples);
115    }
116
117    @Override
118    public String toString() {
119        StringBuilder sb = new StringBuilder();
120
121        List<String> rows = Arrays.asList("count", "mean", "std", "min", "max");
122        int maxRowLen = rows.stream().max(Comparator.comparingInt(String::length)).get().length();
123        String fmtStr = String.format("%%-%ds", maxRowLen+2);
124
125        sb.append(String.format(fmtStr, "count"));
126        sb.append(String.format("%d%n", getN()));
127
128        sb.append(String.format(fmtStr, "mean"));
129        sb.append(String.format("%.6f%n", getMean()));
130
131        sb.append(String.format(fmtStr, "std"));
132        sb.append(String.format("%.6f%n", getStandardDeviation()));
133
134        sb.append(String.format(fmtStr, "min"));
135        sb.append(String.format("%.6f%n", getMin()));
136
137        sb.append(String.format(fmtStr, "max"));
138        sb.append(String.format("%.6f%n", getMax()));
139
140        return sb.toString();
141    }
142}