001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.data.sql;
018
019import com.oracle.labs.mlrg.olcut.config.Config;
020import com.oracle.labs.mlrg.olcut.provenance.ObjectProvenance;
021import com.oracle.labs.mlrg.olcut.provenance.PrimitiveProvenance;
022import com.oracle.labs.mlrg.olcut.provenance.Provenance;
023import com.oracle.labs.mlrg.olcut.provenance.impl.SkeletalConfiguredObjectProvenance;
024import com.oracle.labs.mlrg.olcut.provenance.primitives.DateTimeProvenance;
025import com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance;
026import org.tribuo.DataSource;
027import org.tribuo.Output;
028import org.tribuo.OutputFactory;
029import org.tribuo.data.columnar.ColumnarDataSource;
030import org.tribuo.data.columnar.ColumnarIterator;
031import org.tribuo.data.columnar.FieldProcessor;
032import org.tribuo.data.columnar.RowProcessor;
033import org.tribuo.provenance.ConfiguredDataSourceProvenance;
034
035import java.sql.SQLException;
036import java.sql.Statement;
037import java.time.OffsetDateTime;
038import java.util.HashMap;
039import java.util.HashSet;
040import java.util.Map;
041import java.util.Objects;
042import java.util.Set;
043import java.util.logging.Level;
044import java.util.logging.Logger;
045
046/**
047 * A {@link org.tribuo.DataSource} for loading columnar data from a database
048 * and applying {@link org.tribuo.data.columnar.FieldProcessor}s to it.
049 * The {@link java.sql.Connection}s it creates are closed when the iterator is empty
050 * (ie. when hasNext is called and returns false). Calling close() on SQLDatasource itself closes all connections
051 * created since close was last called.
052 *
053 * <p>
054 *
055 * N.B. This class accepts raw SQL strings and executes them directly via JDBC. It DOES NOT perform
056 * any SQL escaping or other injection prevention. It is the user's responsibility to ensure that SQL passed to this
057 * class performs as desired.
058 */
059public class SQLDataSource<T extends Output<T>> extends ColumnarDataSource<T> implements AutoCloseable {
060
061    private static final Logger logger = Logger.getLogger(SQLDataSource.class.getName());
062
063    @Config(mandatory = true,description="Database configuration.")
064    private SQLDBConfig sqlConfig;
065
066    @Config(mandatory = true,description="SQL query to run.")
067    private String sqlString;
068
069    private final Set<Statement> statements = new HashSet<>();
070
071    private SQLDataSource() {}
072
073    public SQLDataSource(String sqlString, SQLDBConfig sqlConfig, OutputFactory<T> outputFactory, RowProcessor<T> rowProcessor, boolean outputRequired) throws SQLException {
074        super(outputFactory, rowProcessor, outputRequired);
075        this.sqlConfig = sqlConfig;
076        this.sqlString = sqlString;
077    }
078
079    @Override
080    public String toString() {
081        return "SQLDataSource(sqlString=\"" + sqlString + "\", sqlConfig=\"" + sqlConfig.toString() + "\", rowProcessor=" + rowProcessor.getDescription() +")";
082    }
083
084    @Override
085    public ColumnarIterator rowIterator() {
086        try {
087            Statement stmt = sqlConfig.getStatement();
088            statements.add(stmt);
089            return new ResultSetIterator(stmt.executeQuery(sqlString), stmt.getFetchSize());
090        } catch (SQLException e) {
091            throw new IllegalArgumentException("Error Processing SQL", e);
092        }
093    }
094
095    @Override
096    public void close() {
097        for (Statement statement: statements) {
098            try {
099                statement.close();
100            } catch (SQLException e) {
101                logger.log(Level.WARNING, "Error closing statement", e);
102            }
103        }
104        statements.clear();
105    }
106
107    @Override
108    public ConfiguredDataSourceProvenance getProvenance() {
109        return new SQLDataSourceProvenance(this);
110    }
111
112    /**
113     * Provenance for {@link SQLDataSource}.
114     */
115    public static class SQLDataSourceProvenance extends SkeletalConfiguredObjectProvenance implements ConfiguredDataSourceProvenance {
116        private static final long serialVersionUID = 1L;
117
118        private final DateTimeProvenance dataSourceCreationTime;
119
120        <T extends Output<T>> SQLDataSourceProvenance(SQLDataSource<T> host) {
121            super(host,"DataSource");
122            this.dataSourceCreationTime = new DateTimeProvenance(DATASOURCE_CREATION_TIME,OffsetDateTime.now());
123        }
124
125        public SQLDataSourceProvenance(Map<String,Provenance> map) {
126            this(extractProvenanceInfo(map));
127        }
128
129        private SQLDataSourceProvenance(ExtractedInfo info) {
130            super(info);
131            this.dataSourceCreationTime = (DateTimeProvenance) info.instanceValues.get(DATASOURCE_CREATION_TIME);
132        }
133
134        protected static ExtractedInfo extractProvenanceInfo(Map<String,Provenance> map) {
135            Map<String,Provenance> configuredParameters = new HashMap<>(map);
136            String className = ObjectProvenance.checkAndExtractProvenance(configuredParameters,CLASS_NAME, StringProvenance.class, SQLDataSourceProvenance.class.getSimpleName()).getValue();
137            String hostTypeStringName = ObjectProvenance.checkAndExtractProvenance(configuredParameters,HOST_SHORT_NAME, StringProvenance.class, SQLDataSourceProvenance.class.getSimpleName()).getValue();
138
139            Map<String,PrimitiveProvenance<?>> instanceParameters = new HashMap<>();
140            instanceParameters.put(DATASOURCE_CREATION_TIME,ObjectProvenance.checkAndExtractProvenance(configuredParameters,DATASOURCE_CREATION_TIME,DateTimeProvenance.class, SQLDataSourceProvenance.class.getSimpleName()));
141
142            return new ExtractedInfo(className,hostTypeStringName,configuredParameters,instanceParameters);
143        }
144
145        @Override
146        public boolean equals(Object o) {
147            if (this == o) return true;
148            if (!(o instanceof SQLDataSourceProvenance)) return false;
149            if (!super.equals(o)) return false;
150            SQLDataSourceProvenance pairs = (SQLDataSourceProvenance) o;
151            return dataSourceCreationTime.equals(pairs.dataSourceCreationTime);
152        }
153
154        @Override
155        public int hashCode() {
156            return Objects.hash(super.hashCode(), dataSourceCreationTime);
157        }
158
159        @Override
160        public Map<String, PrimitiveProvenance<?>> getInstanceValues() {
161            Map<String,PrimitiveProvenance<?>> map = super.getInstanceValues();
162
163            map.put(DATASOURCE_CREATION_TIME,dataSourceCreationTime);
164
165            return map;
166        }
167
168    }
169}