001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.impl;
018
019import com.oracle.labs.mlrg.olcut.util.SortUtil;
020import org.tribuo.Example;
021import org.tribuo.Feature;
022import org.tribuo.ImmutableFeatureMap;
023import org.tribuo.ImmutableOutputInfo;
024import org.tribuo.Output;
025import org.tribuo.util.Merger;
026
027import java.util.ArrayList;
028import java.util.Arrays;
029import java.util.Collection;
030import java.util.HashMap;
031import java.util.Iterator;
032import java.util.List;
033import java.util.Map;
034import java.util.NoSuchElementException;
035import java.util.Objects;
036import java.util.PriorityQueue;
037
038/**
039 * A version of ArrayExample which also has the id numbers.
040 * <p>
041 * Used in feature selection to provide log n lookups. May be used
042 * elsewhere in the future as a performance optimisation.
043 */
044public class IndexedArrayExample<T extends Output<T>> extends ArrayExample<T> {
045    private static final long serialVersionUID = 1L;
046
047    protected int[] featureIDs;
048
049    protected final int outputID;
050
051    private final ImmutableFeatureMap featureMap;
052
053    private final ImmutableOutputInfo<T> outputMap;
054
055    /**
056     * Copy constructor.
057     * @param other The example to copy.
058     */
059    public IndexedArrayExample(IndexedArrayExample<T> other) {
060        super(other.getOutput(),other.getWeight(),other.getMetadata());
061        featureNames = Arrays.copyOf(other.featureNames,other.featureNames.length);
062        featureIDs = Arrays.copyOf(other.featureIDs,other.size());
063        featureValues = Arrays.copyOf(other.featureValues,other.featureValues.length);
064        featureMap = other.featureMap;
065        outputMap = other.outputMap;
066        outputID = outputMap.getID(output);
067        size = other.size;
068    }
069
070    /**
071     * This constructor removes unknown features.
072     * @param other The example to copy from.
073     * @param featureMap The feature map.
074     * @param outputMap The output info.
075     */
076    public IndexedArrayExample(Example<T> other, ImmutableFeatureMap featureMap, ImmutableOutputInfo<T> outputMap) {
077        super(other);
078        this.featureIDs = new int[other.size()];
079        this.featureMap = featureMap;
080        this.outputMap = outputMap;
081        this.outputID = outputMap.getID(output);
082        for (int i = 0; i < featureNames.length; i++) {
083            featureIDs[i] = featureMap.getID(featureNames[i]);
084        }
085        int[] newIDs = new int[other.size()];
086        String[] newNames = new String[other.size()];
087        double[] newValues = new double[other.size()];
088        int counter = 0;
089        for (int i = 0; i < featureIDs.length; i++) {
090            if (featureIDs[i] != -1) {
091                newIDs[counter] = featureIDs[i];
092                newValues[counter] = featureValues[i];
093                newNames[counter] = featureNames[i];
094                counter++;
095            }
096        }
097        size = counter;
098        featureNames = newNames;
099        featureIDs = newIDs;
100        featureValues = newValues;
101    }
102
103    @Override
104    public boolean equals(Object o) {
105        if (this == o) return true;
106        if (!(o instanceof IndexedArrayExample)) return false;
107        if (!super.equals(o)) return false;
108        IndexedArrayExample<?> that = (IndexedArrayExample<?>) o;
109        return outputID == that.outputID &&
110                Arrays.equals(featureIDs, that.featureIDs) &&
111                featureMap.equals(that.featureMap) &&
112                outputMap.equals(that.outputMap);
113    }
114
115    @Override
116    public int hashCode() {
117        int result = Objects.hash(super.hashCode(), outputID, featureMap, outputMap);
118        result = 31 * result + Arrays.hashCode(featureIDs);
119        return result;
120    }
121
122    @Override
123    protected void growArray(int minCapacity) {
124        int newCapacity = newCapacity(minCapacity);
125        featureNames = Arrays.copyOf(featureNames,newCapacity);
126        featureIDs = Arrays.copyOf(featureIDs,newCapacity);
127        featureValues = Arrays.copyOf(featureValues,newCapacity);
128    }
129
130    @Override
131    public void add(Feature feature) {
132        if (size >= featureNames.length) {
133            growArray();
134        }
135        featureNames[size] = feature.getName();
136        featureIDs[size] = featureMap.getID(feature.getName());
137        featureValues[size] = feature.getValue();
138        size++;
139        sort();
140    }
141
142    @Override
143    public void addAll(Collection<? extends Feature> features) {
144        if (size + features.size() >= featureNames.length) {
145            growArray(size+features.size());
146        }
147        for (Feature f : features) {
148            featureNames[size] = f.getName();
149            featureIDs[size] = featureMap.getID(f.getName());
150            featureValues[size] = f.getValue();
151            size++;
152        }
153        sort();
154    }
155
156    @Override
157    protected void sort() {
158        int[] sortedIndices = SortUtil.argsort(featureNames,0,size,true);
159
160        String[] newNames = Arrays.copyOf(featureNames,size);
161        int[] newIDs = Arrays.copyOf(featureIDs,size);
162        double[] newValues = Arrays.copyOf(featureValues,size);
163        for (int i = 0; i < sortedIndices.length; i++) {
164            featureNames[i] = newNames[sortedIndices[i]];
165            featureIDs[i] = newIDs[sortedIndices[i]];
166            featureValues[i] = newValues[sortedIndices[i]];
167        }
168    }
169
170    @Override
171    public void reduceByName(Merger merger) {
172        if (size > 0) {
173            int[] sortedIndices = SortUtil.argsort(featureNames, 0, size, true);
174            String[] newNames = new String[featureNames.length];
175            int[] newIDs = new int[featureIDs.length];
176            double[] newValues = new double[featureNames.length];
177            for (int i = 0; i < sortedIndices.length; i++) {
178                newNames[i] = featureNames[sortedIndices[i]];
179                newIDs[i] = featureIDs[sortedIndices[i]];
180                newValues[i] = featureValues[sortedIndices[i]];
181            }
182            featureNames[0] = newNames[0];
183            featureIDs[0] = newIDs[0];
184            featureValues[0] = newValues[0];
185            int dest = 0;
186            for (int i = 1; i < size; i++) {
187                while ((i < size) && newNames[i].equals(featureNames[dest])) {
188                    featureValues[dest] = merger.merge(featureValues[dest], newValues[i]);
189                    i++;
190                }
191                if (i < size) {
192                    dest++;
193                    featureNames[dest] = newNames[i];
194                    featureIDs[dest] = newIDs[i];
195                    featureValues[dest] = newValues[i];
196                }
197            }
198            size = dest + 1;
199        }
200    }
201
202    @Override
203    public void removeFeatures(List<Feature> featureList) {
204        Map<String,List<Integer>> map = new HashMap<>();
205        for (int i = 0; i < featureNames.length; i++) {
206            List<Integer> list = map.computeIfAbsent(featureNames[i],(k) -> new ArrayList<>());
207            list.add(i);
208        }
209
210        PriorityQueue<Integer> removeQueue = new PriorityQueue<>();
211        for (Feature f : featureList) {
212            List<Integer> i = map.get(f.getName());
213            if (i != null) {
214                // If we've found this feature remove it from the map to prevent double counting
215                map.remove(f.getName());
216                removeQueue.addAll(i);
217            }
218        }
219
220        String[] newNames = new String[size-removeQueue.size()];
221        int[] newIDs = new int[size-removeQueue.size()];
222        double[] newValues = new double[size-removeQueue.size()];
223
224        int source = 0;
225        int dest = 0;
226        while (!removeQueue.isEmpty()) {
227            int curRemoveIdx = removeQueue.poll();
228            while (source < curRemoveIdx) {
229                newNames[dest] = featureNames[source];
230                newIDs[dest] = featureIDs[source];
231                newValues[dest] = featureValues[source];
232                source++;
233                dest++;
234            }
235            source++;
236        }
237        while (source < size) {
238            newNames[dest] = featureNames[source];
239            newIDs[dest] = featureIDs[source];
240            newValues[dest] = featureValues[source];
241            source++;
242            dest++;
243        }
244        featureNames = newNames;
245        featureIDs = newIDs;
246        featureValues = newValues;
247        size = featureNames.length;
248    }
249
250    /**
251     * Does this example contain a feature with id i.
252     * @param i The index to check.
253     * @return True if the example contains the id.
254     */
255    public boolean contains(int i) {
256        return Arrays.binarySearch(featureIDs,i) > -1;
257    }
258
259    @Override
260    public IndexedArrayExample<T> copy() {
261        return new IndexedArrayExample<>(this);
262    }
263
264    @Override
265    public void densify(List<String> featureList) {
266        if (featureList.size() != featureMap.size()) {
267            throw new IllegalArgumentException("Densifying an example with a different feature map");
268        }
269        // Ensure we have enough space.
270        if (featureList.size() > featureNames.length) {
271            growArray(featureList.size());
272        }
273        int insertedCount = 0;
274        int curPos = 0;
275        for (String curName : featureList) {
276            // If we've reached the end of our old feature set, just insert.
277            if (curPos == size) {
278                featureNames[size + insertedCount] = curName;
279                featureIDs[size + insertedCount] = featureMap.getID(curName);
280                insertedCount++;
281            } else {
282                // Check to see if our insertion candidate is the same as the current feature name.
283                int comparison = curName.compareTo(featureNames[curPos]);
284                if (comparison < 0) {
285                    // If it's earlier, insert it.
286                    featureNames[size + insertedCount] = curName;
287                    featureIDs[size + insertedCount] = featureMap.getID(curName);
288                    insertedCount++;
289                } else if (comparison == 0) {
290                    // Otherwise just bump our pointer, we've already got this feature.
291                    curPos++;
292                }
293            }
294        }
295        // Bump the size up by the number of inserted features.
296        size += insertedCount;
297        // Sort the features
298        sort();
299    }
300
301    /**
302     * Gets the feature at internal index i.
303     * @param i The internal index.
304     * @return The feature index.
305     */
306    public int getIdx(int i) {
307        return featureIDs[i];
308    }
309
310    /**
311     * Gets the output id dimension number.
312     * @return The output id.
313     */
314    public int getOutputID() {
315        return outputID;
316    }
317
318    /**
319     * Iterator over the feature ids and values.
320     * @return The feature ids and values.
321     */
322    public Iterable<FeatureTuple> idIterator() {
323        return ArrayIndexedExampleIterator::new;
324    }
325
326    /**
327     * A tuple of the feature name, id and value.
328     */
329    public static class FeatureTuple {
330        public String name;
331        public int id;
332        public double value;
333
334        public FeatureTuple() { }
335
336        public FeatureTuple(String name, int id, double value) {
337            this.name = name;
338            this.id = id;
339            this.value = value;
340        }
341    }
342
343    class ArrayIndexedExampleIterator implements Iterator<FeatureTuple> {
344        int pos = 0;
345        FeatureTuple tuple = new FeatureTuple();
346
347        @Override
348        public boolean hasNext() {
349            return pos < size;
350        }
351
352        @Override
353        public FeatureTuple next() {
354            if (!hasNext()) {
355                throw new NoSuchElementException("Off the end of the iterator.");
356            }
357            tuple.name = featureNames[pos];
358            tuple.id = featureIDs[pos];
359            tuple.value = featureValues[pos];
360            pos++;
361            return tuple;
362        }
363    }
364}