/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.transform.transformations;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.provenance.ObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import com.oracle.labs.mlrg.olcut.provenance.primitives.DoubleProvenance;
import org.tribuo.transform.TransformStatistics;
import org.tribuo.transform.Transformation;
import org.tribuo.transform.TransformationProvenance;
import org.tribuo.transform.Transformer;

import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;

/**
 * A Transformation which takes an observed distribution and rescales
 * it so all values are between the desired min and max. The scaling
 * is linear.
 * <p>
 * Values outside the observed range are clamped to the desired
 * min or max.
 * <p>
 * If every observed value was identical (the observed range is zero), the
 * generated transformer maps every input to the midpoint of the target range.
 */
public final class LinearScalingTransformation implements Transformation {

    private static final String TARGET_MIN = "targetMin";
    private static final String TARGET_MAX = "targetMax";

    @Config(mandatory = true,description="Minimum value after transformation.")
    private double targetMin = 0.0;

    @Config(mandatory = true,description="Maximum value after transformation.")
    private double targetMax = 1.0;

    // Lazily built on first request in getProvenance().
    private TransformationProvenance provenance;

    /**
     * Defaults to zero - one.
     */
    public LinearScalingTransformation() { }

    /**
     * Creates a linear scaling transformation which maps the observed range onto
     * {@code [targetMin, targetMax]}.
     *
     * @param targetMin The minimum value after transformation.
     * @param targetMax The maximum value after transformation.
     * @throws IllegalArgumentException If either bound is NaN or infinite, or if
     *                                  {@code targetMax < targetMin}.
     */
    public LinearScalingTransformation(double targetMin, double targetMax) {
        this.targetMin = targetMin;
        this.targetMax = targetMax;
        postConfig();
    }

    /**
     * Used by the OLCUT configuration system, and should not be called by external code.
     * <p>
     * Validates that both target bounds are finite and that {@code targetMin <= targetMax}.
     *
     * @throws IllegalArgumentException If the target range is invalid.
     */
    @Override
    public void postConfig() {
        // NaN slips through the ordering check below (all comparisons with NaN are false),
        // and infinite bounds produce NaN/infinite outputs, so reject both explicitly.
        if (!Double.isFinite(targetMin) || !Double.isFinite(targetMax)) {
            throw new IllegalArgumentException("Target range must be finite, min = " + targetMin + ", max = " + targetMax);
        }
        if (targetMax < targetMin) {
            throw new IllegalArgumentException("Range must be positive, min = " + targetMin + ", max = " + targetMax);
        }
    }

    @Override
    public TransformStatistics createStats() {
        return new LinearScalingStatistics(targetMin, targetMax);
    }

    @Override
    public synchronized TransformationProvenance getProvenance() {
        if (provenance == null) {
            provenance = new LinearScalingTransformationProvenance(this);
        }
        return provenance;
    }

    /**
     * Provenance for {@link LinearScalingTransformation}.
     */
    public final static class LinearScalingTransformationProvenance implements TransformationProvenance {
        private static final long serialVersionUID = 1L;

        private final DoubleProvenance targetMin;
        private final DoubleProvenance targetMax;

        LinearScalingTransformationProvenance(LinearScalingTransformation host) {
            this.targetMin = new DoubleProvenance(TARGET_MIN,host.targetMin);
            this.targetMax = new DoubleProvenance(TARGET_MAX,host.targetMax);
        }

        /**
         * Rebuilds the provenance from a map of its configured parameters.
         *
         * @param map A map containing {@code targetMin} and {@code targetMax} as {@link DoubleProvenance}s.
         */
        public LinearScalingTransformationProvenance(Map<String,Provenance> map) {
            targetMin = ObjectProvenance.checkAndExtractProvenance(map,TARGET_MIN,DoubleProvenance.class,LinearScalingTransformationProvenance.class.getSimpleName());
            targetMax = ObjectProvenance.checkAndExtractProvenance(map,TARGET_MAX,DoubleProvenance.class,LinearScalingTransformationProvenance.class.getSimpleName());
        }

