001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.interop;
018
019import com.oracle.labs.mlrg.olcut.provenance.ObjectProvenance;
020import com.oracle.labs.mlrg.olcut.provenance.PrimitiveProvenance;
021import com.oracle.labs.mlrg.olcut.provenance.Provenance;
022import com.oracle.labs.mlrg.olcut.provenance.ProvenanceUtil;
023import com.oracle.labs.mlrg.olcut.provenance.primitives.DateTimeProvenance;
024import com.oracle.labs.mlrg.olcut.provenance.primitives.HashProvenance;
025import com.oracle.labs.mlrg.olcut.provenance.primitives.URLProvenance;
026import org.tribuo.Trainer;
027import org.tribuo.provenance.TrainerProvenance;
028
029import java.net.URL;
030import java.time.OffsetDateTime;
031import java.util.Collections;
032import java.util.HashMap;
033import java.util.Map;
034import java.util.Objects;
035import java.util.Optional;
036
037/**
038 * A dummy provenance for a model trained outside Tribuo.
039 * <p>
040 * It records the timestamp, hash and location of the loaded model.
041 */
042public final class ExternalTrainerProvenance implements TrainerProvenance {
043    private static final long serialVersionUID = 1L;
044
045    private final URLProvenance location;
046    private final DateTimeProvenance fileModifiedTime;
047    private final HashProvenance modelHash;
048
049    /**
050     * Creates an external trainer provenance, storing the location
051     * and pulling in the timestamp and file hash.
052     * @param location The location to use.
053     */
054    public ExternalTrainerProvenance(URL location) {
055        this.location = new URLProvenance("location",location);
056        Optional<OffsetDateTime> time = ProvenanceUtil.getModifiedTime(location);
057        this.fileModifiedTime = time.map(offsetDateTime -> new DateTimeProvenance("fileModifiedTime", offsetDateTime)).orElseGet(() -> new DateTimeProvenance("fileModifiedTime", OffsetDateTime.MIN));
058        this.modelHash = new HashProvenance(DEFAULT_HASH_TYPE,"modelHash", ProvenanceUtil.hashResource(DEFAULT_HASH_TYPE,location));
059    }
060
061    /**
062     * Used by the provenance serialization system.
063     * @param provenance The provenance to reconstruct.
064     */
065    public ExternalTrainerProvenance(Map<String,Provenance> provenance) {
066        this.location = ObjectProvenance.checkAndExtractProvenance(provenance,"location",URLProvenance.class,ExternalTrainerProvenance.class.getSimpleName());
067        this.fileModifiedTime = ObjectProvenance.checkAndExtractProvenance(provenance,"fileModifiedTime",DateTimeProvenance.class,ExternalTrainerProvenance.class.getSimpleName());
068        this.modelHash = ObjectProvenance.checkAndExtractProvenance(provenance,"modelHash",HashProvenance.class,ExternalTrainerProvenance.class.getSimpleName());
069    }
070
071    @Override
072    public Map<String, Provenance> getConfiguredParameters() {
073        return Collections.emptyMap();
074    }
075
076    @Override
077    public String getClassName() {
078        return Trainer.class.getName();
079    }
080
081    @Override
082    public String toString() {
083        return generateString("ExternalTrainer");
084    }
085
086    @Override
087    public Map<String, PrimitiveProvenance<?>> getInstanceValues() {
088        Map<String,PrimitiveProvenance<?>> map = new HashMap<>();
089
090        map.put(location.getKey(),location);
091        map.put(fileModifiedTime.getKey(),fileModifiedTime);
092        map.put(modelHash.getKey(),modelHash);
093
094        return map;
095    }
096
097    @Override
098    public boolean equals(Object o) {
099        if (this == o) return true;
100        if (o == null || getClass() != o.getClass()) return false;
101        ExternalTrainerProvenance other = (ExternalTrainerProvenance) o;
102        return location.equals(other.location) &&
103                fileModifiedTime.equals(other.fileModifiedTime) &&
104                modelHash.equals(other.modelHash);
105    }
106
107    @Override
108    public int hashCode() {
109        return Objects.hash(location, fileModifiedTime, modelHash);
110    }
111}