001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.transform;
018
019import com.oracle.labs.mlrg.olcut.config.Config;
020import com.oracle.labs.mlrg.olcut.config.Configurable;
021import com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance;
022import com.oracle.labs.mlrg.olcut.provenance.Provenancable;
023import com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl;
024import org.tribuo.Dataset;
025import org.tribuo.FeatureMap;
026import org.tribuo.MutableDataset;
027
028import java.util.ArrayList;
029import java.util.Collections;
030import java.util.HashMap;
031import java.util.HashSet;
032import java.util.List;
033import java.util.Map;
034import java.util.Objects;
035import java.util.regex.Pattern;
036
037/**
038 * A carrier type for a set of transformations to be applied to a {@link Dataset}.
039 * <p>
040 * Feature specific transformations are specified using a regex. If multiple
041 * regexes match a given feature, then an {@link IllegalArgumentException} is thrown
042 * when {@link Dataset#createTransformers(TransformationMap)} is called.
043 * <p>
044 * Global transformations are applied <em>after</em> all feature specific transformations.
045 * <p>
046 * Transformations only operate on observed values. To operate on implicit zeros then
047 * first call {@link MutableDataset#densify} on the datasets.
048 */
049public class TransformationMap implements Configurable, Provenancable<ConfiguredObjectProvenance> {
050
051    @Config(mandatory = true,description="Global transformations to apply after the feature specific transforms.")
052    private List<Transformation> globalTransformations;
053
054    @Config(mandatory = true,description="Feature specific transformations. Accepts regexes for feature names.")
055    private Map<String,TransformationList> featureTransformationList = new HashMap<>();
056
057    private final Map<String,List<Transformation>> featureTransformations = new HashMap<>();
058
059    private ConfiguredObjectProvenanceImpl provenance;
060
061    /**
062     * For OLCUT.
063     */
064    private TransformationMap() {}
065
066    public TransformationMap(List<Transformation> globalTransformations, Map<String,List<Transformation>> featureTransformations) {
067        this.globalTransformations = globalTransformations;
068        this.featureTransformations.putAll(featureTransformations);
069
070        // Copy values out for provenance
071        for (Map.Entry<String,List<Transformation>> e : featureTransformations.entrySet()) {
072            featureTransformationList.put(e.getKey(),new TransformationList(e.getValue()));
073        }
074    }
075
076    public TransformationMap(List<Transformation> globalTransformations) {
077        this(globalTransformations, Collections.emptyMap());
078    }
079
080    public TransformationMap(Map<String,List<Transformation>> featureTransformations) {
081        this(Collections.emptyList(),featureTransformations);
082    }
083
084    /**
085     * Used by the OLCUT configuration system, and should not be called by external code.
086     */
087    @Override
088    public void postConfig() {
089        for (Map.Entry<String,TransformationList> e : featureTransformationList.entrySet()) {
090            featureTransformations.put(e.getKey(),e.getValue().list);
091        }
092    }
093
094    /**
095     * Checks that a given transformation set doesn't have conflicts when applied to the supplied featureMap.
096     * @param featureMap The featureMap to check.
097     * @return True if the transformation set doesn't have conflicts, false otherwise.
098     */
099    public boolean validateTransformations(FeatureMap featureMap) {
100        HashSet<String> featuresWithPatterns = new HashSet<>();
101        ArrayList<String> featureNames = new ArrayList<>(featureMap.keySet());
102        boolean valid = true;
103
104        // Loop over all regexes
105        for (String regex : featureTransformations.keySet()) {
106            Pattern p = Pattern.compile(regex);
107            // Loop over all features
108            for (String s : featureNames) {
109                // Check if the pattern matches the feature
110                if (p.matcher(s).matches()) {
111                    // If it matches, add the feature to the HashSet
112                    valid = featuresWithPatterns.add(s);
113                    // If it already was present, there are two patterns for the same feature
114                    // so the Transformations are invalid.
115                    // Bail out.
116                    if (!valid) {
117                        break;
118                    }
119                }
120            }
121            if (!valid) {
122                break;
123            }
124        }
125
126        return valid;
127    }
128
129    @Override
130    public String toString() {
131        return "TransformationMap(featureTransformations="+featureTransformations.toString()+",globalTransformations="+globalTransformations.toString()+")";
132    }
133
134    /**
135     * Gets the global transformations in this TransformationMap.
136     * @return The global transformations
137     */
138    public List<Transformation> getGlobalTransformations() {
139        return globalTransformations;
140    }
141
142    /**
143     * Gets the map of feature specific transformations.
144     * @return The feature specific transformations.
145     */
146    public Map<String, List<Transformation>> getFeatureTransformations() {
147        return featureTransformations;
148    }
149
150    @Override
151    public synchronized ConfiguredObjectProvenance getProvenance() {
152        if (provenance == null) {
153            provenance = cacheProvenance();
154        }
155        return provenance;
156    }
157
158    private ConfiguredObjectProvenanceImpl cacheProvenance() {
159        return new ConfiguredObjectProvenanceImpl(this,"TransformationMap");
160    }
161
162    /**
163     * A carrier type as OLCUT does not support nested generics.
164     * <p>
165     * Will be deprecated if/when OLCUT supports this.
166     */
167    public final static class TransformationList implements Configurable, Provenancable<ConfiguredObjectProvenance> {
168        @Config(description="A list of transformations to apply.")
169        public List<Transformation> list;
170
171        private TransformationList() {}
172
173        public TransformationList(List<Transformation> list) {
174            this.list = list;
175        }
176
177        @Override
178        public ConfiguredObjectProvenance getProvenance() {
179            return new ConfiguredObjectProvenanceImpl(this,"TransformationList");
180        }
181
182        @Override
183        public boolean equals(Object o) {
184            if (this == o) return true;
185            if (!(o instanceof TransformationList)) return false;
186            TransformationList that = (TransformationList) o;
187            return list.equals(that.list);
188        }
189
190        @Override
191        public int hashCode() {
192            return Objects.hash(list);
193        }
194    }
195
196}