/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.datasource;

import com.oracle.labs.mlrg.olcut.provenance.ListProvenance;
import com.oracle.labs.mlrg.olcut.provenance.ObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance;
import com.oracle.labs.mlrg.olcut.util.Pair;
import org.tribuo.DataSource;
import org.tribuo.Example;
import org.tribuo.Output;
import org.tribuo.OutputFactory;
import org.tribuo.provenance.DataSourceProvenance;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Objects;

/**
 * Aggregates multiple {@link DataSource}s into a single data source.
 * <p>
 * Iteration is sequential: every example from the first source is returned,
 * then every example from the second source, and so on in list order. A
 * source's iterator is only created once the previous source has been
 * exhausted.
 * @param <T> The output type.
 */
public class AggregateDataSource<T extends Output<T>> implements DataSource<T> {

    private final List<DataSource<T>> sources;

    /**
     * Creates an aggregate data source over the supplied sources.
     * <p>
     * The list is copied, so later modifications to the argument do not
     * affect this data source.
     * @param sources The sources to aggregate, iterated in list order.
     */
    public AggregateDataSource(List<DataSource<T>> sources) {
        this.sources = Collections.unmodifiableList(new ArrayList<>(sources));
    }

    @Override
    public String toString() {
        return "AggregateDataSource(sources="+sources.toString()+")";
    }

    /**
     * Returns the output factory of the first aggregated source.
     * <p>
     * NOTE(review): this assumes all aggregated sources share a compatible
     * output factory; it is not checked here.
     * @return The first source's output factory.
     * @throws IndexOutOfBoundsException If this aggregate contains no sources.
     */
    @Override
    public OutputFactory<T> getOutputFactory() {
        return sources.get(0).getOutputFactory();
    }

    /**
     * Returns an iterator which yields every example of each source in turn,
     * in list order.
     * @return An iterator over all the aggregated examples.
     */
    @Override
    public Iterator<Example<T>> iterator() {
        return new ADSIterator();
    }

    @Override
    public DataSourceProvenance getProvenance() {
        return new AggregateDataSourceProvenance(this);
    }

    /**
     * Iterator which drains each source's iterator before moving on to the next source.
     */
    private class ADSIterator implements Iterator<Example<T>> {
        private final Iterator<DataSource<T>> sourceIterator = sources.iterator();
        private Iterator<Example<T>> current = null;

        @Override
        public boolean hasNext() {
            // Skip exhausted or empty sources with a loop rather than recursion,
            // so a long run of empty sources cannot overflow the stack.
            while (current == null || !current.hasNext()) {
                if (!sourceIterator.hasNext()) {
                    current = null;
                    return false;
                }
                current = sourceIterator.next().iterator();
            }
            return true;
        }

        @Override
        public Example<T> next() {
            if (!hasNext()) {
                throw new NoSuchElementException("No more data");
            }
            return current.next();
        }
    }

    /**
     * Provenance for the {@link AggregateDataSource}.
     * <p>
     * Records the host class name and the list of provenances of the aggregated sources.
     */
    public static class AggregateDataSourceProvenance implements DataSourceProvenance {
        private static final long serialVersionUID = 1L;

        private static final String SOURCES = "sources";

        private final StringProvenance className;
        private final ListProvenance<DataSourceProvenance> provenances;

        <T extends Output<T>> AggregateDataSourceProvenance(AggregateDataSource<T> host) {
            this.className = new StringProvenance(CLASS_NAME,host.getClass().getName());
            this.provenances = ListProvenance.createListProvenance(host.sources);
        }

        /**
         * Reconstructs the provenance from its marshalled map form.
         * @param map The map of provenance field names to values; must contain the
         *            class name and the {@code "sources"} list provenance.
         */
        @SuppressWarnings("unchecked") //ListProvenance cast
        public AggregateDataSourceProvenance(Map<String,Provenance> map) {
            this.className = ObjectProvenance.checkAndExtractProvenance(map,CLASS_NAME, StringProvenance.class,AggregateDataSourceProvenance.class.getSimpleName());
            this.provenances = ObjectProvenance.checkAndExtractProvenance(map,SOURCES,ListProvenance.class,AggregateDataSourceProvenance.class.getSimpleName());
        }

        @Override
        public String getClassName() {
            return className.getValue();
        }

        @Override
        public Iterator<Pair<String, Provenance>> iterator() {
            ArrayList<Pair<String,Provenance>> list = new ArrayList<>();

            list.add(new Pair<>(CLASS_NAME,className));
            list.add(new Pair<>(SOURCES,provenances));

            return list.iterator();
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (!(o instanceof AggregateDataSourceProvenance)) {
                return false;
            }
            AggregateDataSourceProvenance other = (AggregateDataSourceProvenance) o;
            return className.equals(other.className) &&
                    provenances.equals(other.provenances);
        }

        @Override
        public int hashCode() {
            return Objects.hash(className, provenances);
        }

        @Override
        public String toString() {
            return generateString("DataSource");
        }
    }
}