001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.provenance;
018
019import com.oracle.labs.mlrg.olcut.provenance.ObjectProvenance;
020import com.oracle.labs.mlrg.olcut.provenance.Provenance;
021import com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance;
022import com.oracle.labs.mlrg.olcut.util.Pair;
023import org.tribuo.Tribuo;
024
025import java.util.ArrayList;
026import java.util.Iterator;
027import java.util.Map;
028import java.util.Objects;
029
030/**
031 * Provenance for evaluations.
032 */
033public final class EvaluationProvenance implements ObjectProvenance {
034    private static final long serialVersionUID = 1L;
035
036    private static final String MODEL_PROVENANCE_NAME = "model-provenance";
037    private static final String DATASET_PROVENANCE_NAME = "dataset-provenance";
038    private static final String TRIBUO_VERSION_STRING = "tribuo-version";
039
040    private final StringProvenance className;
041    private final ModelProvenance modelProvenance;
042    private final DataProvenance datasetProvenance;
043    private final StringProvenance versionString;
044
045    public EvaluationProvenance(ModelProvenance modelProvenance, DataProvenance datasetProvenance) {
046        this.className = new StringProvenance(CLASS_NAME, EvaluationProvenance.class.getName());
047        this.modelProvenance = modelProvenance;
048        this.datasetProvenance = datasetProvenance;
049        this.versionString = new StringProvenance(TRIBUO_VERSION_STRING,Tribuo.VERSION);
050    }
051
052    public EvaluationProvenance(Map<String,Provenance> map) {
053        this.className = ObjectProvenance.checkAndExtractProvenance(map,CLASS_NAME,StringProvenance.class, EvaluationProvenance.class.getSimpleName());
054        this.modelProvenance = ObjectProvenance.checkAndExtractProvenance(map,MODEL_PROVENANCE_NAME,ModelProvenance.class, EvaluationProvenance.class.getSimpleName());
055        this.datasetProvenance = ObjectProvenance.checkAndExtractProvenance(map,DATASET_PROVENANCE_NAME,DataProvenance.class, EvaluationProvenance.class.getSimpleName());
056        this.versionString = ObjectProvenance.checkAndExtractProvenance(map,TRIBUO_VERSION_STRING,StringProvenance.class, EvaluationProvenance.class.getSimpleName());
057    }
058
059    @Override
060    public String getClassName() {
061        return className.getValue();
062    }
063
064    /**
065     * The test dataset provenance.
066     * @return The test dataset provenance.
067     */
068    public DataProvenance getTestDatasetProvenance() {
069        return datasetProvenance;
070    }
071
072    /**
073     * The model provenance.
074     * @return The model provenance.
075     */
076    public ModelProvenance getModelProvenance() {
077        return modelProvenance;
078    }
079
080    /**
081     * The Tribuo version used to create this dataset.
082     * @return The Tribuo version.
083     */
084    public String getTribuoVersion() {
085        return versionString.getValue();
086    }
087
088    @Override
089    public Iterator<Pair<String, Provenance>> iterator() {
090        ArrayList<Pair<String,Provenance>> list = new ArrayList<>();
091        list.add(new Pair<>(CLASS_NAME, className));
092        list.add(new Pair<>(MODEL_PROVENANCE_NAME, modelProvenance));
093        list.add(new Pair<>(DATASET_PROVENANCE_NAME, datasetProvenance));
094        list.add(new Pair<>(TRIBUO_VERSION_STRING,versionString));
095        return list.iterator();
096    }
097
098    @Override
099    public String toString() {
100        return generateString("Evaluation");
101    }
102
103    @Override
104    public boolean equals(Object o) {
105        if (this == o) return true;
106        if (o == null || getClass() != o.getClass()) return false;
107        EvaluationProvenance pairs = (EvaluationProvenance) o;
108        return Objects.equals(className, pairs.className) &&
109                Objects.equals(modelProvenance, pairs.modelProvenance) &&
110                Objects.equals(datasetProvenance, pairs.datasetProvenance) &&
111                Objects.equals(versionString, pairs.versionString);
112    }
113
114    @Override
115    public int hashCode() {
116        return Objects.hash(className, modelProvenance, datasetProvenance, versionString);
117    }
118}