001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo;
018
019import org.tribuo.provenance.ModelProvenance;
020
021import java.util.ArrayList;
022import java.util.Collections;
023import java.util.HashMap;
024import java.util.List;
025import java.util.Map;
026
027/**
028 * A model which uses a subset of the features it knows about to make predictions.
029 */
030public abstract class SparseModel<T extends Output<T>> extends Model<T> {
031    private static final long serialVersionUID = 1L;
032
033    private final Map<String,List<String>> activeFeatures;
034
035    /**
036     * Constructs a sparse model from the supplied arguments.
037     * @param name The model name.
038     * @param provenance The model provenance.
039     * @param featureIDMap The features the model knows.
040     * @param outputIDInfo The outputs the model can produce.
041     * @param generatesProbabilities Does this model generate probabilistic outputs.
042     * @param activeFeatures The active features in this model.
043     */
044    public SparseModel(String name, ModelProvenance provenance, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<T> outputIDInfo, boolean generatesProbabilities, Map<String,List<String>> activeFeatures) {
045        super(name, provenance, featureIDMap, outputIDInfo, generatesProbabilities);
046        Map<String,List<String>> tmpActiveFeatures = new HashMap<>();
047        for (Map.Entry<String,List<String>> e : activeFeatures.entrySet()) {
048            List<String> features = new ArrayList<>(e.getValue());
049            Collections.sort(features);
050            tmpActiveFeatures.put(e.getKey(),Collections.unmodifiableList(features));
051        }
052        this.activeFeatures = Collections.unmodifiableMap(tmpActiveFeatures);
053    }
054
055    /**
056     * Return an immutable view on the active features for each dimension.
057     * <p>
058     * Sorted lexicographically.
059     * @return The active features.
060     */
061    public Map<String,List<String>> getActiveFeatures() {
062        return activeFeatures;
063    }
064
065    @Override
066    public SparseModel<T> copy() {
067        return (SparseModel<T>) super.copy();
068    }
069
070}