001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.provenance;
018
019import com.oracle.labs.mlrg.olcut.provenance.ObjectProvenance;
020import com.oracle.labs.mlrg.olcut.provenance.PrimitiveProvenance;
021import com.oracle.labs.mlrg.olcut.provenance.Provenance;
022import com.oracle.labs.mlrg.olcut.provenance.ProvenanceException;
023import com.oracle.labs.mlrg.olcut.provenance.impl.SkeletalConfiguredObjectProvenance;
024import com.oracle.labs.mlrg.olcut.provenance.primitives.BooleanProvenance;
025import com.oracle.labs.mlrg.olcut.provenance.primitives.IntProvenance;
026import com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance;
027import org.tribuo.Output;
028import org.tribuo.Trainer;
029import org.tribuo.Tribuo;
030import org.tribuo.sequence.SequenceTrainer;
031
032import java.util.HashMap;
033import java.util.Map;
034import java.util.Objects;
035
036/**
037 * The skeleton of a TrainerProvenance that extracts the configured parameters.
038 */
039public abstract class SkeletalTrainerProvenance extends SkeletalConfiguredObjectProvenance implements TrainerProvenance {
040    private static final long serialVersionUID = 1L;
041
042    private final IntProvenance invocationCount;
043
044    private final BooleanProvenance isSequence;
045
046    private final StringProvenance version;
047
048    protected <T extends Output<T>> SkeletalTrainerProvenance(Trainer<T> host) {
049        super(host,"Trainer");
050        this.isSequence = new BooleanProvenance(IS_SEQUENCE,false);
051        this.invocationCount = new IntProvenance(TRAIN_INVOCATION_COUNT,host.getInvocationCount());
052        this.version = new StringProvenance(TRIBUO_VERSION_STRING, Tribuo.VERSION);
053    }
054
055    protected <T extends Output<T>> SkeletalTrainerProvenance(SequenceTrainer<T> host) {
056        super(host,"SequenceTrainer");
057        this.isSequence = new BooleanProvenance(IS_SEQUENCE,true);
058        this.invocationCount = new IntProvenance(TRAIN_INVOCATION_COUNT,host.getInvocationCount());
059        this.version = new StringProvenance(TRIBUO_VERSION_STRING, Tribuo.VERSION);
060    }
061
062    protected SkeletalTrainerProvenance(Map<String, Provenance> map) {
063        this(extractProvenanceInfo(map));
064    }
065
066    protected SkeletalTrainerProvenance(ExtractedInfo info) {
067        super(info);
068        this.invocationCount = SkeletalConfiguredObjectProvenance.checkAndExtractProvenance(info,TRAIN_INVOCATION_COUNT,IntProvenance.class, info.className);
069        this.isSequence = SkeletalConfiguredObjectProvenance.checkAndExtractProvenance(info,IS_SEQUENCE,BooleanProvenance.class, info.className);
070        this.version = SkeletalConfiguredObjectProvenance.checkAndExtractProvenance(info,TRIBUO_VERSION_STRING,StringProvenance.class, info.className);
071    }
072
073    /**
074     * Is this a sequence trainer.
075     * @return True if it's a sequence trainer.
076     */
077    public boolean isSequence() {
078        return isSequence.getValue();
079    }
080
081    /**
082     * The Tribuo version.
083     * @return The Tribuo version.
084     */
085    public String getTribuoVersion() {
086        return version.getValue();
087    }
088
089    @Override
090    public Map<String, PrimitiveProvenance<?>> getInstanceValues() {
091        Map<String, PrimitiveProvenance<?>> map = super.getInstanceValues();
092
093        map.put(TRAIN_INVOCATION_COUNT, invocationCount);
094        map.put(IS_SEQUENCE, isSequence);
095
096        return map;
097    }
098
099    @Override
100    public boolean equals(Object o) {
101        if (this == o) return true;
102        if (!(o instanceof SkeletalTrainerProvenance)) return false;
103        if (!super.equals(o)) return false;
104        SkeletalTrainerProvenance pairs = (SkeletalTrainerProvenance) o;
105        return invocationCount.equals(pairs.invocationCount) &&
106                isSequence.equals(pairs.isSequence);
107    }
108
109    @Override
110    public int hashCode() {
111        return Objects.hash(super.hashCode(), invocationCount, isSequence);
112    }
113
114    protected static ExtractedInfo extractProvenanceInfo(Map<String,Provenance> map) {
115        String className;
116        String hostTypeStringName;
117        Map<String,Provenance> configuredParameters = new HashMap<>(map);
118        Map<String,PrimitiveProvenance<?>> instanceValues = new HashMap<>();
119        if (configuredParameters.containsKey(ObjectProvenance.CLASS_NAME)) {
120            className = configuredParameters.remove(ObjectProvenance.CLASS_NAME).toString();
121        } else {
122            throw new ProvenanceException("Failed to find class name when constructing SkeletalTrainerProvenance");
123        }
124        if (configuredParameters.containsKey(SkeletalConfiguredObjectProvenance.HOST_SHORT_NAME)) {
125            hostTypeStringName = configuredParameters.remove(SkeletalConfiguredObjectProvenance.HOST_SHORT_NAME).toString();
126        } else {
127            throw new ProvenanceException("Failed to find host type short name when constructing SkeletalTrainerProvenance");
128        }
129        if (configuredParameters.containsKey(TrainerProvenance.TRAIN_INVOCATION_COUNT)) {
130            Provenance tmpProv = configuredParameters.remove(TrainerProvenance.TRAIN_INVOCATION_COUNT);
131            if (tmpProv instanceof IntProvenance) {
132                instanceValues.put(TRAIN_INVOCATION_COUNT,(IntProvenance) tmpProv);
133            } else {
134                throw new ProvenanceException(TRAIN_INVOCATION_COUNT + " was not of type IntProvenance in class " + className);
135            }
136        } else {
137            throw new ProvenanceException("Failed to find invocation count when constructing SkeletalTrainerProvenance");
138        }
139        if (configuredParameters.containsKey(TrainerProvenance.IS_SEQUENCE)) {
140            Provenance tmpProv = configuredParameters.remove(TrainerProvenance.IS_SEQUENCE);
141            if (tmpProv instanceof BooleanProvenance) {
142                instanceValues.put(IS_SEQUENCE,(BooleanProvenance) tmpProv);
143            } else {
144                throw new ProvenanceException(IS_SEQUENCE + " was not of type BooleanProvenance in class " + className);
145            }
146        } else {
147            throw new ProvenanceException("Failed to find is-sequence when constructing SkeletalTrainerProvenance");
148        }
149        if (configuredParameters.containsKey(TrainerProvenance.TRIBUO_VERSION_STRING)) {
150            Provenance tmpProv = configuredParameters.remove(TrainerProvenance.TRIBUO_VERSION_STRING);
151            if (tmpProv instanceof StringProvenance) {
152                instanceValues.put(TRIBUO_VERSION_STRING,(StringProvenance) tmpProv);
153            } else {
154                throw new ProvenanceException(TRIBUO_VERSION_STRING + " was not of type StringProvenance in class " + className);
155            }
156        } else {
157            throw new ProvenanceException("Failed to find Tribuo version when constructing SkeletalTrainerProvenance");
158        }
159
160        return new ExtractedInfo(className,hostTypeStringName,configuredParameters,instanceValues);
161    }
162}