001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo; 018 019import com.oracle.labs.mlrg.olcut.util.MutableLong; 020 021import java.util.Map; 022 023/** 024 * Same as a {@link CategoricalInfo}, but with an additional int id field. 025 */ 026public class CategoricalIDInfo extends CategoricalInfo implements VariableIDInfo { 027 private static final long serialVersionUID = 2L; 028 029 private final int id; 030 031 /** 032 * Constructs a categorical id info copying the information from the supplied info, with the specified id. 033 * @param info The info to copy. 034 * @param id The id number to use. 035 */ 036 public CategoricalIDInfo(CategoricalInfo info, int id) { 037 super(info); 038 this.id = id; 039 } 040 041 /** 042 * Constructs a copy of the supplied categorical id info with the new name. 043 * <p> 044 * Used in the feature hashing system. 045 * @param info The info to copy. 046 * @param newName The new feature name. 047 */ 048 private CategoricalIDInfo(CategoricalIDInfo info, String newName) { 049 super(info,newName); 050 this.id = info.id; 051 } 052 053 @Override 054 public int getID() { 055 return id; 056 } 057 058 /** 059 * Generates a {@link RealIDInfo} that matches this CategoricalInfo and 060 * also contains an id number. 061 */ 062 @Override 063 public RealIDInfo generateRealInfo() { 064 double min = Double.POSITIVE_INFINITY; 065 double max = Double.NEGATIVE_INFINITY; 066 double sum = 0.0; 067 double sumSquares = 0.0; 068 double mean; 069 070 if (valueCounts != null) { 071 for (Map.Entry<Double, MutableLong> e : valueCounts.entrySet()) { 072 double value = e.getKey(); 073 double valCount = e.getValue().longValue(); 074 if (value > max) { 075 max = value; 076 } 077 if (value < min) { 078 min = value; 079 } 080 sum += value * valCount; 081 } 082 mean = sum / count; 083 084 for (Map.Entry<Double, MutableLong> e : valueCounts.entrySet()) { 085 double value = e.getKey(); 086 double valCount = e.getValue().longValue(); 087 sumSquares += (value - mean) * (value - mean) * valCount; 088 } 089 } else { 090 min = observedValue; 091 max = observedValue; 092 mean = observedValue; 093 sumSquares = 0.0; 094 } 095 096 return new RealIDInfo(name,count,max,min,mean,sumSquares,id); 097 } 098 099 @Override 100 public CategoricalIDInfo copy() { 101 return new CategoricalIDInfo(this,name); 102 } 103 104 @Override 105 public CategoricalIDInfo makeIDInfo(int id) { 106 return new CategoricalIDInfo(this,id); 107 } 108 109 @Override 110 public CategoricalIDInfo rename(String newName) { 111 return new CategoricalIDInfo(this,newName); 112 } 113 114 @Override 115 public String toString() { 116 if (valueCounts != null) { 117 return "CategoricalFeature(name=" + name + ",id=" + id + ",count=" + count + ",map=" + valueCounts.toString() + ")"; 118 } else { 119 return "CategoricalFeature(name=" + name + ",id=" + id + ",count=" + count + ",map={" +observedValue+","+observedCount+"})"; 120 } 121 } 122}