/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.transform.transformations;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.provenance.ObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import com.oracle.labs.mlrg.olcut.provenance.primitives.DoubleProvenance;
import com.oracle.labs.mlrg.olcut.provenance.primitives.EnumProvenance;
import org.tribuo.transform.TransformStatistics;
import org.tribuo.transform.Transformation;
import org.tribuo.transform.TransformationProvenance;
import org.tribuo.transform.Transformer;

import java.io.Serializable;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import java.util.function.DoubleUnaryOperator;

/**
 * This is used for stateless functions such as exp, log, addition or multiplication by a constant.
 * <p>
 * It's a Transformation, Transformer and TransformStatistics as it has
 * no data dependent state. This means a single Transformer can be
 * used for every feature in a dataset.
 * <p>
 * Wraps a {@link DoubleUnaryOperator} which actually performs the
 * transformation.
 */
public final class SimpleTransform implements Transformer, Transformation, TransformStatistics {
    private static final long serialVersionUID = 1L;

    private static final String OP = "op";
    private static final String OPERAND = "operand";
    private static final String SECOND_OPERAND = "secondOperand";

    /**
     * Cut-off used by {@link Operation#binarise}: inputs strictly less than
     * this value are mapped to 0.0, all other inputs are mapped to 1.0.
     */
    public static final double EPSILON = 1e-12;

    /**
     * Operations understood by this Transformation.
     */
    public enum Operation {
        /**
         * Exponentiates the inputs
         */
        exp,
        /**
         * Logs the inputs (base_e)
         */
        log,
        /**
         * Adds the specified constant.
         */
        add,
        /**
         * Subtracts the specified constant.
         */
        sub,
        /**
         * Multiplies by the specified constant.
         */
        mul,
        /**
         * Divides by the specified constant.
         */
        div,
        /**
         * Binarises the input: values below {@link #EPSILON} become 0.0,
         * all other values become 1.0.
         */
        binarise,
        /**
         * Min and max thresholds applied to the input.
         */
        threshold
    }

    @Config(mandatory = true,description="Type of the simple transformation.")
    private Operation op;

    @Config(description="Operand (if required).")
    private double operand = Double.NaN;

    @Config(description="Second operand (if required).")
    private double secondOperand = Double.NaN;

    private SerializableDoubleUnaryOperator operation;

    private transient TransformationProvenance provenance;

    /**
     * For OLCUT.
     */
    private SimpleTransform() {}

    private SimpleTransform(Operation op, double operand, double secondOperand) {
        this.op = op;
        this.operand = operand;
        this.secondOperand = secondOperand;
        postConfig();
    }

    private SimpleTransform(Operation op, double operand) {
        this.op = op;
        this.operand = operand;
        postConfig();
    }

    private SimpleTransform(Operation op) {
        this.op = op;
        postConfig();
    }

    /**
     * Used by the OLCUT configuration system, and should not be called by external code.
     * <p>
     * Validates the operands required by the configured operation and builds
     * the operator which {@link #transform(double)} delegates to.
     * @throws IllegalArgumentException if a required operand is NaN, if the
     * threshold minimum is greater than the maximum, or if the operation is unknown.
     */
    @Override
    public void postConfig() {
        // NOTE(review): the lambdas below are stored in a Serializable-typed field, so
        // presumably they are serialized with this object. A serialized lambda is tied
        // to its compiler-generated implementation method, so keep them inline here and
        // in this order unless breaking previously serialized instances is acceptable.
        switch (op) {
            case exp:
                operation = Math::exp;
                break;
            case log:
                operation = Math::log;
                break;
            case add:
                checkOperandIsSet();
                operation = (double input) -> input + operand;
                break;
            case sub:
                checkOperandIsSet();
                operation = (double input) -> input - operand;
                break;
            case mul:
                checkOperandIsSet();
                operation = (double input) -> input * operand;
                break;
            case div:
                checkOperandIsSet();
                operation = (double input) -> input / operand;
                break;
            case binarise:
                operation = (double input) -> input < EPSILON ? 0.0 : 1.0;
                break;
            case threshold:
                // NaN is rejected first; a comparison involving NaN is always false,
                // so the ordering check below could never catch it.
                if (Double.isNaN(operand) || Double.isNaN(secondOperand)) {
                    throw new IllegalArgumentException("min and/or max must not be NaN");
                } else if (operand > secondOperand) {
                    throw new IllegalArgumentException("Min must not be greater than max, min = " + operand + ", max = " + secondOperand);
                }
                operation = (double input) -> {
                    if (input < operand) {
                        return operand;
                    } else if (input > secondOperand) {
                        return secondOperand;
                    } else {
                        return input;
                    }
                };
                break;
            default:
                throw new IllegalArgumentException("Operation " + op + " is unknown");
        }
    }

