001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.data.csv;
018
019import com.opencsv.CSVReader;
020import com.opencsv.CSVReaderBuilder;
021import com.opencsv.RFC4180ParserBuilder;
022import com.opencsv.exceptions.CsvValidationException;
023import org.tribuo.data.columnar.ColumnarIterator;
024
025import java.io.BufferedReader;
026import java.io.IOException;
027import java.io.InputStreamReader;
028import java.io.Reader;
029import java.net.URI;
030import java.nio.charset.StandardCharsets;
031import java.nio.file.Files;
032import java.nio.file.Paths;
033import java.util.Arrays;
034import java.util.Collections;
035import java.util.HashMap;
036import java.util.List;
037import java.util.Map;
038import java.util.Optional;
039import java.util.logging.Level;
040import java.util.logging.Logger;
041
042/**
043 * An iterator over a CSV file.
044 */
045public class CSVIterator extends ColumnarIterator implements AutoCloseable {
046    private static final Logger logger = Logger.getLogger(CSVIterator.class.getName());
047
048    /**
049     * Default separator character.
050     */
051    public final static char SEPARATOR = ',';
052
053    /**
054     * Default quote character.
055     */
056    public final static char QUOTE = '"';
057
058    private final CSVReader reader;
059
060    // Used as the row index
061    private int recordNum = 0;
062
063    /**
064     * Builds a CSVIterator for the supplied Reader. Defaults to {@link CSVIterator#SEPARATOR} for the separator
065     * and {@link CSVIterator#QUOTE} for the quote.
066     * @param rdr The source to read.
067     */
068    public CSVIterator(Reader rdr) {
069        this(rdr, SEPARATOR, QUOTE);
070    }
071
072    /**
073     * Builds a CSVIterator for the supplied Reader.
074     * @param rdr The source to read.
075     * @param separator The separator character to use.
076     * @param quote The quote character to use.
077     */
078    public CSVIterator(Reader rdr, char separator, char quote) {
079        this(rdr, separator, quote, Collections.emptyList());
080    }
081
082    /**
083     * Builds a CSVIterator for the supplied URI. Defaults to {@link CSVIterator#SEPARATOR} for the separator
084     * and {@link CSVIterator#QUOTE} for the quote.
085     * @param dataFile The source to read.
086     * @throws IOException thrown if the file is not readable in some way.
087     */
088    public CSVIterator(URI dataFile) throws IOException {
089        this(new InputStreamReader(Files.newInputStream(Paths.get(dataFile)), StandardCharsets.UTF_8));
090    }
091
092    /**
093     * Builds a CSVIterator for the supplied URI.
094     * @param dataFile The source to read.
095     * @param separator The separator character to use.
096     * @param quote The quote character to use.
097     * @throws IOException thrown if the file is not readable in some way.
098     */
099    public CSVIterator(URI dataFile, char separator, char quote) throws IOException {
100        this(new InputStreamReader(Files.newInputStream(Paths.get(dataFile)), StandardCharsets.UTF_8), separator, quote);
101    }
102
103    /**
104     * Builds a CSVIterator for the supplied URI.
105     * @param dataFile The source to read.
106     * @param separator The separator character to use.
107     * @param quote The quote character to use.
108     * @param fields The headers to use.
109     * @throws IOException thrown if the file is not readable in some way.
110     */
111    public CSVIterator(URI dataFile, char separator, char quote, String[] fields) throws IOException {
112        this(new InputStreamReader(Files.newInputStream(Paths.get(dataFile)), StandardCharsets.UTF_8), separator, quote, Arrays.asList(fields));
113    }
114
115    /**
116     * Builds a CSVIterator for the supplied URI.
117     * @param dataFile The source to read.
118     * @param separator The separator character to use.
119     * @param quote The quote character to use.
120     * @param fields The headers to use.
121     * @throws IOException thrown if the file is not readable in some way.
122     */
123    public CSVIterator(URI dataFile, char separator, char quote, List<String> fields) throws IOException {
124        this(new InputStreamReader(Files.newInputStream(Paths.get(dataFile)), StandardCharsets.UTF_8), separator, quote, fields);
125    }
126
127    /**
128     * Builds a CSVIterator for the supplied Reader. If headers is null, read the headers from the csv file.
129     * @param rdr The source to read.
130     * @param separator The separator character to use.
131     * @param quote The quote character to use.
132     * @param fields The headers to use.
133     */
134    public CSVIterator(Reader rdr, char separator, char quote, String[] fields) {
135        this(rdr, separator, quote, fields == null ? null : Arrays.asList(fields));
136    }
137
138    /**
139     * Builds a CSVIterator for the supplied Reader. If headers is null, read the headers from the csv file.
140     * @param rdr The source to read.
141     * @param separator The separator character to use.
142     * @param quote The quote character to use.
143     * @param fields The headers to use.
144     */
145    public CSVIterator(Reader rdr, char separator, char quote, List<String> fields) {
146        try {
147            // If someone hands us a BufferedReader, then we'll double buffer it here.
148            Reader bomRemoved = new BufferedReader(CSVDataSource.removeBOM(rdr));
149            reader = new CSVReaderBuilder(bomRemoved).withCSVParser(new RFC4180ParserBuilder().withSeparator(separator).withQuoteChar(quote).build()).build();
150            try {
151                if (fields == null || fields.isEmpty()) {
152                    String[] inducedHeader = reader.readNext();
153                    if(inducedHeader == null) {
154                        logger.warning("Given an empty CSV");
155                    } else {
156                        this.fields = Collections.unmodifiableList(Arrays.asList(inducedHeader));
157                    }
158                } else {
159                    this.fields = Collections.unmodifiableList(fields);
160                }
161            } catch (CsvValidationException | IOException e) {
162                try {
163                    reader.close();
164                } catch (IOException e2) {
165                    logger.log(Level.WARNING, "Error closing reader in another error", e2);
166                }
167                throw new IllegalArgumentException("Error reading file caused by: " + e.getMessage());
168            }
169        } catch (IOException e) {
170            throw new IllegalArgumentException("Error reading file caused by: " + e.getMessage());
171        }
172    }
173
174    /**
175     * Zips together the headers and the line into a Map.
176     * @param headers The field headers.
177     * @param line The extracted line.
178     * @param rowNum The row number used for error messages.
179     * @return A map from header to value.
180     */
181    private static Map<String,String> zip(List<String> headers, String[] line, long rowNum) {
182        if (headers.size() != line.length) {
183            throw new IllegalArgumentException("On row " + rowNum + " headers has " + headers.size() + " elements, current line has " + line.length + " elements.");
184        }
185
186        Map<String,String> map = new HashMap<>();
187        for (int i = 0; i < headers.size(); i++) {
188            map.put(headers.get(i),line[i]);
189        }
190        return map;
191    }
192
193    @Override
194    protected Optional<Row> getRow() {
195        try {
196            String[] rawRow = reader.readNext();
197            if(rawRow != null) {
198                if(reader.getRecordsRead() % 50_000 == 0) {
199                    logger.info(String.format("Read %d records on %d lines", reader.getRecordsRead(), reader.getLinesRead()));
200                }
201                while (rawRow != null && rawRow.length == 1 && rawRow[0].isEmpty()) {
202                    // Found an extraneous newline in the csv file
203                    logger.warning("Ignoring extra newline at line " + reader.getLinesRead());
204                    rawRow = reader.readNext();
205                }
206                if (rawRow == null) {
207                    try {
208                        reader.close();
209                    } catch (IOException e) {
210                        logger.log(Level.WARNING, "Error closing reader at end of file", e);
211                    }
212                    return Optional.empty();
213                }
214
215                // Note this is intentionally recordNum++ as we count records from 0.
216                return Optional.of(new Row(recordNum++,
217                        fields,
218                        zip(fields, rawRow, reader.getRecordsRead())));
219            } else {
220                try {
221                    reader.close();
222                } catch (IOException e) {
223                    logger.log(Level.WARNING, "Error closing reader at end of file", e);
224                }
225                return Optional.empty();
226            }
227        } catch (CsvValidationException | IOException e) {
228            long linesRead = reader.getLinesRead();
229            long recordsRead = reader.getRecordsRead();
230            try {
231                reader.close();
232            } catch (IOException e2) {
233                logger.log(Level.WARNING, "Error closing reader in another error", e2);
234            }
235            throw new IllegalArgumentException(String.format("Error reading CSV on record %d, row %d", recordsRead, linesRead), e);
236        }
237    }
238
239    @Override
240    public void close() throws IOException{
241        if(reader != null) {
242            reader.close();
243        }
244    }
245}