001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo; 018 019import com.oracle.labs.mlrg.olcut.command.Command; 020import com.oracle.labs.mlrg.olcut.command.CommandGroup; 021import com.oracle.labs.mlrg.olcut.command.CommandInterpreter; 022import com.oracle.labs.mlrg.olcut.config.ConfigurationManager; 023import com.oracle.labs.mlrg.olcut.config.Option; 024import com.oracle.labs.mlrg.olcut.config.Options; 025import org.jline.builtins.Completers; 026import org.jline.reader.Completer; 027import org.jline.reader.impl.completer.NullCompleter; 028 029import java.io.BufferedInputStream; 030import java.io.File; 031import java.io.FileInputStream; 032import java.io.FileNotFoundException; 033import java.io.IOException; 034import java.io.ObjectInputStream; 035import java.util.logging.Level; 036import java.util.logging.Logger; 037 038/** 039 * A command line interface for loading in models and inspecting their feature and output spaces. 040 */ 041public class ModelExplorer implements CommandGroup { 042 private static final Logger logger = Logger.getLogger(ModelExplorer.class.getName()); 043 044 /** 045 * The command shell instance. 046 */ 047 protected CommandInterpreter shell; 048 049 private Model<?> model; 050 051 /** 052 * Builds a new model explorer shell. 053 */ 054 public ModelExplorer() { 055 shell = new CommandInterpreter(); 056 shell.setPrompt("model sh% "); 057 } 058 059 @Override 060 public String getName() { 061 return "Model Explorer"; 062 } 063 064 @Override 065 public String getDescription() { 066 return "Commands for inspecting a Model."; 067 } 068 069 /** 070 * Completers for files. 071 * @return The completers for file commands. 072 */ 073 public Completer[] fileCompleter() { 074 return new Completer[]{ 075 new Completers.FileNameCompleter(), 076 new NullCompleter() 077 }; 078 } 079 080 /** 081 * Start the command shell 082 */ 083 public void startShell() { 084 shell.add(this); 085 shell.start(); 086 } 087 088 /** 089 * Loads a model. 090 * @param ci The shell instance. 091 * @param path The path to load. 092 * @return A status string. 093 */ 094 @Command(usage = "<filename> - Load a model from disk.", completers="fileCompleter") 095 public String loadModel(CommandInterpreter ci, File path) { 096 try (ObjectInputStream ois = new ObjectInputStream(new BufferedInputStream(new FileInputStream(path)))) { 097 model = (Model<?>) ois.readObject(); 098 } catch (ClassNotFoundException e) { 099 logger.log(Level.SEVERE,"Failed to load class from stream " + path.getAbsolutePath(),e); 100 return "Failed to load model"; 101 } catch (FileNotFoundException e) { 102 logger.log(Level.SEVERE,"Failed to open file " + path.getAbsolutePath(),e); 103 return "Failed to load model"; 104 } catch (IOException e) { 105 logger.log(Level.SEVERE,"IOException when reading from file " + path.getAbsolutePath(),e); 106 return "Failed to load model"; 107 } 108 109 return "Loaded model from path " + path.toString(); 110 } 111 112 /** 113 * Checks if the model generates probabilities. 114 * @param ci The command shell. 115 * @return A status string. 116 */ 117 @Command(usage="Does the model generate probabilities") 118 public String generatesProbabilities(CommandInterpreter ci) { 119 return ""+model.generatesProbabilities(); 120 } 121 122 /** 123 * Displays the model provenance. 124 * @param ci The command shell. 125 * @return A status string. 126 */ 127 @Command(usage="Shows the model provenance") 128 public String modelProvenance(CommandInterpreter ci) { 129 return model.getProvenance().toString(); 130 } 131 132 /** 133 * Shows a specific feature's information. 134 * @param ci The command shell. 135 * @param featureName The feature name. 136 * @return A status string. 137 */ 138 @Command(usage="Shows the information on a particular feature") 139 public String featureInfo(CommandInterpreter ci, String featureName) { 140 VariableInfo f = model.getFeatureIDMap().get(featureName); 141 if (f != null) { 142 return "" + f.toString(); 143 } else { 144 return "Feature " + featureName + " not found."; 145 } 146 } 147 148 /** 149 * Displays the output info. 150 * @param ci The command shell. 151 * @return A status string. 152 */ 153 @Command(usage="Shows the output information.") 154 public String outputInfo(CommandInterpreter ci) { 155 return model.getOutputIDInfo().toReadableString(); 156 } 157 158 /** 159 * Displays the top n features. 160 * @param ci The command shell 161 * @param numFeatures The number of features to display. 162 * @return A status string. 163 */ 164 @Command(usage="<int> - Shows the top N features in the model") 165 public String topFeatures(CommandInterpreter ci, int numFeatures) { 166 return ""+ model.getTopFeatures(numFeatures); 167 } 168 169 /** 170 * Displays the number of features. 171 * @param ci The command shell. 172 * @return A status string. 173 */ 174 @Command(usage="Shows the number of features in the model") 175 public String numFeatures(CommandInterpreter ci) { 176 return ""+ model.getFeatureIDMap().size(); 177 } 178 179 /** 180 * Shows the number of features which occurred more than min count times. 181 * @param ci The command shell. 182 * @param minCount The minimum occurrence count. 183 * @return A status string. 184 */ 185 @Command(usage="<min count> - Shows the number of features that occurred more than min count times.") 186 public String minCount(CommandInterpreter ci, int minCount) { 187 int counter = 0; 188 for (VariableInfo f : model.getFeatureIDMap()) { 189 if (f.getCount() > minCount) { 190 counter++; 191 } 192 } 193 return counter + " features occurred more than " + minCount + " times."; 194 } 195 196 /** 197 * CLI options for {@link ModelExplorer}. 198 */ 199 public static class ModelExplorerOptions implements Options { 200 /** 201 * Model file to load. 202 */ 203 @Option(charName='f',longName="filename",usage="Model file to load. Optional.") 204 public String modelFilename; 205 } 206 207 /** 208 * Entry point. 209 * @param args CLI args. 210 */ 211 public static void main(String[] args) { 212 ModelExplorerOptions options = new ModelExplorerOptions(); 213 ConfigurationManager cm = new ConfigurationManager(args,options,false); 214 ModelExplorer driver = new ModelExplorer(); 215 if (options.modelFilename != null) { 216 logger.log(Level.INFO,driver.loadModel(driver.shell, new File(options.modelFilename))); 217 } 218 driver.startShell(); 219 } 220}