001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.regression;
018
019import com.oracle.labs.mlrg.olcut.util.MutableDouble;
020import com.oracle.labs.mlrg.olcut.util.MutableLong;
021import com.oracle.labs.mlrg.olcut.util.Pair;
022import org.tribuo.ImmutableOutputInfo;
023
024import java.util.Arrays;
025import java.util.Collections;
026import java.util.Comparator;
027import java.util.HashMap;
028import java.util.Iterator;
029import java.util.LinkedHashMap;
030import java.util.LinkedHashSet;
031import java.util.Map;
032import java.util.Set;
033import java.util.SortedSet;
034import java.util.TreeSet;
035import java.util.logging.Level;
036import java.util.logging.Logger;
037
038/**
039 * A {@link ImmutableOutputInfo} for {@link Regressor}s.
040 */
041public class ImmutableRegressionInfo extends RegressionInfo implements ImmutableOutputInfo<Regressor> {
042    private static final Logger logger = Logger.getLogger(ImmutableRegressionInfo.class.getName());
043
044    private static final long serialVersionUID = 2L;
045
046    private final Map<Integer,String> idLabelMap;
047
048    private final Map<String,Integer> labelIDMap;
049
050    private final Set<Regressor> domain;
051
052    private ImmutableRegressionInfo(ImmutableRegressionInfo info) {
053        super(info);
054        idLabelMap = new LinkedHashMap<>();
055        idLabelMap.putAll(info.idLabelMap);
056        labelIDMap = new LinkedHashMap<>();
057        labelIDMap.putAll(info.labelIDMap);
058        domain = calculateDomain(minMap);
059    }
060
061    ImmutableRegressionInfo(RegressionInfo info) {
062        super(info);
063        idLabelMap = new LinkedHashMap<>();
064        labelIDMap = new LinkedHashMap<>();
065        int counter = 0;
066        for (Map.Entry<String,MutableLong> e : countMap.entrySet()) {
067            idLabelMap.put(counter,e.getKey());
068            labelIDMap.put(e.getKey(),counter);
069            counter++;
070        }
071        domain = calculateDomain(minMap);
072    }
073
074    ImmutableRegressionInfo(RegressionInfo info, Map<Regressor,Integer> mapping) {
075        super(info);
076        if (mapping.size() != info.size()) {
077            throw new IllegalStateException("Mapping and info come from different sources, mapping.size() = " + mapping.size() + ", info.size() = " + info.size());
078        }
079
080        idLabelMap = new HashMap<>();
081        labelIDMap = new HashMap<>();
082        for (Map.Entry<Regressor, Integer> e : mapping.entrySet()) {
083            Regressor r = e.getKey();
084            String[] names = r.getNames();
085            if (names.length == 1) {
086                idLabelMap.put(e.getValue(), names[0]);
087                labelIDMap.put(names[0], e.getValue());
088            } else {
089                throw new IllegalArgumentException("Mapping must contain a single regression dimension per id, but contains " + Arrays.toString(names) + " -> " + e.getValue());
090            }
091        }
092        domain = calculateDomain(minMap);
093    }
094
095    /**
096     * Generates the domain for this regression info.
097     * @param minMap The set of minimum values per dimension.
098     * @return the domain, sorted lexicographically by name
099     */
100    private static Set<Regressor> calculateDomain(Map<String, MutableDouble> minMap) {
101        TreeSet<Regressor.DimensionTuple> outputs = new TreeSet<>(Comparator.comparing(Regressor.DimensionTuple::getName));
102        for (Map.Entry<String,MutableDouble> e : minMap.entrySet()) {
103            outputs.add(new Regressor.DimensionTuple(e.getKey(),e.getValue().doubleValue()));
104        }
105        //
106        // Now that we're sorted, simplify our output into a LinkedHashSet so we don't hang on to
107        // the comparator we used above in the TreeSet
108        LinkedHashSet<Regressor.DimensionTuple> preSortedOutputs = new LinkedHashSet<>();
109        preSortedOutputs.addAll(outputs);
110        @SuppressWarnings("unchecked") // DimensionTuple is a subtype of Regressor, and this set is immutable.
111        Set<Regressor> immutableOutputs = (Set<Regressor>) (Set) Collections.unmodifiableSet(preSortedOutputs);
112        return immutableOutputs;
113    }
114
115    @Override
116    public Set<Regressor> getDomain() {
117        return domain;
118    }
119
120    @Override
121    public int getID(Regressor output) {
122        return labelIDMap.getOrDefault(output.getDimensionNamesString(),-1);
123    }
124
125    @Override
126    public Regressor getOutput(int id) {
127        String label = idLabelMap.get(id);
128        if (label != null) {
129            return new Regressor(label,1.0);
130        } else {
131            logger.log(Level.INFO,"No entry found for id " + id);
132            return null;
133        }
134    }
135
136    @Override
137    public long getTotalObservations() {
138        return overallCount;
139    }
140
141    @Override
142    public ImmutableRegressionInfo copy() {
143        return new ImmutableRegressionInfo(this);
144    }
145
146    @Override
147    public String toString() {
148        StringBuilder builder = new StringBuilder();
149        builder.append("MultipleRegressionOutput(");
150        for (Map.Entry<String,MutableLong> e : countMap.entrySet()) {
151            String name = e.getKey();
152            long count = e.getValue().longValue();
153            builder.append(String.format("{name=%s,id=%d,count=%d,maxMap=%f,min=%f,mean=%f,variance=%f},",
154                    name,
155                    labelIDMap.get(name),
156                    count,
157                    maxMap.get(name).doubleValue(),
158                    minMap.get(name).doubleValue(),
159                    meanMap.get(name).doubleValue(),
160                    (sumSquaresMap.get(name).doubleValue() / (count - 1))
161            ));
162        }
163        builder.deleteCharAt(builder.length()-1);
164        builder.append(')');
165        return builder.toString();
166    }
167
168    @Override
169    public String toReadableString() {
170        return toString();
171    }
172
173    @Override
174    public Iterator<Pair<Integer, Regressor>> iterator() {
175        return new ImmutableInfoIterator(idLabelMap);
176    }
177
178    private static class ImmutableInfoIterator implements Iterator<Pair<Integer, Regressor>> {
179
180        private final Iterator<Map.Entry<Integer,String>> itr;
181
182        public ImmutableInfoIterator(Map<Integer,String> idLabelMap) {
183            itr = idLabelMap.entrySet().iterator();
184        }
185
186        @Override
187        public boolean hasNext() {
188            return itr.hasNext();
189        }
190
191        @Override
192        public Pair<Integer, Regressor> next() {
193            Map.Entry<Integer,String> e = itr.next();
194            return new Pair<>(e.getKey(),new Regressor(e.getValue(),1.0));
195        }
196    }
197}