001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.data.csv;
018
import com.opencsv.CSVParserWriter;
import com.opencsv.ICSVWriter;
import com.opencsv.RFC4180ParserBuilder;
import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.config.Configurable;
import org.tribuo.Dataset;
import org.tribuo.Example;
import org.tribuo.Feature;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.Output;
import org.tribuo.VariableIDInfo;
import org.tribuo.VariableInfo;

import java.io.BufferedWriter;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.logging.Logger;
043
044/**
045 * Saves a Dataset in CSV format suitable for loading by {@link CSVLoader}.
046 * <p>
047 * CSVSaver is thread safe and immutable.
048 */
049public class CSVSaver implements Configurable {
050
051    public final static String DEFAULT_RESPONSE = "Response";
052    private static final Logger logger = Logger.getLogger(CSVSaver.class.getName());
053
054    @Config(description="The column separator.")
055    private char separator = CSVIterator.SEPARATOR;
056    @Config(description="The quote character.")
057    private char quote = CSVIterator.QUOTE;
058
059    /**
060     * Builds a CSV saver using the supplied separator and quote.
061     * @param separator The column separator.
062     * @param quote The quote character.
063     */
064    public CSVSaver(char separator, char quote) {
065        if (separator == quote) {
066            throw new IllegalArgumentException("Quote and separator must be different characters.");
067        }
068        this.separator = separator;
069        this.quote = quote;
070    }
071
072    /**
073     * Builds a CSV saver using the default separator and quote from {@link CSVIterator}.
074     */
075    public CSVSaver() {
076        this(CSVIterator.SEPARATOR, CSVIterator.QUOTE);
077    }
078
079    /**
080     * Saves the dataset to the specified path.
081     * @param csvPath The path to save to.
082     * @param dataset The dataset to save.
083     * @param responseName The name of the response variable.
084     * @param <T> The output type.
085     * @throws IOException If the disk write failed.
086     */
087    public <T extends Output<T>> void save(Path csvPath, Dataset<T> dataset, String responseName) throws IOException {
088        save(csvPath, dataset, Collections.singleton(responseName));
089    }
090
091    /**
092     * Saves the dataset to the specified path.
093     * @param csvPath The path to save to.
094     * @param dataset The dataset to save.
095     * @param responseNames The response names set.
096     * @param <T> The output type.
097     * @throws IOException If the disk write failed.
098     */
099    public <T extends Output<T>> void save(Path csvPath, Dataset<T> dataset, Set<String> responseNames) throws IOException {
100        boolean isMultiOutput = responseNames.size() > 1;
101        ImmutableFeatureMap features = dataset.getFeatureIDMap();
102        int ncols = features.size() + responseNames.size();
103        //
104        // Initialize the CSV header row.
105        String[] headerLine = new String[ncols];
106        Map<String, Integer> responseToColumn = new HashMap<>();
107        int col = 0;
108        for (String response : responseNames) {
109            headerLine[col] = response;
110            responseToColumn.put(response, col);
111            col++;
112        }
113        for (VariableInfo feature : features) {
114            headerLine[col++] = feature.getName();
115        }
116        //
117        // Write the CSV
118        try (ICSVWriter writer = new CSVParserWriter(
119                Files.newBufferedWriter(csvPath, StandardCharsets.UTF_8),
120                new RFC4180ParserBuilder()
121                        .withSeparator(separator)
122                        .withQuoteChar(quote)
123                        .build(), "\n")) {
124
125            writer.writeNext(headerLine);
126
127            for (Example<T> e : dataset) {
128                String[] denseOutput = (isMultiOutput) ?
129                        densifyMultiOutput(e, responseToColumn) :
130                        densifySingleOutput(e);
131                String[] featureArr = generateFeatureArray(e, features);
132                if (featureArr.length != features.size()) {
133                    throw new IllegalStateException(String.format("Invalid example: had %d features, expected %d.", featureArr.length, features.size()));
134                }
135                //
136                // Copy responses and features into a single array
137                String[] line = new String[ncols];
138                System.arraycopy(denseOutput, 0, line, 0, denseOutput.length);
139                System.arraycopy(featureArr, 0, line, denseOutput.length, featureArr.length);
140                writer.writeNext(line);
141            }
142        }
143    }
144
145    private static <T extends Output<T>> String[] densifySingleOutput(Example<T> e) {
146        return new String[]{e.getOutput().getSerializableForm(false)};
147    }
148
149    private static <T extends Output<T>> String[] densifyMultiOutput(Example<T> e, Map<String, Integer> responseToColumn) {
150        String[] denseOutput = new String[responseToColumn.size()];
151        //
152        // Initialize to false/0 everywhere
153        // TODO ^^^ seems bad to me. maybe OutputFactory could give us a "zero-value" in addition to an "unknown-value".
154        //TODO: maybe this instead? outputFactory.getUnknownOutput().toString();
155        Arrays.fill(denseOutput, "0");
156
157        //
158        // Convert sparse output to a dense format
159        // Sparse output format: "a=true,b=true..." for classification or "a=0.0,b=2.0..." for regression
160        String csv = e.getOutput().getSerializableForm(false);
161        if (csv.isEmpty()) {
162            //
163            // If the string is empty, then the denseOutput will be false/0 everywhere.
164            return denseOutput;
165        }
166        String[] sparseOutput = csv.split(","); // TODO should comma be hard-coded into this 'split' call?
167
168        for (String elem : sparseOutput) {
169            String[] kv = elem.split("=");
170            if (kv.length != 2) {
171                throw new IllegalArgumentException("Bad serialized string element: '" + elem + "'");
172            }
173            String responseName = kv[0];
174            String responseValue = kv[1];
175            int index = responseToColumn.getOrDefault(responseName,-1);
176            if (index == -1) {
177                //
178                // We have to check for a special-case here:
179                // In the multi-output case, we might have a CSV like the following:
180                //
181                // Feature1,Feature2,...,Label1,Label2
182                // 1.0,0.5,...,False,False
183                //
184                // In this case, where we have false for all labels, the multi-output label will just be the
185                // empty string. In the single label case, an empty-string label should be an error.
186                if (responseName.equals("")) {
187                    continue;
188                } else {
189                    throw new IllegalStateException(String.format("Invalid example: unknown response name '%s'. (known response names: %s)", responseName, responseToColumn.keySet()));
190                }
191            }
192            denseOutput[index] = responseValue;
193        }
194
195        return denseOutput;
196    }
197
198    /**
199     * Converts an Example's features into a dense row, filling in unobserved values with 0.
200     * @param example The example to convert.
201     * @param features The featureMap to use for the ids.
202     * @return A String array, one element per feature plus the output.
203     */
204    private static <T extends Output<T>> String[] generateFeatureArray(Example<T> example, ImmutableFeatureMap features) {
205        String[] output = new String[features.size()];
206        HashMap<Integer,Double> featureMap = new HashMap<>();
207        for (Feature f : example) {
208            VariableIDInfo info = features.get(f.getName());
209            if (info != null) {
210                featureMap.put(info.getID(), f.getValue());
211            }
212        }
213        for (int i = 0; i < features.size(); i++) {
214            Double curFeature = featureMap.get(i);
215            if (curFeature == null) {
216                output[i] = "0";
217            } else {
218                output[i] = curFeature.toString();
219            }
220        }
221        return output;
222    }
223
224}