001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.sequence; 018 019import com.oracle.labs.mlrg.olcut.provenance.Provenancable; 020import org.tribuo.Dataset; 021import org.tribuo.Example; 022import org.tribuo.FeatureMap; 023import org.tribuo.ImmutableDataset; 024import org.tribuo.ImmutableFeatureMap; 025import org.tribuo.ImmutableOutputInfo; 026import org.tribuo.Output; 027import org.tribuo.OutputFactory; 028import org.tribuo.OutputInfo; 029import org.tribuo.provenance.DataProvenance; 030import org.tribuo.provenance.DatasetProvenance; 031 032import java.io.Serializable; 033import java.util.ArrayList; 034import java.util.Collections; 035import java.util.Iterator; 036import java.util.List; 037import java.util.Set; 038import java.util.logging.Logger; 039 040/** 041 * A class for sets of data, which are used to train and evaluate classifiers. 042 * <p> 043 * Subclass either {@link MutableSequenceDataset} or {@link ImmutableSequenceDataset} rather than this class. 044 * 045 * @param <T> the type of the outputs in the data set. 046 */ 047public abstract class SequenceDataset<T extends Output<T>> implements Iterable<SequenceExample<T>>, Provenancable<DatasetProvenance>, Serializable { 048 private static final Logger logger = Logger.getLogger(SequenceDataset.class.getName()); 049 private static final long serialVersionUID = 2L; 050 051 /** 052 * A factory for making {@link OutputInfo} and {@link Output} of the appropriate type. 053 */ 054 protected final OutputFactory<T> outputFactory; 055 056 /** 057 * The data in this data set. 058 */ 059 protected final List<SequenceExample<T>> data = new ArrayList<>(); 060 061 /** 062 * The provenance of the data source, extracted on construction. 063 */ 064 protected final DataProvenance sourceProvenance; 065 066 protected SequenceDataset(DataProvenance sourceProvenance, OutputFactory<T> outputFactory) { 067 this.sourceProvenance = sourceProvenance; 068 this.outputFactory = outputFactory; 069 } 070 071 /** 072 * Returns the description of the source provenance. 073 * @return The source provenance in text form. 074 */ 075 public String getSourceDescription() { 076 return "SequenceDataset(source=" + sourceProvenance.toString() + ")"; 077 } 078 079 /** 080 * Returns an unmodifiable view on the data. 081 * @return The data. 082 */ 083 public List<SequenceExample<T>> getData() { 084 return Collections.unmodifiableList(data); 085 } 086 087 /** 088 * Returns the source provenance. 089 * @return The source provenance. 090 */ 091 public DataProvenance getSourceProvenance() { 092 return sourceProvenance; 093 } 094 095 /** 096 * Gets the set of labels that occur in the examples in this dataset. 097 * 098 * @return the set of labels that occur in the examples in this dataset. 099 */ 100 public abstract Set<T> getOutputs(); 101 102 /** 103 * Gets the example at the specified index, or throws IllegalArgumentException if 104 * the index is out of bounds. 105 * @param index The index. 106 * @return The example at that index. 107 */ 108 public SequenceExample<T> getExample(int index) { 109 if ((index < 0) || (index >= size())) { 110 throw new IllegalArgumentException("Example index " + index + " is out of bounds."); 111 } 112 return data.get(index); 113 } 114 115 /** 116 * Returns a view on this SequenceDataset which aggregates all 117 * the examples and ignores the sequence structure. 118 * 119 * @return A flattened view on this dataset. 120 */ 121 public Dataset<T> getFlatDataset() { 122 return new FlatDataset<>(this); 123 } 124 125 /** 126 * Gets the size of the data set. 127 * 128 * @return the size of the data set. 129 */ 130 public int size() { 131 return data.size(); 132 } 133 134 /** 135 * An immutable view on the output info in this dataset. 136 * @return The output info. 137 */ 138 public abstract ImmutableOutputInfo<T> getOutputIDInfo(); 139 140 /** 141 * The output info in this dataset. 142 * @return The output info. 143 */ 144 public abstract OutputInfo<T> getOutputInfo(); 145 146 /** 147 * An immutable view on the feature map. 148 * @return The feature map. 149 */ 150 public abstract ImmutableFeatureMap getFeatureIDMap(); 151 152 /** 153 * The feature map. 154 * @return The feature map. 155 */ 156 public abstract FeatureMap getFeatureMap(); 157 158 /** 159 * Gets the output factory. 160 * @return The output factory. 161 */ 162 public OutputFactory<T> getOutputFactory() { 163 return outputFactory; 164 } 165 166 @Override 167 public Iterator<SequenceExample<T>> iterator() { 168 return data.iterator(); 169 } 170 171 @Override 172 public String toString() { 173 return "SequenceDataset(source=" + sourceProvenance.toString() + ")"; 174 } 175 176 private static class FlatDataset<T extends Output<T>> extends ImmutableDataset<T> { 177 private static final long serialVersionUID = 1L; 178 179 public FlatDataset(SequenceDataset<T> sequenceDataset) { 180 super(sequenceDataset.sourceProvenance, sequenceDataset.outputFactory, sequenceDataset.getFeatureIDMap(), sequenceDataset.getOutputIDInfo()); 181 for (SequenceExample<T> seq : sequenceDataset) { 182 for (Example<T> e : seq) { 183 data.add(e); 184 } 185 } 186 } 187 } 188 189} 190