001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.regression; 018 019import com.oracle.labs.mlrg.olcut.util.MutableDouble; 020import com.oracle.labs.mlrg.olcut.util.MutableLong; 021import com.oracle.labs.mlrg.olcut.util.Pair; 022import org.tribuo.ImmutableOutputInfo; 023 024import java.util.Arrays; 025import java.util.Collections; 026import java.util.Comparator; 027import java.util.HashMap; 028import java.util.Iterator; 029import java.util.LinkedHashMap; 030import java.util.LinkedHashSet; 031import java.util.Map; 032import java.util.Set; 033import java.util.SortedSet; 034import java.util.TreeSet; 035import java.util.logging.Level; 036import java.util.logging.Logger; 037 038/** 039 * A {@link ImmutableOutputInfo} for {@link Regressor}s. 040 */ 041public class ImmutableRegressionInfo extends RegressionInfo implements ImmutableOutputInfo<Regressor> { 042 private static final Logger logger = Logger.getLogger(ImmutableRegressionInfo.class.getName()); 043 044 private static final long serialVersionUID = 2L; 045 046 private final Map<Integer,String> idLabelMap; 047 048 private final Map<String,Integer> labelIDMap; 049 050 private final Set<Regressor> domain; 051 052 private ImmutableRegressionInfo(ImmutableRegressionInfo info) { 053 super(info); 054 idLabelMap = new LinkedHashMap<>(); 055 idLabelMap.putAll(info.idLabelMap); 056 labelIDMap = new LinkedHashMap<>(); 057 labelIDMap.putAll(info.labelIDMap); 058 domain = calculateDomain(minMap); 059 } 060 061 ImmutableRegressionInfo(RegressionInfo info) { 062 super(info); 063 idLabelMap = new LinkedHashMap<>(); 064 labelIDMap = new LinkedHashMap<>(); 065 int counter = 0; 066 for (Map.Entry<String,MutableLong> e : countMap.entrySet()) { 067 idLabelMap.put(counter,e.getKey()); 068 labelIDMap.put(e.getKey(),counter); 069 counter++; 070 } 071 domain = calculateDomain(minMap); 072 } 073 074 ImmutableRegressionInfo(RegressionInfo info, Map<Regressor,Integer> mapping) { 075 super(info); 076 if (mapping.size() != info.size()) { 077 throw new IllegalStateException("Mapping and info come from different sources, mapping.size() = " + mapping.size() + ", info.size() = " + info.size()); 078 } 079 080 idLabelMap = new HashMap<>(); 081 labelIDMap = new HashMap<>(); 082 for (Map.Entry<Regressor, Integer> e : mapping.entrySet()) { 083 Regressor r = e.getKey(); 084 String[] names = r.getNames(); 085 if (names.length == 1) { 086 idLabelMap.put(e.getValue(), names[0]); 087 labelIDMap.put(names[0], e.getValue()); 088 } else { 089 throw new IllegalArgumentException("Mapping must contain a single regression dimension per id, but contains " + Arrays.toString(names) + " -> " + e.getValue()); 090 } 091 } 092 domain = calculateDomain(minMap); 093 } 094 095 /** 096 * Generates the domain for this regression info. 097 * @param minMap The set of minimum values per dimension. 098 * @return the domain, sorted lexicographically by name 099 */ 100 private static Set<Regressor> calculateDomain(Map<String, MutableDouble> minMap) { 101 TreeSet<Regressor.DimensionTuple> outputs = new TreeSet<>(Comparator.comparing(Regressor.DimensionTuple::getName)); 102 for (Map.Entry<String,MutableDouble> e : minMap.entrySet()) { 103 outputs.add(new Regressor.DimensionTuple(e.getKey(),e.getValue().doubleValue())); 104 } 105 // 106 // Now that we're sorted, simplify our output into a LinkedHashSet so we don't hang on to 107 // the comparator we used above in the TreeSet 108 LinkedHashSet<Regressor.DimensionTuple> preSortedOutputs = new LinkedHashSet<>(); 109 preSortedOutputs.addAll(outputs); 110 @SuppressWarnings("unchecked") // DimensionTuple is a subtype of Regressor, and this set is immutable. 111 Set<Regressor> immutableOutputs = (Set<Regressor>) (Set) Collections.unmodifiableSet(preSortedOutputs); 112 return immutableOutputs; 113 } 114 115 @Override 116 public Set<Regressor> getDomain() { 117 return domain; 118 } 119 120 @Override 121 public int getID(Regressor output) { 122 return labelIDMap.getOrDefault(output.getDimensionNamesString(),-1); 123 } 124 125 @Override 126 public Regressor getOutput(int id) { 127 String label = idLabelMap.get(id); 128 if (label != null) { 129 return new Regressor(label,1.0); 130 } else { 131 logger.log(Level.INFO,"No entry found for id " + id); 132 return null; 133 } 134 } 135 136 @Override 137 public long getTotalObservations() { 138 return overallCount; 139 } 140 141 @Override 142 public ImmutableRegressionInfo copy() { 143 return new ImmutableRegressionInfo(this); 144 } 145 146 @Override 147 public String toString() { 148 StringBuilder builder = new StringBuilder(); 149 builder.append("MultipleRegressionOutput("); 150 for (Map.Entry<String,MutableLong> e : countMap.entrySet()) { 151 String name = e.getKey(); 152 long count = e.getValue().longValue(); 153 builder.append(String.format("{name=%s,id=%d,count=%d,maxMap=%f,min=%f,mean=%f,variance=%f},", 154 name, 155 labelIDMap.get(name), 156 count, 157 maxMap.get(name).doubleValue(), 158 minMap.get(name).doubleValue(), 159 meanMap.get(name).doubleValue(), 160 (sumSquaresMap.get(name).doubleValue() / (count - 1)) 161 )); 162 } 163 builder.deleteCharAt(builder.length()-1); 164 builder.append(')'); 165 return builder.toString(); 166 } 167 168 @Override 169 public String toReadableString() { 170 return toString(); 171 } 172 173 @Override 174 public Iterator<Pair<Integer, Regressor>> iterator() { 175 return new ImmutableInfoIterator(idLabelMap); 176 } 177 178 private static class ImmutableInfoIterator implements Iterator<Pair<Integer, Regressor>> { 179 180 private final Iterator<Map.Entry<Integer,String>> itr; 181 182 public ImmutableInfoIterator(Map<Integer,String> idLabelMap) { 183 itr = idLabelMap.entrySet().iterator(); 184 } 185 186 @Override 187 public boolean hasNext() { 188 return itr.hasNext(); 189 } 190 191 @Override 192 public Pair<Integer, Regressor> next() { 193 Map.Entry<Integer,String> e = itr.next(); 194 return new Pair<>(e.getKey(),new Regressor(e.getValue(),1.0)); 195 } 196 } 197}