    /**
     * Checks that the first operand has been supplied, i.e. it is not the NaN default.
     * @throws IllegalArgumentException if the operand is NaN.
     */
    private void checkOperandIsSet() {
        if (Double.isNaN(operand)) {
            throw new IllegalArgumentException("operand must not be NaN");
        }
    }

    /**
     * Returns the provenance describing this transformation, creating and
     * caching it on first use.
     * @return The provenance of this transformation.
     */
    @Override
    public TransformationProvenance getProvenance() {
        if (provenance == null) {
            provenance = new SimpleTransformProvenance(this);
        }
        return provenance;
    }

    /**
     * Provenance for {@link SimpleTransform}.
     */
    public static final class SimpleTransformProvenance implements TransformationProvenance {
        private static final long serialVersionUID = 1L;

        private final EnumProvenance<Operation> op;
        private final DoubleProvenance operand;
        private final DoubleProvenance secondOperand;

        /**
         * Records the operation and both operands of the supplied transform.
         * @param host The transform to describe.
         */
        SimpleTransformProvenance(SimpleTransform host) {
            this.op = new EnumProvenance<>(OP,host.op);
            this.operand = new DoubleProvenance(OPERAND,host.operand);
            this.secondOperand = new DoubleProvenance(SECOND_OPERAND,host.secondOperand);
        }

        /**
         * Rebuilds the provenance from a map of its configured parameters.
         * @param map The provenance map, keyed by parameter name.
         */
        @SuppressWarnings("unchecked") // Enum cast
        public SimpleTransformProvenance(Map<String,Provenance> map) {
            op = ObjectProvenance.checkAndExtractProvenance(map,OP,EnumProvenance.class, SimpleTransformProvenance.class.getSimpleName());
            operand = ObjectProvenance.checkAndExtractProvenance(map,OPERAND,DoubleProvenance.class, SimpleTransformProvenance.class.getSimpleName());
            secondOperand = ObjectProvenance.checkAndExtractProvenance(map,SECOND_OPERAND,DoubleProvenance.class,SimpleTransformProvenance.class.getSimpleName());
        }

        @Override
        public String getClassName() {
            return SimpleTransform.class.getName();
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (!(o instanceof SimpleTransformProvenance)) {
                return false;
            }
            SimpleTransformProvenance other = (SimpleTransformProvenance) o;
            return op.equals(other.op) &&
                    operand.equals(other.operand) &&
                    secondOperand.equals(other.secondOperand);
        }

        @Override
        public int hashCode() {
            return Objects.hash(op, operand, secondOperand);
        }

        @Override
        public Map<String, Provenance> getConfiguredParameters() {
            Map<String,Provenance> map = new HashMap<>();
            map.put(OP,op);
            map.put(OPERAND,operand);
            map.put(SECOND_OPERAND,secondOperand);
            return Collections.unmodifiableMap(map);
        }
    }

    /**
     * No-op on this TransformStatistics.
     * @param value The value to observe
     */
    @Override
    public void observeValue(double value) { }

