001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.data.sql;
018
019import com.oracle.labs.mlrg.olcut.util.IOSpliterator;
020import org.tribuo.data.columnar.ColumnarIterator;
021
022import java.sql.ResultSet;
023import java.sql.ResultSetMetaData;
024import java.sql.SQLException;
025import java.util.ArrayList;
026import java.util.HashMap;
027import java.util.Map;
028import java.util.Optional;
029import java.util.logging.Level;
030import java.util.logging.Logger;
031
032/**
033 * An iterator over a ResultSet returned from JDBC.
034 */
035public class ResultSetIterator extends ColumnarIterator {
036    private static final Logger logger = Logger.getLogger(ResultSetIterator.class.getName());
037
038    private final ResultSet resultSet;
039
040    private int rowNum = 0;
041
042    public ResultSetIterator(ResultSet rs) throws SQLException {
043        resultSet = rs;
044        ResultSetMetaData rsm = resultSet.getMetaData();
045        fields = new ArrayList<>();
046        for(int i=1; i <= rsm.getColumnCount(); i++) {
047            fields.add(rsm.getColumnName(i));
048        }
049    }
050
051    public ResultSetIterator(ResultSet rs, int fetchSize) throws SQLException {
052        super(IOSpliterator.DEFAULT_CHARACTERISTICS, fetchSize, Long.MAX_VALUE);
053        resultSet = rs;
054        ResultSetMetaData rsm = resultSet.getMetaData();
055        fields = new ArrayList<>();
056        for(int i=1; i <= rsm.getColumnCount(); i++) {
057            fields.add(rsm.getColumnName(i));
058        }
059    }
060
061    @Override
062    protected Optional<Row> getRow() {
063        try {
064            if(!resultSet.isClosed() && resultSet.next()) {
065                Map<String, String> rowMap = new HashMap<>();
066                for(int i=0; i < fields.size(); i++) {
067                    Object obj = null;
068                    try {
069                        obj = resultSet.getObject(i + 1);
070                    } catch (SQLException e) {
071                        logger.log(Level.SEVERE, "Missing object at index: " + (i + 1), e);
072                    }
073                    rowMap.put(fields.get(i), obj == null ? "" : obj.toString());
074                }
075                rowNum++;
076                if (rowNum % 50_000 == 0) {
077                    logger.info(String.format("Iterated over %d rows", rowNum));
078                }
079                return Optional.of(new Row(rowNum, fields, rowMap));
080            } else {
081                if(!resultSet.isClosed()) {
082                    resultSet.close();
083                }
084                return Optional.empty();
085            }
086        } catch (SQLException e) {
087            try {
088                resultSet.close();
089            } catch (SQLException e2) {
090                logger.log(Level.WARNING, "Error closing ResultSet inside another error", e2);
091            }
092            throw new IllegalStateException("Error while reading from SQL at row " + rowNum, e);
093        }
094    }
095}