/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.sequence;

import com.oracle.labs.mlrg.olcut.provenance.Provenancable;
import com.oracle.labs.mlrg.olcut.util.Pair;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Output;
import org.tribuo.Prediction;
import org.tribuo.provenance.ModelProvenance;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;

/**
 * Base class for models which produce a list of predictions for a sequence example.
 * <p>
 * SequenceModel implementations must be serializable!
 * </p>
 * @param <T> the type of the outputs used to train the model.
 */
public abstract class SequenceModel<T extends Output<T>> implements Provenancable<ModelProvenance>, Serializable {
    private static final long serialVersionUID = 1L;

    /**
     * The model name, may be changed via {@link #setName(String)}.
     */
    protected String name;

    /**
     * The provenance supplied when the model was constructed.
     */
    private final ModelProvenance provenance;

    /**
     * The string form of the provenance, computed once in the constructor
     * and used by {@link #toString()}.
     */
    protected final String provenanceOutput;

    /**
     * The feature domain supplied at construction.
     */
    protected final ImmutableFeatureMap featureIDMap;

    /**
     * The output domain supplied at construction.
     */
    protected final ImmutableOutputInfo<T> outputIDMap;

    /**
     * Stores the supplied name, provenance and domains, and caches the
     * provenance's string representation.
     * @param name The model name.
     * @param provenance The model provenance. Its {@code toString} is invoked here,
     *                   so a null provenance causes a {@link NullPointerException}.
     * @param featureIDMap The feature domain.
     * @param outputIDMap The output domain.
     */
    public SequenceModel(String name, ModelProvenance provenance, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<T> outputIDMap) {
        this.name = name;
        this.provenance = provenance;
        this.provenanceOutput = this.provenance.toString();
        this.featureIDMap = featureIDMap;
        this.outputIDMap = outputIDMap;
    }

    /**
     * Checks that this model supports the supplied output type.
     * <p>
     * The output type is erased at runtime, so deserialising a model is an unchecked
     * operation. Calling this method lets the user confirm that a deserialised model has
     * the expected output type up front, instead of discovering a mismatch when
     * {@link SequenceModel#predict} throws a {@link ClassCastException}.
     * </p>
     * @param clazz The class object to verify the output type against.
     * @return True if every element of the output domain is an instance of {@code clazz}
     * (which includes the case of an empty domain), false otherwise.
     */
    public boolean validate(Class<? extends Output<?>> clazz) {
        Set<T> knownOutputs = outputIDMap.getDomain();
        return knownOutputs.stream().allMatch(clazz::isInstance);
    }

    /**
     * Returns the name of this model.
     * @return The model name.
     */
    public String getName() {
        return name;
    }

    /**
     * Replaces the name of this model.
     * @param name The model name.
     */
    public void setName(String name) {
        this.name = name;
    }

    @Override
    public ModelProvenance getProvenance() {
        return provenance;
    }

    /**
     * Returns the cached provenance string, prefixed by the model name and
     * {@code " - "} when the name is neither null nor empty.
     * @return A description of this model.
     */
    @Override
    public String toString() {
        if (name == null || name.isEmpty()) {
            return provenanceOutput;
        }
        return name + " - " + provenanceOutput;
    }

    /**
     * Returns the feature domain.
     * @return The feature domain.
     */
    public ImmutableFeatureMap getFeatureIDMap() {
        return featureIDMap;
    }

    /**
     * Returns the output domain.
     * @return The output domain.
     */
    public ImmutableOutputInfo<T> getOutputIDInfo() {
        return outputIDMap;
    }

    /**
     * Uses the model to predict the output for a single example.
     * @param example the example to predict.
     * @return the result of the prediction.
     */
    public abstract List<Prediction<T>> predict(SequenceExample<T> example);

    /**
     * Uses the model to predict the output for multiple examples.
     * @param examples the examples to predict.
     * @return the results of the prediction, in the same order as the
     * examples.
     */
    public List<List<Prediction<T>>> predict(Iterable<SequenceExample<T>> examples) {
        return predictEach(examples);
    }

    /**
     * Uses the model to predict the labels for multiple examples contained in
     * a data set.
     * @param examples the data set containing the examples to predict.
     * @return the results of the predictions, in the same order as the
     * data set generates the example.
     */
    public List<List<Prediction<T>>> predict(SequenceDataset<T> examples) {
        return predictEach(examples);
    }

    /**
     * Calls {@link #predict(SequenceExample)} on each element in iteration order
     * and collects the results into a new list.
     * <p>
     * Shared by the two bulk predict overloads; those overloads deliberately do not
     * call one another so overriding one in a subclass does not affect the other.
     * </p>
     * @param sequences the examples to predict.
     * @return one prediction list per example, in iteration order.
     */
    private List<List<Prediction<T>>> predictEach(Iterable<? extends SequenceExample<T>> sequences) {
        List<List<Prediction<T>>> results = new ArrayList<>();
        for (SequenceExample<T> sequence : sequences) {
            results.add(predict(sequence));
        }
        return results;
    }

    /**
     * Gets the top {@code n} features associated with this model.
     * <p>
     * If the model does not produce per output feature lists, it returns
     * a map with a single element with key Model.ALL_OUTPUTS.
     * </p>
     * <p>
     * If the model cannot describe its top features then it returns {@link java.util.Collections#emptyMap}.
     * </p>
     * @param n the number of features to return. If this value is less than 0,
     * all features should be returned for each class, unless the model cannot score its features.
     * @return a map from string outputs to an ordered list of pairs of
     * feature names and weights associated with that feature in the model
     */
    public abstract Map<String, List<Pair<String, Double>>> getTopFeatures(int n);

    /**
     * Extracts the output from each prediction via {@link Prediction#getOutput},
     * keeping the order of the supplied list.
     * @param predictions the predictions to convert.
     * @param <T> the output type.
     * @return a list containing the output of each prediction.
     */
    public static <T extends Output<T>> List<T> toMaxLabels(List<Prediction<T>> predictions) {
        return predictions.stream()
                .map(Prediction::getOutput)
                .collect(Collectors.toList());
    }
}