001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.sequence;
018
import com.oracle.labs.mlrg.olcut.provenance.Provenancable;
import com.oracle.labs.mlrg.olcut.util.Pair;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Output;
import org.tribuo.Prediction;
import org.tribuo.provenance.ModelProvenance;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
033
/**
 * A prediction model, which is used to predict outputs for unseen sequence examples.
 * <p>
 * SequenceModel implementations must be serializable!
 * </p>
 * @param <T> the type of the outputs used to train the model.
 */
039public abstract class SequenceModel<T extends Output<T>> implements Provenancable<ModelProvenance>, Serializable {
040    private static final long serialVersionUID = 1L;
041
042    protected String name;
043
044    private final ModelProvenance provenance;
045
046    protected final String provenanceOutput;
047
048    protected final ImmutableFeatureMap featureIDMap;
049
050    protected final ImmutableOutputInfo<T> outputIDMap;
051
052    public SequenceModel(String name, ModelProvenance provenance, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<T> outputIDMap) {
053        this.name = name;
054        this.provenance = provenance;
055        this.provenanceOutput = provenance.toString();
056        this.featureIDMap = featureIDMap;
057        this.outputIDMap = outputIDMap;
058    }
059
060    /**
061     * Validates that this Model does in fact support the supplied output type.
062     * <p>
063     * As the output type is erased at runtime, deserialising a Model is an unchecked
064     * operation. This method allows the user to check that the deserialised model is
065     * of the appropriate type, rather than seeing if {@link SequenceModel#predict} throws a {@link ClassCastException}
066     * when called.
067     * </p>
068     * @param clazz The class object to verify the output type against.
069     * @return True if the output type is assignable to the class object type, false otherwise.
070     */
071    public boolean validate(Class<? extends Output<?>> clazz) {
072        Set<T> domain = outputIDMap.getDomain();
073        boolean output = true;
074        for (T type : domain) {
075            output &= clazz.isInstance(type);
076        }
077        return output;
078    }
079
080    /**
081     * Gets the model name.
082     * @return The model name.
083     */
084    public String getName() {
085        return name;
086    }
087
088    /**
089     * Sets the model name.
090     * @param name The model name.
091     */
092    public void setName(String name) {
093        this.name = name;
094    }
095
096    @Override
097    public ModelProvenance getProvenance() {
098        return provenance;
099    }
100
101    @Override
102    public String toString() {
103        if (name != null && !name.isEmpty()) {
104            return name + " - " + provenanceOutput;
105        } else {
106            return provenanceOutput;
107        }
108    }
109
110    /**
111     * Gets the feature domain.
112     * @return The feature domain.
113     */
114    public ImmutableFeatureMap getFeatureIDMap() {
115        return featureIDMap;
116    }
117
118    /**
119     * Gets the output domain.
120     * @return The output domain.
121     */
122    public ImmutableOutputInfo<T> getOutputIDInfo() {
123        return outputIDMap;
124    }
125
126    /**
127     * Uses the model to predict the output for a single example.
128     * @param example the example to predict.
129     * @return the result of the prediction.
130     */
131    public abstract List<Prediction<T>> predict(SequenceExample<T> example);
132    
133    /**
134     * Uses the model to predict the output for multiple examples.
135     * @param examples the examples to predict.
136     * @return the results of the prediction, in the same order as the 
137     * examples.
138     */
139    public List<List<Prediction<T>>> predict(Iterable<SequenceExample<T>> examples) {
140        List<List<Prediction<T>>> predictions = new ArrayList<>();
141        for(SequenceExample<T> example : examples) {
142            predictions.add(predict(example));
143        }
144        return predictions;
145    }
146
147    /**
148     * Uses the model to predict the labels for multiple examples contained in
149     * a data set.
150     * @param examples the data set containing the examples to predict.
151     * @return the results of the predictions, in the same order as the 
152     * data set generates the example.
153     */
154    public List<List<Prediction<T>>> predict(SequenceDataset<T> examples) {
155        List<List<Prediction<T>>> predictions = new ArrayList<>();
156        for (SequenceExample<T> example : examples) {
157            predictions.add(predict(example));
158        }
159        return predictions;
160    }
161    
162    /**
163     * Gets the top {@code n} features associated with this model.
164     * <p>
165     * If the model does not produce per output feature lists, it returns
166     * a map with a single element with key Model.ALL_OUTPUTS.
167     * </p>
168     * <p>
169     * If the model cannot describe it's top features then it returns {@link java.util.Collections#emptyMap}.
170     * </p>
171     * @param n the number of features to return. If this value is less than 0,
172     * all features should be returned for each class, unless the model cannot score it's features.
173     * @return a map from string outputs to an ordered list of pairs of
174     * feature names and weights associated with that feature in the model
175     */
176    public abstract Map<String, List<Pair<String, Double>>> getTopFeatures(int n);
177
178    public static <T extends Output<T>> List<T> toMaxLabels(List<Prediction<T>> predictions) {
179        return predictions.stream().map(Prediction::getOutput).collect(Collectors.toList());
180    }
181}