001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.data.csv; 018 019import com.opencsv.CSVParserWriter; 020import com.opencsv.ICSVWriter; 021import com.opencsv.RFC4180ParserBuilder; 022import com.oracle.labs.mlrg.olcut.config.Config; 023import com.oracle.labs.mlrg.olcut.config.Configurable; 024import org.tribuo.Dataset; 025import org.tribuo.Example; 026import org.tribuo.Feature; 027import org.tribuo.ImmutableFeatureMap; 028import org.tribuo.Output; 029import org.tribuo.VariableIDInfo; 030import org.tribuo.VariableInfo; 031 032import java.io.IOException; 033import java.nio.charset.StandardCharsets; 034import java.nio.file.Files; 035import java.nio.file.Path; 036import java.util.Arrays; 037import java.util.Collections; 038import java.util.HashMap; 039import java.util.List; 040import java.util.Map; 041import java.util.Set; 042import java.util.logging.Logger; 043 044/** 045 * Saves a Dataset in CSV format suitable for loading by {@link CSVLoader}. 046 * <p> 047 * CSVSaver is thread safe and immutable. 048 */ 049public class CSVSaver implements Configurable { 050 051 public final static String DEFAULT_RESPONSE = "Response"; 052 private static final Logger logger = Logger.getLogger(CSVSaver.class.getName()); 053 054 @Config(description="The column separator.") 055 private char separator = CSVIterator.SEPARATOR; 056 @Config(description="The quote character.") 057 private char quote = CSVIterator.QUOTE; 058 059 /** 060 * Builds a CSV saver using the supplied separator and quote. 061 * @param separator The column separator. 062 * @param quote The quote character. 063 */ 064 public CSVSaver(char separator, char quote) { 065 if (separator == quote) { 066 throw new IllegalArgumentException("Quote and separator must be different characters."); 067 } 068 this.separator = separator; 069 this.quote = quote; 070 } 071 072 /** 073 * Builds a CSV saver using the default separator and quote from {@link CSVIterator}. 074 */ 075 public CSVSaver() { 076 this(CSVIterator.SEPARATOR, CSVIterator.QUOTE); 077 } 078 079 /** 080 * Saves the dataset to the specified path. 081 * @param csvPath The path to save to. 082 * @param dataset The dataset to save. 083 * @param responseName The name of the response variable. 084 * @param <T> The output type. 085 * @throws IOException If the disk write failed. 086 */ 087 public <T extends Output<T>> void save(Path csvPath, Dataset<T> dataset, String responseName) throws IOException { 088 save(csvPath, dataset, Collections.singleton(responseName)); 089 } 090 091 /** 092 * Saves the dataset to the specified path. 093 * @param csvPath The path to save to. 094 * @param dataset The dataset to save. 095 * @param responseNames The response names set. 096 * @param <T> The output type. 097 * @throws IOException If the disk write failed. 098 */ 099 public <T extends Output<T>> void save(Path csvPath, Dataset<T> dataset, Set<String> responseNames) throws IOException { 100 boolean isMultiOutput = responseNames.size() > 1; 101 ImmutableFeatureMap features = dataset.getFeatureIDMap(); 102 int ncols = features.size() + responseNames.size(); 103 // 104 // Initialize the CSV header row. 105 String[] headerLine = new String[ncols]; 106 Map<String, Integer> responseToColumn = new HashMap<>(); 107 int col = 0; 108 for (String response : responseNames) { 109 headerLine[col] = response; 110 responseToColumn.put(response, col); 111 col++; 112 } 113 for (VariableInfo feature : features) { 114 headerLine[col++] = feature.getName(); 115 } 116 // 117 // Write the CSV 118 try (ICSVWriter writer = new CSVParserWriter( 119 Files.newBufferedWriter(csvPath, StandardCharsets.UTF_8), 120 new RFC4180ParserBuilder() 121 .withSeparator(separator) 122 .withQuoteChar(quote) 123 .build(), "\n")) { 124 125 writer.writeNext(headerLine); 126 127 for (Example<T> e : dataset) { 128 String[] denseOutput = (isMultiOutput) ? 129 densifyMultiOutput(e, responseToColumn) : 130 densifySingleOutput(e); 131 String[] featureArr = generateFeatureArray(e, features); 132 if (featureArr.length != features.size()) { 133 throw new IllegalStateException(String.format("Invalid example: had %d features, expected %d.", featureArr.length, features.size())); 134 } 135 // 136 // Copy responses and features into a single array 137 String[] line = new String[ncols]; 138 System.arraycopy(denseOutput, 0, line, 0, denseOutput.length); 139 System.arraycopy(featureArr, 0, line, denseOutput.length, featureArr.length); 140 writer.writeNext(line); 141 } 142 } 143 } 144 145 private static <T extends Output<T>> String[] densifySingleOutput(Example<T> e) { 146 return new String[]{e.getOutput().getSerializableForm(false)}; 147 } 148 149 private static <T extends Output<T>> String[] densifyMultiOutput(Example<T> e, Map<String, Integer> responseToColumn) { 150 String[] denseOutput = new String[responseToColumn.size()]; 151 // 152 // Initialize to false/0 everywhere 153 // TODO ^^^ seems bad to me. maybe OutputFactory could give us a "zero-value" in addition to an "unknown-value". 154 //TODO: maybe this instead? outputFactory.getUnknownOutput().toString(); 155 Arrays.fill(denseOutput, "0"); 156 157 // 158 // Convert sparse output to a dense format 159 // Sparse output format: "a=true,b=true..." for classification or "a=0.0,b=2.0..." for regression 160 String csv = e.getOutput().getSerializableForm(false); 161 if (csv.isEmpty()) { 162 // 163 // If the string is empty, then the denseOutput will be false/0 everywhere. 164 return denseOutput; 165 } 166 String[] sparseOutput = csv.split(","); // TODO should comma be hard-coded into this 'split' call? 167 168 for (String elem : sparseOutput) { 169 String[] kv = elem.split("="); 170 if (kv.length != 2) { 171 throw new IllegalArgumentException("Bad serialized string element: '" + elem + "'"); 172 } 173 String responseName = kv[0]; 174 String responseValue = kv[1]; 175 int index = responseToColumn.getOrDefault(responseName,-1); 176 if (index == -1) { 177 // 178 // We have to check for a special-case here: 179 // In the multi-output case, we might have a CSV like the following: 180 // 181 // Feature1,Feature2,...,Label1,Label2 182 // 1.0,0.5,...,False,False 183 // 184 // In this case, where we have false for all labels, the multi-output label will just be the 185 // empty string. In the single label case, an empty-string label should be an error. 186 if (responseName.equals("")) { 187 continue; 188 } else { 189 throw new IllegalStateException(String.format("Invalid example: unknown response name '%s'. (known response names: %s)", responseName, responseToColumn.keySet())); 190 } 191 } 192 denseOutput[index] = responseValue; 193 } 194 195 return denseOutput; 196 } 197 198 /** 199 * Converts an Example's features into a dense row, filling in unobserved values with 0. 200 * @param example The example to convert. 201 * @param features The featureMap to use for the ids. 202 * @return A String array, one element per feature plus the output. 203 */ 204 private static <T extends Output<T>> String[] generateFeatureArray(Example<T> example, ImmutableFeatureMap features) { 205 String[] output = new String[features.size()]; 206 HashMap<Integer,Double> featureMap = new HashMap<>(); 207 for (Feature f : example) { 208 VariableIDInfo info = features.get(f.getName()); 209 if (info != null) { 210 featureMap.put(info.getID(), f.getValue()); 211 } 212 } 213 for (int i = 0; i < features.size(); i++) { 214 Double curFeature = featureMap.get(i); 215 if (curFeature == null) { 216 output[i] = "0"; 217 } else { 218 output[i] = curFeature.toString(); 219 } 220 } 221 return output; 222 } 223 224}