Package org.tribuo

Class Dataset<T extends Output<T>>

java.lang.Object
org.tribuo.Dataset<T>
Type Parameters:
T - the type of the outputs in the data set.
All Implemented Interfaces:
com.oracle.labs.mlrg.olcut.provenance.Provenancable<DatasetProvenance>, Serializable, Iterable<Example<T>>, ProtoSerializable<org.tribuo.protos.core.DatasetProto>
Direct Known Subclasses:
ImmutableDataset, MutableDataset

public abstract class Dataset<T extends Output<T>> extends Object implements Iterable<Example<T>>, ProtoSerializable<org.tribuo.protos.core.DatasetProto>, com.oracle.labs.mlrg.olcut.provenance.Provenancable<DatasetProvenance>, Serializable
A class for sets of data, which are used to train and evaluate classifiers.

Subclass MutableDataset rather than this class.

  • Field Details

    • data

      protected final List<Example<T>> data
      The data in this data set.
    • sourceProvenance

      protected final DataProvenance sourceProvenance
      The provenance of the data source, extracted on construction.
    • outputFactory

      protected final OutputFactory<T> outputFactory
      A factory for making OutputInfo and Output of the appropriate type.
    • tribuoVersion

      protected final String tribuoVersion
      The Tribuo version which originally created this dataset.
    • indices

      protected int[] indices
      The indices of the shuffled order.
  • Constructor Details

    • Dataset

      protected Dataset(DataProvenance provenance, OutputFactory<T> outputFactory)
      Creates a dataset.
      Parameters:
      provenance - A description of the data, including preprocessing steps.
      outputFactory - The output factory.
    • Dataset

      protected Dataset(DataProvenance provenance, OutputFactory<T> outputFactory, String tribuoVersion)
      Creates a dataset.
      Parameters:
      provenance - A description of the data, including preprocessing steps.
      outputFactory - The output factory.
      tribuoVersion - The Tribuo version.
    • Dataset

      protected Dataset(DataSource<T> dataSource)
      Creates a dataset.
      Parameters:
      dataSource - the DataSource to use.
  • Method Details

    • deserialize

      public static Dataset<?> deserialize(org.tribuo.protos.core.DatasetProto datasetProto)
      Deserializes a dataset proto into a dataset.
      Parameters:
      datasetProto - The proto to deserialize.
      Returns:
      The dataset.
    • deserializeFromFile

      public static Dataset<?> deserializeFromFile(Path path) throws IOException
      Reads an instance of DatasetProto from the supplied path and deserializes it.
      Parameters:
      path - The path to read.
      Returns:
      The deserialized dataset.
      Throws:
      IOException - If the path could not be read from, or the parsing failed.
    • deserializeFromStream

      public static Dataset<?> deserializeFromStream(InputStream is) throws IOException
      Reads an instance of DatasetProto from the supplied input stream and deserializes it.
      Parameters:
      is - The input stream to read.
      Returns:
      The deserialized dataset.
      Throws:
      IOException - If the stream could not be read from, or the parsing failed.
    • serializeToFile

      public void serializeToFile(Path path) throws IOException
      Serializes this dataset to a DatasetProto and writes it to the supplied path.
      Parameters:
      path - The path to write to.
      Throws:
      IOException - If the path could not be written to.
    • serializeToStream

      public void serializeToStream(OutputStream stream) throws IOException
      Serializes this dataset to a DatasetProto and writes it to the supplied output stream.

      Does not close the stream.

      Parameters:
      stream - The output stream to write to.
      Throws:
      IOException - If the stream could not be written to.
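
      Because serializeToStream does not close the stream, the caller keeps ownership of it. The following is a minimal sketch of that contract using only the JDK; writeRecord is a hypothetical stand-in for a serializeToStream-style method and is not part of Tribuo.

```java
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.io.UncheckedIOException;
import java.nio.charset.StandardCharsets;

// Hypothetical illustration of a "does not close the stream" contract; not Tribuo code.
class CallerClosesStreamSketch {
    // Stand-in for a serializeToStream-style method: writes, but leaves the stream open.
    static void writeRecord(String record, OutputStream stream) throws IOException {
        stream.write(record.getBytes(StandardCharsets.UTF_8));
    }

    // The caller owns the stream, so it can write several records and then close it once.
    static String writeTwoRecords() {
        ByteArrayOutputStream buffer = new ByteArrayOutputStream();
        try (OutputStream out = buffer) {
            writeRecord("first;", out);
            writeRecord("second;", out);
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        return new String(buffer.toByteArray(), StandardCharsets.UTF_8);
    }

    public static void main(String[] args) {
        System.out.println(writeTwoRecords()); // prints "first;second;"
    }
}
```

      The try-with-resources block sits in the caller, which is where the stream was opened.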
    • getSourceDescription

      public String getSourceDescription()
      A String description of this dataset.
      Returns:
      The description.
    • getSourceProvenance

      public DataProvenance getSourceProvenance()
      The provenance of the data this Dataset contains.
      Returns:
      The data provenance.
    • getData

      public List<Example<T>> getData()
      Gets the examples as an unmodifiable list. This list will throw an UnsupportedOperationException if any elements are added to it.

      In other words, using the following to add additional examples to this dataset will throw an exception: dataset.getData().add(example). Instead, use MutableDataset.add(Example).

      Returns:
      The unmodifiable example list.
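
      The read-only behaviour described for getData can be illustrated with a JDK-only sketch. ReadOnlyViewSketch is a hypothetical toy class, and using Collections.unmodifiableList here is an assumption made for illustration, not a statement about Tribuo's internals.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

// Hypothetical stand-in for a dataset; not Tribuo's implementation.
class ReadOnlyViewSketch {
    private final List<String> data = new ArrayList<>();

    // Mutation goes through the owning class, analogous to MutableDataset.add(Example).
    void add(String example) {
        data.add(example);
    }

    // Callers receive a read-only view of the backing list.
    List<String> getData() {
        return Collections.unmodifiableList(data);
    }

    // Returns true if adding through the view is rejected.
    static boolean viewRejectsAdd(ReadOnlyViewSketch dataset) {
        try {
            dataset.getData().add("rejected");
            return false;
        } catch (UnsupportedOperationException e) {
            return true;
        }
    }

    public static void main(String[] args) {
        ReadOnlyViewSketch dataset = new ReadOnlyViewSketch();
        dataset.add("first");
        System.out.println(viewRejectsAdd(dataset));  // prints "true"
        System.out.println(dataset.getData().size()); // prints "1"
    }
}
```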
    • getOutputFactory

      public OutputFactory<T> getOutputFactory()
      Gets the output factory this dataset contains.
      Returns:
      The output factory.
    • getOutputs

      public abstract Set<T> getOutputs()
      Gets the set of outputs that occur in the examples in this dataset.
      Returns:
      the set of outputs that occur in the examples in this dataset.
    • getExample

      public Example<T> getExample(int index)
      Gets the example at the supplied index.

      Throws IllegalArgumentException if the index is invalid or outside the bounds.

      Parameters:
      index - The index of the example.
      Returns:
      The example.
    • size

      public int size()
      Gets the size of the data set.
      Returns:
      the size of the data set.
    • shuffle

      public void shuffle(boolean shuffle)
      Shuffles the indices, or stops shuffling them.

      The shuffle only affects the iterator; it does not affect getExample(int).

      Multiple calls with the argument true will shuffle the dataset multiple times. The RNG is shared across all Dataset instances, so methods which access it are synchronized.

      Using this method will prevent the provenance system from tracking the exact state of the dataset, which may be important for trainers which depend on the example order, like those using stochastic gradient descent.

      Parameters:
      shuffle - If true, shuffle the data.
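
      The separation between iteration order and positional access described for shuffle can be sketched with an index permutation. ShuffledViewSketch is a hypothetical toy container written against the JDK only. It takes the RNG as a parameter for simplicity, whereas the text above describes an RNG shared across Dataset instances.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.NoSuchElementException;
import java.util.Random;

// Hypothetical toy container; not Tribuo's implementation.
class ShuffledViewSketch implements Iterable<String> {
    private final List<String> data;
    private int[] indices; // null means "iterate in storage order"

    ShuffledViewSketch(List<String> data) {
        this.data = new ArrayList<>(data);
    }

    // Builds a fresh permutation when shuffle is true, drops it otherwise.
    void shuffle(boolean shuffle, Random rng) {
        if (!shuffle) {
            indices = null;
            return;
        }
        indices = new int[data.size()];
        for (int i = 0; i < indices.length; i++) {
            indices[i] = i;
        }
        // Fisher-Yates shuffle of the index array; the data list itself is untouched.
        for (int i = indices.length - 1; i > 0; i--) {
            int j = rng.nextInt(i + 1);
            int tmp = indices[i];
            indices[i] = indices[j];
            indices[j] = tmp;
        }
    }

    // Positional access ignores the shuffled order.
    String get(int index) {
        return data.get(index);
    }

    @Override
    public Iterator<String> iterator() {
        final int[] order = indices; // snapshot of the current order
        return new Iterator<String>() {
            private int position = 0;

            @Override
            public boolean hasNext() {
                return position < data.size();
            }

            @Override
            public String next() {
                if (!hasNext()) {
                    throw new NoSuchElementException();
                }
                int i = position++;
                return data.get(order == null ? i : order[i]);
            }
        };
    }

    public static void main(String[] args) {
        ShuffledViewSketch view = new ShuffledViewSketch(Arrays.asList("a", "b", "c", "d"));
        view.shuffle(true, new Random(42L));
        for (String s : view) {
            System.out.print(s + " "); // some permutation of a b c d
        }
        System.out.println();
        System.out.println(view.get(0)); // prints "a" regardless of the shuffle
    }
}
```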
    • getOutputIDInfo

      public abstract ImmutableOutputInfo<T> getOutputIDInfo()
      Returns or generates an ImmutableOutputInfo.
      Returns:
      An immutable output info.
    • getOutputInfo

      public abstract OutputInfo<T> getOutputInfo()
      Returns this dataset's OutputInfo.
      Returns:
      The output info.
    • getFeatureIDMap

      public abstract ImmutableFeatureMap getFeatureIDMap()
      Returns or generates an ImmutableFeatureMap.
      Returns:
      An immutable feature map with id numbers.
    • getFeatureMap

      public abstract FeatureMap getFeatureMap()
      Returns this dataset's FeatureMap.
      Returns:
      The feature map from this dataset.
    • iterator

      public Iterator<Example<T>> iterator()
      Specified by:
      iterator in interface Iterable<Example<T>>
    • toString

      public String toString()
      Overrides:
      toString in class Object
    • createTransformers

      public TransformerMap createTransformers(TransformationMap transformations)
      Takes a TransformationMap and converts it into a TransformerMap by observing all the values in this dataset.

      Does not mutate the dataset. If you wish to apply the TransformerMap, use MutableDataset.transform(org.tribuo.transform.TransformerMap) or TransformerMap.transformDataset(org.tribuo.Dataset<T>).

      TransformerMaps operate on feature values which are present; implicit zeros in sparse examples are ignored and not transformed. If the zeros should be transformed, call MutableDataset.densify() on the dataset before applying a transformer.

      This method calls createTransformers(TransformationMap, boolean) with includeImplicitZeroFeatures set to false, thus ignoring implicitly zero features when fitting the transformations. This is the default behaviour in Tribuo 4.0, but causes erroneous behaviour in IDFTransformation so should be avoided with that transformation. See org.tribuo.transform for a more detailed discussion of densify and includeImplicitZeroFeatures.

      Throws IllegalArgumentException if the TransformationMap object has regexes which apply to multiple features.

      Parameters:
      transformations - The transformations to fit.
      Returns:
      A TransformerMap which can apply the transformations to a dataset.
    • createTransformers

      public TransformerMap createTransformers(TransformationMap transformations, boolean includeImplicitZeroFeatures)
      Takes a TransformationMap and converts it into a TransformerMap by observing all the values in this dataset.

      Does not mutate the dataset. If you wish to apply the TransformerMap, use MutableDataset.transform(org.tribuo.transform.TransformerMap) or TransformerMap.transformDataset(org.tribuo.Dataset<T>).

      TransformerMaps operate on feature values which are present; implicit zeros in sparse examples are ignored and not transformed. If the zeros should be transformed, call MutableDataset.densify() on the dataset before applying a transformer. See org.tribuo.transform for a more detailed discussion of densify and includeImplicitZeroFeatures.

      Throws IllegalArgumentException if the TransformationMap object has regexes which apply to multiple features.

      Parameters:
      transformations - The transformations to fit.
      includeImplicitZeroFeatures - Use the implicit zero feature values to construct the transformations.
      Returns:
      A TransformerMap which can apply the transformations to a dataset.
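
      To see why includeImplicitZeroFeatures matters, consider fitting a min/max range for one sparse feature. The sketch below is a simplified, JDK-only illustration. fitMinMax and its parameters are hypothetical names, and the logic is an assumption about the general idea rather than Tribuo's fitting code.

```java
import java.util.Arrays;

// Hypothetical illustration of fitting statistics with and without implicit zeros; not Tribuo code.
class ImplicitZeroSketch {
    // Fits {min, max} for one feature. presentValues holds the explicitly stored values;
    // numExamples is the total example count, so any example without a stored value
    // contributes an implicit zero for this feature.
    static double[] fitMinMax(double[] presentValues, int numExamples, boolean includeImplicitZeros) {
        double min = Double.POSITIVE_INFINITY;
        double max = Double.NEGATIVE_INFINITY;
        for (double v : presentValues) {
            min = Math.min(min, v);
            max = Math.max(max, v);
        }
        // Only fold in zero when asked to, and only if some example lacks the feature.
        if (includeImplicitZeros && presentValues.length < numExamples) {
            min = Math.min(min, 0.0);
            max = Math.max(max, 0.0);
        }
        return new double[] {min, max};
    }

    public static void main(String[] args) {
        double[] present = {2.0, 5.0}; // feature is stored in 2 of 4 examples
        System.out.println(Arrays.toString(fitMinMax(present, 4, false))); // prints "[2.0, 5.0]"
        System.out.println(Arrays.toString(fitMinMax(present, 4, true)));  // prints "[0.0, 5.0]"
    }
}
```

      In this toy case the fitted range changes from [2, 5] to [0, 5] once the two missing values are counted as zeros, so any rescaling derived from the range would change as well.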
    • createDataCarrier

      protected DatasetDataCarrier<T> createDataCarrier(FeatureMap featureMap, OutputInfo<T> outputInfo)
      Constructs the data carrier for serialization.
      Parameters:
      featureMap - The feature domain.
      outputInfo - The output domain.
      Returns:
      The serialization data carrier.
    • createDataCarrier

      protected DatasetDataCarrier<T> createDataCarrier(FeatureMap featureMap, OutputInfo<T> outputInfo, List<com.oracle.labs.mlrg.olcut.provenance.ObjectProvenance> transformationProvenances)
      Constructs the data carrier for serialization.
      Parameters:
      featureMap - The feature domain.
      outputInfo - The output domain.
      transformationProvenances - The transformation provenances, must be non-null, but can be empty.
      Returns:
      The serialization data carrier.
    • validate

      public boolean validate(Class<? extends Output<?>> clazz)
      Validates that this Dataset does in fact contain the supplied output type.

      As the output type is erased at runtime, deserialising a Dataset is an unchecked operation. This method allows the user to check that the deserialised dataset is of the appropriate type, rather than seeing if the Dataset throws a ClassCastException when used.

      Parameters:
      clazz - The class object to verify the output type against.
      Returns:
      True if the output type is assignable to the class object type, false otherwise.
    • castDataset

      public static <T extends Output<T>> Dataset<T> castDataset(Dataset<?> inputDataset, Class<T> outputType)
      Casts the dataset to the specified output type, assuming it is valid.

      If it's not valid, throws ClassCastException.

      Type Parameters:
      T - The output type.
      Parameters:
      inputDataset - The dataset to cast.
      outputType - The output type to cast to.
      Returns:
      The dataset cast to the correct type.
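
      The validate-then-cast pattern described for validate and castDataset can be sketched on a generic container whose type parameter is erased at runtime. CheckedCastSketch is a hypothetical toy class built only on the JDK. It records its element class at construction, which is an assumption made for this illustration.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical toy container; not Tribuo's implementation.
class CheckedCastSketch<T> {
    private final List<T> items = new ArrayList<>();
    private final Class<?> elementType; // recorded at construction, so it survives erasure

    CheckedCastSketch(Class<T> elementType) {
        this.elementType = elementType;
    }

    void add(T item) {
        items.add(item);
    }

    T get(int index) {
        return items.get(index);
    }

    // Analogue of validate: true if the stored element type is assignable to clazz.
    boolean validate(Class<?> clazz) {
        return clazz.isAssignableFrom(elementType);
    }

    // Analogue of castDataset: check first, then perform the unchecked cast.
    static <U> CheckedCastSketch<U> cast(CheckedCastSketch<?> input, Class<U> type) {
        if (!input.validate(type)) {
            throw new ClassCastException(
                "Container holds " + input.elementType.getName() + ", not " + type.getName());
        }
        @SuppressWarnings("unchecked") // guarded by the validate call above
        CheckedCastSketch<U> result = (CheckedCastSketch<U>) input;
        return result;
    }

    public static void main(String[] args) {
        CheckedCastSketch<String> strings = new CheckedCastSketch<>(String.class);
        strings.add("example");
        CheckedCastSketch<?> erased = strings; // what a deserializer would hand back
        System.out.println(erased.validate(Integer.class)); // prints "false"
        CheckedCastSketch<String> typed = cast(erased, String.class);
        System.out.println(typed.get(0)); // prints "example"
    }
}
```

      Checking up front means a type mismatch surfaces at the cast site instead of later, when an element is first used.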
    • deserializeExamples

      protected static List<Example<?>> deserializeExamples(List<org.tribuo.protos.core.ExampleProto> examplesList, Class<?> outputClass, FeatureMap fmap)
      Deserializes a list of example protos into a list of examples.
      Parameters:
      examplesList - The protos.
      outputClass - The output class.
      fmap - The feature domain.
      Returns:
      The list of deserialized examples.