001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.transform.transformations;
018
019import com.oracle.labs.mlrg.olcut.config.Config;
020import com.oracle.labs.mlrg.olcut.provenance.ObjectProvenance;
021import com.oracle.labs.mlrg.olcut.provenance.Provenance;
022import com.oracle.labs.mlrg.olcut.provenance.primitives.DoubleProvenance;
023import com.oracle.labs.mlrg.olcut.provenance.primitives.EnumProvenance;
024import org.tribuo.transform.TransformStatistics;
025import org.tribuo.transform.Transformation;
026import org.tribuo.transform.TransformationProvenance;
027import org.tribuo.transform.Transformer;
028
029import java.io.Serializable;
030import java.util.Collections;
031import java.util.HashMap;
032import java.util.Map;
033import java.util.Objects;
034import java.util.function.DoubleUnaryOperator;
035
036/**
037 * This is used for stateless functions such as exp, log, addition or multiplication by a constant.
038 * <p>
039 * It's a Transformation, Transformer and TransformStatistics as it has
040 * no data dependent state. This means a single Transformer can be
041 * used for every feature in a dataset.
042 * <p>
043 * Wraps a {@link DoubleUnaryOperator} which actually performs the
044 * transformation.
045 */
046public final class SimpleTransform implements Transformer, Transformation, TransformStatistics {
047    private static final long serialVersionUID = 1L;
048
049    private static final String OP = "op";
050    private static final String OPERAND = "operand";
051    private static final String SECOND_OPERAND = "secondOperand";
052
053    public static final double EPSILON = 1e-12;
054
055    /**
056     * Operations understood by this Transformation.
057     */
058    public enum Operation {
059        /**
060         * Exponentiates the inputs
061         */
062        exp,
063        /**
064         * Logs the inputs (base_e)
065         */
066        log,
067        /**
068         * Adds the specified constant.
069         */
070        add,
071        /**
072         * Subtracts the specified constant.
073         */
074        sub,
075        /**
076         * Multiplies by the specified constant.
077         */
078        mul,
079        /**
080         * Divides by the specified constant.
081         */
082        div,
083        /**
084         * Binarises the output around 1.0.
085         */
086        binarise,
087        /**
088         * Min and max thresholds applied to the input.
089         */
090        threshold
091    }
092
093    @Config(mandatory = true,description="Type of the simple transformation.")
094    private Operation op;
095
096    @Config(description="Operand (if required).")
097    private double operand = Double.NaN;
098
099    @Config(description="Second operand (if required).")
100    private double secondOperand = Double.NaN;
101
102    private SerializableDoubleUnaryOperator operation;
103
104    private transient TransformationProvenance provenance;
105
106    /**
107     * For OLCUT.
108     */
109    private SimpleTransform() {}
110
111    private SimpleTransform(Operation op, double operand, double secondOperand) {
112        this.op = op;
113        this.operand = operand;
114        this.secondOperand = secondOperand;
115        postConfig();
116    }
117
118    private SimpleTransform(Operation op, double operand) {
119        this.op = op;
120        this.operand = operand;
121        postConfig();
122    }
123
124    private SimpleTransform(Operation op) {
125        this.op = op;
126        postConfig();
127    }
128
129    /**
130     * Used by the OLCUT configuration system, and should not be called by external code.
131     */
132    @Override
133    public void postConfig() {
134        switch (op) {
135            case exp:
136                operation = Math::exp;
137                break;
138            case log:
139                operation = Math::log;
140                break;
141            case add:
142                if (Double.isNaN(operand)) {
143                    throw new IllegalArgumentException("operand must not be NaN");
144                }
145                operation = (double input) -> input + operand;
146                break;
147            case sub:
148                if (Double.isNaN(operand)) {
149                    throw new IllegalArgumentException("operand must not be NaN");
150                }
151                operation = (double input) -> input - operand;
152                break;
153            case mul:
154                if (Double.isNaN(operand)) {
155                    throw new IllegalArgumentException("operand must not be NaN");
156                }
157                operation = (double input) -> input * operand;
158                break;
159            case div:
160                if (Double.isNaN(operand)) {
161                    throw new IllegalArgumentException("operand must not be NaN");
162                }
163                operation = (double input) -> input / operand;
164                break;
165            case binarise:
166                operation = (double input) -> input < EPSILON ? 0.0 : 1.0;
167                break;
168            case threshold:
169                if (operand > secondOperand) {
170                    throw new IllegalArgumentException("Min must be greater than max, min = " + operand + ", max = " + secondOperand);
171                } else if (Double.isNaN(operand) || Double.isNaN(secondOperand)) {
172                    throw new IllegalArgumentException("min and/or max must not be NaN");
173                }
174                operation = (double input) -> { if (input < operand) { return operand; } else if (input > secondOperand) { return secondOperand; } else { return input; } };
175                break;
176            default:
177                throw new IllegalArgumentException("Operation " + op + " is unknown");
178        }
179    }
180
181    @Override
182    public TransformationProvenance getProvenance() {
183        if (provenance == null) {
184            provenance = new SimpleTransformProvenance(this);
185        }
186        return provenance;
187    }
188
189    /**
190     * Provenance for {@link SimpleTransform}.
191     */
192    public final static class SimpleTransformProvenance implements TransformationProvenance {
193        private static final long serialVersionUID = 1L;
194
195        private final EnumProvenance<Operation> op;
196        private final DoubleProvenance operand;
197        private final DoubleProvenance secondOperand;
198
199        SimpleTransformProvenance(SimpleTransform host) {
200            this.op = new EnumProvenance<>(OP,host.op);
201            this.operand = new DoubleProvenance(OPERAND,host.operand);
202            this.secondOperand = new DoubleProvenance(SECOND_OPERAND,host.secondOperand);
203        }
204
205        @SuppressWarnings("unchecked") // Enum cast
206        public SimpleTransformProvenance(Map<String,Provenance> map) {
207            op = ObjectProvenance.checkAndExtractProvenance(map,OP,EnumProvenance.class, SimpleTransformProvenance.class.getSimpleName());
208            operand = ObjectProvenance.checkAndExtractProvenance(map,OPERAND,DoubleProvenance.class, SimpleTransformProvenance.class.getSimpleName());
209            secondOperand = ObjectProvenance.checkAndExtractProvenance(map,SECOND_OPERAND,DoubleProvenance.class,SimpleTransformProvenance.class.getSimpleName());
210        }
211
212        @Override
213        public String getClassName() {
214            return SimpleTransform.class.getName();
215        }
216
217        @Override
218        public boolean equals(Object o) {
219            if (this == o) return true;
220            if (!(o instanceof SimpleTransformProvenance)) return false;
221            SimpleTransformProvenance pairs = (SimpleTransformProvenance) o;
222            return op.equals(pairs.op) &&
223                    operand.equals(pairs.operand) &&
224                    secondOperand.equals(pairs.secondOperand);
225        }
226
227        @Override
228        public int hashCode() {
229            return Objects.hash(op, operand, secondOperand);
230        }
231
232        @Override
233        public Map<String, Provenance> getConfiguredParameters() {
234            Map<String,Provenance> map = new HashMap<>();
235            map.put(OP,op);
236            map.put(OPERAND,operand);
237            map.put(SECOND_OPERAND,secondOperand);
238            return Collections.unmodifiableMap(map);
239        }
240    }
241
242    /**
243     * No-op on this TransformStatistics.
244     * @param value The value to observe
245     */
246    @Override
247    public void observeValue(double value) { }
248
249    /**
250     * No-op on this TransformStatistics.
251     */
252    @Override
253    public void observeSparse() { }
254
255    /**
256     * No-op on this TransformStatistics.
257     */
258    @Override
259    public void observeSparse(int count) { }
260
261    /**
262     * Returns itself.
263     * @return this.
264     */
265    @Override
266    public Transformer generateTransformer() {
267        return this;
268    }
269
270    /**
271     * Returns itself.
272     * @return this.
273     */
274    @Override
275    public TransformStatistics createStats() {
276        return this;
277    }
278
279    /**
280     * Apply the operation to the input.
281     * @param input The input value to transform.
282     * @return The transformed value.
283     */
284    @Override
285    public double transform(double input) {
286        return operation.applyAsDouble(input);
287    }
288
289    @Override
290    public String toString() {
291        switch (op) {
292            case exp:
293                return "exp()";
294            case log:
295                return "log()";
296            case add:
297                return "add("+operand+")";
298            case sub:
299                return "sub("+operand+")";
300            case mul:
301                return "mul("+operand+")";
302            case div:
303                return "div("+operand+")";
304            case binarise:
305                return "binarise()";
306            case threshold:
307                return "threshold(min="+operand+",max="+secondOperand+")";
308            default:
309                return op.toString();
310        }
311    }
312
313    /**
314     * Generate a SimpleTransform that applies
315     * {@link Math#exp}.
316     * @return The exponential function.
317     */
318    public static SimpleTransform exp() {
319        return new SimpleTransform(Operation.exp);
320    }
321
322    /**
323     * Generate a SimpleTransform that applies
324     * {@link Math#log}.
325     * @return The logarithm function.
326     */
327    public static SimpleTransform log() {
328        return new SimpleTransform(Operation.log);
329    }
330
331    /**
332     * Generate a SimpleTransform that
333     * adds the operand to each value.
334     * @param operand The operand to add.
335     * @return An addition function.
336     */
337    public static SimpleTransform add(double operand) {
338        return new SimpleTransform(Operation.add,operand);
339    }
340
341    /**
342     * Generate a SimpleTransform that
343     * subtracts the operand from each value.
344     * @param operand The operand to subtract.
345     * @return A subtraction function.
346     */
347    public static SimpleTransform sub(double operand) {
348        return new SimpleTransform(Operation.sub,operand);
349    }
350
351    /**
352     * Generate a SimpleTransform that
353     * multiplies each value by the operand.
354     * @param operand The operand to multiply.
355     * @return A multiplication function.
356     */
357    public static SimpleTransform mul(double operand) {
358        return new SimpleTransform(Operation.mul,operand);
359    }
360
361    /**
362     * Generate a SimpleTransform that
363     * divides each value by the operand.
364     * @param operand The divisor.
365     * @return A division function.
366     */
367    public static SimpleTransform div(double operand) {
368        return new SimpleTransform(Operation.div,operand);
369    }
370
371    /**
372     * Generate a SimpleTransform that sets negative and
373     * zero values to zero and positive values to one.
374     * @return A binarising function.
375     */
376    public static SimpleTransform binarise() {
377        return new SimpleTransform(Operation.binarise);
378    }
379
380    /**
381     * Generate a SimpleTransform that sets values below min to
382     * min, and values above max to max.
383     * @param min The minimum value. To not threshold below, set to {@link Double#NEGATIVE_INFINITY}.
384     * @param max The maximum value. To not threshold above, set to {@link Double#POSITIVE_INFINITY}.
385     * @return A thresholding function.
386     */
387    public static SimpleTransform threshold(double min, double max) {
388        return new SimpleTransform(Operation.threshold,min,max);
389    }
390
391    /**
392     * Tag interface to make the operators serializable.
393     */
394    interface SerializableDoubleUnaryOperator extends DoubleUnaryOperator, Serializable {}
395}