001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.data.columnar.processors.response;
018
019import com.oracle.labs.mlrg.olcut.config.Config;
020import com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance;
021import com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl;
022import org.tribuo.Output;
023import org.tribuo.OutputFactory;
024import org.tribuo.data.columnar.ResponseProcessor;
025
026import java.util.Optional;
027
028/**
029 * Processes the response into quartiles and emits them as classification outputs.
030 * <p>
031 * The emitted outputs are of the form {@code {<name>:first, <name>:second, <name>:third, <name>:fourth} }.
032 */
033public class QuartileResponseProcessor<T extends Output<T>> implements ResponseProcessor<T> {
034
035    @Config(mandatory = true,description="The string to emit.")
036    private String name;
037
038    @Config(mandatory = true,description="The field name to read.")
039    private String fieldName;
040
041    @Config(mandatory = true,description="The quartile to use.")
042    private Quartile quartile;
043
044    @Config(mandatory = true,description="The output factory to use.")
045    private OutputFactory<T> outputFactory;
046
047    /**
048     * For olcut.
049     */
050    private QuartileResponseProcessor() {}
051
052    /**
053     * Constructs a repsonse processor which emits 4 distinct bins for the output factory to process.
054     * <p>
055     * This works best with classification outputs as the discrete binning is tricky to do in other output
056     * types.
057     * @param name The output string to emit.
058     * @param fieldName The field to read.
059     * @param quartile The quartile range to use.
060     * @param outputFactory The output factory to use.
061     */
062    public QuartileResponseProcessor(String name, String fieldName, Quartile quartile, OutputFactory<T> outputFactory) {
063        this.name = name;
064        this.fieldName = fieldName;
065        this.quartile = quartile;
066        this.outputFactory = outputFactory;
067    }
068
069    @Deprecated
070    @Override
071    public void setFieldName(String fieldName) {
072        this.fieldName = fieldName;
073    }
074
075    @Override
076    public OutputFactory<T> getOutputFactory() {
077        return outputFactory;
078    }
079
080    @Override
081    public String getFieldName() {
082        return fieldName;
083    }
084
085    @Override
086    public Optional<T> process(String value) {
087        if(value == null) {
088            return Optional.of(outputFactory.generateOutput(name + ":NONE"));
089        }
090        double dv = Double.parseDouble(value);
091        T output;
092        if (dv <= quartile.getLowerMedian()) {
093            output = outputFactory.generateOutput(name + ":first");
094        } else if (dv > quartile.getLowerMedian() && dv <= quartile.getMedian()) {
095            output = outputFactory.generateOutput(name + ":second");
096        } else if (dv > quartile.getMedian() && dv <= quartile.getUpperMedian()) {
097            output = outputFactory.generateOutput(name + ":third");
098        } else {
099            output = outputFactory.generateOutput(name + ":fourth");
100        }
101        return Optional.of(output);
102    }
103
104    @Override
105    public String toString() {
106        return "QuartileResponseProcessor(fieldName="+ fieldName +",quartile="+quartile.toString()+")";
107    }
108
109    @Override
110    public ConfiguredObjectProvenance getProvenance() {
111        return new ConfiguredObjectProvenanceImpl(this,"ResponseProcessor");
112    }
113}