001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.data; 018 019import com.oracle.labs.mlrg.olcut.command.Command; 020import com.oracle.labs.mlrg.olcut.command.CommandGroup; 021import com.oracle.labs.mlrg.olcut.command.CommandInterpreter; 022import com.oracle.labs.mlrg.olcut.config.ConfigurationManager; 023import com.oracle.labs.mlrg.olcut.config.Option; 024import com.oracle.labs.mlrg.olcut.config.Options; 025import org.tribuo.Dataset; 026import org.tribuo.VariableInfo; 027import org.tribuo.data.csv.CSVSaver; 028import org.jline.builtins.Completers; 029import org.jline.reader.Completer; 030import org.jline.reader.impl.completer.NullCompleter; 031 032import java.io.BufferedInputStream; 033import java.io.File; 034import java.io.FileInputStream; 035import java.io.FileNotFoundException; 036import java.io.IOException; 037import java.io.ObjectInputStream; 038import java.nio.file.Paths; 039import java.util.logging.Level; 040import java.util.logging.Logger; 041 042/** 043 * A CLI for exploring a serialised {@link Dataset}. 044 */ 045public final class DatasetExplorer implements CommandGroup { 046 private static final Logger logger = Logger.getLogger(DatasetExplorer.class.getName()); 047 048 protected CommandInterpreter shell; 049 050 private Dataset<?> dataset; 051 052 public DatasetExplorer() { 053 shell = new CommandInterpreter(); 054 shell.setPrompt("dataset sh% "); 055 } 056 057 @Override 058 public String getName() { 059 return "Dataset Explorer"; 060 } 061 062 @Override 063 public String getDescription() { 064 return "Commands for inspecting a Dataset."; 065 } 066 067 public Completer[] fileCompleter() { 068 return new Completer[]{ 069 new Completers.FileNameCompleter(), 070 new NullCompleter() 071 }; 072 } 073 074 /** 075 * Start the command shell 076 */ 077 public void startShell() { 078 shell.add(this); 079 shell.start(); 080 } 081 082 @Command(usage = "<filename> - Load a dataset from disk.", completers="fileCompleter") 083 public String loadDataset(CommandInterpreter ci, File path) { 084 try (ObjectInputStream ois = new ObjectInputStream(new BufferedInputStream(new FileInputStream(path)))) { 085 dataset = (Dataset<?>) ois.readObject(); 086 } catch (ClassNotFoundException e) { 087 logger.log(Level.SEVERE,"Failed to load class from stream " + path.getAbsolutePath(),e); 088 return "Failed to load dataset"; 089 } catch (FileNotFoundException e) { 090 logger.log(Level.SEVERE,"Failed to open file " + path.getAbsolutePath(),e); 091 return "Failed to load dataset"; 092 } catch (IOException e) { 093 logger.log(Level.SEVERE,"IOException when reading from file " + path.getAbsolutePath(),e); 094 return "Failed to load dataset"; 095 } 096 097 return "Loaded dataset from path " + path.toString(); 098 } 099 100 @Command(usage="Shows the information on a particular feature") 101 public String featureInfo(CommandInterpreter ci, String featureName) { 102 VariableInfo f = dataset.getFeatureMap().get(featureName); 103 if (f != null) { 104 return "" + f.toString(); 105 } else { 106 return "Feature " + featureName + " not found."; 107 } 108 } 109 110 @Command(usage="Shows the output information.") 111 public String outputInfo(CommandInterpreter ci) { 112 return dataset.getOutputInfo().toReadableString(); 113 } 114 115 @Command(usage="Shows the number of rows in the dataset") 116 public String numExamples(CommandInterpreter ci) { 117 return ""+dataset.getData().size(); 118 } 119 120 @Command(usage="Shows the number of features in the dataset") 121 public String numFeatures(CommandInterpreter ci) { 122 return ""+dataset.getFeatureMap().size(); 123 } 124 125 @Command(usage="<min count> - Shows the number of features that occurred more than min count times.") 126 public String minCount(CommandInterpreter ci, int minCount) { 127 int counter = 0; 128 for (VariableInfo f : dataset.getFeatureMap()) { 129 if (f.getCount() > minCount) { 130 counter++; 131 } 132 } 133 return counter + " features occurred more than " + minCount + " times."; 134 } 135 136 @Command(usage="Shows the output statistics") 137 public String showLabelStats(CommandInterpreter ci) { 138 return "Label histogram : \n" + dataset.getOutputInfo().toReadableString(); 139 } 140 141 @Command(usage="Saves out the data as a CSV.") 142 public String saveCSV(CommandInterpreter ci, String path) { 143 CSVSaver saver = new CSVSaver(); 144 try { 145 saver.save(Paths.get(path),dataset,CSVSaver.DEFAULT_RESPONSE); 146 return "Saved"; 147 } catch (IOException e) { 148 e.printStackTrace(ci.out); 149 return "Failed to save to CSV."; 150 } 151 } 152 153 @Command(usage="Shows the dataset provenance") 154 public String showProvenance(CommandInterpreter ci) { 155 return dataset.getProvenance().toString(); 156 } 157 158 public static class DatasetExplorerOptions implements Options { 159 @Option(charName='f',longName="filename",usage="Dataset file to load. Optional.") 160 public String modelFilename; 161 } 162 163 public static void main(String[] args) { 164 DatasetExplorerOptions options = new DatasetExplorerOptions(); 165 ConfigurationManager cm = new ConfigurationManager(args,options,false); 166 DatasetExplorer driver = new DatasetExplorer(); 167 if (options.modelFilename != null) { 168 logger.log(Level.INFO,driver.loadDataset(driver.shell, new File(options.modelFilename))); 169 } 170 driver.startShell(); 171 } 172}