T - the type of the outputs used to train the model.

public abstract class SequenceModel<T extends Output<T>> extends Object implements com.oracle.labs.mlrg.olcut.provenance.Provenancable<ModelProvenance>, Serializable
Modifier and Type | Field and Description |
---|---|
protected ImmutableFeatureMap | featureIDMap |
protected String | name |
protected ImmutableOutputInfo<T> | outputIDMap |
protected String | provenanceOutput |
Constructor and Description |
---|
SequenceModel(String name, ModelProvenance provenance, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<T> outputIDMap) |
Modifier and Type | Method and Description |
---|---|
ImmutableFeatureMap | getFeatureIDMap(): Gets the feature domain. |
String | getName(): Gets the model name. |
ImmutableOutputInfo<T> | getOutputIDInfo(): Gets the output domain. |
ModelProvenance | getProvenance() |
abstract Map<String,List<com.oracle.labs.mlrg.olcut.util.Pair<String,Double>>> | getTopFeatures(int n): Gets the top n features associated with this model. |
List<List<Prediction<T>>> | predict(Iterable<SequenceExample<T>> examples): Uses the model to predict the output for multiple examples. |
List<List<Prediction<T>>> | predict(SequenceDataset<T> examples): Uses the model to predict the labels for multiple examples contained in a data set. |
abstract List<Prediction<T>> | predict(SequenceExample<T> example): Uses the model to predict the output for a single example. |
void | setName(String name): Sets the model name. |
static <T extends Output<T>> List<T> | toMaxLabels(List<Prediction<T>> predictions) |
String | toString() |
boolean | validate(Class<? extends Output<?>> clazz): Validates that this Model does in fact support the supplied output type. |
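As a rough picture of how the three `predict` overloads in the summary relate, the sketch below assumes that the batch forms simply call the abstract single-example method once per sequence, in order. That delegation is an assumption, not something this page states. `ToySequenceModel` and `CapitalTagger` are hypothetical stand-ins that use `String[]` for a sequence and `String` for an output; they are not Tribuo classes.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

class PredictOverloadsSketch {
    /** Hypothetical stand-in for a sequence model; not the Tribuo class. */
    abstract static class ToySequenceModel {
        /** One prediction per element of a single sequence. */
        abstract List<String> predict(String[] sequence);

        /** Batch form, assumed to delegate to the single-sequence method in order. */
        List<List<String>> predict(Iterable<String[]> sequences) {
            List<List<String>> all = new ArrayList<>();
            for (String[] sequence : sequences) {
                all.add(predict(sequence));
            }
            return all;
        }
    }

    /** Toy model that tags capitalised tokens as NAME and everything else as O. */
    static class CapitalTagger extends ToySequenceModel {
        @Override
        List<String> predict(String[] sequence) {
            List<String> tags = new ArrayList<>();
            for (String token : sequence) {
                boolean capital = !token.isEmpty() && Character.isUpperCase(token.charAt(0));
                tags.add(capital ? "NAME" : "O");
            }
            return tags;
        }
    }

    public static void main(String[] args) {
        ToySequenceModel model = new CapitalTagger();
        List<String[]> batch = Arrays.asList(
                new String[] {"hello", "Alice"},
                new String[] {"Bob", "waves", "back"});
        // Each inner list holds the per-token predictions for one sequence.
        System.out.println(model.predict(batch)); // [[O, NAME], [NAME, O, O]]
    }
}
```

The outer list has one entry per input sequence and each inner list has one entry per element of that sequence, which matches the `List<List<Prediction<T>>>` return type of the batch overloads.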
protected String name
protected final String provenanceOutput
protected final ImmutableFeatureMap featureIDMap
protected final ImmutableOutputInfo<T> outputIDMap
public SequenceModel(String name, ModelProvenance provenance, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<T> outputIDMap)
public boolean validate(Class<? extends Output<?>> clazz)
As the output type is erased at runtime, deserialising a Model is an unchecked operation. This method allows the user to check that the deserialised model is of the appropriate type, rather than seeing if predict(org.tribuo.sequence.SequenceExample<T>) throws a ClassCastException when called.
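The erasure problem described above can be reproduced without any library code. The following is a minimal sketch: `ToyModel`, `load`, and `firstOutput` are hypothetical names, and the check inside `validate` (every known output value must be an instance of the supplied class) is an assumption about how such a test could work, not the documented implementation.

```java
import java.util.Arrays;
import java.util.List;

class ValidateSketch {
    /** Hypothetical stand-in for a model whose output type T is erased at runtime. */
    static class ToyModel<T> {
        private final List<T> outputDomain;

        ToyModel(List<T> outputDomain) {
            this.outputDomain = outputDomain;
        }

        /** Assumed check: every known output value is an instance of clazz. */
        boolean validate(Class<?> clazz) {
            for (T value : outputDomain) {
                if (!clazz.isInstance(value)) {
                    return false;
                }
            }
            return true;
        }

        T firstOutput() {
            return outputDomain.get(0);
        }
    }

    /** Mimics deserialisation: the caller picks T and the cast cannot be checked. */
    @SuppressWarnings("unchecked")
    static <T> ToyModel<T> load(Object deserialised) {
        return (ToyModel<T>) deserialised;
    }

    public static void main(String[] args) {
        Object stored = new ToyModel<>(Arrays.asList("B-PER", "O"));

        // The wrong type argument compiles, and the cast succeeds at runtime.
        ToyModel<Integer> wrong = load(stored);
        System.out.println(wrong.validate(Integer.class)); // false
        System.out.println(wrong.validate(String.class));  // true

        try {
            // The compiler-inserted cast to Integer fails only here, on first use.
            Integer first = wrong.firstOutput();
            System.out.println(first);
        } catch (ClassCastException e) {
            System.out.println("ClassCastException on first use");
        }
    }
}
```

Calling the check immediately after loading moves the failure to a predictable place instead of the first prediction.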
clazz - The class object to verify the output type against.

public String getName()
public void setName(String name)
name - The model name.

public ModelProvenance getProvenance()
Specified by: getProvenance in interface com.oracle.labs.mlrg.olcut.provenance.Provenancable<ModelProvenance>
public ImmutableFeatureMap getFeatureIDMap()
public ImmutableOutputInfo<T> getOutputIDInfo()
public abstract List<Prediction<T>> predict(SequenceExample<T> example)
example - the example to predict.

public List<List<Prediction<T>>> predict(Iterable<SequenceExample<T>> examples)
examples - the examples to predict.

public List<List<Prediction<T>>> predict(SequenceDataset<T> examples)
examples - the data set containing the examples to predict.

public abstract Map<String,List<com.oracle.labs.mlrg.olcut.util.Pair<String,Double>>> getTopFeatures(int n)
Gets the top n features associated with this model.
If the model does not produce per-output feature lists, it returns a map with a single element with key Model.ALL_OUTPUTS.
If the model cannot describe its top features, it returns Collections.emptyMap().
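To make the return contract concrete, here is a minimal sketch for a model with one shared weight per feature, so there are no per-output lists. Several things are assumptions made for illustration: the `ALL_OUTPUTS` constant is a local placeholder for `Model.ALL_OUTPUTS`, `Map.Entry` stands in for the OLCUT `Pair`, features are ranked by absolute weight, and an empty weight map represents a model that cannot describe its features. A negative `n` returns every feature, following the convention documented for the `n` parameter.

```java
import java.util.AbstractMap.SimpleEntry;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

class TopFeaturesSketch {
    // Local placeholder standing in for Model.ALL_OUTPUTS (assumption).
    static final String ALL_OUTPUTS = "ALL_OUTPUTS";

    static Map<String, List<Map.Entry<String, Double>>> getTopFeatures(Map<String, Double> weights, int n) {
        if (weights.isEmpty()) {
            // Stand-in for a model that cannot describe its top features.
            return Collections.emptyMap();
        }
        List<Map.Entry<String, Double>> ranked = new ArrayList<>();
        for (Map.Entry<String, Double> e : weights.entrySet()) {
            ranked.add(new SimpleEntry<>(e.getKey(), e.getValue()));
        }
        // Assumption: rank by absolute weight, largest first.
        ranked.sort(Comparator.comparingDouble(
                (Map.Entry<String, Double> e) -> Math.abs(e.getValue())).reversed());
        // A negative n means "return all features".
        int limit = n < 0 ? ranked.size() : Math.min(n, ranked.size());
        List<Map.Entry<String, Double>> top = new ArrayList<>(ranked.subList(0, limit));
        return Collections.singletonMap(ALL_OUTPUTS, top);
    }

    public static void main(String[] args) {
        Map<String, Double> weights = new LinkedHashMap<>();
        weights.put("word=the", 0.1);
        weights.put("suffix=ing", -2.0);
        weights.put("capitalised", 1.5);
        System.out.println(getTopFeatures(weights, 2));
        System.out.println(getTopFeatures(weights, -1));
    }
}
```

A model that does keep separate weights per output would instead return one map entry per output name, each with its own ranked list.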
n - the number of features to return. If this value is less than 0, all features should be returned for each class, unless the model cannot score its features.

public static <T extends Output<T>> List<T> toMaxLabels(List<Prediction<T>> predictions)
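This page gives no description for `toMaxLabels`. Judging only from the signature, one plausible reading is that it unwraps each prediction in a sequence into its predicted output. The sketch below illustrates that reading and should be treated as an assumption. `ToyPrediction` and its `getOutput()` accessor are hypothetical stand-ins, not the Tribuo `Prediction` class.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

class ToMaxLabelsSketch {
    /** Hypothetical stand-in for a prediction: the chosen output plus its score. */
    static class ToyPrediction<T> {
        private final T output;
        private final double score;

        ToyPrediction(T output, double score) {
            this.output = output;
            this.score = score;
        }

        T getOutput() {
            return output;
        }

        double getScore() {
            return score;
        }
    }

    /** Assumed behaviour: keep only the predicted output of each element, in order. */
    static <T> List<T> toMaxLabels(List<ToyPrediction<T>> predictions) {
        List<T> labels = new ArrayList<>(predictions.size());
        for (ToyPrediction<T> prediction : predictions) {
            labels.add(prediction.getOutput());
        }
        return labels;
    }

    public static void main(String[] args) {
        List<ToyPrediction<String>> sequence = Arrays.asList(
                new ToyPrediction<String>("B-PER", 0.9),
                new ToyPrediction<String>("O", 0.8));
        System.out.println(toMaxLabels(sequence)); // [B-PER, O]
    }
}
```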
Copyright © 2015–2021 Oracle and/or its affiliates. All rights reserved.