提交 20e713b9 编写于 作者: T terrymanu

add InsertNamesAndValuesAssert for ShardingSQLStatementAssert

上级 d5ab72d3
......@@ -20,6 +20,7 @@ package org.apache.shardingsphere.sql.parser.integrate.asserts;
import com.google.common.base.Optional;
import org.apache.shardingsphere.sql.parser.integrate.asserts.groupby.GroupByAssert;
import org.apache.shardingsphere.sql.parser.integrate.asserts.index.IndexAssert;
import org.apache.shardingsphere.sql.parser.integrate.asserts.insert.InsertNamesAndValuesAssert;
import org.apache.shardingsphere.sql.parser.integrate.asserts.orderby.OrderByAssert;
import org.apache.shardingsphere.sql.parser.integrate.asserts.pagination.PaginationAssert;
import org.apache.shardingsphere.sql.parser.integrate.asserts.predicate.PredicateAssert;
......@@ -36,6 +37,7 @@ import org.apache.shardingsphere.sql.parser.sql.segment.dml.predicate.WhereSegme
import org.apache.shardingsphere.sql.parser.sql.segment.generic.TableSegment;
import org.apache.shardingsphere.sql.parser.sql.statement.SQLStatement;
import org.apache.shardingsphere.sql.parser.sql.statement.ddl.AlterTableStatement;
import org.apache.shardingsphere.sql.parser.sql.statement.dml.InsertStatement;
import org.apache.shardingsphere.sql.parser.sql.statement.dml.SelectStatement;
import org.apache.shardingsphere.sql.parser.sql.statement.tcl.SetAutoCommitStatement;
import org.apache.shardingsphere.sql.parser.sql.statement.tcl.TCLStatement;
......@@ -72,7 +74,11 @@ public final class ShardingSQLStatementAssert {
private final PredicateAssert predicateAssert;
public ShardingSQLStatementAssert(final SQLStatement actual, final String sqlCaseId, final SQLCaseType sqlCaseType) {
private final InsertNamesAndValuesAssert insertNamesAndValuesAssert;
private final String databaseType;
public ShardingSQLStatementAssert(final SQLStatement actual, final String sqlCaseId, final SQLCaseType sqlCaseType, final String databaseType) {
SQLStatementAssertMessage assertMessage = new SQLStatementAssertMessage(
ShardingSQLCasesRegistry.getInstance().getSqlCasesLoader(), ShardingParserResultSetRegistry.getInstance().getRegistry(), sqlCaseId, sqlCaseType);
this.actual = actual;
......@@ -85,6 +91,8 @@ public final class ShardingSQLStatementAssert {
alterTableAssert = new AlterTableAssert(assertMessage);
selectItemAssert = new SelectItemAssert(sqlCaseType, assertMessage);
predicateAssert = new PredicateAssert(sqlCaseType, assertMessage);
insertNamesAndValuesAssert = new InsertNamesAndValuesAssert(assertMessage, sqlCaseType);
this.databaseType = databaseType;
}
/**
......@@ -96,6 +104,9 @@ public final class ShardingSQLStatementAssert {
if (actual instanceof SelectStatement) {
assertSelectStatement((SelectStatement) actual);
}
if (actual instanceof InsertStatement) {
assertInsertStatement((InsertStatement) actual, databaseType);
}
if (actual instanceof AlterTableStatement) {
assertAlterTableStatement((AlterTableStatement) actual);
}
......@@ -128,6 +139,14 @@ public final class ShardingSQLStatementAssert {
}
}
private void assertInsertStatement(final InsertStatement actual, final String databaseType) {
// TODO remove it when oracle fix for column names extract
if ("oracle".equalsIgnoreCase(databaseType)) {
return;
}
insertNamesAndValuesAssert.assertInsertNamesAndValues(actual, expected.getInsertColumnsAndValues());
}
private void assertAlterTableStatement(final AlterTableStatement actual) {
if (null != expected.getAlterTable()) {
alterTableAssert.assertAlterTable(actual, expected.getAlterTable());
......
......@@ -44,6 +44,9 @@ public final class AssignmentAssert {
*/
public void assertAssignment(final ExpressionSegment actual, final ExpectedAssignment expected) {
if (SQLCaseType.Placeholder == sqlCaseType) {
if (null == expected.getTypeForPlaceholder()) {
return;
}
assertThat(assertMessage.getFullAssertMessage("SQL expression type for placeholder error: "), actual.getClass().getSimpleName(), is(expected.getTypeForPlaceholder()));
assertThat(assertMessage.getFullAssertMessage("SQL expression text for placeholder error: "), getText(actual), is(expected.getTextForPlaceholder()));
} else {
......
......@@ -60,6 +60,6 @@ public final class ShardingParameterizedParsingTest {
public void assertSupportedSQL() {
String sql = sqlCasesLoader.getSQL(sqlCaseId, sqlCaseType, parserResultSetRegistry.get(sqlCaseId).getParameters());
SQLStatement sqlStatement = SQLParseEngineFactory.getSQLParseEngine("H2".equals(databaseType) ? "MySQL" : databaseType).parse(sql, false);
new ShardingSQLStatementAssert(sqlStatement, sqlCaseId, sqlCaseType).assertSQLStatement();
new ShardingSQLStatementAssert(sqlStatement, sqlCaseId, sqlCaseType, databaseType).assertSQLStatement();
}
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册