001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.regression; 018 019import com.oracle.labs.mlrg.olcut.config.Config; 020import com.oracle.labs.mlrg.olcut.provenance.Provenance; 021import com.oracle.labs.mlrg.olcut.provenance.primitives.CharProvenance; 022import com.oracle.labs.mlrg.olcut.util.Pair; 023import org.tribuo.ImmutableOutputInfo; 024import org.tribuo.MutableOutputInfo; 025import org.tribuo.OutputFactory; 026import org.tribuo.evaluation.Evaluator; 027import org.tribuo.provenance.OutputFactoryProvenance; 028import org.tribuo.regression.evaluation.RegressionEvaluation; 029import org.tribuo.regression.evaluation.RegressionEvaluator; 030 031import java.util.ArrayList; 032import java.util.Collection; 033import java.util.Collections; 034import java.util.List; 035import java.util.Map; 036import java.util.Objects; 037 038/** 039 * A factory for creating {@link Regressor}s and {@link RegressionInfo}s. 040 * <p> 041 * It parses the regression dimensions by toStringing the input and calling {@link Regressor#parseString}, unless 042 * the input is a collection, in which case it extracts the elements. 043 * <p> 044 * This OutputFactory has mutable state, namely the character which the dimension input is split on. 045 * In most cases the default {@link RegressionFactory#DEFAULT_SPLIT_CHAR} is fine. 046 */ 047public final class RegressionFactory implements OutputFactory<Regressor> { 048 private static final long serialVersionUID = 2L; 049 050 public static final char DEFAULT_SPLIT_CHAR = ','; 051 052 @Config(description="The character to split the dimensions on.") 053 private char splitChar = DEFAULT_SPLIT_CHAR; 054 055 public static final Regressor UNKNOWN_REGRESSOR = new Regressor(new String[]{"UNKNOWN"}, new double[]{Double.NaN}); 056 057 public static final Regressor UNKNOWN_MULTIPLE_REGRESSOR = UNKNOWN_REGRESSOR; 058 059 private RegressionFactoryProvenance provenance; 060 061 private static final RegressionEvaluator evaluator = new RegressionEvaluator(); 062 063 /** 064 * Builds a regression factory using the default split character {@link RegressionFactory#DEFAULT_SPLIT_CHAR}. 065 */ 066 public RegressionFactory() { 067 this.provenance = new RegressionFactoryProvenance(splitChar); 068 } 069 070 /** 071 * Sets the split character used to parse {@link Regressor} instances from Strings. 072 * @param splitChar The split character. 073 */ 074 public RegressionFactory(char splitChar) { 075 this.splitChar = splitChar; 076 postConfig(); 077 } 078 079 /** 080 * Used by the OLCUT configuration system, and should not be called by external code. 081 */ 082 @Override 083 public void postConfig() { 084 this.provenance = new RegressionFactoryProvenance(splitChar); 085 } 086 087 /** 088 * Parses the MultipleRegression value either by toStringing the input and calling {@link Regressor#parseString} 089 * or if it's a collection iterating over the elements calling toString on each element in turn and using 090 * {@link Regressor#parseElement}. 091 * @param label An input value. 092 * @param <V> The type of the input value. 093 * @return A MultipleRegressor with sentinel variances. 094 */ 095 @Override 096 public <V> Regressor generateOutput(V label) { 097 if (label instanceof Collection) { 098 Collection<?> c = (Collection<?>) label; 099 List<Pair<String,Double>> dimensions = new ArrayList<>(); 100 int i = 0; 101 for (Object o : c) { 102 dimensions.add(Regressor.parseElement(i,o.toString())); 103 i++; 104 } 105 return Regressor.createFromPairList(dimensions); 106 } else { 107 return Regressor.parseString(label.toString(), splitChar); 108 } 109 } 110 111 @Override 112 public Regressor getUnknownOutput() { 113 return UNKNOWN_REGRESSOR; 114 } 115 116 @Override 117 public MutableOutputInfo<Regressor> generateInfo() { 118 return new MutableRegressionInfo(); 119 } 120 121 @Override 122 public ImmutableOutputInfo<Regressor> constructInfoForExternalModel(Map<Regressor,Integer> mapping) { 123 // Validate inputs are dense 124 OutputFactory.validateMapping(mapping); 125 126 MutableRegressionInfo info = new MutableRegressionInfo(); 127 128 for (Map.Entry<Regressor,Integer> e : mapping.entrySet()) { 129 info.observe(e.getKey()); 130 } 131 132 return new ImmutableRegressionInfo(info,mapping); 133 } 134 135 @Override 136 public Evaluator<Regressor, RegressionEvaluation> getEvaluator() { 137 return evaluator; 138 } 139 140 @Override 141 public int hashCode() { 142 return "RegressionFactory".hashCode(); 143 } 144 145 @Override 146 public boolean equals(Object obj) { 147 return obj instanceof RegressionFactory; 148 } 149 150 @Override 151 public OutputFactoryProvenance getProvenance() { 152 return provenance; 153 } 154 155 /** 156 * Provenance for {@link RegressionFactory}. 157 */ 158 public final static class RegressionFactoryProvenance implements OutputFactoryProvenance { 159 private static final long serialVersionUID = 1L; 160 161 private final char splitChar; 162 163 /** 164 * Constructs a provenance for the factory, reading it's split character. 165 * @param splitChar The split character used by the factory. 166 */ 167 RegressionFactoryProvenance(char splitChar) { 168 this.splitChar = splitChar; 169 } 170 171 /** 172 * Constructs a provenance from it's marshalled form. 173 * @param map The provenance map, containing a splitChar field. 174 */ 175 public RegressionFactoryProvenance(Map<String,Provenance> map) { 176 this.splitChar = ((CharProvenance)map.get("splitChar")).getValue(); 177 } 178 179 @Override 180 public Map<String, Provenance> getConfiguredParameters() { 181 return Collections.singletonMap("splitChar",new CharProvenance("splitChar",splitChar)); 182 } 183 184 @Override 185 public String getClassName() { 186 return RegressionFactory.class.getName(); 187 } 188 189 @Override 190 public String toString() { 191 return generateString("OutputFactory"); 192 } 193 194 @Override 195 public boolean equals(Object o) { 196 if (this == o) return true; 197 if (!(o instanceof RegressionFactoryProvenance)) return false; 198 RegressionFactoryProvenance pairs = (RegressionFactoryProvenance) o; 199 return splitChar == pairs.splitChar; 200 } 201 202 @Override 203 public int hashCode() { 204 return Objects.hash(splitChar); 205 } 206 } 207}