001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.transform.transformations;
018
019import com.oracle.labs.mlrg.olcut.config.Config;
020import com.oracle.labs.mlrg.olcut.provenance.ObjectProvenance;
021import com.oracle.labs.mlrg.olcut.provenance.Provenance;
022import com.oracle.labs.mlrg.olcut.provenance.primitives.DoubleProvenance;
023import org.tribuo.transform.TransformStatistics;
024import org.tribuo.transform.Transformation;
025import org.tribuo.transform.TransformationProvenance;
026import org.tribuo.transform.Transformer;
027
028import java.util.Collections;
029import java.util.HashMap;
030import java.util.Map;
031import java.util.Objects;
032
033/**
034 * A Transformation which takes an observed distribution and rescales
035 * it so all values are between the desired min and max. The scaling
036 * is linear.
037 * <p>
038 * Values outside the observed range are clamped to the desired
039 * min or max.
040 */
041public final class LinearScalingTransformation implements Transformation {
042
043    private static final String TARGET_MIN = "targetMin";
044    private static final String TARGET_MAX = "targetMax";
045
046    @Config(mandatory = true,description="Minimum value after transformation.")
047    private double targetMin = 0.0;
048
049    @Config(mandatory = true,description="Maximum value after transformation.")
050    private double targetMax = 1.0;
051
052    private TransformationProvenance provenance;
053
054    /**
055     * Defaults to zero - one.
056     */
057    public LinearScalingTransformation() { }
058
059    public LinearScalingTransformation(double targetMin, double targetMax) {
060        this.targetMin = targetMin;
061        this.targetMax = targetMax;
062        postConfig();
063    }
064
065    /**
066     * Used by the OLCUT configuration system, and should not be called by external code.
067     */
068    @Override
069    public void postConfig() {
070        if (targetMax < targetMin) {
071            throw new IllegalArgumentException("Range must be positive, min = " + targetMin + ", max = " + targetMax);
072        }
073    }
074
075    @Override
076    public TransformStatistics createStats() {
077        return new LinearScalingStatistics(targetMin, targetMax);
078    }
079
080    @Override
081    public synchronized TransformationProvenance getProvenance() {
082        if (provenance == null) {
083            provenance = new LinearScalingTransformationProvenance(this);
084        }
085        return provenance;
086    }
087
088    /**
089     * Provenance for {@link LinearScalingTransformation}.
090     */
091    public final static class LinearScalingTransformationProvenance implements TransformationProvenance {
092        private static final long serialVersionUID = 1L;
093
094        private final DoubleProvenance targetMin;
095        private final DoubleProvenance targetMax;
096
097        LinearScalingTransformationProvenance(LinearScalingTransformation host) {
098            this.targetMin = new DoubleProvenance(TARGET_MIN,host.targetMin);
099            this.targetMax = new DoubleProvenance(TARGET_MAX,host.targetMax);
100        }
101
102        public LinearScalingTransformationProvenance(Map<String,Provenance> map) {
103            targetMin = ObjectProvenance.checkAndExtractProvenance(map,TARGET_MIN,DoubleProvenance.class,LinearScalingTransformationProvenance.class.getSimpleName());
104            targetMax = ObjectProvenance.checkAndExtractProvenance(map,TARGET_MAX,DoubleProvenance.class,LinearScalingTransformationProvenance.class.getSimpleName());
105        }
106
107        @Override
108        public String getClassName() {
109            return LinearScalingTransformation.class.getName();
110        }
111
112        @Override
113        public boolean equals(Object o) {
114            if (this == o) return true;
115            if (!(o instanceof LinearScalingTransformationProvenance)) return false;
116            LinearScalingTransformationProvenance pairs = (LinearScalingTransformationProvenance) o;
117            return targetMin.equals(pairs.targetMin) &&
118                    targetMax.equals(pairs.targetMax);
119        }
120
121        @Override
122        public int hashCode() {
123            return Objects.hash(targetMin, targetMax);
124        }
125
126        @Override
127        public Map<String, Provenance> getConfiguredParameters() {
128            Map<String,Provenance> map = new HashMap<>();
129            map.put(TARGET_MIN,targetMin);
130            map.put(TARGET_MAX,targetMax);
131            return Collections.unmodifiableMap(map);
132        }
133    }
134
135    @Override
136    public String toString() {
137        return "LinearScalingTransformation(targetMin=" + targetMin + ",targetMax=" + targetMax + ")";
138    }
139
140    private static class LinearScalingStatistics implements TransformStatistics {
141
142        private final double targetMin;
143        private final double targetMax;
144
145        private double min = Double.POSITIVE_INFINITY;
146        private double max = Double.NEGATIVE_INFINITY;
147
148        public LinearScalingStatistics(double targetMin, double targetMax) {
149            this.targetMin = targetMin;
150            this.targetMax = targetMax;
151        }
152
153        @Override
154        public void observeValue(double value) {
155            if (value < min) {
156                min = value;
157            }
158            if (value > max) {
159                max = value;
160            }
161        }
162
163        @Override
164        public void observeSparse() { }
165
166        @Override
167        public void observeSparse(int count) { }
168
169        @Override
170        public Transformer generateTransformer() {
171            return new LinearScalingTransformer(min, max, targetMin, targetMax);
172        }
173
174        @Override
175        public String toString() {
176            return "LinearScalingStatistics(min="+min+",max="+max
177                    +",targetMin="+targetMin+",targetMax="+targetMax+")";
178        }
179    }
180
181    private static class LinearScalingTransformer implements Transformer {
182        private static final long serialVersionUID = 1L;
183
184        private final double observedMin;
185        private final double observedMax;
186        private final double targetMin;
187        private final double targetMax;
188        private final double scalingFactor;
189        private final boolean constant;
190
191        public LinearScalingTransformer(double observedMin, double observedMax, double targetMin, double targetMax) {
192            this.observedMin = observedMin;
193            this.observedMax = observedMax;
194            this.targetMin = targetMin;
195            this.targetMax = targetMax;
196            double observedRange = observedMax - observedMin;
197            this.constant = (observedRange == 0.0);
198            double targetRange = targetMax - targetMin;
199            this.scalingFactor = targetRange / observedRange;
200        }
201
202        @Override
203        public double transform(double input) {
204            if (constant) {
205                return (targetMax - targetMin) / 2.0;
206            } else if (input < observedMin) {
207                // If outside observed range, clamp to min or max.
208                return targetMin;
209            } else if (input > observedMax) {
210                return targetMax;
211            } else {
212                return ((input - observedMin) * scalingFactor) + targetMin;
213            }
214        }
215
216        @Override
217        public String toString() {
218            return "LinearScalingTransformer(observedMin="+observedMin+",observedMax="+observedMax+",targetMin="+targetMin+",targetMax="+targetMax+")";
219        }
220    }
221}