001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.transform;
018
import com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.ObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.Provenancable;
import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance;
import com.oracle.labs.mlrg.olcut.util.Pair;
import org.tribuo.Dataset;
import org.tribuo.Example;
import org.tribuo.MutableDataset;
import org.tribuo.Output;
import org.tribuo.impl.ArrayExample;
import org.tribuo.provenance.DatasetProvenance;
import org.tribuo.transform.TransformerMap.TransformerMapProvenance;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
040
041/**
042 * A collection of {@link Transformer}s which can be applied to a {@link Dataset}
043 * or {@link Example}. A TransformerMap is created by applying a {@link TransformationMap}
044 * to a Dataset. It contains Transformers which are specific to the Dataset which created
045 * it, for example the range of a feature used in binning is fixed to the value from
046 * that Dataset.
047 * <p>
048 * Transformations only operate on observed values. To operate on implicit zeros then
049 * first call {@link MutableDataset#densify} on the datasets.
050 */
051public final class TransformerMap implements Provenancable<TransformerMapProvenance>, Serializable {
052    private static final long serialVersionUID = 2L;
053
054    private final Map<String, List<Transformer>> map;
055    private final DatasetProvenance datasetProvenance;
056    private final ConfiguredObjectProvenance transformationMapProvenance;
057
058    /**
059     * Constructs a transformer map which encapsulates a set of transformers that can be applied to features.
060     * @param map The transformers, one per transformed feature.
061     * @param datasetProvenance The provenance of the dataset the transformers were fit against.
062     * @param transformationMapProvenance The provenance of the transformation map that was fit.
063     */
064    public TransformerMap(Map<String,List<Transformer>> map, DatasetProvenance datasetProvenance, ConfiguredObjectProvenance transformationMapProvenance) {
065        this.map = map;
066        this.datasetProvenance = datasetProvenance;
067        this.transformationMapProvenance = transformationMapProvenance;
068    }
069
070    /**
071     * Applies a {@link List} of {@link Transformer}s to the supplied double value,
072     * returning the transformed value.
073     * @param value The value to transform.
074     * @param transformerList The transformers to apply.
075     * @return The transformed value.
076     */
077    public static double applyTransformerList(double value, List<Transformer> transformerList) {
078        if (transformerList != null) {
079            for (Transformer t : transformerList) {
080                value = t.transform(value);
081            }
082        }
083        return value;
084    }
085
086    /**
087     * Copies the supplied example and applies the transformers to it.
088     * @param example The example to transform.
089     * @param <T> The type of Output.
090     * @return A copy of the example with the transformers applied to it's features.
091     */
092    public <T extends Output<T>> Example<T> transformExample(Example<T> example) {
093        ArrayExample<T> newExample = new ArrayExample<>(example);
094        newExample.transform(this);
095        return newExample;
096    }
097
098    /**
099     * Copies the supplied example and applies the transformers to it.
100     * @param example The example to transform.
101     * @param featureNames The feature names to densify.
102     * @param <T> The type of Output.
103     * @return A copy of the example with the transformers applied to it's features.
104     */
105    public <T extends Output<T>> Example<T> transformExample(Example<T> example, List<String> featureNames) {
106        ArrayExample<T> newExample = new ArrayExample<>(example);
107        newExample.densify(featureNames);
108        newExample.transform(this);
109        return newExample;
110    }
111
112    /**
113     * Copies the supplied dataset and applies the transformers to each example in it.
114     * <p>
115     * Does not densify the dataset first.
116     * @param dataset The dataset to transform.
117     * @param <T> The type of Output.
118     * @return A deep copy of the dataset (and it's examples) with the transformers applied to it's features.
119     */
120    public <T extends Output<T>> MutableDataset<T> transformDataset(Dataset<T> dataset) {
121        return transformDataset(dataset,false);
122    }
123
124    /**
125     * Copies the supplied dataset and applies the transformers to each example in it.
126     * @param dataset The dataset to transform.
127     * @param densify Densify the dataset before transforming it.
128     * @param <T> The type of Output.
129     * @return A deep copy of the dataset (and it's examples) with the transformers applied to it's features.
130     */
131    public <T extends Output<T>> MutableDataset<T> transformDataset(Dataset<T> dataset, boolean densify) {
132        MutableDataset<T> newDataset = MutableDataset.createDeepCopy(dataset);
133
134        if (densify) {
135            newDataset.densify();
136        }
137
138        newDataset.transform(this);
139
140        return newDataset;
141    }
142
143    @Override
144    public String toString() {
145        return "TransformerMap(map="+map.toString()+")";
146    }
147
148    /**
149     * Get the feature names and associated list of transformers.
150     * @return The entry set of the transformer map.
151     */
152    public Set<Map.Entry<String,List<Transformer>>> entrySet() {
153        return map.entrySet();
154    }
155
156    @Override
157    public TransformerMapProvenance getProvenance() {
158        return new TransformerMapProvenance(this);
159    }
160
161    /**
162     * Provenance for {@link TransformerMap}.
163     */
164    public final static class TransformerMapProvenance implements ObjectProvenance {
165        private static final long serialVersionUID = 1L;
166
167        private static final String TRANSFORMATION_MAP = "transformation-map";
168        private static final String DATASET = "dataset";
169
170        private final String className;
171        private final ConfiguredObjectProvenance transformationMapProvenance;
172        private final DatasetProvenance datasetProvenance;
173
174        TransformerMapProvenance(TransformerMap host) {
175            this.className = host.getClass().getName();
176            this.transformationMapProvenance = host.transformationMapProvenance;
177            this.datasetProvenance = host.datasetProvenance;
178        }
179
180        public TransformerMapProvenance(Map<String,Provenance> map) {
181            this.className = ObjectProvenance.checkAndExtractProvenance(map,CLASS_NAME,StringProvenance.class,TransformerMapProvenance.class.getSimpleName()).getValue();
182            this.transformationMapProvenance = ObjectProvenance.checkAndExtractProvenance(map,TRANSFORMATION_MAP,ConfiguredObjectProvenance.class,TransformerMapProvenance.class.getSimpleName());
183            this.datasetProvenance = ObjectProvenance.checkAndExtractProvenance(map,DATASET,DatasetProvenance.class,TransformerMapProvenance.class.getSimpleName());
184        }
185
186        @Override
187        public String getClassName() {
188            return className;
189        }
190
191        @Override
192        public Iterator<Pair<String, Provenance>> iterator() {
193            ArrayList<Pair<String,Provenance>> list = new ArrayList<>();
194
195            list.add(new Pair<>(CLASS_NAME,new StringProvenance(CLASS_NAME,className)));
196            list.add(new Pair<>(TRANSFORMATION_MAP,transformationMapProvenance));
197            list.add(new Pair<>(DATASET,datasetProvenance));
198
199            return list.iterator();
200        }
201
202        @Override
203        public boolean equals(Object o) {
204            if (this == o) return true;
205            if (!(o instanceof TransformerMapProvenance)) return false;
206            TransformerMapProvenance pairs = (TransformerMapProvenance) o;
207            return className.equals(pairs.className) &&
208                    transformationMapProvenance.equals(pairs.transformationMapProvenance) &&
209                    datasetProvenance.equals(pairs.datasetProvenance);
210        }
211
212        @Override
213        public int hashCode() {
214            return Objects.hash(className, transformationMapProvenance, datasetProvenance);
215        }
216
217        @Override
218        public String toString() {
219            return generateString("TransformerMap");
220        }
221    }
222}