/*
* Copyright 1999-2015 dangdang.com.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package com.dangdang.ddframe.rdb.sharding.jdbc.core.statement.prepared;
import com.dangdang.ddframe.rdb.sharding.executor.PreparedStatementExecutor;
import com.dangdang.ddframe.rdb.sharding.executor.wrapper.PreparedStatementExecutorWrapper;
import com.dangdang.ddframe.rdb.sharding.jdbc.adapter.AbstractPreparedStatementAdapter;
import com.dangdang.ddframe.rdb.sharding.jdbc.core.connection.ShardingConnection;
import com.dangdang.ddframe.rdb.sharding.merger.ResultSetFactory;
import com.dangdang.ddframe.rdb.sharding.parsing.parser.context.GeneratedKey;
import com.dangdang.ddframe.rdb.sharding.routing.PreparedStatementRoutingEngine;
import com.dangdang.ddframe.rdb.sharding.routing.SQLExecutionUnit;
import com.dangdang.ddframe.rdb.sharding.routing.SQLRouteResult;
import com.google.common.base.Optional;
import com.google.common.base.Predicate;
import com.google.common.collect.Iterators;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.LinkedList;
import java.util.List;
import java.util.Objects;
/**
 * Prepared statement that supports sharding.
 *
 * <p>Every execution routes the SQL with the currently bound parameters through
 * {@link PreparedStatementRoutingEngine}, obtains one backend {@link PreparedStatement} per routed
 * {@link SQLExecutionUnit}, replays the recorded method and parameter invocations onto it, and
 * delegates the actual work to {@link PreparedStatementExecutor}.</p>
 *
 * @author zhangliang
 * @author caohao
 */
public final class ShardingPreparedStatement extends AbstractPreparedStatementAdapter {
    
    private final PreparedStatementRoutingEngine preparedStatementRoutingEngine;
    
    /**
     * Backend statements prepared for batching. Reused by later {@code addBatch()} calls that route to the
     * same connection and SQL; never removed by this class (presumably closed through
     * {@code getRoutedStatements()} in the adapter — verify).
     */
    private final List<BackendPreparedStatementWrapper> cachedRoutedPreparedStatements = new LinkedList<>();
    
    /** Executor wrappers for the current execution or pending batch; emptied by {@link #clearRouteContext()}. */
    private final List<PreparedStatementExecutorWrapper> cachedPreparedStatementWrappers = new LinkedList<>();
    
    /** Number of logical {@code addBatch()} calls since the route context was last cleared. */
    private int batchIndex;
    
    public ShardingPreparedStatement(final ShardingConnection shardingConnection, final String sql) {
        this(shardingConnection, sql, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY, ResultSet.HOLD_CURSORS_OVER_COMMIT);
    }
    
    public ShardingPreparedStatement(final ShardingConnection shardingConnection, final String sql, final int resultSetType, final int resultSetConcurrency) {
        this(shardingConnection, sql, resultSetType, resultSetConcurrency, ResultSet.HOLD_CURSORS_OVER_COMMIT);
    }
    
    public ShardingPreparedStatement(final ShardingConnection shardingConnection, final String sql, final int resultSetType, final int resultSetConcurrency, final int resultSetHoldability) {
        super(shardingConnection, resultSetType, resultSetConcurrency, resultSetHoldability);
        preparedStatementRoutingEngine = new PreparedStatementRoutingEngine(sql, shardingConnection.getShardingContext());
    }
    
    /**
     * Creates a statement that optionally returns generated keys.
     *
     * @param autoGeneratedKeys {@code RETURN_GENERATED_KEYS} to mark this statement as returning generated keys;
     *                          any other value leaves it unmarked
     */
    public ShardingPreparedStatement(final ShardingConnection shardingConnection, final String sql, final int autoGeneratedKeys) {
        this(shardingConnection, sql);
        if (RETURN_GENERATED_KEYS == autoGeneratedKeys) {
            markReturnGeneratedKeys();
        }
    }
    
