/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.data.columnar.processors.response;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl;
import org.tribuo.Output;
import org.tribuo.OutputFactory;
import org.tribuo.data.columnar.ResponseProcessor;

import java.util.Optional;

/**
 * Processes the response into quartiles and emits them as classification outputs.
 * <p>
 * The emitted outputs are of the form {@code {<name>:first, <name>:second, <name>:third, <name>:fourth} }.
 * A null input value is emitted as {@code <name>:NONE}.
 *
 * @param <T> The output type generated by the configured output factory.
 */
public class QuartileResponseProcessor<T extends Output<T>> implements ResponseProcessor<T> {

    @Config(mandatory = true,description="The string to emit.")
    private String name;

    @Config(mandatory = true,description="The field name to read.")
    private String fieldName;

    @Config(mandatory = true,description="The quartile to use.")
    private Quartile quartile;

    @Config(mandatory = true,description="The output factory to use.")
    private OutputFactory<T> outputFactory;

    /**
     * For olcut.
     */
    private QuartileResponseProcessor() {}

    /**
     * Constructs a response processor which emits 4 distinct bins for the output factory to process.
     * <p>
     * This works best with classification outputs as the discrete binning is tricky to do in other output
     * types.
     * @param name The output string to emit.
     * @param fieldName The field to read.
     * @param quartile The quartile range to use.
     * @param outputFactory The output factory to use.
     */
    public QuartileResponseProcessor(String name, String fieldName, Quartile quartile, OutputFactory<T> outputFactory) {
        this.name = name;
        this.fieldName = fieldName;
        this.quartile = quartile;
        this.outputFactory = outputFactory;
    }

    /**
     * Replaces the field name this processor reads.
     * @param fieldName The new field name.
     * @deprecated This setter is marked deprecated; the field name can be supplied
     * through the constructor instead.
     */
    @Deprecated
    @Override
    public void setFieldName(String fieldName) {
        this.fieldName = fieldName;
    }

    /**
     * Returns the output factory used to build the emitted outputs.
     * @return The output factory.
     */
    @Override
    public OutputFactory<T> getOutputFactory() {
        return outputFactory;
    }

    /**
     * Returns the name of the field this processor reads.
     * @return The field name.
     */
    @Override
    public String getFieldName() {
        return fieldName;
    }

    /**
     * Bins the supplied value into one of the four quartile ranges and generates the
     * corresponding output.
     * <p>
     * Each bin includes its upper boundary: values less than or equal to the lower median
     * map to {@code <name>:first}, values up to and including the median map to
     * {@code <name>:second}, values up to and including the upper median map to
     * {@code <name>:third}, and everything else (including NaN) maps to {@code <name>:fourth}.
     * A null value maps to {@code <name>:NONE}.
     * @param value The string form of the response value, or null.
     * @return An optional which always contains the generated output.
     * @throws NumberFormatException If the value is non-null and cannot be parsed as a double.
     */
    @Override
    public Optional<T> process(String value) {
        if (value == null) {
            return Optional.of(outputFactory.generateOutput(name + ":NONE"));
        }
        double dv = Double.parseDouble(value);
        // The bins are tested in ascending order, so reaching a branch already implies
        // the value exceeded every earlier boundary; only the upper bound needs checking.
        String suffix;
        if (dv <= quartile.getLowerMedian()) {
            suffix = ":first";
        } else if (dv <= quartile.getMedian()) {
            suffix = ":second";
        } else if (dv <= quartile.getUpperMedian()) {
            suffix = ":third";
        } else {
            suffix = ":fourth";
        }
        return Optional.of(outputFactory.generateOutput(name + suffix));
    }

    @Override
    public String toString() {
        return "QuartileResponseProcessor(fieldName="+ fieldName +",quartile="+quartile.toString()+")";
    }

    @Override
    public ConfiguredObjectProvenance getProvenance() {
        return new ConfiguredObjectProvenanceImpl(this,"ResponseProcessor");
    }
}