001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.sequence;
018
019import com.oracle.labs.mlrg.olcut.command.Command;
020import com.oracle.labs.mlrg.olcut.command.CommandGroup;
021import com.oracle.labs.mlrg.olcut.command.CommandInterpreter;
022import com.oracle.labs.mlrg.olcut.config.ConfigurationManager;
023import com.oracle.labs.mlrg.olcut.config.Option;
024import com.oracle.labs.mlrg.olcut.config.Options;
025import org.tribuo.VariableInfo;
026import org.jline.builtins.Completers;
027import org.jline.reader.Completer;
028import org.jline.reader.impl.completer.NullCompleter;
029
030import java.io.BufferedInputStream;
031import java.io.File;
032import java.io.FileInputStream;
033import java.io.FileNotFoundException;
034import java.io.IOException;
035import java.io.ObjectInputStream;
036import java.util.logging.Level;
037import java.util.logging.Logger;
038
039/**
040 * A CLI for interacting with a {@link SequenceModel}.
041 */
042public class SequenceModelExplorer implements CommandGroup {
043    private static final Logger logger = Logger.getLogger(SequenceModelExplorer.class.getName());
044
045    protected CommandInterpreter shell;
046
047    private SequenceModel<?> model;
048
049    public SequenceModelExplorer() {
050        shell = new CommandInterpreter();
051        shell.setPrompt("model sh% ");
052    }
053
054    @Override
055    public String getName() {
056        return "Sequence Model Explorer";
057    }
058
059    @Override
060    public String getDescription() {
061        return "Commands for inspecting a SequenceModel.";
062    }
063
064    public Completer[] fileCompleter() {
065        return new Completer[]{
066                new Completers.FileNameCompleter(),
067                new NullCompleter()
068        };
069    }
070
071    /**
072     * Start the command shell
073     */
074    public void startShell() {
075        shell.add(this);
076        shell.start();
077    }
078
079    @Command(usage = "<filename> - Load a model from disk.", completers="fileCompleter")
080    public String loadModel(CommandInterpreter ci, File path) {
081        try (ObjectInputStream ois = new ObjectInputStream(new BufferedInputStream(new FileInputStream(path)))) {
082            model = (SequenceModel<?>) ois.readObject();
083        } catch (ClassNotFoundException e) {
084            logger.log(Level.SEVERE,"Failed to load class from stream " + path.getAbsolutePath(),e);
085            return "Failed to load model";
086        } catch (FileNotFoundException e) {
087            logger.log(Level.SEVERE,"Failed to open file " + path.getAbsolutePath(),e);
088            return "Failed to load model";
089        } catch (IOException e) {
090            logger.log(Level.SEVERE,"IOException when reading from file " + path.getAbsolutePath(),e);
091            return "Failed to load model";
092        }
093
094        return "Loaded model from path " + path.toString();
095    }
096
097    @Command(usage="Shows the model description")
098    public String modelDescription(CommandInterpreter ci) {
099        return model.toString();
100    }
101
102    @Command(usage="Shows the information on a particular feature")
103    public String featureInfo(CommandInterpreter ci, String featureName) {
104        VariableInfo f = model.getFeatureIDMap().get(featureName);
105        if (f != null) {
106            return "" + f.toString();
107        } else {
108            return "Feature " + featureName + " not found.";
109        }
110    }
111
112    @Command(usage="Shows the output information.")
113    public String outputInfo(CommandInterpreter ci) {
114        return model.getOutputIDInfo().toReadableString();
115    }
116
117    @Command(usage="<int> - Shows the top N features in the model")
118    public String topFeatures(CommandInterpreter ci, int numFeatures) {
119        return ""+ model.getTopFeatures(numFeatures);
120    }
121
122    @Command(usage="Shows the number of features in the model")
123    public String numFeatures(CommandInterpreter ci) {
124        return ""+ model.getFeatureIDMap().size();
125    }
126
127    @Command(usage="<min count> - Shows the number of features that occurred more than min count times.")
128    public String minCount(CommandInterpreter ci, int minCount) {
129        int counter = 0;
130        for (VariableInfo f : model.getFeatureIDMap()) {
131            if (f.getCount() > minCount) {
132                counter++;
133            }
134        }
135        return counter + " features occurred more than " + minCount + " times.";
136    }
137
138    public static String usage() {
139        StringBuilder string = new StringBuilder();
140        string.append("Usage: ModelExplorer\n");
141
142        string.append("Optional parameters\n");
143        string.append("     -f <model-filename>\n");
144        string.append("         Load in a model from file.\n");
145
146        return string.toString();
147    }
148
149    public static class SequenceModelExplorerOptions implements Options {
150        @Option(charName='f',longName="filename",usage="Model file to load. Optional.")
151        public String modelFilename;
152    }
153
154    public static void main(String[] args) {
155        SequenceModelExplorerOptions options = new SequenceModelExplorerOptions();
156        ConfigurationManager cm = new ConfigurationManager(args,options,false);
157        SequenceModelExplorer driver = new SequenceModelExplorer();
158        if (options.modelFilename != null) {
159            logger.log(Level.INFO,driver.loadModel(driver.shell, new File(options.modelFilename)));
160        }
161        driver.startShell();
162    }
163}