001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.data;
018
019import com.oracle.labs.mlrg.olcut.command.Command;
020import com.oracle.labs.mlrg.olcut.command.CommandGroup;
021import com.oracle.labs.mlrg.olcut.command.CommandInterpreter;
022import com.oracle.labs.mlrg.olcut.config.ConfigurationManager;
023import com.oracle.labs.mlrg.olcut.config.Option;
024import com.oracle.labs.mlrg.olcut.config.Options;
025import org.tribuo.Dataset;
026import org.tribuo.VariableInfo;
027import org.tribuo.data.csv.CSVSaver;
028import org.jline.builtins.Completers;
029import org.jline.reader.Completer;
030import org.jline.reader.impl.completer.NullCompleter;
031
032import java.io.BufferedInputStream;
033import java.io.File;
034import java.io.FileInputStream;
035import java.io.FileNotFoundException;
036import java.io.IOException;
037import java.io.ObjectInputStream;
038import java.nio.file.Paths;
039import java.util.logging.Level;
040import java.util.logging.Logger;
041
042/**
043 * A CLI for exploring a serialised {@link Dataset}.
044 */
045public final class DatasetExplorer implements CommandGroup {
046    private static final Logger logger = Logger.getLogger(DatasetExplorer.class.getName());
047
048    protected CommandInterpreter shell;
049
050    private Dataset<?> dataset;
051
052    public DatasetExplorer() {
053        shell = new CommandInterpreter();
054        shell.setPrompt("dataset sh% ");
055    }
056
057    @Override
058    public String getName() {
059        return "Dataset Explorer";
060    }
061
062    @Override
063    public String getDescription() {
064        return "Commands for inspecting a Dataset.";
065    }
066
067    public Completer[] fileCompleter() {
068        return new Completer[]{
069                new Completers.FileNameCompleter(),
070                new NullCompleter()
071        };
072    }
073
074    /**
075     * Start the command shell
076     */
077    public void startShell() {
078        shell.add(this);
079        shell.start();
080    }
081
082    @Command(usage = "<filename> - Load a dataset from disk.", completers="fileCompleter")
083    public String loadDataset(CommandInterpreter ci, File path) {
084        try (ObjectInputStream ois = new ObjectInputStream(new BufferedInputStream(new FileInputStream(path)))) {
085            dataset = (Dataset<?>) ois.readObject();
086        } catch (ClassNotFoundException e) {
087            logger.log(Level.SEVERE,"Failed to load class from stream " + path.getAbsolutePath(),e);
088            return "Failed to load dataset";
089        } catch (FileNotFoundException e) {
090            logger.log(Level.SEVERE,"Failed to open file " + path.getAbsolutePath(),e);
091            return "Failed to load dataset";
092        } catch (IOException e) {
093            logger.log(Level.SEVERE,"IOException when reading from file " + path.getAbsolutePath(),e);
094            return "Failed to load dataset";
095        }
096
097        return "Loaded dataset from path " + path.toString();
098    }
099
100    @Command(usage="Shows the information on a particular feature")
101    public String featureInfo(CommandInterpreter ci, String featureName) {
102        VariableInfo f = dataset.getFeatureMap().get(featureName);
103        if (f != null) {
104            return "" + f.toString();
105        } else {
106            return "Feature " + featureName + " not found.";
107        }
108    }
109
110    @Command(usage="Shows the output information.")
111    public String outputInfo(CommandInterpreter ci) {
112        return dataset.getOutputInfo().toReadableString();
113    }
114
115    @Command(usage="Shows the number of rows in the dataset")
116    public String numExamples(CommandInterpreter ci) {
117        return ""+dataset.getData().size();
118    }
119
120    @Command(usage="Shows the number of features in the dataset")
121    public String numFeatures(CommandInterpreter ci) {
122        return ""+dataset.getFeatureMap().size();
123    }
124
125    @Command(usage="<min count> - Shows the number of features that occurred more than min count times.")
126    public String minCount(CommandInterpreter ci, int minCount) {
127        int counter = 0;
128        for (VariableInfo f : dataset.getFeatureMap()) {
129            if (f.getCount() > minCount) {
130                counter++;
131            }
132        }
133        return counter + " features occurred more than " + minCount + " times.";
134    }
135
136    @Command(usage="Shows the output statistics")
137    public String showLabelStats(CommandInterpreter ci) {
138        return "Label histogram : \n" + dataset.getOutputInfo().toReadableString();
139    }
140
141    @Command(usage="Saves out the data as a CSV.")
142    public String saveCSV(CommandInterpreter ci, String path) {
143        CSVSaver saver = new CSVSaver();
144        try {
145            saver.save(Paths.get(path),dataset,CSVSaver.DEFAULT_RESPONSE);
146            return "Saved";
147        } catch (IOException e) {
148            e.printStackTrace(ci.out);
149            return "Failed to save to CSV.";
150        }
151    }
152
153    @Command(usage="Shows the dataset provenance")
154    public String showProvenance(CommandInterpreter ci) {
155        return dataset.getProvenance().toString();
156    }
157
158    public static class DatasetExplorerOptions implements Options {
159        @Option(charName='f',longName="filename",usage="Dataset file to load. Optional.")
160        public String modelFilename;
161    }
162
163    public static void main(String[] args) {
164        DatasetExplorerOptions options = new DatasetExplorerOptions();
165        ConfigurationManager cm = new ConfigurationManager(args,options,false);
166        DatasetExplorer driver = new DatasetExplorer();
167        if (options.modelFilename != null) {
168            logger.log(Level.INFO,driver.loadDataset(driver.shell, new File(options.modelFilename)));
169        }
170        driver.startShell();
171    }
172}