/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.provenance;

import com.oracle.labs.mlrg.olcut.provenance.ObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.PrimitiveProvenance;
import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import com.oracle.labs.mlrg.olcut.provenance.ProvenanceException;
import com.oracle.labs.mlrg.olcut.provenance.impl.SkeletalConfiguredObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.primitives.BooleanProvenance;
import com.oracle.labs.mlrg.olcut.provenance.primitives.IntProvenance;
import com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance;
import org.tribuo.Output;
import org.tribuo.Trainer;
import org.tribuo.Tribuo;
import org.tribuo.sequence.SequenceTrainer;

import java.util.HashMap;
import java.util.Map;
import java.util.Objects;

/**
 * The skeleton of a TrainerProvenance that extracts the configured parameters.
 * <p>
 * On top of what the superclass stores, this class records three instance values:
 * the train invocation count, whether the host is a sequence trainer, and the
 * Tribuo version string.
 */
public abstract class SkeletalTrainerProvenance extends SkeletalConfiguredObjectProvenance implements TrainerProvenance {
    private static final long serialVersionUID = 1L;

    private final IntProvenance invocationCount;

    private final BooleanProvenance isSequence;

    private final StringProvenance version;

    /**
     * Builds the provenance from a {@link Trainer}, marking it as a non-sequence trainer.
     * @param host The trainer to record.
     * @param <T> The output type of the trainer.
     */
    protected <T extends Output<T>> SkeletalTrainerProvenance(Trainer<T> host) {
        super(host,"Trainer");
        this.isSequence = new BooleanProvenance(IS_SEQUENCE,false);
        this.invocationCount = new IntProvenance(TRAIN_INVOCATION_COUNT,host.getInvocationCount());
        this.version = new StringProvenance(TRIBUO_VERSION_STRING, Tribuo.VERSION);
    }

    /**
     * Builds the provenance from a {@link SequenceTrainer}, marking it as a sequence trainer.
     * @param host The sequence trainer to record.
     * @param <T> The output type of the trainer.
     */
    protected <T extends Output<T>> SkeletalTrainerProvenance(SequenceTrainer<T> host) {
        super(host,"SequenceTrainer");
        this.isSequence = new BooleanProvenance(IS_SEQUENCE,true);
        this.invocationCount = new IntProvenance(TRAIN_INVOCATION_COUNT,host.getInvocationCount());
        this.version = new StringProvenance(TRIBUO_VERSION_STRING, Tribuo.VERSION);
    }

    /**
     * Rebuilds the provenance from a map of provenances.
     * @param map The provenance map, processed by {@link #extractProvenanceInfo(Map)}.
     */
    protected SkeletalTrainerProvenance(Map<String, Provenance> map) {
        this(extractProvenanceInfo(map));
    }

    /**
     * Rebuilds the provenance from previously extracted information.
     * @param info The extracted information; it must hold the invocation count,
     *             the is-sequence flag and the Tribuo version as instance values.
     */
    protected SkeletalTrainerProvenance(ExtractedInfo info) {
        super(info);
        this.invocationCount = SkeletalConfiguredObjectProvenance.checkAndExtractProvenance(info,TRAIN_INVOCATION_COUNT,IntProvenance.class, info.className);
        this.isSequence = SkeletalConfiguredObjectProvenance.checkAndExtractProvenance(info,IS_SEQUENCE,BooleanProvenance.class, info.className);
        this.version = SkeletalConfiguredObjectProvenance.checkAndExtractProvenance(info,TRIBUO_VERSION_STRING,StringProvenance.class, info.className);
    }

    /**
     * Is this a sequence trainer.
     * @return True if it's a sequence trainer.
     */
    public boolean isSequence() {
        return isSequence.getValue();
    }

    /**
     * The Tribuo version.
     * @return The Tribuo version.
     */
    public String getTribuoVersion() {
        return version.getValue();
    }

    /**
     * Returns the superclass instance values extended with the invocation count,
     * the is-sequence flag and the Tribuo version.
     * <p>
     * The version must be included here because {@link #extractProvenanceInfo(Map)}
     * refuses to rebuild a provenance whose map lacks it.
     * @return The instance values.
     */
    @Override
    public Map<String, PrimitiveProvenance<?>> getInstanceValues() {
        Map<String, PrimitiveProvenance<?>> map = super.getInstanceValues();

        map.put(TRAIN_INVOCATION_COUNT, invocationCount);
        map.put(IS_SEQUENCE, isSequence);
        map.put(TRIBUO_VERSION_STRING, version);

        return map;
    }

