001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.provenance; 018 019import com.oracle.labs.mlrg.olcut.provenance.ObjectProvenance; 020import com.oracle.labs.mlrg.olcut.provenance.Provenance; 021import com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance; 022import com.oracle.labs.mlrg.olcut.util.Pair; 023import org.tribuo.Tribuo; 024 025import java.util.ArrayList; 026import java.util.Iterator; 027import java.util.Map; 028import java.util.Objects; 029 030/** 031 * Provenance for evaluations. 032 */ 033public final class EvaluationProvenance implements ObjectProvenance { 034 private static final long serialVersionUID = 1L; 035 036 private static final String MODEL_PROVENANCE_NAME = "model-provenance"; 037 private static final String DATASET_PROVENANCE_NAME = "dataset-provenance"; 038 private static final String TRIBUO_VERSION_STRING = "tribuo-version"; 039 040 private final StringProvenance className; 041 private final ModelProvenance modelProvenance; 042 private final DataProvenance datasetProvenance; 043 private final StringProvenance versionString; 044 045 public EvaluationProvenance(ModelProvenance modelProvenance, DataProvenance datasetProvenance) { 046 this.className = new StringProvenance(CLASS_NAME, EvaluationProvenance.class.getName()); 047 this.modelProvenance = modelProvenance; 048 this.datasetProvenance = datasetProvenance; 049 this.versionString = new StringProvenance(TRIBUO_VERSION_STRING,Tribuo.VERSION); 050 } 051 052 public EvaluationProvenance(Map<String,Provenance> map) { 053 this.className = ObjectProvenance.checkAndExtractProvenance(map,CLASS_NAME,StringProvenance.class, EvaluationProvenance.class.getSimpleName()); 054 this.modelProvenance = ObjectProvenance.checkAndExtractProvenance(map,MODEL_PROVENANCE_NAME,ModelProvenance.class, EvaluationProvenance.class.getSimpleName()); 055 this.datasetProvenance = ObjectProvenance.checkAndExtractProvenance(map,DATASET_PROVENANCE_NAME,DataProvenance.class, EvaluationProvenance.class.getSimpleName()); 056 this.versionString = ObjectProvenance.checkAndExtractProvenance(map,TRIBUO_VERSION_STRING,StringProvenance.class, EvaluationProvenance.class.getSimpleName()); 057 } 058 059 @Override 060 public String getClassName() { 061 return className.getValue(); 062 } 063 064 /** 065 * The test dataset provenance. 066 * @return The test dataset provenance. 067 */ 068 public DataProvenance getTestDatasetProvenance() { 069 return datasetProvenance; 070 } 071 072 /** 073 * The model provenance. 074 * @return The model provenance. 075 */ 076 public ModelProvenance getModelProvenance() { 077 return modelProvenance; 078 } 079 080 /** 081 * The Tribuo version used to create this dataset. 082 * @return The Tribuo version. 083 */ 084 public String getTribuoVersion() { 085 return versionString.getValue(); 086 } 087 088 @Override 089 public Iterator<Pair<String, Provenance>> iterator() { 090 ArrayList<Pair<String,Provenance>> list = new ArrayList<>(); 091 list.add(new Pair<>(CLASS_NAME, className)); 092 list.add(new Pair<>(MODEL_PROVENANCE_NAME, modelProvenance)); 093 list.add(new Pair<>(DATASET_PROVENANCE_NAME, datasetProvenance)); 094 list.add(new Pair<>(TRIBUO_VERSION_STRING,versionString)); 095 return list.iterator(); 096 } 097 098 @Override 099 public String toString() { 100 return generateString("Evaluation"); 101 } 102 103 @Override 104 public boolean equals(Object o) { 105 if (this == o) return true; 106 if (o == null || getClass() != o.getClass()) return false; 107 EvaluationProvenance pairs = (EvaluationProvenance) o; 108 return Objects.equals(className, pairs.className) && 109 Objects.equals(modelProvenance, pairs.modelProvenance) && 110 Objects.equals(datasetProvenance, pairs.datasetProvenance) && 111 Objects.equals(versionString, pairs.versionString); 112 } 113 114 @Override 115 public int hashCode() { 116 return Objects.hash(className, modelProvenance, datasetProvenance, versionString); 117 } 118}