001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.classification.sequence.viterbi; 018 019import com.oracle.labs.mlrg.olcut.config.Config; 020import com.oracle.labs.mlrg.olcut.provenance.Provenance; 021import org.tribuo.Dataset; 022import org.tribuo.Example; 023import org.tribuo.Feature; 024import org.tribuo.Model; 025import org.tribuo.Trainer; 026import org.tribuo.classification.Label; 027import org.tribuo.classification.sequence.viterbi.ViterbiModel.ScoreAggregation; 028import org.tribuo.provenance.ModelProvenance; 029import org.tribuo.provenance.TrainerProvenance; 030import org.tribuo.provenance.impl.TrainerProvenanceImpl; 031import org.tribuo.sequence.ImmutableSequenceDataset; 032import org.tribuo.sequence.MutableSequenceDataset; 033import org.tribuo.sequence.SequenceDataset; 034import org.tribuo.sequence.SequenceExample; 035import org.tribuo.sequence.SequenceModel; 036import org.tribuo.sequence.SequenceTrainer; 037 038import java.time.OffsetDateTime; 039import java.util.ArrayList; 040import java.util.List; 041import java.util.Map; 042 043/** 044 * Builds a Viterbi model using the supplied {@link Trainer}. 045 * Has a parameter to control the label features which are added to the features supplied by the data. 046 */ 047public final class ViterbiTrainer implements SequenceTrainer<Label> { 048 049 @Config(mandatory = true, description = "Inner trainer for each sequence element.") 050 private Trainer<Label> trainer; 051 052 @Config(mandatory = true, description = "Feature extractor to pull in surrounding label features.") 053 private LabelFeatureExtractor labelFeatureExtractor; 054 055 @Config(mandatory = true, description = "Number of candidate paths.") 056 private int stackSize; 057 058 @Config(mandatory = true, description = "Score aggregation function.") 059 private ScoreAggregation scoreAggregation; 060 061 private int trainInvocationCounter = 0; 062 063 public ViterbiTrainer(Trainer<Label> trainer, LabelFeatureExtractor labelFeatureExtractor, 064 ScoreAggregation scoreAggregation) { 065 this(trainer, labelFeatureExtractor, -1, scoreAggregation); 066 } 067 068 public ViterbiTrainer(Trainer<Label> trainer, LabelFeatureExtractor labelFeatureExtractor, int stackSize, 069 ScoreAggregation scoreAggregation) { 070 this.trainer = trainer; 071 this.labelFeatureExtractor = labelFeatureExtractor; 072 this.stackSize = stackSize; 073 this.scoreAggregation = scoreAggregation; 074 } 075 076 /** 077 * The viterbi train method is unique because it delegates to a regular 078 * {@link Model} train method, but before it does, it adds features derived 079 * from preceding labels. The pipeline upstream of this call should not care 080 * that these features are being added - that is, we would not want to make 081 * the upstream logic worry about what kind of trainer will be used and have 082 * conditional logic that says to add special label-derived features if 083 * using the ViterbiTrainer. So, these one-of-a-kind unique-in-the-world 084 * label-derived features are generated here and added to the sequence 085 * examples of the passed in dataset. If you pass in a 086 * MutableSequenceDataset, then please be aware that your dataset will be 087 * modified after calling this method and therefore subsequent calls to 088 * other SequenceModel.train methods with your dataset should be avoided. If 089 * you pass in an ImmutableSequenceDataset, then please be aware that your 090 * entire dataset is going to be copied as a MutableSequenceDataset - so 091 * there is a memory penalty. 092 * @param dataset The input dataset. 093 * @param runProvenance Any additional information to record in the provenance. 094 * @return A {@link SequenceModel} using Viterbi wrapped around an inner {@link Model}. 095 */ 096 @Override 097 public SequenceModel<Label> train(SequenceDataset<Label> dataset, Map<String, Provenance> runProvenance) { 098 if (dataset.getOutputInfo().getUnknownCount() > 0) { 099 throw new IllegalArgumentException("The supplied Dataset contained unknown Outputs, and this Trainer is supervised."); 100 } 101 // if stack size isn't specified, then we will calculate it based on the 102 // number of unique output values 103 if (stackSize == -1) { 104 stackSize = dataset.getOutputIDInfo().size(); 105 } 106 107 // create a copy of the dataset to a mutable one. See note above. 108 if (dataset instanceof ImmutableSequenceDataset) { 109 dataset = new MutableSequenceDataset<>((ImmutableSequenceDataset<Label>) dataset); 110 } 111 112 if (!(dataset instanceof MutableSequenceDataset)) { 113 throw new IllegalArgumentException("unable to handle sub-type of dataset: " + dataset.getClass().getName()); 114 } 115 116 for (SequenceExample<Label> sequenceExample : dataset) { 117 List<Label> labels = new ArrayList<>(); 118 119 for (Example<Label> example : sequenceExample) { 120 List<Feature> labelFeatures = extractFeatures(labels, (MutableSequenceDataset<Label>) dataset, 121 1.0); 122 example.addAll(labelFeatures); 123 labels.add(example.getOutput()); 124 } 125 } 126 127 TrainerProvenance trainerProvenance = getProvenance(); 128 ModelProvenance provenance = new ModelProvenance(ViterbiModel.class.getName(), OffsetDateTime.now(), dataset.getProvenance(), trainerProvenance, runProvenance); 129 trainInvocationCounter++; 130 Dataset<Label> flatData = dataset.getFlatDataset(); 131 Model<Label> model = trainer.train(flatData); 132 return new ViterbiModel("viterbi+" + model.getName(), provenance, model, 133 labelFeatureExtractor, stackSize, scoreAggregation); 134 } 135 136 @Override 137 public int getInvocationCount() { 138 return trainInvocationCounter; 139 } 140 141 private List<Feature> extractFeatures(List<Label> labels, 142 MutableSequenceDataset<Label> dataset, double value) { 143 List<Feature> labelFeatures = new ArrayList<>(); 144 for (Feature labelFeature : labelFeatureExtractor.extractFeatures(labels, value)) { 145 dataset.getFeatureMap().add(labelFeature.getName(), labelFeature.getValue()); 146 labelFeatures.add(labelFeature); 147 } 148 return labelFeatures; 149 } 150 151 @Override 152 public String toString() { 153 return "ViterbiTrainer(innerTrainer=" + trainer.toString() + ",labelFeatureExtractor=" + labelFeatureExtractor.toString() + ")"; 154 } 155 156 @Override 157 public TrainerProvenance getProvenance() { 158 return new TrainerProvenanceImpl(this); 159 } 160}