001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.common.liblinear; 018 019import de.bwaldvogel.liblinear.Linear; 020import org.tribuo.Example; 021import org.tribuo.Excuse; 022import org.tribuo.ImmutableFeatureMap; 023import org.tribuo.ImmutableOutputInfo; 024import org.tribuo.Model; 025import org.tribuo.Output; 026import org.tribuo.provenance.ModelProvenance; 027 028import java.io.IOException; 029import java.io.StringReader; 030import java.io.StringWriter; 031import java.util.ArrayList; 032import java.util.Collections; 033import java.util.List; 034import java.util.Optional; 035import java.util.logging.Logger; 036 037/** 038 * A {@link Model} which wraps a LibLinear-java model. 039 * <p> 040 * It disables the LibLinear debug output as it's very chatty. 041 * <p> 042 * See: 043 * <pre> 044 * Fan RE, Chang KW, Hsieh CJ, Wang XR, Lin CJ. 045 * "LIBLINEAR: A library for Large Linear Classification" 046 * Journal of Machine Learning Research, 2008. 047 * </pre> 048 * and for the original algorithm: 049 * <pre> 050 * Cortes C, Vapnik V. 051 * "Support-Vector Networks" 052 * Machine Learning, 1995. 053 * </pre> 054 */ 055public abstract class LibLinearModel<T extends Output<T>> extends Model<T> { 056 private static final long serialVersionUID = 3L; 057 058 private static final Logger logger = Logger.getLogger(LibLinearModel.class.getName()); 059 060 /** 061 * The list of LibLinear models. Multiple models are used by multi-label and multidimensional regression outputs. 062 */ 063 protected final List<de.bwaldvogel.liblinear.Model> models; 064 065 /** 066 * Constructs a LibLinear model from the supplied arguments. 067 * @param name The model name. 068 * @param description The model provenance. 069 * @param featureIDMap The features this model knows about. 070 * @param labelIDMap The outputs this model produces. 071 * @param generatesProbabilities Does this model generate probabilities? 072 * @param models The liblinear models themselves. 073 */ 074 protected LibLinearModel(String name, ModelProvenance description, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<T> labelIDMap, boolean generatesProbabilities, List<de.bwaldvogel.liblinear.Model> models) { 075 super(name, description, featureIDMap, labelIDMap, generatesProbabilities); 076 this.models = models; 077 Linear.disableDebugOutput(); 078 } 079 080 /** 081 * Returns an unmodifiable list containing a copy of each model. 082 * <p> 083 * As liblinear-java models don't expose a copy constructor this requires 084 * serializing each model to a String and rebuilding it, and is thus quite expensive. 085 * 086 * @return A copy of all of the models. 087 */ 088 public List<de.bwaldvogel.liblinear.Model> getInnerModels() { 089 List<de.bwaldvogel.liblinear.Model> copy = new ArrayList<>(); 090 091 for (de.bwaldvogel.liblinear.Model m : models) { 092 copy.add(copyModel(m)); 093 } 094 095 return Collections.unmodifiableList(copy); 096 } 097 098 /** 099 * This call is expensive as it copies out the weight matrix from the 100 * LibLinear model. 101 * <p> 102 * Prefer {@link LibLinearModel#getExcuses} to get multiple excuses. 103 * <p> 104 * @param e The example to excuse. 105 * @return An {@link Excuse} for this example. 106 */ 107 @Override 108 public Optional<Excuse<T>> getExcuse(Example<T> e) { 109 double[][] featureWeights = getFeatureWeights(); 110 return Optional.of(innerGetExcuse(e, featureWeights)); 111 } 112 113 @Override 114 public Optional<List<Excuse<T>>> getExcuses(Iterable<Example<T>> examples) { 115 //This call copies out the weights, so it's better to do it once 116 double[][] featureWeights = getFeatureWeights(); 117 List<Excuse<T>> excuses = new ArrayList<>(); 118 119 for (Example<T> e : examples) { 120 excuses.add(innerGetExcuse(e, featureWeights)); 121 } 122 123 return Optional.of(excuses); 124 } 125 126 /** 127 * Copies the model by writing it out to a String and loading it back in. 128 * <p> 129 * Unfortunately liblinear-java doesn't have a copy constructor on it's model. 130 * @param model The model to copy. 131 * @return A deep copy of the model. 132 */ 133 protected static de.bwaldvogel.liblinear.Model copyModel(de.bwaldvogel.liblinear.Model model) { 134 try { 135 StringWriter writer = new StringWriter(); 136 Linear.saveModel(writer,model); 137 String modelString = writer.toString(); 138 StringReader reader = new StringReader(modelString); 139 return Linear.loadModel(reader); 140 } catch (IOException e) { 141 throw new IllegalStateException("IOException found when copying the model in memory via a String.",e); 142 } 143 } 144 145 /** 146 * Extracts the feature weights from the models. 147 * The first dimension corresponds to the model index. 148 * @return The feature weights. 149 */ 150 protected abstract double[][] getFeatureWeights(); 151 152 /** 153 * The call to getFeatureWeights in the public methods copies the 154 * weights array so this inner method exists to save the copy in getExcuses. 155 * <p> 156 * If it becomes a problem then we could cache the feature weights in the 157 * model. 158 * <p> 159 * @param e The example. 160 * @param featureWeights The per dimension feature weights. 161 * @return An excuse for this example. 162 */ 163 protected abstract Excuse<T> innerGetExcuse(Example<T> e, double[][] featureWeights); 164}