001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.transform.transformations;
018
019import com.oracle.labs.mlrg.olcut.config.Config;
020import com.oracle.labs.mlrg.olcut.provenance.ObjectProvenance;
021import com.oracle.labs.mlrg.olcut.provenance.Provenance;
022import com.oracle.labs.mlrg.olcut.provenance.primitives.DoubleProvenance;
023import org.tribuo.transform.TransformStatistics;
024import org.tribuo.transform.Transformation;
025import org.tribuo.transform.TransformationProvenance;
026import org.tribuo.transform.Transformer;
027
028import java.util.Collections;
029import java.util.HashMap;
030import java.util.Map;
031import java.util.Objects;
032
033/**
034 * A Transformation which takes an observed distribution and rescales
035 * it so it has the desired mean and standard deviation.
036 * <p>
037 * Checks to see that the requested standard deviation is
038 * positive, throws IllegalArgumentException otherwise.
039 */
040public final class MeanStdDevTransformation implements Transformation {
041
042    private static final String TARGET_MEAN = "targetMean";
043    private static final String TARGET_STDDEV = "targetStdDev";
044
045    @Config(mandatory = true,description="Mean value after transformation.")
046    private double targetMean = 0.0;
047
048    @Config(mandatory = true,description="Standard deviation after transformation.")
049    private double targetStdDev = 1.0;
050
051    private MeanStdDevTransformationProvenance provenance;
052
053    /**
054     * Defaults to zero mean, one std dev.
055     */
056    public MeanStdDevTransformation() { }
057
058    public MeanStdDevTransformation(double targetMean, double targetStdDev) {
059        this.targetMean = targetMean;
060        this.targetStdDev = targetStdDev;
061        postConfig();
062    }
063
064    /**
065     * Used by the OLCUT configuration system, and should not be called by external code.
066     */
067    @Override
068    public void postConfig() {
069        if (targetStdDev < SimpleTransform.EPSILON) {
070            throw new IllegalArgumentException("Target standard deviation must be positive, found " + targetStdDev);
071        }
072    }
073
074    @Override
075    public TransformStatistics createStats() {
076        return new MeanStdDevStatistics(targetMean,targetStdDev);
077    }
078
079    @Override
080    public TransformationProvenance getProvenance() {
081        if (provenance == null) {
082            provenance = new MeanStdDevTransformationProvenance(this);
083        }
084        return provenance;
085    }
086
087    /**
088     * Provenance for {@link MeanStdDevTransformation}.
089     */
090    public final static class MeanStdDevTransformationProvenance implements TransformationProvenance {
091        private static final long serialVersionUID = 1L;
092
093        private final DoubleProvenance targetMean;
094        private final DoubleProvenance targetStdDev;
095
096        MeanStdDevTransformationProvenance(MeanStdDevTransformation host) {
097            this.targetMean = new DoubleProvenance(TARGET_MEAN, host.targetMean);
098            this.targetStdDev = new DoubleProvenance(TARGET_STDDEV, host.targetStdDev);
099        }
100
101        public MeanStdDevTransformationProvenance(Map<String, Provenance> map) {
102            targetMean = ObjectProvenance.checkAndExtractProvenance(map, TARGET_MEAN, DoubleProvenance.class, MeanStdDevTransformationProvenance.class.getSimpleName());
103            targetStdDev = ObjectProvenance.checkAndExtractProvenance(map, TARGET_STDDEV, DoubleProvenance.class, MeanStdDevTransformationProvenance.class.getSimpleName());
104        }
105
106        @Override
107        public String getClassName() {
108            return MeanStdDevTransformation.class.getName();
109        }
110
111        @Override
112        public boolean equals(Object o) {
113            if (this == o) return true;
114            if (!(o instanceof MeanStdDevTransformationProvenance)) return false;
115            MeanStdDevTransformationProvenance pairs = (MeanStdDevTransformationProvenance) o;
116            return targetMean.equals(pairs.targetMean) &&
117                    targetStdDev.equals(pairs.targetStdDev);
118        }
119
120        @Override
121        public int hashCode() {
122            return Objects.hash(targetMean, targetStdDev);
123        }
124
125        @Override
126        public Map<String, Provenance> getConfiguredParameters() {
127            Map<String,Provenance> map = new HashMap<>();
128            map.put(TARGET_MEAN,targetMean);
129            map.put(TARGET_STDDEV,targetStdDev);
130            return Collections.unmodifiableMap(map);
131        }
132    }
133
134    @Override
135    public String toString() {
136        return "MeanStdDevTransformation(targetMean="+targetMean+",targetStdDev="+targetStdDev+")";
137    }
138
139    private static class MeanStdDevStatistics implements TransformStatistics {
140
141        private final double targetMean;
142        private final double targetStdDev;
143
144        private double mean = 0;
145        private double sumSquares = 0;
146        private long count = 0;
147
148        public MeanStdDevStatistics(double targetMean, double targetStdDev) {
149            this.targetMean = targetMean;
150            this.targetStdDev = targetStdDev;
151        }
152
153        @Override
154        public void observeValue(double value) {
155            count++;
156            double delta = value - mean;
157            mean += delta / count;
158            double delta2 = value - mean;
159            sumSquares += delta * delta2;
160        }
161
162        @Override
163        public void observeSparse() { }
164
165        @Override
166        public void observeSparse(int count) { }
167
168        @Override
169        public Transformer generateTransformer() {
170            return new MeanStdDevTransformer(mean,Math.sqrt(sumSquares/(count-1)),targetMean,targetStdDev);
171        }
172
173        @Override
174        public String toString() {
175            return "MeanStdDevStatistics(mean="+mean
176                    +",sumSquares="+sumSquares+",count="+count
177                    +"targetMean="+targetMean+",targetStdDev="+targetStdDev+")";
178        }
179    }
180
181    private static class MeanStdDevTransformer implements Transformer {
182        private static final long serialVersionUID = 1L;
183
184        private final double observedMean;
185        private final double observedStdDev;
186        private final double targetMean;
187        private final double targetStdDev;
188
189        public MeanStdDevTransformer(double observedMean, double observedStdDev, double targetMean, double targetStdDev) {
190            this.observedMean = observedMean;
191            this.observedStdDev = observedStdDev;
192            this.targetMean = targetMean;
193            this.targetStdDev = targetStdDev;
194        }
195
196        @Override
197        public double transform(double input) {
198            return (((input - observedMean) / observedStdDev) * targetStdDev) + targetMean;
199        }
200
201        @Override
202        public String toString() {
203            return "MeanStdDevTransformer(observedMean="+observedMean+",observedStdDev="+observedStdDev+",targetMean="+targetMean+",targetStdDev="+targetStdDev+")";
204        }
205    }
206}
207