001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.provenance;
018
019import com.oracle.labs.mlrg.olcut.provenance.ListProvenance;
020import com.oracle.labs.mlrg.olcut.provenance.ObjectProvenance;
021import com.oracle.labs.mlrg.olcut.provenance.Provenance;
022import com.oracle.labs.mlrg.olcut.provenance.primitives.BooleanProvenance;
023import com.oracle.labs.mlrg.olcut.provenance.primitives.IntProvenance;
024import com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance;
025import com.oracle.labs.mlrg.olcut.util.Pair;
026import org.tribuo.Dataset;
027import org.tribuo.MutableDataset;
028import org.tribuo.Output;
029import org.tribuo.Tribuo;
030import org.tribuo.sequence.SequenceDataset;
031
032import java.util.ArrayList;
033import java.util.Iterator;
034import java.util.List;
035import java.util.Map;
036import java.util.Objects;
037
038/**
039 * Base class for dataset provenance.
040 * <p>
041 * Dataset provenance can be a chain of other DataProvenances which track operations like selection
042 * and subsampling.
043 * </p>
044 */
045public class DatasetProvenance implements DataProvenance, ObjectProvenance {
046    private static final long serialVersionUID = 1L;
047
048    static final String DATASOURCE = "datasource";
049    static final String TRANSFORMATIONS = "transformations";
050    static final String IS_DENSE = "is-dense";
051    static final String IS_SEQUENCE = "is-sequence";
052    static final String NUM_EXAMPLES = "num-examples";
053    static final String NUM_FEATURES = "num-features";
054    static final String NUM_OUTPUTS = "num-outputs";
055    private static final String TRIBUO_VERSION_STRING = "tribuo-version";
056
057    private final String className;
058
059    private final DataProvenance sourceProvenance;
060
061    private final ListProvenance<ObjectProvenance> transformationProvenance;
062
063    private final boolean isDense;
064
065    private final boolean isSequence;
066
067    private final int numExamples;
068
069    private final int numFeatures;
070
071    private final int numOutputs;
072
073    private final String versionString;
074
075    public <T extends Output<T>> DatasetProvenance(DataProvenance sourceProvenance, ListProvenance<ObjectProvenance> transformationProvenance, Dataset<T> dataset) {
076        this(sourceProvenance,transformationProvenance,dataset.getClass().getName(),dataset instanceof MutableDataset && ((MutableDataset<T>) dataset).isDense(),false,dataset.size(),dataset.getFeatureMap().size(),dataset.getOutputInfo().size());
077    }
078
079    public <T extends Output<T>> DatasetProvenance(DataProvenance sourceProvenance, ListProvenance<ObjectProvenance> transformationProvenance, SequenceDataset<T> dataset) {
080        this(sourceProvenance,transformationProvenance,dataset.getClass().getName(),false,true,dataset.size(),dataset.getFeatureMap().size(),dataset.getOutputInfo().size());
081    }
082
083    protected DatasetProvenance(DataProvenance sourceProvenance, ListProvenance<ObjectProvenance> transformationProvenance, String datasetClassName, boolean isDense, boolean isSequence, int numExamples, int numFeatures, int numOutputs) {
084        this.className = datasetClassName;
085        this.sourceProvenance = sourceProvenance;
086        this.transformationProvenance = transformationProvenance;
087        this.isDense = isDense;
088        this.isSequence = isSequence;
089        this.numExamples = numExamples;
090        this.numFeatures = numFeatures;
091        this.numOutputs = numOutputs;
092        this.versionString = Tribuo.VERSION;
093    }
094
095    @SuppressWarnings("unchecked") //ListProvenance assignment
096    public DatasetProvenance(Map<String,Provenance> map) {
097        this.className = ObjectProvenance.checkAndExtractProvenance(map,CLASS_NAME,StringProvenance.class, DatasetProvenance.class.getSimpleName()).getValue();
098        this.sourceProvenance = ObjectProvenance.checkAndExtractProvenance(map,DATASOURCE,DataProvenance.class, DatasetProvenance.class.getSimpleName());
099        this.transformationProvenance = ObjectProvenance.checkAndExtractProvenance(map,TRANSFORMATIONS,ListProvenance.class, DatasetProvenance.class.getSimpleName());
100        this.isDense = ObjectProvenance.checkAndExtractProvenance(map,IS_DENSE,BooleanProvenance.class, DatasetProvenance.class.getSimpleName()).getValue();
101        this.isSequence = ObjectProvenance.checkAndExtractProvenance(map,IS_SEQUENCE,BooleanProvenance.class, DatasetProvenance.class.getSimpleName()).getValue();
102        this.numExamples = ObjectProvenance.checkAndExtractProvenance(map,NUM_EXAMPLES,IntProvenance.class, DatasetProvenance.class.getSimpleName()).getValue();
103        this.numFeatures = ObjectProvenance.checkAndExtractProvenance(map,NUM_FEATURES,IntProvenance.class, DatasetProvenance.class.getSimpleName()).getValue();
104        this.numOutputs = ObjectProvenance.checkAndExtractProvenance(map,NUM_OUTPUTS,IntProvenance.class, DatasetProvenance.class.getSimpleName()).getValue();
105        this.versionString = ObjectProvenance.checkAndExtractProvenance(map, TRIBUO_VERSION_STRING,StringProvenance.class, DatasetProvenance.class.getSimpleName()).getValue();
106    }
107
108    @Override
109    public String getClassName() {
110        return className;
111    }
112
113    /**
114     * The input data provenance.
115     * @return The data provenance.
116     */
117    public DataProvenance getSourceProvenance() {
118        return sourceProvenance;
119    }
120
121    /**
122     * The transformation provenances, in application order.
123     * @return The transformation provenances.
124     */
125    public ListProvenance<ObjectProvenance> getTransformationProvenance() {
126        return transformationProvenance;
127    }
128
129    /**
130     * Is the Dataset dense?
131     * @return True if dense.
132     */
133    public boolean isDense() {
134        return isDense;
135    }
136
137    /**
138     * Is it a sequence dataset?
139     * @return True if a sequence dataset.
140     */
141    public boolean isSequence() {
142        return isSequence;
143    }
144
145    /**
146     * The number of examples.
147     * @return The number of examples.
148     */
149    public int getNumExamples() {
150        return numExamples;
151    }
152
153    /**
154     * The number of features.
155     * @return The number of features.
156     */
157    public int getNumFeatures() {
158        return numFeatures;
159    }
160
161    /**
162     * The number of output dimensions.
163     * @return The number of output dimensions.
164     */
165    public int getNumOutputs() {
166        return numOutputs;
167    }
168
169    /**
170     * The Tribuo version used to create this dataset.
171     * @return The Tribuo version.
172     */
173    public String getTribuoVersion() {
174        return versionString;
175    }
176
177    @Override
178    public Iterator<Pair<String, Provenance>> iterator() {
179        List<Pair<String,Provenance>> iterable = allProvenances();
180        return iterable.iterator();
181    }
182
183    protected List<Pair<String,Provenance>> allProvenances() {
184        ArrayList<Pair<String,Provenance>> provenances = new ArrayList<>();
185        provenances.add(new Pair<>(CLASS_NAME,new StringProvenance(CLASS_NAME,className)));
186        provenances.add(new Pair<>(DATASOURCE,sourceProvenance));
187        provenances.add(new Pair<>(TRANSFORMATIONS,transformationProvenance));
188        provenances.add(new Pair<>(IS_SEQUENCE,new BooleanProvenance(IS_SEQUENCE,isSequence)));
189        provenances.add(new Pair<>(IS_DENSE,new BooleanProvenance(IS_DENSE,isDense)));
190        provenances.add(new Pair<>(NUM_EXAMPLES,new IntProvenance(NUM_EXAMPLES,numExamples)));
191        provenances.add(new Pair<>(NUM_FEATURES,new IntProvenance(NUM_FEATURES,numFeatures)));
192        provenances.add(new Pair<>(NUM_OUTPUTS,new IntProvenance(NUM_OUTPUTS,numOutputs)));
193        provenances.add(new Pair<>(TRIBUO_VERSION_STRING,new StringProvenance(TRIBUO_VERSION_STRING,versionString)));
194        return provenances;
195    }
196
197    @Override
198    public boolean equals(Object o) {
199        if (this == o) return true;
200        if (!(o instanceof DatasetProvenance)) return false;
201        DatasetProvenance pairs = (DatasetProvenance) o;
202        return isDense == pairs.isDense &&
203                isSequence == pairs.isSequence &&
204                numExamples == pairs.numExamples &&
205                numFeatures == pairs.numFeatures &&
206                numOutputs == pairs.numOutputs &&
207                className.equals(pairs.className) &&
208                sourceProvenance.equals(pairs.sourceProvenance) &&
209                transformationProvenance.equals(pairs.transformationProvenance) &&
210                versionString.equals(pairs.versionString);
211    }
212
213    @Override
214    public int hashCode() {
215        return Objects.hash(className, sourceProvenance, transformationProvenance, isDense, isSequence, numExamples, numFeatures, numOutputs, versionString);
216    }
217
218    @Override
219    public String toString() {
220        if (isSequence) {
221            return generateString("SequenceDataset");
222        } else {
223            return generateString("Dataset");
224        }
225    }
226}