001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.transform; 018 019import com.oracle.labs.mlrg.olcut.config.Config; 020import com.oracle.labs.mlrg.olcut.config.Configurable; 021import com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance; 022import com.oracle.labs.mlrg.olcut.provenance.Provenancable; 023import com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl; 024import org.tribuo.Dataset; 025import org.tribuo.FeatureMap; 026import org.tribuo.MutableDataset; 027 028import java.util.ArrayList; 029import java.util.Collections; 030import java.util.HashMap; 031import java.util.HashSet; 032import java.util.List; 033import java.util.Map; 034import java.util.Objects; 035import java.util.regex.Pattern; 036 037/** 038 * A carrier type for a set of transformations to be applied to a {@link Dataset}. 039 * <p> 040 * Feature specific transformations are specified using a regex. If multiple 041 * regexes match a given feature, then an {@link IllegalArgumentException} is thrown 042 * when {@link Dataset#createTransformers(TransformationMap)} is called. 043 * <p> 044 * Global transformations are applied <em>after</em> all feature specific transformations. 045 * <p> 046 * Transformations only operate on observed values. To operate on implicit zeros then 047 * first call {@link MutableDataset#densify} on the datasets. 048 */ 049public class TransformationMap implements Configurable, Provenancable<ConfiguredObjectProvenance> { 050 051 @Config(mandatory = true,description="Global transformations to apply after the feature specific transforms.") 052 private List<Transformation> globalTransformations; 053 054 @Config(mandatory = true,description="Feature specific transformations. Accepts regexes for feature names.") 055 private Map<String,TransformationList> featureTransformationList = new HashMap<>(); 056 057 private final Map<String,List<Transformation>> featureTransformations = new HashMap<>(); 058 059 private ConfiguredObjectProvenanceImpl provenance; 060 061 /** 062 * For OLCUT. 063 */ 064 private TransformationMap() {} 065 066 public TransformationMap(List<Transformation> globalTransformations, Map<String,List<Transformation>> featureTransformations) { 067 this.globalTransformations = globalTransformations; 068 this.featureTransformations.putAll(featureTransformations); 069 070 // Copy values out for provenance 071 for (Map.Entry<String,List<Transformation>> e : featureTransformations.entrySet()) { 072 featureTransformationList.put(e.getKey(),new TransformationList(e.getValue())); 073 } 074 } 075 076 public TransformationMap(List<Transformation> globalTransformations) { 077 this(globalTransformations, Collections.emptyMap()); 078 } 079 080 public TransformationMap(Map<String,List<Transformation>> featureTransformations) { 081 this(Collections.emptyList(),featureTransformations); 082 } 083 084 /** 085 * Used by the OLCUT configuration system, and should not be called by external code. 086 */ 087 @Override 088 public void postConfig() { 089 for (Map.Entry<String,TransformationList> e : featureTransformationList.entrySet()) { 090 featureTransformations.put(e.getKey(),e.getValue().list); 091 } 092 } 093 094 /** 095 * Checks that a given transformation set doesn't have conflicts when applied to the supplied featureMap. 096 * @param featureMap The featureMap to check. 097 * @return True if the transformation set doesn't have conflicts, false otherwise. 098 */ 099 public boolean validateTransformations(FeatureMap featureMap) { 100 HashSet<String> featuresWithPatterns = new HashSet<>(); 101 ArrayList<String> featureNames = new ArrayList<>(featureMap.keySet()); 102 boolean valid = true; 103 104 // Loop over all regexes 105 for (String regex : featureTransformations.keySet()) { 106 Pattern p = Pattern.compile(regex); 107 // Loop over all features 108 for (String s : featureNames) { 109 // Check if the pattern matches the feature 110 if (p.matcher(s).matches()) { 111 // If it matches, add the feature to the HashSet 112 valid = featuresWithPatterns.add(s); 113 // If it already was present, there are two patterns for the same feature 114 // so the Transformations are invalid. 115 // Bail out. 116 if (!valid) { 117 break; 118 } 119 } 120 } 121 if (!valid) { 122 break; 123 } 124 } 125 126 return valid; 127 } 128 129 @Override 130 public String toString() { 131 return "TransformationMap(featureTransformations="+featureTransformations.toString()+",globalTransformations="+globalTransformations.toString()+")"; 132 } 133 134 /** 135 * Gets the global transformations in this TransformationMap. 136 * @return The global transformations 137 */ 138 public List<Transformation> getGlobalTransformations() { 139 return globalTransformations; 140 } 141 142 /** 143 * Gets the map of feature specific transformations. 144 * @return The feature specific transformations. 145 */ 146 public Map<String, List<Transformation>> getFeatureTransformations() { 147 return featureTransformations; 148 } 149 150 @Override 151 public synchronized ConfiguredObjectProvenance getProvenance() { 152 if (provenance == null) { 153 provenance = cacheProvenance(); 154 } 155 return provenance; 156 } 157 158 private ConfiguredObjectProvenanceImpl cacheProvenance() { 159 return new ConfiguredObjectProvenanceImpl(this,"TransformationMap"); 160 } 161 162 /** 163 * A carrier type as OLCUT does not support nested generics. 164 * <p> 165 * Will be deprecated if/when OLCUT supports this. 166 */ 167 public final static class TransformationList implements Configurable, Provenancable<ConfiguredObjectProvenance> { 168 @Config(description="A list of transformations to apply.") 169 public List<Transformation> list; 170 171 private TransformationList() {} 172 173 public TransformationList(List<Transformation> list) { 174 this.list = list; 175 } 176 177 @Override 178 public ConfiguredObjectProvenance getProvenance() { 179 return new ConfiguredObjectProvenanceImpl(this,"TransformationList"); 180 } 181 182 @Override 183 public boolean equals(Object o) { 184 if (this == o) return true; 185 if (!(o instanceof TransformationList)) return false; 186 TransformationList that = (TransformationList) o; 187 return list.equals(that.list); 188 } 189 190 @Override 191 public int hashCode() { 192 return Objects.hash(list); 193 } 194 } 195 196}