001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.util.infotheory.impl;
018
import java.lang.reflect.Array;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.ListIterator;
import java.util.NoSuchElementException;
import java.util.Set;
028
029/**
030 * An implementation of a List which wraps a set of lists.
031 * <p>
032 * Each access returns a {@link Row} drawn by taking an element from each list.
033 * <p>
034 * The rows only expose equals and hashcode, as the information theoretic calculations
035 * only care about equality.
036 * @param <T> The type stored in the lists.
037 */
038public final class RowList<T> implements List<Row<T>> {
039    private final Set<List<T>> set;
040    private final int size;    
041
042    public RowList(Set<List<T>> set) {
043        this.set = Collections.unmodifiableSet(new LinkedHashSet<>(set));
044        size = set.iterator().next().size();
045        for (Collection<T> element : this.set) {
046            if (size != element.size()) {
047                throw new IllegalArgumentException("Not all the collections in the set are the same length");
048            }
049        }
050    }
051    
052    @Override
053    public int size() {
054        return size;
055    }
056
057    @Override
058    public boolean isEmpty() {
059        return size == 0;
060    }
061
062    @Override
063    public boolean contains(Object o) {
064        if (o instanceof Row) {
065            Row<?> otherRow = (Row<?>) o;
066            boolean found = false;
067            for (Row<T> row : this) {
068                if (otherRow.equals(row)) {
069                    found = true;
070                    break;
071                }
072            }
073            return found;
074        } else {
075            return false;
076        }
077    }
078
079    @Override
080    public Iterator<Row<T>> iterator() {
081        return new RowListIterator<>(set);
082    }
083
084    @Override
085    public Object[] toArray() {
086        Object[] output = new Object[size];
087        int counter = 0;
088        for (Row<T> row : this) {
089            output[counter] = row;
090            counter++;
091        }
092        return output;
093    }
094
095    @Override
096    @SuppressWarnings("unchecked")
097    public <U> U[] toArray(U[] a) {
098        U[] output = a;
099        if (output.length < size) {
100            output = (U[]) Array.newInstance(a[0].getClass(), size);
101        }
102        int counter = 0;
103        for (Row<T> row : this) {
104            output[counter] = (U) row;
105            counter++;
106        }
107        if (output.length > size) {
108            //fill with nulls if bigger.
109            for (; counter < output.length; counter++) {
110                output[counter] = null;
111            }
112        }
113        return output;
114    }
115
116    @Override
117    public Row<T> get(int index) {
118        ArrayList<T> list = new ArrayList<>(set.size());
119        int counter = 0;
120        for (List<T> element : set) {
121            list.add(counter, element.get(index));
122            counter++;
123        }
124        return new Row<>(list);
125    }
126
127    @Override
128    public boolean containsAll(Collection<?> c) {
129        boolean found = true;
130        Iterator<?> itr = c.iterator();
131        while (itr.hasNext() && found) {
132            found = this.contains(itr.next());
133        }
134        return found;
135    }
136
137    @Override
138    public int indexOf(Object o) {
139        if (o instanceof Row) {
140            Row<?> otherRow = (Row<?>) o;
141            int counter = 0;
142            int found = -1;
143            Iterator<Row<T>> itr = this.iterator();
144            while (itr.hasNext() && found == -1) {
145                if (itr.next().equals(otherRow)) {
146                    found = counter;
147                }
148                counter++;
149            }
150            return found;
151        } else {
152            return -1;
153        }
154    }
155
156    @Override
157    public int lastIndexOf(Object o) {
158        if (o instanceof Row) {
159            Row<?> otherRow = (Row<?>) o;
160            int counter = 0;
161            int found = -1;
162            for (Row<T> tRow : this) {
163                if (tRow.equals(otherRow)) {
164                    found = counter;
165                }
166                counter++;
167            }
168            return found;
169        } else {
170            return -1;
171        }
172    }
173
174    @Override
175    public ListIterator<Row<T>> listIterator() {
176        return new RowListIterator<>(set);
177    }
178
179    @Override
180    public ListIterator<Row<T>> listIterator(int index) {
181        return new RowListIterator<>(set,index);
182    }
183
184    /**
185     * Unsupported. Throws UnsupportedOperationException.
186     * @param fromIndex n/a
187     * @param toIndex n/a
188     * @return n/a
189     */
190    @Override
191    public List<Row<T>> subList(int fromIndex, int toIndex) {
192        throw new UnsupportedOperationException("Views are not supported on a RowList.");
193    }
194
195    //*************************************************************************
196    // The remaining operations are unsupported as this list is immutable.
197    //*************************************************************************
198    /**
199     * Unsupported. Throws UnsupportedOperationException.
200     * @param e n/a
201     * @return n/a
202     */
203    @Override
204    public boolean add(Row<T> e) {
205        throw new UnsupportedOperationException("This list is immutable.");
206    }
207
208    /**
209     * Unsupported. Throws UnsupportedOperationException.
210     * @param o n/a
211     * @return n/a
212     */
213    @Override
214    public boolean remove(Object o) {
215        throw new UnsupportedOperationException("This list is immutable.");
216    }
217
218    /**
219     * Unsupported. Throws UnsupportedOperationException.
220     * @param c n/a
221     * @return n/a
222     */
223    @Override
224    public boolean addAll(Collection<? extends Row<T>> c) {
225        throw new UnsupportedOperationException("This list is immutable.");
226    }
227
228    /**
229     * Unsupported. Throws UnsupportedOperationException.
230     * @param index n/a
231     * @param c n/a
232     * @return n/a
233     */
234    @Override
235    public boolean addAll(int index, Collection<? extends Row<T>> c) {
236        throw new UnsupportedOperationException("This list is immutable.");
237    }
238
239    /**
240     * Unsupported. Throws UnsupportedOperationException.
241     * @param c n/a
242     * @return n/a
243     */
244    @Override
245    public boolean removeAll(Collection<?> c) {
246        throw new UnsupportedOperationException("This list is immutable.");
247    }
248
249    /**
250     * Unsupported. Throws UnsupportedOperationException.
251     * @param c n/a
252     * @return n/a
253     */
254    @Override
255    public boolean retainAll(Collection<?> c) {
256        throw new UnsupportedOperationException("This list is immutable.");
257    }
258
259    /**
260     * Unsupported. Throws UnsupportedOperationException.
261     */
262    @Override
263    public void clear() {
264        throw new UnsupportedOperationException("This list is immutable.");
265    }
266
267    /**
268     * Unsupported. Throws UnsupportedOperationException.
269     * @param index n/a
270     * @param element n/a
271     * @return n/a
272     */
273    @Override
274    public Row<T> set(int index, Row<T> element) {
275        throw new UnsupportedOperationException("This list is immutable.");
276    }
277
278    /**
279     * Unsupported. Throws UnsupportedOperationException.
280     * @param index n/a
281     * @param element n/a
282     */
283    @Override
284    public void add(int index, Row<T> element) {
285        throw new UnsupportedOperationException("This list is immutable.");
286    }
287
288    /**
289     * Unsupported. Throws UnsupportedOperationException.
290     * @param index n/a
291     * @return n/a
292     */
293    @Override
294    public Row<T> remove(int index) {
295        throw new UnsupportedOperationException("This list is immutable.");
296    }
297
298    /**
299     * The iterator over the rows.
300     * @param <T> The type of the row.
301     */
302    private static class RowListIterator<T> implements ListIterator<Row<T>> {
303        private int curIndex;
304        private final int size;
305        private final Set<List<T>> set;
306
307        public RowListIterator(Set<List<T>> set) {
308            this(set,0);
309        }
310
311        public RowListIterator(Set<List<T>> set, int curIndex) {
312            this.curIndex = curIndex;
313            this.set = set;
314            this.size = set.iterator().next().size();
315        }
316
317        @Override
318        public boolean hasNext() {
319            return curIndex < size;
320        }
321
322        @Override
323        public Row<T> next() {
324            ArrayList<T> list = new ArrayList<>(set.size());
325            int counter = 0;
326            for (List<T> element : set) {
327                list.add(counter, element.get(curIndex));
328                counter++;
329            }
330            curIndex++;
331            return new Row<>(list);
332        }
333
334        @Override
335        public boolean hasPrevious() {
336            return curIndex > 0;
337        }
338
339        @Override
340        public Row<T> previous() {
341            ArrayList<T> list = new ArrayList<>(set.size());
342            curIndex--;
343            int counter = 0;
344            for (List<T> element : set) {
345                list.add(counter, element.get(curIndex));
346                counter++;
347            }
348            return new Row<>(list);
349        }
350
351        @Override
352        public int nextIndex() {
353            return curIndex;
354        }
355
356        @Override
357        public int previousIndex() {
358            return curIndex - 1;
359        }
360
361        @Override
362        public void remove() {
363            throw new UnsupportedOperationException("The list backing this iterator is immutable.");
364        }
365
366        @Override
367        public void set(Row<T> e) {
368            throw new UnsupportedOperationException("The list backing this iterator is immutable.");
369        }
370
371        @Override
372        public void add(Row<T> e) {
373            throw new UnsupportedOperationException("The list backing this iterator is immutable.");
374        }
375    }
376}
377