    /**
     * Routes and executes the query, merging the per-shard results into a single result set.
     *
     * @return the merged result set, also registered as the current result set
     * @throws SQLException if routing, preparing or executing fails
     */
    @Override
    public ResultSet executeQuery() throws SQLException {
        ResultSet result;
        try {
            result = ResultSetFactory.getResultSet(
                    new PreparedStatementExecutor(getShardingConnection().getShardingContext().getExecutorEngine(), routeSQL()).executeQuery(), getRouteResult().getSqlStatement());
        } finally {
            clearRouteContext();
        }
        setCurrentResultSet(result);
        return result;
    }
    
    /**
     * Routes and executes an update on every routed shard.
     *
     * @return the update count reported by {@link PreparedStatementExecutor#executeUpdate()}
     * @throws SQLException if routing, preparing or executing fails
     */
    @Override
    public int executeUpdate() throws SQLException {
        try {
            return new PreparedStatementExecutor(getShardingConnection().getShardingContext().getExecutorEngine(), routeSQL()).executeUpdate();
        } finally {
            clearRouteContext();
        }
    }
    
    /**
     * Routes and executes the statement on every routed shard.
     *
     * @return the value reported by {@link PreparedStatementExecutor#execute()}
     * @throws SQLException if routing, preparing or executing fails
     */
    @Override
    public boolean execute() throws SQLException {
        try {
            return new PreparedStatementExecutor(getShardingConnection().getShardingContext().getExecutorEngine(), routeSQL()).execute();
        } finally {
            clearRouteContext();
        }
    }
    
    /**
     * Resets per-execution state: clears the current result set and the bound parameters, drops the cached
     * executor wrappers and resets the batch counter.
     *
     * <p>NOTE(review): bound parameters are cleared after every execution, whereas plain JDBC statements
     * normally keep parameter values until {@code clearParameters()} is called. Behavior kept as-is.</p>
     *
     * @throws SQLException if clearing the parameters fails
     */
    protected void clearRouteContext() throws SQLException {
        resetBatch();
        cachedPreparedStatementWrappers.clear();
        batchIndex = 0;
    }
    
    /**
     * Discards all pending batch commands.
     *
     * <p>Besides resetting the route context, the queued commands are also cleared on the cached backend
     * statements. Otherwise those statements, which are reused by later {@code addBatch()} calls, would
     * still carry the discarded commands into the next {@code executeBatch()}.</p>
     *
     * @throws SQLException if a backend statement fails to clear its batch
     */
    @Override
    public void clearBatch() throws SQLException {
        try {
            for (BackendPreparedStatementWrapper each : cachedRoutedPreparedStatements) {
                each.getPreparedStatement().clearBatch();
            }
        } finally {
            clearRouteContext();
        }
    }
    
    /**
     * Routes the currently bound parameters and adds one batch entry to each routed backend statement.
     * Each backend statement records which logical batch index the entry belongs to. The bound parameters
     * are cleared afterwards, even on failure.
     *
     * @throws SQLException if routing, preparing or adding the batch fails
     */
    @Override
    public void addBatch() throws SQLException {
        try {
            for (PreparedStatementExecutorWrapper each : routeSQLForBatch()) {
                each.getPreparedStatement().addBatch();
                each.mapBatchIndex(batchIndex);
            }
            batchIndex++;
        } finally {
            resetBatch();
        }
    }
    
    /** Clears the current result set and the bound parameters. */
    private void resetBatch() throws SQLException {
        setCurrentResultSet(null);
        clearParameters();
    }
    
    /**
     * Executes all batch entries accumulated since the last reset.
     *
     * @return the per-logical-batch update counts computed by {@link PreparedStatementExecutor#executeBatch(int)}
     * @throws SQLException if execution fails
     */
    @Override
    public int[] executeBatch() throws SQLException {
        try {
            return new PreparedStatementExecutor(getShardingConnection().getShardingContext().getExecutorEngine(), cachedPreparedStatementWrappers).executeBatch(batchIndex);
        } finally {
            clearRouteContext();
        }
    }
    
