/*
* Copyright 1999-2015 dangdang.com.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package com.dangdang.ddframe.rdb.common.sql.base;
import com.dangdang.ddframe.rdb.common.jaxb.SqlAssertData;
import com.dangdang.ddframe.rdb.common.jaxb.SqlShardingRule;
import com.dangdang.ddframe.rdb.common.sql.common.ShardingTestStrategy;
import com.dangdang.ddframe.rdb.integrate.util.DBUnitUtil;
import com.dangdang.ddframe.rdb.integrate.util.DataBaseEnvironment;
import com.dangdang.ddframe.rdb.sharding.constant.DatabaseType;
import com.dangdang.ddframe.rdb.sharding.constant.SQLType;
import com.dangdang.ddframe.rdb.sharding.jdbc.core.datasource.ShardingDataSource;
import com.google.common.base.Strings;
import com.google.common.collect.Lists;
import org.dbunit.dataset.IDataSet;
import org.dbunit.dataset.ITable;
import org.dbunit.dataset.ITableIterator;
import org.dbunit.dataset.xml.FlatXmlDataSetBuilder;
import org.junit.Test;
import java.io.File;
import java.net.URL;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Set;
import static com.dangdang.ddframe.rdb.common.sql.common.ShardingTestStrategy.masterslave;
import static com.dangdang.ddframe.rdb.integrate.util.SqlPlaceholderUtil.replacePreparedStatement;
import static com.dangdang.ddframe.rdb.integrate.util.SqlPlaceholderUtil.replaceStatement;
import static org.dbunit.Assertion.assertEquals;
public abstract class AbstractSqlAssertTest extends AbstractBaseSqlTest {
private final String sql;
private final Set types;
private final List shardingRules;
protected AbstractSqlAssertTest(final String testCaseName, final String sql, final Set types, final List shardingRules) {
this.sql = sql;
this.types = types;
this.shardingRules = shardingRules;
}
protected abstract ShardingTestStrategy getShardingStrategy();
protected abstract Map getShardingDataSources();
@Test
public void assertWithPreparedStatement() throws Exception {
execute(true);
}
@Test
public void assertWithStatement() throws Exception {
execute(false);
}
private void execute(final boolean isPreparedStatement) throws Exception {
for (Map.Entry each : getShardingDataSources().entrySet()) {
if (types.size() == 0 || types.contains(each.getKey())) {
try {
executeAndAssertSql(isPreparedStatement, each.getValue());
//CHECKSTYLE:OFF
} catch (final Exception ex) {
//CHECKSTYLE:ON
if (ex.getMessage().startsWith("Dynamic table")) {
continue;
}
throw new RuntimeException(ex);
}
}
}
}
private void executeAndAssertSql(final boolean isPreparedStatement, final ShardingDataSource shardingDataSource) throws Exception {
for (SqlShardingRule sqlShardingRule : shardingRules) {
if (needAssert(sqlShardingRule)) {
for (SqlAssertData each : sqlShardingRule.getData()) {
String expected = each.getExpected() == null ? "integrate/dataset/EmptyTable.xml"
: String.format("integrate/dataset/%s/expect/" + each.getExpected(), getShardingStrategy().name(), getShardingStrategy().name());
URL url = AbstractSqlAssertTest.class.getClassLoader().getResource(expected);
if (null == url) {
throw new Exception("Wrong expected file:" + expected);
}
File expectedDataSetFile = new File(url.getPath());
if (sql.toUpperCase().startsWith("SELECT")) {
assertSelectSql(isPreparedStatement, shardingDataSource, each, expectedDataSetFile);
} else {
assertDmlSql(isPreparedStatement, shardingDataSource, each, expectedDataSetFile);
}
}
}
}
}
private boolean needAssert(final SqlShardingRule sqlShardingRule) {
String shardingRules = sqlShardingRule.getValue();
if (null == shardingRules) {
return true;
}
for (String each : shardingRules.split(",")) {
if (getShardingStrategy().name().equals(each)) {
return true;
}
}
return false;
}
private void assertSelectSql(final boolean isPreparedStatement, final ShardingDataSource shardingDataSource, final SqlAssertData data, final File expectedDataSetFile) throws Exception {
if (isPreparedStatement) {
executeQueryWithPreparedStatement(shardingDataSource, getParameters(data), expectedDataSetFile);
} else {
executeQueryWithStatement(shardingDataSource, getParameters(data), expectedDataSetFile);
}
}
private void assertDmlSql(final boolean isPreparedStatement, final ShardingDataSource shardingDataSource, final SqlAssertData data, final File expectedDataSetFile) throws Exception {
if (isPreparedStatement) {
executeWithPreparedStatement(shardingDataSource, getParameters(data));
} else {
executeWithStatement(shardingDataSource, getParameters(data));
}
String dataSourceName = getDataSourceName(data.getExpected());
SQLType sqlType = getSqlType();
try (Connection conn = shardingDataSource.getConnection().getConnection(dataSourceName, sqlType)) {
assertResult(conn, expectedDataSetFile);
}
}
private SQLType getSqlType() {
return masterslave == getShardingStrategy() ? SQLType.INSERT : SQLType.SELECT;
}
private String getDataSourceName(final String expected) {
String result = String.format(expected.split("/")[1].split(".xml")[0], getShardingStrategy().name());
if (!result.contains("_")) {
result = result + "_0";
}
if (result.startsWith("tbl")) {
result = "tbl";
}
if (result.contains("masterslave")) {
result = result.replace("masterslave", "ms");
} else {
result = "dataSource_" + result;
}
return result;
}
private List getParameters(final SqlAssertData data) {
return Strings.isNullOrEmpty(data.getParameter()) ? Collections.emptyList() : Lists.newArrayList(data.getParameter().split(","));
}
private void executeWithPreparedStatement(final ShardingDataSource dataSource, final List parameters) throws SQLException {
try (Connection connection = dataSource.getConnection();
PreparedStatement preparedStatement = connection.prepareStatement(replacePreparedStatement(sql))) {
int index = 1;
for (String each : parameters) {
if (each.contains("'")) {
preparedStatement.setString(index++, each.replace("'", ""));
} else {
preparedStatement.setInt(index++, Integer.valueOf(each));
}
}
preparedStatement.execute();
}
}
private void executeWithStatement(final ShardingDataSource dataSource, final List parameters) throws SQLException {
try (Connection connection = dataSource.getConnection();
Statement statement = connection.createStatement()) {
statement.execute(replaceStatement(sql, parameters.toArray()));
}
}
private void executeQueryWithPreparedStatement(final ShardingDataSource dataSource, final List parameters, final File file) throws Exception {
try (Connection conn = dataSource.getConnection();
PreparedStatement preparedStatement = conn.prepareStatement(replacePreparedStatement(sql))) {
int index = 1;
for (String param : parameters) {
if (param.contains("'")) {
preparedStatement.setString(index++, param.replace("'", ""));
} else {
preparedStatement.setInt(index++, Integer.valueOf(param));
}
}
ITableIterator expectedTableIterator = new FlatXmlDataSetBuilder().build(file).iterator();
while (expectedTableIterator.next()) {
ITable expectedTable = expectedTableIterator.getTable();
String actualTableName = expectedTable.getTableMetaData().getTableName();
ITable actualTable = DBUnitUtil.getConnection(new DataBaseEnvironment(DatabaseType.valueFrom(conn.getMetaData().getDatabaseProductName())), conn)
.createTable(actualTableName, preparedStatement);
IDataSet expectedDataSet = new FlatXmlDataSetBuilder().build(file);
assertEquals(expectedDataSet.getTable(actualTableName), actualTable);
}
}
}
private void executeQueryWithStatement(final ShardingDataSource dataSource, final List parameters, final File file) throws Exception {
try (Connection conn = dataSource.getConnection()) {
String querySql = replaceStatement(sql, parameters.toArray());
ITableIterator expectedTableIterator = new FlatXmlDataSetBuilder().build(file).iterator();
while (expectedTableIterator.next()) {
ITable expectedTable = expectedTableIterator.getTable();
String actualTableName = expectedTable.getTableMetaData().getTableName();
ITable actualTable = DBUnitUtil.getConnection(new DataBaseEnvironment(DatabaseType.valueFrom(conn.getMetaData().getDatabaseProductName())), conn)
.createQueryTable(actualTableName, querySql);
IDataSet expectedDataSet = new FlatXmlDataSetBuilder().build(file);
assertEquals(expectedDataSet.getTable(actualTableName), actualTable);
}
}
}
private void assertResult(final Connection connection, final File file) throws Exception {
ITableIterator expectedTableIterator = new FlatXmlDataSetBuilder().build(file).iterator();
try (Connection conn = connection) {
while (expectedTableIterator.next()) {
ITable expectedTable = expectedTableIterator.getTable();
String actualTableName = expectedTable.getTableMetaData().getTableName();
String verifySql = "SELECT * FROM " + actualTableName + " WHERE status = '" + getStatus(file) + "'";
ITable actualTable = DBUnitUtil.getConnection(new DataBaseEnvironment(DatabaseType.valueFrom(conn.getMetaData().getDatabaseProductName())), conn)
.createQueryTable(actualTableName, verifySql);
assertEquals(expectedTable, actualTable);
}
}
}
private String getStatus(final File file) {
if (sql.toUpperCase().startsWith("DELETE")) {
return masterslave == getShardingStrategy() ? "init_master" : "init";
}
return file.getParentFile().getName();
}
}