001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.provenance;
018
import com.oracle.labs.mlrg.olcut.provenance.ListProvenance;
import com.oracle.labs.mlrg.olcut.provenance.ObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import com.oracle.labs.mlrg.olcut.provenance.primitives.DateTimeProvenance;
import com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance;
import com.oracle.labs.mlrg.olcut.util.Pair;

import java.time.OffsetDateTime;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.Map;
import java.util.Objects;
030
031/**
032 * Model provenance for ensemble models.
033 */
034public class EnsembleModelProvenance extends ModelProvenance {
035    private static final long serialVersionUID = 1L;
036
037    protected static final String MEMBERS = "member-provenance";
038
039    private final ListProvenance<? extends ModelProvenance> memberProvenance;
040
041    public EnsembleModelProvenance(String className, OffsetDateTime time, DatasetProvenance datasetProvenance, TrainerProvenance trainerProvenance, ListProvenance<? extends ModelProvenance> memberProvenance) {
042        super(className, time, datasetProvenance, trainerProvenance);
043        this.memberProvenance = memberProvenance;
044    }
045
046    public EnsembleModelProvenance(String className, OffsetDateTime time, DatasetProvenance datasetProvenance, TrainerProvenance trainerProvenance, Map<String, Provenance> instanceProvenance, ListProvenance<? extends ModelProvenance> memberProvenance) {
047        super(className, time, datasetProvenance, trainerProvenance, instanceProvenance);
048        this.memberProvenance = memberProvenance;
049    }
050
051    @SuppressWarnings("unchecked") // member provenance cast.
052    public EnsembleModelProvenance(Map<String, Provenance> map) {
053        super(map);
054        this.memberProvenance = (ListProvenance<? extends ModelProvenance>) ObjectProvenance.checkAndExtractProvenance(map,MEMBERS,ListProvenance.class, EnsembleModelProvenance.class.getSimpleName());
055    }
056
057    public ListProvenance<? extends ModelProvenance> getMemberProvenance() {
058        return memberProvenance;
059    }
060
061    @Override
062    public String toString() {
063        return generateString("EnsembleModel");
064    }
065
066    @Override
067    public Iterator<Pair<String, Provenance>> iterator() {
068        ArrayList<Pair<String,Provenance>> iterable = new ArrayList<>();
069        iterable.add(new Pair<>(CLASS_NAME,new StringProvenance(CLASS_NAME,className)));
070        iterable.add(new Pair<>(DATASET,datasetProvenance));
071        iterable.add(new Pair<>(TRAINER,trainerProvenance));
072        iterable.add(new Pair<>(TRAINING_TIME,new DateTimeProvenance(TRAINING_TIME,time)));
073        iterable.add(new Pair<>(INSTANCE_VALUES,instanceProvenance));
074        iterable.add(new Pair<>(MEMBERS,memberProvenance));
075        return iterable.iterator();
076    }
077}