001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.provenance; 018 019import com.oracle.labs.mlrg.olcut.provenance.MapProvenance; 020import com.oracle.labs.mlrg.olcut.provenance.ObjectProvenance; 021import com.oracle.labs.mlrg.olcut.provenance.Provenance; 022import com.oracle.labs.mlrg.olcut.provenance.primitives.DateTimeProvenance; 023import com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance; 024import com.oracle.labs.mlrg.olcut.util.Pair; 025import org.tribuo.Tribuo; 026 027import java.time.OffsetDateTime; 028import java.util.ArrayList; 029import java.util.Iterator; 030import java.util.Map; 031import java.util.Objects; 032 033/** 034 * Contains provenance information for an instance of a {@link org.tribuo.Model}. 035 * <p> 036 * Made up of the class name of the model object, the date and time it was trained, the provenance of 037 * the training data, and the provenance of the trainer. 038 */ 039public class ModelProvenance implements ObjectProvenance { 040 private static final long serialVersionUID = 1L; 041 042 protected static final String DATASET = "dataset"; 043 protected static final String TRAINER = "trainer"; 044 protected static final String TRAINING_TIME = "trained-at"; 045 protected static final String INSTANCE_VALUES = "instance-values"; 046 private static final String TRIBUO_VERSION_STRING = "tribuo-version"; 047 048 protected final String className; 049 050 protected final OffsetDateTime time; 051 052 protected final DatasetProvenance datasetProvenance; 053 054 protected final TrainerProvenance trainerProvenance; 055 056 protected final MapProvenance<? extends Provenance> instanceProvenance; 057 058 protected final String versionString; 059 060 public ModelProvenance(String className, OffsetDateTime time, DatasetProvenance datasetProvenance, TrainerProvenance trainerProvenance) { 061 this.className = className; 062 this.time = time; 063 this.datasetProvenance = datasetProvenance; 064 this.trainerProvenance = trainerProvenance; 065 this.instanceProvenance = new MapProvenance<>(); 066 this.versionString = Tribuo.VERSION; 067 } 068 069 public ModelProvenance(String className, OffsetDateTime time, DatasetProvenance datasetProvenance, TrainerProvenance trainerProvenance, Map<String,Provenance> instanceProvenance) { 070 this.className = className; 071 this.time = time; 072 this.datasetProvenance = datasetProvenance; 073 this.trainerProvenance = trainerProvenance; 074 this.instanceProvenance = new MapProvenance<>(instanceProvenance); 075 this.versionString = Tribuo.VERSION; 076 } 077 078 public ModelProvenance(Map<String,Provenance> map) { 079 this.className = ObjectProvenance.checkAndExtractProvenance(map,CLASS_NAME,StringProvenance.class, ModelProvenance.class.getSimpleName()).getValue(); 080 this.datasetProvenance = ObjectProvenance.checkAndExtractProvenance(map,DATASET,DatasetProvenance.class, ModelProvenance.class.getSimpleName()); 081 this.trainerProvenance = ObjectProvenance.checkAndExtractProvenance(map,TRAINER,TrainerProvenance.class, ModelProvenance.class.getSimpleName()); 082 this.time = ObjectProvenance.checkAndExtractProvenance(map,TRAINING_TIME,DateTimeProvenance.class, ModelProvenance.class.getSimpleName()).getValue(); 083 this.instanceProvenance = (MapProvenance<?>) ObjectProvenance.checkAndExtractProvenance(map,INSTANCE_VALUES,MapProvenance.class, ModelProvenance.class.getSimpleName()); 084 this.versionString = ObjectProvenance.checkAndExtractProvenance(map, TRIBUO_VERSION_STRING,StringProvenance.class, DatasetProvenance.class.getSimpleName()).getValue(); 085 } 086 087 /** 088 * The training timestamp. 089 * @return The timestamp. 090 */ 091 public OffsetDateTime getTrainingTime() { 092 return time; 093 } 094 095 /** 096 * The training dataset provenance. 097 * @return The training dataset provenance. 098 */ 099 public DatasetProvenance getDatasetProvenance() { 100 return datasetProvenance; 101 } 102 103 /** 104 * The trainer provenance. 105 * @return The trainer provenance. 106 */ 107 public TrainerProvenance getTrainerProvenance() { 108 return trainerProvenance; 109 } 110 111 /** 112 * Provenance for the specific training run which created this model. 113 * @return The instance provenance. 114 */ 115 public MapProvenance<? extends Provenance> getInstanceProvenance() { 116 return instanceProvenance; 117 } 118 119 /** 120 * The Tribuo version used to create this dataset. 121 * @return The Tribuo version. 122 */ 123 public String getTribuoVersion() { 124 return versionString; 125 } 126 127 @Override 128 public String toString() { 129 return generateString("Model"); 130 } 131 132 @Override 133 public String getClassName() { 134 return className; 135 } 136 137 @Override 138 public boolean equals(Object o) { 139 if (this == o) return true; 140 if (!(o instanceof ModelProvenance)) return false; 141 ModelProvenance pairs = (ModelProvenance) o; 142 return className.equals(pairs.className) && 143 time.equals(pairs.time) && 144 datasetProvenance.equals(pairs.datasetProvenance) && 145 trainerProvenance.equals(pairs.trainerProvenance) && 146 instanceProvenance.equals(pairs.instanceProvenance) && 147 versionString.equals(pairs.versionString); 148 } 149 150 @Override 151 public int hashCode() { 152 return Objects.hash(className, time, datasetProvenance, trainerProvenance, instanceProvenance, versionString); 153 } 154 155 @Override 156 public Iterator<Pair<String, Provenance>> iterator() { 157 ArrayList<Pair<String,Provenance>> iterable = new ArrayList<>(); 158 iterable.add(new Pair<>(CLASS_NAME,new StringProvenance(CLASS_NAME,className))); 159 iterable.add(new Pair<>(DATASET,datasetProvenance)); 160 iterable.add(new Pair<>(TRAINER,trainerProvenance)); 161 iterable.add(new Pair<>(TRAINING_TIME,new DateTimeProvenance(TRAINING_TIME,time))); 162 iterable.add(new Pair<>(INSTANCE_VALUES,instanceProvenance)); 163 iterable.add(new Pair<>(TRIBUO_VERSION_STRING,new StringProvenance(TRIBUO_VERSION_STRING,versionString))); 164 return iterable.iterator(); 165 } 166}