    /**
     * Two trainer provenances are equal when the superclass considers them equal and
     * their invocation count, is-sequence flag and Tribuo version all match.
     * @param o The object to compare against.
     * @return True if the objects are equal.
     */
    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (!(o instanceof SkeletalTrainerProvenance)) {
            return false;
        }
        if (!super.equals(o)) {
            return false;
        }
        SkeletalTrainerProvenance other = (SkeletalTrainerProvenance) o;
        return invocationCount.equals(other.invocationCount) &&
                isSequence.equals(other.isSequence) &&
                version.equals(other.version);
    }

    /**
     * Hashes the same fields that {@link #equals(Object)} compares.
     * @return The hash code.
     */
    @Override
    public int hashCode() {
        return Objects.hash(super.hashCode(), invocationCount, isSequence, version);
    }

    /**
     * Splits a provenance map into the class name, the host short name, the trainer
     * instance values (invocation count, is-sequence flag, Tribuo version) and the
     * remaining entries, which are passed on as the configured parameters.
     * <p>
     * The supplied map is copied first and is not modified.
     * @param map The provenance map to split.
     * @return The extracted information.
     * @throws ProvenanceException If a required entry is missing, or an instance
     *                             value has an unexpected type.
     */
    protected static ExtractedInfo extractProvenanceInfo(Map<String,Provenance> map) {
        Map<String,Provenance> configuredParameters = new HashMap<>(map);
        Map<String,PrimitiveProvenance<?>> instanceValues = new HashMap<>();

        String className = removeRequiredString(configuredParameters,
                ObjectProvenance.CLASS_NAME, "class name");
        String hostTypeStringName = removeRequiredString(configuredParameters,
                SkeletalConfiguredObjectProvenance.HOST_SHORT_NAME, "host type short name");

        instanceValues.put(TRAIN_INVOCATION_COUNT, removeRequired(configuredParameters,
                TrainerProvenance.TRAIN_INVOCATION_COUNT, IntProvenance.class, "invocation count", className));
        instanceValues.put(IS_SEQUENCE, removeRequired(configuredParameters,
                TrainerProvenance.IS_SEQUENCE, BooleanProvenance.class, "is-sequence", className));
        instanceValues.put(TRIBUO_VERSION_STRING, removeRequired(configuredParameters,
                TrainerProvenance.TRIBUO_VERSION_STRING, StringProvenance.class, "Tribuo version", className));

        return new ExtractedInfo(className,hostTypeStringName,configuredParameters,instanceValues);
    }

    /**
     * Removes a required entry from the map and returns its string form.
     * @param parameters The map to remove the entry from.
     * @param key The key of the required entry.
     * @param description The human readable name of the entry, used in the error message.
     * @return The result of calling {@code toString()} on the removed provenance.
     * @throws ProvenanceException If the key is absent.
     */
    private static String removeRequiredString(Map<String,Provenance> parameters, String key, String description) {
        if (!parameters.containsKey(key)) {
            throw new ProvenanceException("Failed to find " + description + " when constructing SkeletalTrainerProvenance");
        }
        return parameters.remove(key).toString();
    }

    /**
     * Removes a required entry from the map and checks it has the expected type.
     * @param parameters The map to remove the entry from.
     * @param key The key of the required entry.
     * @param type The expected provenance type.
     * @param description The human readable name of the entry, used in the error message.
     * @param className The class name being reconstructed, used in the error message.
     * @param <P> The expected provenance type.
     * @return The removed provenance, cast to the expected type.
     * @throws ProvenanceException If the key is absent or the value is not of the expected type.
     */
    private static <P extends PrimitiveProvenance<?>> P removeRequired(Map<String,Provenance> parameters,
                                                                     String key, Class<P> type,
                                                                     String description, String className) {
        if (!parameters.containsKey(key)) {
            throw new ProvenanceException("Failed to find " + description + " when constructing SkeletalTrainerProvenance");
        }
        Provenance value = parameters.remove(key);
        // isInstance is false for null, so a present-but-null entry reports a type error.
        if (!type.isInstance(value)) {
            throw new ProvenanceException(key + " was not of type " + type.getSimpleName() + " in class " + className);
        }
        return type.cast(value);
    }
}