001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.provenance; 018 019import com.oracle.labs.mlrg.olcut.provenance.ListProvenance; 020import com.oracle.labs.mlrg.olcut.provenance.ObjectProvenance; 021import com.oracle.labs.mlrg.olcut.provenance.Provenance; 022import com.oracle.labs.mlrg.olcut.provenance.primitives.DateTimeProvenance; 023import com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance; 024import com.oracle.labs.mlrg.olcut.util.Pair; 025 026import java.time.OffsetDateTime; 027import java.util.ArrayList; 028import java.util.Iterator; 029import java.util.Map; 030 031/** 032 * Model provenance for ensemble models. 033 */ 034public class EnsembleModelProvenance extends ModelProvenance { 035 private static final long serialVersionUID = 1L; 036 037 protected static final String MEMBERS = "member-provenance"; 038 039 private final ListProvenance<? extends ModelProvenance> memberProvenance; 040 041 public EnsembleModelProvenance(String className, OffsetDateTime time, DatasetProvenance datasetProvenance, TrainerProvenance trainerProvenance, ListProvenance<? extends ModelProvenance> memberProvenance) { 042 super(className, time, datasetProvenance, trainerProvenance); 043 this.memberProvenance = memberProvenance; 044 } 045 046 public EnsembleModelProvenance(String className, OffsetDateTime time, DatasetProvenance datasetProvenance, TrainerProvenance trainerProvenance, Map<String, Provenance> instanceProvenance, ListProvenance<? extends ModelProvenance> memberProvenance) { 047 super(className, time, datasetProvenance, trainerProvenance, instanceProvenance); 048 this.memberProvenance = memberProvenance; 049 } 050 051 @SuppressWarnings("unchecked") // member provenance cast. 052 public EnsembleModelProvenance(Map<String, Provenance> map) { 053 super(map); 054 this.memberProvenance = (ListProvenance<? extends ModelProvenance>) ObjectProvenance.checkAndExtractProvenance(map,MEMBERS,ListProvenance.class, EnsembleModelProvenance.class.getSimpleName()); 055 } 056 057 public ListProvenance<? extends ModelProvenance> getMemberProvenance() { 058 return memberProvenance; 059 } 060 061 @Override 062 public String toString() { 063 return generateString("EnsembleModel"); 064 } 065 066 @Override 067 public Iterator<Pair<String, Provenance>> iterator() { 068 ArrayList<Pair<String,Provenance>> iterable = new ArrayList<>(); 069 iterable.add(new Pair<>(CLASS_NAME,new StringProvenance(CLASS_NAME,className))); 070 iterable.add(new Pair<>(DATASET,datasetProvenance)); 071 iterable.add(new Pair<>(TRAINER,trainerProvenance)); 072 iterable.add(new Pair<>(TRAINING_TIME,new DateTimeProvenance(TRAINING_TIME,time))); 073 iterable.add(new Pair<>(INSTANCE_VALUES,instanceProvenance)); 074 iterable.add(new Pair<>(MEMBERS,memberProvenance)); 075 return iterable.iterator(); 076 } 077}