/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.transform.transformations;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.provenance.ObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import com.oracle.labs.mlrg.olcut.provenance.primitives.DoubleProvenance;
import org.tribuo.transform.TransformStatistics;
import org.tribuo.transform.Transformation;
import org.tribuo.transform.TransformationProvenance;
import org.tribuo.transform.Transformer;

import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;

/**
 * A Transformation which takes an observed distribution and rescales
 * it so it has the desired mean and standard deviation.
 * <p>
 * Checks to see that the requested standard deviation is
 * positive, throws IllegalArgumentException otherwise.
 * <p>
 * If the observed standard deviation is not a strictly positive number
 * (fewer than two observations, or every observed value identical), the
 * generated transformer only re-centres and scales by the target standard
 * deviation rather than dividing by zero or NaN.
 */
public final class MeanStdDevTransformation implements Transformation {

    private static final String TARGET_MEAN = "targetMean";
    private static final String TARGET_STDDEV = "targetStdDev";

    @Config(mandatory = true,description="Mean value after transformation.")
    private double targetMean = 0.0;

    @Config(mandatory = true,description="Standard deviation after transformation.")
    private double targetStdDev = 1.0;

    // Lazily created by getProvenance() and then reused.
    private MeanStdDevTransformationProvenance provenance;

    /**
     * Defaults to zero mean, one std dev.
     */
    public MeanStdDevTransformation() { }

    /**
     * Constructs a transformation with the supplied target mean and standard deviation.
     *
     * @param targetMean The mean the transformed values should have.
     * @param targetStdDev The standard deviation the transformed values should have.
     * @throws IllegalArgumentException if {@code targetStdDev} is NaN or
     *         smaller than {@code SimpleTransform.EPSILON}.
     */
    public MeanStdDevTransformation(double targetMean, double targetStdDev) {
        this.targetMean = targetMean;
        this.targetStdDev = targetStdDev;
        postConfig();
    }

    /**
     * Used by the OLCUT configuration system, and should not be called by external code.
     *
     * @throws IllegalArgumentException if the target standard deviation is NaN or
     *         smaller than {@code SimpleTransform.EPSILON}.
     */
    @Override
    public void postConfig() {
        // A bare "<" comparison is false for NaN, so NaN must be rejected explicitly.
        if (Double.isNaN(targetStdDev) || targetStdDev < SimpleTransform.EPSILON) {
            throw new IllegalArgumentException("Target standard deviation must be positive, found " + targetStdDev);
        }
    }

    /**
     * Creates a fresh statistics accumulator configured with this
     * transformation's target mean and standard deviation.
     *
     * @return A new, empty statistics object.
     */
    @Override
    public TransformStatistics createStats() {
        return new MeanStdDevStatistics(targetMean,targetStdDev);
    }

    /**
     * Returns the provenance of this transformation, creating it on first use.
     *
     * @return The provenance describing the target mean and standard deviation.
     */
    @Override
    public TransformationProvenance getProvenance() {
        if (provenance == null) {
            provenance = new MeanStdDevTransformationProvenance(this);
        }
        return provenance;
    }

    /**
     * Provenance for {@link MeanStdDevTransformation}.
     */
    public static final class MeanStdDevTransformationProvenance implements TransformationProvenance {
        private static final long serialVersionUID = 1L;

        private final DoubleProvenance targetMean;
        private final DoubleProvenance targetStdDev;

        /**
         * Captures the current configuration of the supplied transformation.
         *
         * @param host The transformation to record.
         */
        MeanStdDevTransformationProvenance(MeanStdDevTransformation host) {
            this.targetMean = new DoubleProvenance(TARGET_MEAN, host.targetMean);
            this.targetStdDev = new DoubleProvenance(TARGET_STDDEV, host.targetStdDev);
        }

        /**
         * Rebuilds the provenance from a map of named provenance values.
         *
         * @param map The map to read the {@code targetMean} and {@code targetStdDev}
         *            entries from.
         */
        public MeanStdDevTransformationProvenance(Map<String, Provenance> map) {
            String hostName = MeanStdDevTransformationProvenance.class.getSimpleName();
            targetMean = ObjectProvenance.checkAndExtractProvenance(map, TARGET_MEAN, DoubleProvenance.class, hostName);
            targetStdDev = ObjectProvenance.checkAndExtractProvenance(map, TARGET_STDDEV, DoubleProvenance.class, hostName);
        }

