001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.data.sql; 018 019import com.oracle.labs.mlrg.olcut.util.IOSpliterator; 020import org.tribuo.data.columnar.ColumnarIterator; 021 022import java.sql.ResultSet; 023import java.sql.ResultSetMetaData; 024import java.sql.SQLException; 025import java.util.ArrayList; 026import java.util.HashMap; 027import java.util.Map; 028import java.util.Optional; 029import java.util.logging.Level; 030import java.util.logging.Logger; 031 032/** 033 * An iterator over a ResultSet returned from JDBC. 034 */ 035public class ResultSetIterator extends ColumnarIterator { 036 private static final Logger logger = Logger.getLogger(ResultSetIterator.class.getName()); 037 038 private final ResultSet resultSet; 039 040 private int rowNum = 0; 041 042 public ResultSetIterator(ResultSet rs) throws SQLException { 043 resultSet = rs; 044 ResultSetMetaData rsm = resultSet.getMetaData(); 045 fields = new ArrayList<>(); 046 for(int i=1; i <= rsm.getColumnCount(); i++) { 047 fields.add(rsm.getColumnName(i)); 048 } 049 } 050 051 public ResultSetIterator(ResultSet rs, int fetchSize) throws SQLException { 052 super(IOSpliterator.DEFAULT_CHARACTERISTICS, fetchSize, Long.MAX_VALUE); 053 resultSet = rs; 054 ResultSetMetaData rsm = resultSet.getMetaData(); 055 fields = new ArrayList<>(); 056 for(int i=1; i <= rsm.getColumnCount(); i++) { 057 fields.add(rsm.getColumnName(i)); 058 } 059 } 060 061 @Override 062 protected Optional<Row> getRow() { 063 try { 064 if(!resultSet.isClosed() && resultSet.next()) { 065 Map<String, String> rowMap = new HashMap<>(); 066 for(int i=0; i < fields.size(); i++) { 067 Object obj = null; 068 try { 069 obj = resultSet.getObject(i + 1); 070 } catch (SQLException e) { 071 logger.log(Level.SEVERE, "Missing object at index: " + (i + 1), e); 072 } 073 rowMap.put(fields.get(i), obj == null ? "" : obj.toString()); 074 } 075 rowNum++; 076 if (rowNum % 50_000 == 0) { 077 logger.info(String.format("Iterated over %d rows", rowNum)); 078 } 079 return Optional.of(new Row(rowNum, fields, rowMap)); 080 } else { 081 if(!resultSet.isClosed()) { 082 resultSet.close(); 083 } 084 return Optional.empty(); 085 } 086 } catch (SQLException e) { 087 try { 088 resultSet.close(); 089 } catch (SQLException e2) { 090 logger.log(Level.WARNING, "Error closing ResultSet inside another error", e2); 091 } 092 throw new IllegalStateException("Error while reading from SQL at row " + rowNum, e); 093 } 094 } 095}