001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo;
018
019import com.oracle.labs.mlrg.olcut.command.Command;
020import com.oracle.labs.mlrg.olcut.command.CommandGroup;
021import com.oracle.labs.mlrg.olcut.command.CommandInterpreter;
022import com.oracle.labs.mlrg.olcut.config.ConfigurationManager;
023import com.oracle.labs.mlrg.olcut.config.Option;
024import com.oracle.labs.mlrg.olcut.config.Options;
025import org.jline.builtins.Completers;
026import org.jline.reader.Completer;
027import org.jline.reader.impl.completer.NullCompleter;
028
029import java.io.BufferedInputStream;
030import java.io.File;
031import java.io.FileInputStream;
032import java.io.FileNotFoundException;
033import java.io.IOException;
034import java.io.ObjectInputStream;
035import java.util.logging.Level;
036import java.util.logging.Logger;
037
038/**
039 * A command line interface for loading in models and inspecting their feature and output spaces.
040 */
041public class ModelExplorer implements CommandGroup {
042    private static final Logger logger = Logger.getLogger(ModelExplorer.class.getName());
043
044    /**
045     * The command shell instance.
046     */
047    protected CommandInterpreter shell;
048
049    private Model<?> model;
050
051    /**
052     * Builds a new model explorer shell.
053     */
054    public ModelExplorer() {
055        shell = new CommandInterpreter();
056        shell.setPrompt("model sh% ");
057    }
058
059    @Override
060    public String getName() {
061        return "Model Explorer";
062    }
063
064    @Override
065    public String getDescription() {
066        return "Commands for inspecting a Model.";
067    }
068
069    /**
070     * Completers for files.
071     * @return The completers for file commands.
072     */
073    public Completer[] fileCompleter() {
074        return new Completer[]{
075                new Completers.FileNameCompleter(),
076                new NullCompleter()
077        };
078    }
079
080    /**
081     * Start the command shell
082     */
083    public void startShell() {
084        shell.add(this);
085        shell.start();
086    }
087
088    /**
089     * Loads a model.
090     * @param ci The shell instance.
091     * @param path The path to load.
092     * @return A status string.
093     */
094    @Command(usage = "<filename> - Load a model from disk.", completers="fileCompleter")
095    public String loadModel(CommandInterpreter ci, File path) {
096        try (ObjectInputStream ois = new ObjectInputStream(new BufferedInputStream(new FileInputStream(path)))) {
097            model = (Model<?>) ois.readObject();
098        } catch (ClassNotFoundException e) {
099            logger.log(Level.SEVERE,"Failed to load class from stream " + path.getAbsolutePath(),e);
100            return "Failed to load model";
101        } catch (FileNotFoundException e) {
102            logger.log(Level.SEVERE,"Failed to open file " + path.getAbsolutePath(),e);
103            return "Failed to load model";
104        } catch (IOException e) {
105            logger.log(Level.SEVERE,"IOException when reading from file " + path.getAbsolutePath(),e);
106            return "Failed to load model";
107        }
108
109        return "Loaded model from path " + path.toString();
110    }
111
112    /**
113     * Checks if the model generates probabilities.
114     * @param ci The command shell.
115     * @return A status string.
116     */
117    @Command(usage="Does the model generate probabilities")
118    public String generatesProbabilities(CommandInterpreter ci) {
119        return ""+model.generatesProbabilities();
120    }
121
122    /**
123     * Displays the model provenance.
124     * @param ci The command shell.
125     * @return A status string.
126     */
127    @Command(usage="Shows the model provenance")
128    public String modelProvenance(CommandInterpreter ci) {
129        return model.getProvenance().toString();
130    }
131
132    /**
133     * Shows a specific feature's information.
134     * @param ci The command shell.
135     * @param featureName The feature name.
136     * @return A status string.
137     */
138    @Command(usage="Shows the information on a particular feature")
139    public String featureInfo(CommandInterpreter ci, String featureName) {
140        VariableInfo f = model.getFeatureIDMap().get(featureName);
141        if (f != null) {
142            return "" + f.toString();
143        } else {
144            return "Feature " + featureName + " not found.";
145        }
146    }
147
148    /**
149     * Displays the output info.
150     * @param ci The command shell.
151     * @return A status string.
152     */
153    @Command(usage="Shows the output information.")
154    public String outputInfo(CommandInterpreter ci) {
155        return model.getOutputIDInfo().toReadableString();
156    }
157
158    /**
159     * Displays the top n features.
160     * @param ci The command shell
161     * @param numFeatures The number of features to display.
162     * @return A status string.
163     */
164    @Command(usage="<int> - Shows the top N features in the model")
165    public String topFeatures(CommandInterpreter ci, int numFeatures) {
166        return ""+ model.getTopFeatures(numFeatures);
167    }
168
169    /**
170     * Displays the number of features.
171     * @param ci The command shell.
172     * @return A status string.
173     */
174    @Command(usage="Shows the number of features in the model")
175    public String numFeatures(CommandInterpreter ci) {
176        return ""+ model.getFeatureIDMap().size();
177    }
178
179    /**
180     * Shows the number of features which occurred more than min count times.
181     * @param ci The command shell.
182     * @param minCount The minimum occurrence count.
183     * @return A status string.
184     */
185    @Command(usage="<min count> - Shows the number of features that occurred more than min count times.")
186    public String minCount(CommandInterpreter ci, int minCount) {
187        int counter = 0;
188        for (VariableInfo f : model.getFeatureIDMap()) {
189            if (f.getCount() > minCount) {
190                counter++;
191            }
192        }
193        return counter + " features occurred more than " + minCount + " times.";
194    }
195
196    /**
197     * CLI options for {@link ModelExplorer}.
198     */
199    public static class ModelExplorerOptions implements Options {
200        /**
201         * Model file to load.
202         */
203        @Option(charName='f',longName="filename",usage="Model file to load. Optional.")
204        public String modelFilename;
205    }
206
207    /**
208     * Entry point.
209     * @param args CLI args.
210     */
211    public static void main(String[] args) {
212        ModelExplorerOptions options = new ModelExplorerOptions();
213        ConfigurationManager cm = new ConfigurationManager(args,options,false);
214        ModelExplorer driver = new ModelExplorer();
215        if (options.modelFilename != null) {
216            logger.log(Level.INFO,driver.loadModel(driver.shell, new File(options.modelFilename)));
217        }
218        driver.startShell();
219    }
220}