/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
016
017package org.tribuo.regression;
018
019import com.oracle.labs.mlrg.olcut.config.Config;
020import com.oracle.labs.mlrg.olcut.provenance.Provenance;
021import com.oracle.labs.mlrg.olcut.provenance.primitives.CharProvenance;
022import com.oracle.labs.mlrg.olcut.util.Pair;
023import org.tribuo.ImmutableOutputInfo;
024import org.tribuo.MutableOutputInfo;
025import org.tribuo.OutputFactory;
026import org.tribuo.evaluation.Evaluator;
027import org.tribuo.provenance.OutputFactoryProvenance;
028import org.tribuo.regression.evaluation.RegressionEvaluation;
029import org.tribuo.regression.evaluation.RegressionEvaluator;
030
031import java.util.ArrayList;
032import java.util.Collection;
033import java.util.Collections;
034import java.util.List;
035import java.util.Map;
036import java.util.Objects;
037
038/**
039 * A factory for creating {@link Regressor}s and {@link RegressionInfo}s.
040 * <p>
041 * It parses the regression dimensions by toStringing the input and calling {@link Regressor#parseString}, unless
042 * the input is a collection, in which case it extracts the elements.
043 * <p>
044 * This OutputFactory has mutable state, namely the character which the dimension input is split on.
045 * In most cases the default {@link RegressionFactory#DEFAULT_SPLIT_CHAR} is fine.
046 */
047public final class RegressionFactory implements OutputFactory<Regressor> {
048    private static final long serialVersionUID = 2L;
049
050    public static final char DEFAULT_SPLIT_CHAR = ',';
051
052    @Config(description="The character to split the dimensions on.")
053    private char splitChar = DEFAULT_SPLIT_CHAR;
054
055    public static final Regressor UNKNOWN_REGRESSOR = new Regressor(new String[]{"UNKNOWN"}, new double[]{Double.NaN});
056
057    public static final Regressor UNKNOWN_MULTIPLE_REGRESSOR = UNKNOWN_REGRESSOR;
058
059    private RegressionFactoryProvenance provenance;
060
061    private static final RegressionEvaluator evaluator = new RegressionEvaluator();
062
063    /**
064     * Builds a regression factory using the default split character {@link RegressionFactory#DEFAULT_SPLIT_CHAR}.
065     */
066    public RegressionFactory() {
067        this.provenance = new RegressionFactoryProvenance(splitChar);
068    }
069
070    /**
071     * Sets the split character used to parse {@link Regressor} instances from Strings.
072     * @param splitChar The split character.
073     */
074    public RegressionFactory(char splitChar) {
075        this.splitChar = splitChar;
076        postConfig();
077    }
078
079    /**
080     * Used by the OLCUT configuration system, and should not be called by external code.
081     */
082    @Override
083    public void postConfig() {
084        this.provenance = new RegressionFactoryProvenance(splitChar);
085    }
086
087    /**
088     * Parses the MultipleRegression value either by toStringing the input and calling {@link Regressor#parseString}
089     * or if it's a collection iterating over the elements calling toString on each element in turn and using
090     * {@link Regressor#parseElement}.
091     * @param label An input value.
092     * @param <V> The type of the input value.
093     * @return A MultipleRegressor with sentinel variances.
094     */
095    @Override
096    public <V> Regressor generateOutput(V label) {
097        if (label instanceof Collection) {
098            Collection<?> c = (Collection<?>) label;
099            List<Pair<String,Double>> dimensions = new ArrayList<>();
100            int i = 0;
101            for (Object o : c) {
102                dimensions.add(Regressor.parseElement(i,o.toString()));
103                i++;
104            }
105            return Regressor.createFromPairList(dimensions);
106        } else {
107            return Regressor.parseString(label.toString(), splitChar);
108        }
109    }
110
111    @Override
112    public Regressor getUnknownOutput() {
113        return UNKNOWN_REGRESSOR;
114    }
115
116    @Override
117    public MutableOutputInfo<Regressor> generateInfo() {
118        return new MutableRegressionInfo();
119    }
120
121    @Override
122    public ImmutableOutputInfo<Regressor> constructInfoForExternalModel(Map<Regressor,Integer> mapping) {
123        // Validate inputs are dense
124        OutputFactory.validateMapping(mapping);
125
126        MutableRegressionInfo info = new MutableRegressionInfo();
127
128        for (Map.Entry<Regressor,Integer> e : mapping.entrySet()) {
129            info.observe(e.getKey());
130        }
131
132        return new ImmutableRegressionInfo(info,mapping);
133    }
134
135    @Override
136    public Evaluator<Regressor, RegressionEvaluation> getEvaluator() {
137        return evaluator;
138    }
139
140    @Override
141    public int hashCode() {
142        return "RegressionFactory".hashCode();
143    }
144
145    @Override
146    public boolean equals(Object obj) {
147        return obj instanceof RegressionFactory;
148    }
149
150    @Override
151    public OutputFactoryProvenance getProvenance() {
152        return provenance;
153    }
154
155    /**
156     * Provenance for {@link RegressionFactory}.
157     */
158    public final static class RegressionFactoryProvenance implements OutputFactoryProvenance {
159        private static final long serialVersionUID = 1L;
160
161        private final char splitChar;
162
163        /**
164         * Constructs a provenance for the factory, reading it's split character.
165         * @param splitChar The split character used by the factory.
166         */
167        RegressionFactoryProvenance(char splitChar) {
168            this.splitChar = splitChar;
169        }
170
171        /**
172         * Constructs a provenance from it's marshalled form.
173         * @param map The provenance map, containing a splitChar field.
174         */
175        public RegressionFactoryProvenance(Map<String,Provenance> map) {
176            this.splitChar = ((CharProvenance)map.get("splitChar")).getValue();
177        }
178
179        @Override
180        public Map<String, Provenance> getConfiguredParameters() {
181            return Collections.singletonMap("splitChar",new CharProvenance("splitChar",splitChar));
182        }
183
184        @Override
185        public String getClassName() {
186            return RegressionFactory.class.getName();
187        }
188
189        @Override
190        public String toString() {
191            return generateString("OutputFactory");
192        }
193
194        @Override
195        public boolean equals(Object o) {
196            if (this == o) return true;
197            if (!(o instanceof RegressionFactoryProvenance)) return false;
198            RegressionFactoryProvenance pairs = (RegressionFactoryProvenance) o;
199            return splitChar == pairs.splitChar;
200        }
201
202        @Override
203        public int hashCode() {
204            return Objects.hash(splitChar);
205        }
206    }
207}