001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.evaluation; 018 019import org.tribuo.util.Util; 020 021import java.util.ArrayList; 022import java.util.Arrays; 023import java.util.Comparator; 024import java.util.List; 025import java.util.Objects; 026 027/** 028 * Descriptive statistics calculated across a list of doubles. 029 */ 030public final class DescriptiveStats { 031 032 private final List<Double> samples = new ArrayList<>(); 033 034 public DescriptiveStats() {} 035 036 public DescriptiveStats(List<Double> values) { 037 this.samples.addAll(values); 038 } 039 040 /** 041 * Package private method for appending a value to a DescriptiveStats. 042 * @param value A value to append. 043 */ 044 void addValue(double value) { 045 samples.add(value); 046 } 047 048 /** 049 * Calculates the mean of the values. 050 * @return The mean. 051 */ 052 public double getMean() { 053 return Util.mean(samples); 054 } 055 056 /** 057 * Calculates the sample variance of the values. 058 * @return The sample variance. 059 */ 060 public double getVariance() { 061 return Util.sampleVariance(samples); 062 } 063 064 /** 065 * Calculates the standard deviation of the values. 066 * @return The standard deviation. 067 */ 068 public double getStandardDeviation() { 069 return Util.sampleStandardDeviation(samples); 070 } 071 072 /** 073 * Calculates the max of the values. 074 * @return The maximum value. 075 */ 076 public double getMax() { 077 return Util.argmax(samples).getB(); 078 } 079 080 /** 081 * Calculates the min of the values. 082 * @return The minimum value. 083 */ 084 public double getMin() { 085 return Util.argmin(samples).getB(); 086 } 087 088 /** 089 * Returns the number of values. 090 * @return The number of values. 091 */ 092 public long getN() { 093 return samples.size(); 094 } 095 096 /** 097 * Returns a copy of the values. 098 * @return A copy of the values. 099 */ 100 public List<Double> values() { 101 return new ArrayList<>(samples); 102 } 103 104 @Override 105 public boolean equals(Object o) { 106 if (this == o) return true; 107 if (o == null || getClass() != o.getClass()) return false; 108 DescriptiveStats that = (DescriptiveStats) o; 109 return samples.equals(that.samples); 110 } 111 112 @Override 113 public int hashCode() { 114 return Objects.hash(samples); 115 } 116 117 @Override 118 public String toString() { 119 StringBuilder sb = new StringBuilder(); 120 121 List<String> rows = Arrays.asList("count", "mean", "std", "min", "max"); 122 int maxRowLen = rows.stream().max(Comparator.comparingInt(String::length)).get().length(); 123 String fmtStr = String.format("%%-%ds", maxRowLen+2); 124 125 sb.append(String.format(fmtStr, "count")); 126 sb.append(String.format("%d%n", getN())); 127 128 sb.append(String.format(fmtStr, "mean")); 129 sb.append(String.format("%.6f%n", getMean())); 130 131 sb.append(String.format(fmtStr, "std")); 132 sb.append(String.format("%.6f%n", getStandardDeviation())); 133 134 sb.append(String.format(fmtStr, "min")); 135 sb.append(String.format("%.6f%n", getMin())); 136 137 sb.append(String.format(fmtStr, "max")); 138 sb.append(String.format("%.6f%n", getMax())); 139 140 return sb.toString(); 141 } 142}