001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.transform;
018
019import com.oracle.labs.mlrg.olcut.util.Pair;
020import org.tribuo.Dataset;
021import org.tribuo.Example;
022import org.tribuo.Excuse;
023import org.tribuo.Model;
024import org.tribuo.MutableDataset;
025import org.tribuo.Output;
026import org.tribuo.Prediction;
027import org.tribuo.provenance.ModelProvenance;
028
029import java.util.ArrayList;
030import java.util.Collections;
031import java.util.List;
032import java.util.Map;
033import java.util.Optional;
034
035/**
036 * Wraps a {@link Model} with it's {@link TransformerMap} so all {@link Example}s are transformed
037 * appropriately before the model makes predictions.
038 * <p>
039 * If the densify flag is set, densifies all incoming data before transforming it.
040 * <p>
041 * Transformations only operate on observed values. To operate on implicit zeros then
042 * first call {@link MutableDataset#densify} on the datasets.
043 */
044public class TransformedModel<T extends Output<T>> extends Model<T> {
045    private static final long serialVersionUID = 1L;
046
047    private final Model<T> innerModel;
048
049    private final TransformerMap transformerMap;
050
051    private final boolean densify;
052
053    private ArrayList<String> featureNames;
054
055    TransformedModel(ModelProvenance modelProvenance, Model<T> innerModel, TransformerMap transformerMap, boolean densify) {
056        super(innerModel.getName(),
057              modelProvenance,
058              innerModel.getFeatureIDMap(),
059              innerModel.getOutputIDInfo(),
060              innerModel.generatesProbabilities());
061        this.innerModel = innerModel;
062        this.transformerMap = transformerMap;
063        this.densify = densify;
064        this.featureNames = new ArrayList<>(featureIDMap.keySet());
065        Collections.sort(featureNames);
066    }
067
068    @Override
069    public Prediction<T> predict(Example<T> example) {
070        Example<T> transformedExample;
071        if (densify) {
072            transformedExample = transformerMap.transformExample(example,featureNames);
073        } else {
074            transformedExample = transformerMap.transformExample(example);
075        }
076        return innerModel.predict(transformedExample);
077    }
078
079    @Override
080    public List<Prediction<T>> predict(Dataset<T> examples) {
081        Dataset<T> transformedDataset = transformerMap.transformDataset(examples,densify);
082
083        List<Prediction<T>> predictions = new ArrayList<>();
084        for (Example<T> example : transformedDataset) {
085            predictions.add(innerModel.predict(example));
086        }
087
088        return predictions;
089    }
090
091    @Override
092    public Map<String, List<Pair<String, Double>>> getTopFeatures(int n) {
093        return innerModel.getTopFeatures(n);
094    }
095
096    @Override
097    public Optional<Excuse<T>> getExcuse(Example<T> example) {
098        Example<T> transformedExample = transformerMap.transformExample(example);
099        return innerModel.getExcuse(transformedExample);
100    }
101
102    @Override
103    protected TransformedModel<T> copy(String name, ModelProvenance newProvenance) {
104        return new TransformedModel<>(newProvenance,innerModel,transformerMap,densify);
105    }
106}