001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.data.sql; 018 019import com.oracle.labs.mlrg.olcut.config.Config; 020import com.oracle.labs.mlrg.olcut.provenance.ObjectProvenance; 021import com.oracle.labs.mlrg.olcut.provenance.PrimitiveProvenance; 022import com.oracle.labs.mlrg.olcut.provenance.Provenance; 023import com.oracle.labs.mlrg.olcut.provenance.impl.SkeletalConfiguredObjectProvenance; 024import com.oracle.labs.mlrg.olcut.provenance.primitives.DateTimeProvenance; 025import com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance; 026import org.tribuo.DataSource; 027import org.tribuo.Output; 028import org.tribuo.OutputFactory; 029import org.tribuo.data.columnar.ColumnarDataSource; 030import org.tribuo.data.columnar.ColumnarIterator; 031import org.tribuo.data.columnar.FieldProcessor; 032import org.tribuo.data.columnar.RowProcessor; 033import org.tribuo.provenance.ConfiguredDataSourceProvenance; 034 035import java.sql.SQLException; 036import java.sql.Statement; 037import java.time.OffsetDateTime; 038import java.util.HashMap; 039import java.util.HashSet; 040import java.util.Map; 041import java.util.Objects; 042import java.util.Set; 043import java.util.logging.Level; 044import java.util.logging.Logger; 045 046/** 047 * A {@link org.tribuo.DataSource} for loading columnar data from a database 048 * and applying {@link org.tribuo.data.columnar.FieldProcessor}s to it. 049 * The {@link java.sql.Connection}s it creates are closed when the iterator is empty 050 * (ie. when hasNext is called and returns false). Calling close() on SQLDatasource itself closes all connections 051 * created since close was last called. 052 * 053 * <p> 054 * 055 * N.B. This class accepts raw SQL strings and executes them directly via JDBC. It DOES NOT perform 056 * any SQL escaping or other injection prevention. It is the user's responsibility to ensure that SQL passed to this 057 * class performs as desired. 058 */ 059public class SQLDataSource<T extends Output<T>> extends ColumnarDataSource<T> implements AutoCloseable { 060 061 private static final Logger logger = Logger.getLogger(SQLDataSource.class.getName()); 062 063 @Config(mandatory = true,description="Database configuration.") 064 private SQLDBConfig sqlConfig; 065 066 @Config(mandatory = true,description="SQL query to run.") 067 private String sqlString; 068 069 private final Set<Statement> statements = new HashSet<>(); 070 071 private SQLDataSource() {} 072 073 public SQLDataSource(String sqlString, SQLDBConfig sqlConfig, OutputFactory<T> outputFactory, RowProcessor<T> rowProcessor, boolean outputRequired) throws SQLException { 074 super(outputFactory, rowProcessor, outputRequired); 075 this.sqlConfig = sqlConfig; 076 this.sqlString = sqlString; 077 } 078 079 @Override 080 public String toString() { 081 return "SQLDataSource(sqlString=\"" + sqlString + "\", sqlConfig=\"" + sqlConfig.toString() + "\", rowProcessor=" + rowProcessor.getDescription() +")"; 082 } 083 084 @Override 085 public ColumnarIterator rowIterator() { 086 try { 087 Statement stmt = sqlConfig.getStatement(); 088 statements.add(stmt); 089 return new ResultSetIterator(stmt.executeQuery(sqlString), stmt.getFetchSize()); 090 } catch (SQLException e) { 091 throw new IllegalArgumentException("Error Processing SQL", e); 092 } 093 } 094 095 @Override 096 public void close() { 097 for (Statement statement: statements) { 098 try { 099 statement.close(); 100 } catch (SQLException e) { 101 logger.log(Level.WARNING, "Error closing statement", e); 102 } 103 } 104 statements.clear(); 105 } 106 107 @Override 108 public ConfiguredDataSourceProvenance getProvenance() { 109 return new SQLDataSourceProvenance(this); 110 } 111 112 /** 113 * Provenance for {@link SQLDataSource}. 114 */ 115 public static class SQLDataSourceProvenance extends SkeletalConfiguredObjectProvenance implements ConfiguredDataSourceProvenance { 116 private static final long serialVersionUID = 1L; 117 118 private final DateTimeProvenance dataSourceCreationTime; 119 120 <T extends Output<T>> SQLDataSourceProvenance(SQLDataSource<T> host) { 121 super(host,"DataSource"); 122 this.dataSourceCreationTime = new DateTimeProvenance(DATASOURCE_CREATION_TIME,OffsetDateTime.now()); 123 } 124 125 public SQLDataSourceProvenance(Map<String,Provenance> map) { 126 this(extractProvenanceInfo(map)); 127 } 128 129 private SQLDataSourceProvenance(ExtractedInfo info) { 130 super(info); 131 this.dataSourceCreationTime = (DateTimeProvenance) info.instanceValues.get(DATASOURCE_CREATION_TIME); 132 } 133 134 protected static ExtractedInfo extractProvenanceInfo(Map<String,Provenance> map) { 135 Map<String,Provenance> configuredParameters = new HashMap<>(map); 136 String className = ObjectProvenance.checkAndExtractProvenance(configuredParameters,CLASS_NAME, StringProvenance.class, SQLDataSourceProvenance.class.getSimpleName()).getValue(); 137 String hostTypeStringName = ObjectProvenance.checkAndExtractProvenance(configuredParameters,HOST_SHORT_NAME, StringProvenance.class, SQLDataSourceProvenance.class.getSimpleName()).getValue(); 138 139 Map<String,PrimitiveProvenance<?>> instanceParameters = new HashMap<>(); 140 instanceParameters.put(DATASOURCE_CREATION_TIME,ObjectProvenance.checkAndExtractProvenance(configuredParameters,DATASOURCE_CREATION_TIME,DateTimeProvenance.class, SQLDataSourceProvenance.class.getSimpleName())); 141 142 return new ExtractedInfo(className,hostTypeStringName,configuredParameters,instanceParameters); 143 } 144 145 @Override 146 public boolean equals(Object o) { 147 if (this == o) return true; 148 if (!(o instanceof SQLDataSourceProvenance)) return false; 149 if (!super.equals(o)) return false; 150 SQLDataSourceProvenance pairs = (SQLDataSourceProvenance) o; 151 return dataSourceCreationTime.equals(pairs.dataSourceCreationTime); 152 } 153 154 @Override 155 public int hashCode() { 156 return Objects.hash(super.hashCode(), dataSourceCreationTime); 157 } 158 159 @Override 160 public Map<String, PrimitiveProvenance<?>> getInstanceValues() { 161 Map<String,PrimitiveProvenance<?>> map = super.getInstanceValues(); 162 163 map.put(DATASOURCE_CREATION_TIME,dataSourceCreationTime); 164 165 return map; 166 } 167 168 } 169}