001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.provenance;
018
019import com.oracle.labs.mlrg.olcut.provenance.MapProvenance;
020import com.oracle.labs.mlrg.olcut.provenance.ObjectProvenance;
021import com.oracle.labs.mlrg.olcut.provenance.Provenance;
022import com.oracle.labs.mlrg.olcut.provenance.primitives.DateTimeProvenance;
023import com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance;
024import com.oracle.labs.mlrg.olcut.util.Pair;
025import org.tribuo.Tribuo;
026
027import java.time.OffsetDateTime;
028import java.util.ArrayList;
029import java.util.Iterator;
030import java.util.Map;
031import java.util.Objects;
032
033/**
034 * Contains provenance information for an instance of a {@link org.tribuo.Model}.
035 * <p>
036 * Made up of the class name of the model object, the date and time it was trained, the provenance of
037 * the training data, and the provenance of the trainer.
038 */
039public class ModelProvenance implements ObjectProvenance {
040    private static final long serialVersionUID = 1L;
041
042    protected static final String DATASET = "dataset";
043    protected static final String TRAINER = "trainer";
044    protected static final String TRAINING_TIME = "trained-at";
045    protected static final String INSTANCE_VALUES = "instance-values";
046    private static final String TRIBUO_VERSION_STRING = "tribuo-version";
047
048    protected final String className;
049
050    protected final OffsetDateTime time;
051
052    protected final DatasetProvenance datasetProvenance;
053
054    protected final TrainerProvenance trainerProvenance;
055
056    protected final MapProvenance<? extends Provenance> instanceProvenance;
057
058    protected final String versionString;
059
060    public ModelProvenance(String className, OffsetDateTime time, DatasetProvenance datasetProvenance, TrainerProvenance trainerProvenance) {
061        this.className = className;
062        this.time = time;
063        this.datasetProvenance = datasetProvenance;
064        this.trainerProvenance = trainerProvenance;
065        this.instanceProvenance = new MapProvenance<>();
066        this.versionString = Tribuo.VERSION;
067    }
068
069    public ModelProvenance(String className, OffsetDateTime time, DatasetProvenance datasetProvenance, TrainerProvenance trainerProvenance, Map<String,Provenance> instanceProvenance) {
070        this.className = className;
071        this.time = time;
072        this.datasetProvenance = datasetProvenance;
073        this.trainerProvenance = trainerProvenance;
074        this.instanceProvenance = new MapProvenance<>(instanceProvenance);
075        this.versionString = Tribuo.VERSION;
076    }
077
078    public ModelProvenance(Map<String,Provenance> map) {
079        this.className = ObjectProvenance.checkAndExtractProvenance(map,CLASS_NAME,StringProvenance.class, ModelProvenance.class.getSimpleName()).getValue();
080        this.datasetProvenance = ObjectProvenance.checkAndExtractProvenance(map,DATASET,DatasetProvenance.class, ModelProvenance.class.getSimpleName());
081        this.trainerProvenance = ObjectProvenance.checkAndExtractProvenance(map,TRAINER,TrainerProvenance.class, ModelProvenance.class.getSimpleName());
082        this.time = ObjectProvenance.checkAndExtractProvenance(map,TRAINING_TIME,DateTimeProvenance.class, ModelProvenance.class.getSimpleName()).getValue();
083        this.instanceProvenance = (MapProvenance<?>) ObjectProvenance.checkAndExtractProvenance(map,INSTANCE_VALUES,MapProvenance.class, ModelProvenance.class.getSimpleName());
084        this.versionString = ObjectProvenance.checkAndExtractProvenance(map, TRIBUO_VERSION_STRING,StringProvenance.class, DatasetProvenance.class.getSimpleName()).getValue();
085    }
086
087    /**
088     * The training timestamp.
089     * @return The timestamp.
090     */
091    public OffsetDateTime getTrainingTime() {
092        return time;
093    }
094
095    /**
096     * The training dataset provenance.
097     * @return The training dataset provenance.
098     */
099    public DatasetProvenance getDatasetProvenance() {
100        return datasetProvenance;
101    }
102
103    /**
104     * The trainer provenance.
105     * @return The trainer provenance.
106     */
107    public TrainerProvenance getTrainerProvenance() {
108        return trainerProvenance;
109    }
110
111    /**
112     * Provenance for the specific training run which created this model.
113     * @return The instance provenance.
114     */
115    public MapProvenance<? extends Provenance> getInstanceProvenance() {
116        return instanceProvenance;
117    }
118
119    /**
120     * The Tribuo version used to create this dataset.
121     * @return The Tribuo version.
122     */
123    public String getTribuoVersion() {
124        return versionString;
125    }
126
127    @Override
128    public String toString() {
129        return generateString("Model");
130    }
131
132    @Override
133    public String getClassName() {
134        return className;
135    }
136
137    @Override
138    public boolean equals(Object o) {
139        if (this == o) return true;
140        if (!(o instanceof ModelProvenance)) return false;
141        ModelProvenance pairs = (ModelProvenance) o;
142        return className.equals(pairs.className) &&
143                time.equals(pairs.time) &&
144                datasetProvenance.equals(pairs.datasetProvenance) &&
145                trainerProvenance.equals(pairs.trainerProvenance) &&
146                instanceProvenance.equals(pairs.instanceProvenance) &&
147                versionString.equals(pairs.versionString);
148    }
149
150    @Override
151    public int hashCode() {
152        return Objects.hash(className, time, datasetProvenance, trainerProvenance, instanceProvenance, versionString);
153    }
154
155    @Override
156    public Iterator<Pair<String, Provenance>> iterator() {
157        ArrayList<Pair<String,Provenance>> iterable = new ArrayList<>();
158        iterable.add(new Pair<>(CLASS_NAME,new StringProvenance(CLASS_NAME,className)));
159        iterable.add(new Pair<>(DATASET,datasetProvenance));
160        iterable.add(new Pair<>(TRAINER,trainerProvenance));
161        iterable.add(new Pair<>(TRAINING_TIME,new DateTimeProvenance(TRAINING_TIME,time)));
162        iterable.add(new Pair<>(INSTANCE_VALUES,instanceProvenance));
163        iterable.add(new Pair<>(TRIBUO_VERSION_STRING,new StringProvenance(TRIBUO_VERSION_STRING,versionString)));
164        return iterable.iterator();
165    }
166}