001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.transform; 018 019import com.oracle.labs.mlrg.olcut.util.Pair; 020import org.tribuo.Dataset; 021import org.tribuo.Example; 022import org.tribuo.Excuse; 023import org.tribuo.Model; 024import org.tribuo.MutableDataset; 025import org.tribuo.Output; 026import org.tribuo.Prediction; 027import org.tribuo.provenance.ModelProvenance; 028 029import java.util.ArrayList; 030import java.util.Collections; 031import java.util.List; 032import java.util.Map; 033import java.util.Optional; 034 035/** 036 * Wraps a {@link Model} with it's {@link TransformerMap} so all {@link Example}s are transformed 037 * appropriately before the model makes predictions. 038 * <p> 039 * If the densify flag is set, densifies all incoming data before transforming it. 040 * <p> 041 * Transformations only operate on observed values. To operate on implicit zeros then 042 * first call {@link MutableDataset#densify} on the datasets. 043 */ 044public class TransformedModel<T extends Output<T>> extends Model<T> { 045 private static final long serialVersionUID = 1L; 046 047 private final Model<T> innerModel; 048 049 private final TransformerMap transformerMap; 050 051 private final boolean densify; 052 053 private ArrayList<String> featureNames; 054 055 TransformedModel(ModelProvenance modelProvenance, Model<T> innerModel, TransformerMap transformerMap, boolean densify) { 056 super(innerModel.getName(), 057 modelProvenance, 058 innerModel.getFeatureIDMap(), 059 innerModel.getOutputIDInfo(), 060 innerModel.generatesProbabilities()); 061 this.innerModel = innerModel; 062 this.transformerMap = transformerMap; 063 this.densify = densify; 064 this.featureNames = new ArrayList<>(featureIDMap.keySet()); 065 Collections.sort(featureNames); 066 } 067 068 @Override 069 public Prediction<T> predict(Example<T> example) { 070 Example<T> transformedExample; 071 if (densify) { 072 transformedExample = transformerMap.transformExample(example,featureNames); 073 } else { 074 transformedExample = transformerMap.transformExample(example); 075 } 076 return innerModel.predict(transformedExample); 077 } 078 079 @Override 080 public List<Prediction<T>> predict(Dataset<T> examples) { 081 Dataset<T> transformedDataset = transformerMap.transformDataset(examples,densify); 082 083 List<Prediction<T>> predictions = new ArrayList<>(); 084 for (Example<T> example : transformedDataset) { 085 predictions.add(innerModel.predict(example)); 086 } 087 088 return predictions; 089 } 090 091 @Override 092 public Map<String, List<Pair<String, Double>>> getTopFeatures(int n) { 093 return innerModel.getTopFeatures(n); 094 } 095 096 @Override 097 public Optional<Excuse<T>> getExcuse(Example<T> example) { 098 Example<T> transformedExample = transformerMap.transformExample(example); 099 return innerModel.getExcuse(transformedExample); 100 } 101 102 @Override 103 protected TransformedModel<T> copy(String name, ModelProvenance newProvenance) { 104 return new TransformedModel<>(newProvenance,innerModel,transformerMap,densify); 105 } 106}