    /**
     * Routes the SQL for a single (non-batch) execution. A fresh backend statement is prepared for every
     * execution unit and registered with {@code getRoutedStatements()}.
     */
    private List<PreparedStatementExecutorWrapper> routeSQL() throws SQLException {
        List<PreparedStatementExecutorWrapper> result = new ArrayList<>();
        SQLRouteResult sqlRouteResult = preparedStatementRoutingEngine.route(getParameters());
        setRouteResult(sqlRouteResult);
        for (SQLExecutionUnit each : sqlRouteResult.getExecutionUnits()) {
            BackendPreparedStatementWrapper backendPreparedStatementWrapper = generateStatement(
                    getShardingConnection().getConnection(each.getDataSource(), sqlRouteResult.getSqlStatement().getType()), each.getSql());
            PreparedStatement preparedStatement = backendPreparedStatementWrapper.getPreparedStatement();
            getRoutedStatements().add(preparedStatement);
            replayMethodsInvocation(preparedStatement);
            getParameters().replayMethodsInvocation(preparedStatement);
            result.add(wrap(preparedStatement, each));
        }
        return result;
    }
    
    /**
     * Routes the SQL for one {@code addBatch()} call. Backend statements are reused per (connection, SQL)
     * pair so that successive batch entries accumulate on the same statement.
     */
    private List<PreparedStatementExecutorWrapper> routeSQLForBatch() throws SQLException {
        List<PreparedStatementExecutorWrapper> result = new ArrayList<>();
        SQLRouteResult sqlRouteResult = preparedStatementRoutingEngine.route(getParameters());
        setRouteResult(sqlRouteResult);
        for (SQLExecutionUnit each : sqlRouteResult.getExecutionUnits()) {
            PreparedStatement preparedStatement = getStatementForBatch(
                    getShardingConnection().getConnection(each.getDataSource(), sqlRouteResult.getSqlStatement().getType()), each.getSql());
            replayMethodsInvocation(preparedStatement);
            getParameters().replayMethodsInvocation(preparedStatement);
            result.add(wrap(preparedStatement, each));
        }
        return result;
    }
    
    /** Returns the cached batch statement for {@code connection}/{@code sql}, preparing and caching a new one if absent. */
    private PreparedStatement getStatementForBatch(final Connection connection, final String sql) throws SQLException {
        for (BackendPreparedStatementWrapper each : cachedRoutedPreparedStatements) {
            if (each.isBelongTo(connection, sql)) {
                return each.getPreparedStatement();
            }
        }
        BackendPreparedStatementWrapper statement = generateStatement(connection, sql);
        getRoutedStatements().add(statement.getPreparedStatement());
        cachedRoutedPreparedStatements.add(statement);
        return statement.getPreparedStatement();
    }
    
    /**
     * Prepares a backend statement. Generated keys are requested only when this statement was marked to
     * return them and a generated key is present. Otherwise the configured result set type, concurrency and
     * holdability are used.
     */
    private BackendPreparedStatementWrapper generateStatement(final Connection connection, final String sql) throws SQLException {
        Optional<GeneratedKey> generatedKey = getGeneratedKey();
        if (isReturnGeneratedKeys() && generatedKey.isPresent()) {
            return new BackendPreparedStatementWrapper(connection.prepareStatement(sql, RETURN_GENERATED_KEYS), sql);
        }
        return new BackendPreparedStatementWrapper(connection.prepareStatement(sql, getResultSetType(), getResultSetConcurrency(), getResultSetHoldability()), sql);
    }
    
    /**
     * Returns the executor wrapper for {@code preparedStatement}.
     *
     * <p>If a wrapper for the same backend statement is already cached, the current parameters are appended
     * to it as another batch entry and it is returned. Otherwise a new wrapper is created and cached.</p>
     */
    private PreparedStatementExecutorWrapper wrap(final PreparedStatement preparedStatement, final SQLExecutionUnit sqlExecutionUnit) {
        Optional<PreparedStatementExecutorWrapper> wrapperOptional = Iterators.tryFind(cachedPreparedStatementWrappers.iterator(), new Predicate<PreparedStatementExecutorWrapper>() {
            
            @Override
            public boolean apply(final PreparedStatementExecutorWrapper input) {
                return Objects.equals(input.getPreparedStatement(), preparedStatement);
            }
        });
        if (wrapperOptional.isPresent()) {
            wrapperOptional.get().addBatchParameters(getParameters());
            return wrapperOptional.get();
        }
        PreparedStatementExecutorWrapper result = new PreparedStatementExecutorWrapper(getRouteResult().getSqlStatement().getType(), preparedStatement, sqlExecutionUnit);
        cachedPreparedStatementWrappers.add(result);
        return result;
    }
}