001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.datasource;
018
019import com.oracle.labs.mlrg.olcut.provenance.ListProvenance;
020import com.oracle.labs.mlrg.olcut.provenance.ObjectProvenance;
021import com.oracle.labs.mlrg.olcut.provenance.Provenance;
022import com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance;
023import com.oracle.labs.mlrg.olcut.util.Pair;
024import org.tribuo.DataSource;
025import org.tribuo.Example;
026import org.tribuo.Output;
027import org.tribuo.OutputFactory;
028import org.tribuo.provenance.DataSourceProvenance;
029
030import java.util.ArrayList;
031import java.util.Collections;
032import java.util.Iterator;
033import java.util.List;
034import java.util.Map;
035import java.util.NoSuchElementException;
036import java.util.Objects;
037
038/**
039 * Aggregates multiple {@link DataSource}s, and round-robins the iterators.
040 */
041public class AggregateDataSource<T extends Output<T>> implements DataSource<T> {
042    
043    private final List<DataSource<T>> sources;
044
045    public AggregateDataSource(List<DataSource<T>> sources) {
046        this.sources = Collections.unmodifiableList(new ArrayList<>(sources));
047    }
048    
049    @Override
050    public String toString() {
051        return "AggregateDataSource(sources="+sources.toString()+")";
052    }
053
054    @Override
055    public OutputFactory<T> getOutputFactory() {
056        return sources.get(0).getOutputFactory();
057    }
058
059    @Override
060    public Iterator<Example<T>> iterator() {
061        return new ADSIterator();
062    }
063
064    @Override
065    public DataSourceProvenance getProvenance() {
066        return new AggregateDataSourceProvenance(this);
067    }
068
069    private class ADSIterator implements Iterator<Example<T>> {
070        Iterator<DataSource<T>> si = sources.iterator();
071        Iterator<Example<T>> curr = null;
072        @Override
073        public boolean hasNext() {
074            if (curr == null) {
075                if(si.hasNext()) {
076                    DataSource<T> nds = si.next();
077                    curr = nds.iterator();
078                    return hasNext();
079                } else {
080                    return false;
081                }
082            } else {
083                if(curr.hasNext()) {
084                    return true;
085                } else {
086                    curr = null;
087                    return hasNext();
088                }
089            }
090        }
091
092        @Override
093        public Example<T> next() {
094            if (hasNext()) {
095                return curr.next();
096            } else {
097                throw new NoSuchElementException("No more data");
098            }
099        }
100    }
101
102    /**
103     * Provenance for the {@link AggregateDataSource}.
104     */
105    public static class AggregateDataSourceProvenance implements DataSourceProvenance {
106        private static final long serialVersionUID = 1L;
107
108        private static final String SOURCES = "sources";
109
110        private final StringProvenance className;
111        private final ListProvenance<DataSourceProvenance> provenances;
112
113        <T extends Output<T>> AggregateDataSourceProvenance(AggregateDataSource<T> host) {
114            this.className = new StringProvenance(CLASS_NAME,host.getClass().getName());
115            this.provenances = ListProvenance.createListProvenance(host.sources);
116        }
117
118        @SuppressWarnings("unchecked") //ListProvenance cast
119        public AggregateDataSourceProvenance(Map<String,Provenance> map) {
120            this.className = ObjectProvenance.checkAndExtractProvenance(map,CLASS_NAME, StringProvenance.class,AggregateDataSourceProvenance.class.getSimpleName());
121            this.provenances = ObjectProvenance.checkAndExtractProvenance(map,SOURCES,ListProvenance.class,AggregateDataSourceProvenance.class.getSimpleName());
122        }
123
124        @Override
125        public String getClassName() {
126            return className.getValue();
127        }
128
129        @Override
130        public Iterator<Pair<String, Provenance>> iterator() {
131            ArrayList<Pair<String,Provenance>> list = new ArrayList<>();
132
133            list.add(new Pair<>(CLASS_NAME,className));
134            list.add(new Pair<>(SOURCES,provenances));
135
136            return list.iterator();
137        }
138
139        @Override
140        public boolean equals(Object o) {
141            if (this == o) return true;
142            if (!(o instanceof AggregateDataSourceProvenance)) return false;
143            AggregateDataSourceProvenance pairs = (AggregateDataSourceProvenance) o;
144            return className.equals(pairs.className) &&
145                    provenances.equals(pairs.provenances);
146        }
147
148        @Override
149        public int hashCode() {
150            return Objects.hash(className, provenances);
151        }
152
153        @Override
154        public String toString() {
155            return generateString("DataSource");
156        }
157    }
158}