        @Override
        public String getClassName() {
            return MeanStdDevTransformation.class.getName();
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (!(o instanceof MeanStdDevTransformationProvenance)) {
                return false;
            }
            MeanStdDevTransformationProvenance that = (MeanStdDevTransformationProvenance) o;
            return targetMean.equals(that.targetMean) &&
                    targetStdDev.equals(that.targetStdDev);
        }

        @Override
        public int hashCode() {
            return Objects.hash(targetMean, targetStdDev);
        }

        /**
         * Returns an unmodifiable view of the two configured parameters,
         * keyed by their field names.
         */
        @Override
        public Map<String, Provenance> getConfiguredParameters() {
            Map<String,Provenance> map = new HashMap<>();
            map.put(TARGET_MEAN,targetMean);
            map.put(TARGET_STDDEV,targetStdDev);
            return Collections.unmodifiableMap(map);
        }
    }

    @Override
    public String toString() {
        return "MeanStdDevTransformation(targetMean="+targetMean+",targetStdDev="+targetStdDev+")";
    }

    /**
     * Accumulates a running mean and sum of squared deviations over the
     * observed values using Welford's online update.
     */
    private static class MeanStdDevStatistics implements TransformStatistics {

        private final double targetMean;
        private final double targetStdDev;

        private double mean = 0;
        private double sumSquares = 0;
        private long count = 0;

        public MeanStdDevStatistics(double targetMean, double targetStdDev) {
            this.targetMean = targetMean;
            this.targetStdDev = targetStdDev;
        }

        /**
         * Folds one value into the running mean and sum of squared deviations.
         *
         * @param value The observed value.
         */
        @Override
        public void observeValue(double value) {
            count++;
            double delta = value - mean;      // deviation from the previous mean
            mean += delta / count;
            double delta2 = value - mean;     // deviation from the updated mean
            sumSquares += delta * delta2;
        }

        /**
         * No-op: sparse observations do not contribute to these statistics.
         */
        @Override
        public void observeSparse() { }

        /**
         * No-op: sparse observations do not contribute to these statistics.
         *
         * @param count The number of sparse observations (unused; note it
         *              shadows the {@code count} field).
         */
        @Override
        public void observeSparse(int count) { }

        /**
         * Builds the transformer from the observed mean and the sample
         * (n-1 denominator) standard deviation.
         * <p>
         * With fewer than two observations, or when every observation was
         * identical, the computed standard deviation is NaN or zero; dividing
         * by it would turn every transformed value into NaN or an infinity.
         * In that case a unit observed standard deviation is used instead, so
         * the transformer maps {@code x} to
         * {@code (x - mean) * targetStdDev + targetMean}.
         *
         * @return The transformer for the observed data.
         */
        @Override
        public Transformer generateTransformer() {
            double observedStdDev = Math.sqrt(sumSquares/(count-1));
            // The negated comparison also catches NaN and negative zero.
            if (!(observedStdDev > 0.0)) {
                observedStdDev = 1.0;
            }
            return new MeanStdDevTransformer(mean,observedStdDev,targetMean,targetStdDev);
        }

        @Override
        public String toString() {
            return "MeanStdDevStatistics(mean="+mean
                    +",sumSquares="+sumSquares+",count="+count
                    +",targetMean="+targetMean+",targetStdDev="+targetStdDev+")";
        }
    }

    /**
     * Standardises a value with the observed mean and standard deviation,
     * then rescales it to the target mean and standard deviation.
     */
    private static class MeanStdDevTransformer implements Transformer {
        private static final long serialVersionUID = 1L;

        private final double observedMean;
        private final double observedStdDev;
        private final double targetMean;
        private final double targetStdDev;

        public MeanStdDevTransformer(double observedMean, double observedStdDev, double targetMean, double targetStdDev) {
            this.observedMean = observedMean;
            this.observedStdDev = observedStdDev;
            this.targetMean = targetMean;
            this.targetStdDev = targetStdDev;
        }

        /**
         * Computes {@code ((input - observedMean) / observedStdDev) * targetStdDev + targetMean}.
         *
         * @param input The value to transform.
         * @return The rescaled value.
         */
        @Override
        public double transform(double input) {
            return (((input - observedMean) / observedStdDev) * targetStdDev) + targetMean;
        }

        @Override
        public String toString() {
            return "MeanStdDevTransformer(observedMean="+observedMean+",observedStdDev="+observedStdDev+",targetMean="+targetMean+",targetStdDev="+targetStdDev+")";
        }
    }
}