    /**
     * No-op on this TransformStatistics.
     */
    @Override
    public void observeSparse() { }

    /**
     * No-op on this TransformStatistics.
     * @param count The number of sparse values observed (ignored).
     */
    @Override
    public void observeSparse(int count) { }

    /**
     * Returns itself.
     * @return this.
     */
    @Override
    public Transformer generateTransformer() {
        return this;
    }

    /**
     * Returns itself.
     * @return this.
     */
    @Override
    public TransformStatistics createStats() {
        return this;
    }

    /**
     * Apply the operation to the input.
     * @param input The input value to transform.
     * @return The transformed value.
     */
    @Override
    public double transform(double input) {
        return operation.applyAsDouble(input);
    }

    @Override
    public String toString() {
        switch (op) {
            case exp:
                return "exp()";
            case log:
                return "log()";
            case add:
                return "add("+operand+")";
            case sub:
                return "sub("+operand+")";
            case mul:
                return "mul("+operand+")";
            case div:
                return "div("+operand+")";
            case binarise:
                return "binarise()";
            case threshold:
                return "threshold(min="+operand+",max="+secondOperand+")";
            default:
                return op.toString();
        }
    }

    /**
     * Generate a SimpleTransform that applies
     * {@link Math#exp}.
     * @return The exponential function.
     */
    public static SimpleTransform exp() {
        return new SimpleTransform(Operation.exp);
    }

    /**
     * Generate a SimpleTransform that applies
     * {@link Math#log}.
     * @return The logarithm function.
     */
    public static SimpleTransform log() {
        return new SimpleTransform(Operation.log);
    }

    /**
     * Generate a SimpleTransform that
     * adds the operand to each value.
     * @param operand The operand to add.
     * @return An addition function.
     * @throws IllegalArgumentException if the operand is NaN.
     */
    public static SimpleTransform add(double operand) {
        return new SimpleTransform(Operation.add,operand);
    }

    /**
     * Generate a SimpleTransform that
     * subtracts the operand from each value.
     * @param operand The operand to subtract.
     * @return A subtraction function.
     * @throws IllegalArgumentException if the operand is NaN.
     */
    public static SimpleTransform sub(double operand) {
        return new SimpleTransform(Operation.sub,operand);
    }

    /**
     * Generate a SimpleTransform that
     * multiplies each value by the operand.
     * @param operand The operand to multiply.
     * @return A multiplication function.
     * @throws IllegalArgumentException if the operand is NaN.
     */
    public static SimpleTransform mul(double operand) {
        return new SimpleTransform(Operation.mul,operand);
    }

    /**
     * Generate a SimpleTransform that
     * divides each value by the operand.
     * @param operand The divisor.
     * @return A division function.
     * @throws IllegalArgumentException if the operand is NaN.
     */
    public static SimpleTransform div(double operand) {
        return new SimpleTransform(Operation.div,operand);
    }

    /**
     * Generate a SimpleTransform that sets values below
     * {@link #EPSILON} (i.e. negative and zero values) to zero
     * and all other values to one.
     * @return A binarising function.
     */
    public static SimpleTransform binarise() {
        return new SimpleTransform(Operation.binarise);
    }

    /**
     * Generate a SimpleTransform that sets values below min to
     * min, and values above max to max.
     * @param min The minimum value. To not threshold below, set to {@link Double#NEGATIVE_INFINITY}.
     * @param max The maximum value. To not threshold above, set to {@link Double#POSITIVE_INFINITY}.
     * @return A thresholding function.
     * @throws IllegalArgumentException if min or max is NaN, or if min is greater than max.
     */
    public static SimpleTransform threshold(double min, double max) {
        return new SimpleTransform(Operation.threshold,min,max);
    }

    /**
     * Tag interface to make the operators serializable.
     */
    interface SerializableDoubleUnaryOperator extends DoubleUnaryOperator, Serializable {}
}