        @Override
        public String getClassName() {
            return LinearScalingTransformation.class.getName();
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) return true;
            if (!(o instanceof LinearScalingTransformationProvenance)) return false;
            LinearScalingTransformationProvenance pairs = (LinearScalingTransformationProvenance) o;
            return targetMin.equals(pairs.targetMin) &&
                    targetMax.equals(pairs.targetMax);
        }

        @Override
        public int hashCode() {
            return Objects.hash(targetMin, targetMax);
        }

        @Override
        public Map<String, Provenance> getConfiguredParameters() {
            Map<String,Provenance> map = new HashMap<>();
            map.put(TARGET_MIN,targetMin);
            map.put(TARGET_MAX,targetMax);
            return Collections.unmodifiableMap(map);
        }
    }

    @Override
    public String toString() {
        return "LinearScalingTransformation(targetMin=" + targetMin + ",targetMax=" + targetMax + ")";
    }

    /**
     * Tracks the observed minimum and maximum of a feature.
     * <p>
     * Sparse observations are ignored. If no value is ever observed, min and max keep
     * their infinite sentinels, and the resulting transformer clamps every finite
     * input to {@code targetMin}.
     */
    private static class LinearScalingStatistics implements TransformStatistics {

        private final double targetMin;
        private final double targetMax;

        private double min = Double.POSITIVE_INFINITY;
        private double max = Double.NEGATIVE_INFINITY;

        public LinearScalingStatistics(double targetMin, double targetMax) {
            this.targetMin = targetMin;
            this.targetMax = targetMax;
        }

        @Override
        public void observeValue(double value) {
            if (value < min) {
                min = value;
            }
            if (value > max) {
                max = value;
            }
        }

        @Override
        public void observeSparse() { }

        @Override
        public void observeSparse(int count) { }

        @Override
        public Transformer generateTransformer() {
            return new LinearScalingTransformer(min, max, targetMin, targetMax);
        }

        @Override
        public String toString() {
            return "LinearScalingStatistics(min="+min+",max="+max
                    +",targetMin="+targetMin+",targetMax="+targetMax+")";
        }
    }

    /**
     * Linearly maps {@code [observedMin, observedMax]} onto {@code [targetMin, targetMax]},
     * clamping inputs outside the observed range.
     */
    private static class LinearScalingTransformer implements Transformer {
        private static final long serialVersionUID = 1L;

        private final double observedMin;
        private final double observedMax;
        private final double targetMin;
        private final double targetMax;
        private final double scalingFactor;
        private final boolean constant;

        public LinearScalingTransformer(double observedMin, double observedMax, double targetMin, double targetMax) {
            this.observedMin = observedMin;
            this.observedMax = observedMax;
            this.targetMin = targetMin;
            this.targetMax = targetMax;
            double observedRange = observedMax - observedMin;
            this.constant = (observedRange == 0.0);
            double targetRange = targetMax - targetMin;
            // Infinite when constant; never used in that case (transform checks constant first).
            this.scalingFactor = targetRange / observedRange;
        }

        @Override
        public double transform(double input) {
            if (constant) {
                // No observed spread to scale from, so map to the centre of the target range.
                // The midpoint is computed here rather than stored in a new field, so that
                // previously serialized transformers get the same result.
                return targetMin + (targetMax - targetMin) / 2.0;
            } else if (input < observedMin) {
                // If outside observed range, clamp to min or max.
                return targetMin;
            } else if (input > observedMax) {
                return targetMax;
            } else {
                return ((input - observedMin) * scalingFactor) + targetMin;
            }
        }

        @Override
        public String toString() {
            return "LinearScalingTransformer(observedMin="+observedMin+",observedMax="+observedMax+",targetMin="+targetMin+",targetMax="+targetMax+")";
        }
    }
}