未验证 提交 70c5428a 编写于 作者: L Liang Zhang 提交者: GitHub

decouple find InsertColumnsSegment with insert statement (#4346)

上级 75121668
......@@ -19,13 +19,12 @@ package org.apache.shardingsphere.encrypt.rewrite.token.generator.impl;
import com.google.common.base.Optional;
import com.google.common.base.Preconditions;
import org.apache.shardingsphere.encrypt.rewrite.token.generator.BaseEncryptSQLTokenGenerator;
import org.apache.shardingsphere.encrypt.strategy.EncryptTable;
import org.apache.shardingsphere.sql.parser.relation.statement.SQLStatementContext;
import org.apache.shardingsphere.sql.parser.relation.statement.impl.InsertSQLStatementContext;
import org.apache.shardingsphere.sql.parser.sql.segment.dml.column.ColumnSegment;
import org.apache.shardingsphere.sql.parser.sql.segment.dml.column.InsertColumnsSegment;
import org.apache.shardingsphere.sql.parser.sql.statement.dml.InsertStatement;
import org.apache.shardingsphere.encrypt.rewrite.token.generator.BaseEncryptSQLTokenGenerator;
import org.apache.shardingsphere.underlying.rewrite.sql.token.generator.CollectionSQLTokenGenerator;
import org.apache.shardingsphere.underlying.rewrite.sql.token.pojo.generic.InsertColumnsToken;
......@@ -40,8 +39,10 @@ public final class AssistQueryAndPlainInsertColumnsTokenGenerator extends BaseEn
@Override
protected boolean isGenerateSQLTokenForEncrypt(final SQLStatementContext sqlStatementContext) {
return sqlStatementContext instanceof InsertSQLStatementContext && sqlStatementContext.getSqlStatement().findSQLSegment(InsertColumnsSegment.class).isPresent()
&& !((InsertStatement) sqlStatementContext.getSqlStatement()).useDefaultColumns();
if (!(sqlStatementContext instanceof InsertSQLStatementContext)) {
return false;
}
return ((InsertStatement) sqlStatementContext.getSqlStatement()).getInsertColumns().isPresent() && !((InsertStatement) sqlStatementContext.getSqlStatement()).useDefaultColumns();
}
@Override
......
......@@ -77,7 +77,7 @@ public final class EncryptForUseDefaultInsertColumnsTokenGenerator extends BaseE
}
private UseDefaultInsertColumnsToken generateNewSQLToken(final InsertSQLStatementContext sqlStatementContext, final String tableName) {
Optional<InsertColumnsSegment> insertColumnsSegment = sqlStatementContext.getSqlStatement().findSQLSegment(InsertColumnsSegment.class);
Optional<InsertColumnsSegment> insertColumnsSegment = ((InsertStatement) sqlStatementContext.getSqlStatement()).getInsertColumns();
Preconditions.checkState(insertColumnsSegment.isPresent());
Optional<EncryptTable> encryptTable = getEncryptRule().findEncryptTable(tableName);
Preconditions.checkState(encryptTable.isPresent());
......
......@@ -24,6 +24,7 @@ import org.apache.shardingsphere.sql.parser.relation.statement.impl.InsertSQLSta
import org.apache.shardingsphere.sql.parser.sql.segment.dml.column.ColumnSegment;
import org.apache.shardingsphere.sql.parser.sql.segment.dml.column.InsertColumnsSegment;
import org.apache.shardingsphere.encrypt.rewrite.token.generator.BaseEncryptSQLTokenGenerator;
import org.apache.shardingsphere.sql.parser.sql.statement.dml.InsertStatement;
import org.apache.shardingsphere.underlying.rewrite.sql.token.generator.CollectionSQLTokenGenerator;
import org.apache.shardingsphere.underlying.rewrite.sql.token.pojo.generic.SubstitutableColumnNameToken;
......@@ -38,13 +39,16 @@ public final class InsertCipherNameTokenGenerator extends BaseEncryptSQLTokenGen
@Override
protected boolean isGenerateSQLTokenForEncrypt(final SQLStatementContext sqlStatementContext) {
Optional<InsertColumnsSegment> insertColumnsSegment = sqlStatementContext.getSqlStatement().findSQLSegment(InsertColumnsSegment.class);
return sqlStatementContext instanceof InsertSQLStatementContext && insertColumnsSegment.isPresent() && !insertColumnsSegment.get().getColumns().isEmpty();
if (!(sqlStatementContext instanceof InsertSQLStatementContext)) {
return false;
}
Optional<InsertColumnsSegment> insertColumnsSegment = ((InsertStatement) sqlStatementContext.getSqlStatement()).getInsertColumns();
return insertColumnsSegment.isPresent() && !insertColumnsSegment.get().getColumns().isEmpty();
}
@Override
public Collection<SubstitutableColumnNameToken> generateSQLTokens(final SQLStatementContext sqlStatementContext) {
Optional<InsertColumnsSegment> sqlSegment = sqlStatementContext.getSqlStatement().findSQLSegment(InsertColumnsSegment.class);
Optional<InsertColumnsSegment> sqlSegment = ((InsertStatement) sqlStatementContext.getSqlStatement()).getInsertColumns();
Preconditions.checkState(sqlSegment.isPresent());
Map<String, String> logicAndCipherColumns = getEncryptRule().getLogicAndCipherColumns(sqlStatementContext.getTablesContext().getSingleTableName());
Collection<SubstitutableColumnNameToken> result = new LinkedList<>();
......
......@@ -24,6 +24,7 @@ import org.apache.shardingsphere.sql.parser.relation.statement.SQLStatementConte
import org.apache.shardingsphere.sql.parser.relation.statement.impl.InsertSQLStatementContext;
import org.apache.shardingsphere.sql.parser.sql.segment.dml.column.ColumnSegment;
import org.apache.shardingsphere.sql.parser.sql.segment.dml.column.InsertColumnsSegment;
import org.apache.shardingsphere.sql.parser.sql.statement.dml.InsertStatement;
import org.apache.shardingsphere.underlying.rewrite.sql.token.generator.CollectionSQLTokenGenerator;
import org.apache.shardingsphere.underlying.rewrite.sql.token.pojo.generic.RemoveToken;
......@@ -37,13 +38,16 @@ public final class RemoveShadowColumnTokenGenerator extends BaseShadowSQLTokenGe
@Override
protected boolean isGenerateSQLTokenForShadow(final SQLStatementContext sqlStatementContext) {
Optional<InsertColumnsSegment> insertColumnsSegment = sqlStatementContext.getSqlStatement().findSQLSegment(InsertColumnsSegment.class);
return sqlStatementContext instanceof InsertSQLStatementContext && insertColumnsSegment.isPresent() && !insertColumnsSegment.get().getColumns().isEmpty();
if (!(sqlStatementContext instanceof InsertSQLStatementContext)) {
return false;
}
Optional<InsertColumnsSegment> insertColumnsSegment = ((InsertStatement) sqlStatementContext.getSqlStatement()).getInsertColumns();
return insertColumnsSegment.isPresent() && !insertColumnsSegment.get().getColumns().isEmpty();
}
@Override
public Collection<RemoveToken> generateSQLTokens(final SQLStatementContext sqlStatementContext) {
Optional<InsertColumnsSegment> sqlSegment = sqlStatementContext.getSqlStatement().findSQLSegment(InsertColumnsSegment.class);
Optional<InsertColumnsSegment> sqlSegment = ((InsertStatement) sqlStatementContext.getSqlStatement()).getInsertColumns();
Preconditions.checkState(sqlSegment.isPresent());
Collection<RemoveToken> result = new LinkedList<>();
LinkedList<ColumnSegment> columns = (LinkedList<ColumnSegment>) sqlSegment.get().getColumns();
......
......@@ -45,7 +45,7 @@ public final class PreparedJudgementEngineTest {
InsertStatement insertStatement = new InsertStatement();
InsertColumnsSegment insertColumnsSegment = new InsertColumnsSegment(0, 0,
Arrays.asList(new ColumnSegment(0, 0, new IdentifierValue("id")), new ColumnSegment(0, 0, new IdentifierValue("name")), new ColumnSegment(0, 0, new IdentifierValue("shadow"))));
insertStatement.setColumns(insertColumnsSegment);
insertStatement.setInsertColumns(insertColumnsSegment);
InsertSQLStatementContext insertSQLStatementContext = new InsertSQLStatementContext(relationMetas, Arrays.<Object>asList(1, "Tom", 2, "Jerry", 3, true), insertStatement);
PreparedJudgementEngine preparedJudgementEngine = new PreparedJudgementEngine(shadowRule, insertSQLStatementContext, Arrays.<Object>asList(1, "Tom", true));
Assert.assertTrue("should be shadow", preparedJudgementEngine.isShadowSQL());
......
......@@ -67,7 +67,7 @@ public final class SimpleJudgementEngineTest {
InsertStatement insertStatement = new InsertStatement();
InsertColumnsSegment insertColumnsSegment = new InsertColumnsSegment(0, 0,
Arrays.asList(new ColumnSegment(0, 0, new IdentifierValue("id")), new ColumnSegment(0, 0, new IdentifierValue("name")), new ColumnSegment(0, 0, new IdentifierValue("shadow"))));
insertStatement.setColumns(insertColumnsSegment);
insertStatement.setInsertColumns(insertColumnsSegment);
insertStatement.getValues()
.addAll(Collections.singletonList(new InsertValuesSegment(0, 0, new ArrayList<ExpressionSegment>() {
{
......
......@@ -70,7 +70,7 @@ public final class ShardingResultMergerEngineTest {
InsertStatement insertStatement = new InsertStatement();
insertStatement.getAllSQLSegments().add(new TableSegment(0, 0, new IdentifierValue("tbl")));
InsertColumnsSegment insertColumnsSegment = new InsertColumnsSegment(0, 0, Collections.singletonList(new ColumnSegment(0, 0, new IdentifierValue("col"))));
insertStatement.setColumns(insertColumnsSegment);
insertStatement.setInsertColumns(insertColumnsSegment);
SQLStatementContext sqlStatementContext = new InsertSQLStatementContext(null, Collections.emptyList(), insertStatement);
ShardingSphereProperties properties = new ShardingSphereProperties(new Properties());
assertThat(new ShardingResultMergerEngine().newInstance(DatabaseTypes.getActualDatabaseType("MySQL"), null, properties, sqlStatementContext), instanceOf(TransparentResultMerger.class));
......
......@@ -41,7 +41,7 @@ public final class GeneratedKeyForUseDefaultInsertColumnsTokenGenerator extends
@Override
protected UseDefaultInsertColumnsToken generateSQLToken(final SQLStatementContext sqlStatementContext, final GeneratedKey generatedKey) {
Optional<InsertColumnsSegment> insertColumnsSegment = sqlStatementContext.getSqlStatement().findSQLSegment(InsertColumnsSegment.class);
Optional<InsertColumnsSegment> insertColumnsSegment = ((InsertStatement) sqlStatementContext.getSqlStatement()).getInsertColumns();
Preconditions.checkState(insertColumnsSegment.isPresent());
return new UseDefaultInsertColumnsToken(insertColumnsSegment.get().getStopIndex(), getColumnNames((InsertSQLStatementContext) sqlStatementContext, generatedKey));
}
......
......@@ -32,13 +32,13 @@ public final class GeneratedKeyInsertColumnTokenGenerator extends BaseGeneratedK
@Override
protected boolean isGenerateSQLToken(final InsertStatement insertStatement) {
Optional<InsertColumnsSegment> sqlSegment = insertStatement.findSQLSegment(InsertColumnsSegment.class);
Optional<InsertColumnsSegment> sqlSegment = insertStatement.getInsertColumns();
return sqlSegment.isPresent() && !sqlSegment.get().getColumns().isEmpty();
}
@Override
protected GeneratedKeyInsertColumnToken generateSQLToken(final SQLStatementContext sqlStatementContext, final GeneratedKey generatedKey) {
Optional<InsertColumnsSegment> sqlSegment = sqlStatementContext.getSqlStatement().findSQLSegment(InsertColumnsSegment.class);
Optional<InsertColumnsSegment> sqlSegment = ((InsertStatement) sqlStatementContext.getSqlStatement()).getInsertColumns();
Preconditions.checkState(sqlSegment.isPresent());
return new GeneratedKeyInsertColumnToken(sqlSegment.get().getStopIndex(), generatedKey.getColumnName());
}
......
......@@ -58,7 +58,7 @@ public final class GeneratedKeyTest {
@Before
public void setUp() {
insertStatement.setTable(new TableSegment(0, 0, new IdentifierValue("tbl")));
insertStatement.setColumns(new InsertColumnsSegment(0, 0, Collections.singletonList(new ColumnSegment(0, 0, new IdentifierValue("id")))));
insertStatement.setInsertColumns(new InsertColumnsSegment(0, 0, Collections.singletonList(new ColumnSegment(0, 0, new IdentifierValue("id")))));
}
@Test
......
......@@ -29,6 +29,6 @@ public final class InsertColumnsFiller implements SQLSegmentFiller<InsertColumns
@Override
public void fill(final InsertColumnsSegment sqlSegment, final SQLStatement sqlStatement) {
((InsertStatement) sqlStatement).setColumns(sqlSegment);
((InsertStatement) sqlStatement).setInsertColumns(sqlSegment);
}
}
......@@ -45,7 +45,7 @@ public final class InsertStatement extends DMLStatement implements TableSegmentA
private TableSegment table;
private InsertColumnsSegment columns;
private InsertColumnsSegment insertColumns;
private SetAssignmentSegment setAssignment;
......@@ -59,7 +59,7 @@ public final class InsertStatement extends DMLStatement implements TableSegmentA
* @return insert columns segment
*/
public Optional<InsertColumnsSegment> getInsertColumns() {
return Optional.fromNullable(columns);
return Optional.fromNullable(insertColumns);
}
/**
......@@ -68,7 +68,7 @@ public final class InsertStatement extends DMLStatement implements TableSegmentA
* @return columns
*/
public Collection<ColumnSegment> getColumns() {
return null == columns ? Collections.<ColumnSegment>emptyList() : columns.getColumns();
return null == insertColumns ? Collections.<ColumnSegment>emptyList() : insertColumns.getColumns();
}
/**
......
......@@ -45,7 +45,7 @@ public final class InsertStatementTest {
public void assertNotUseDefaultColumnsWithColumns() {
InsertStatement insertStatement = new InsertStatement();
InsertColumnsSegment insertColumnsSegment = new InsertColumnsSegment(0, 0, Collections.singletonList(new ColumnSegment(0, 0, new IdentifierValue("col"))));
insertStatement.setColumns(insertColumnsSegment);
insertStatement.setInsertColumns(insertColumnsSegment);
assertFalse(insertStatement.useDefaultColumns());
}
......@@ -53,7 +53,7 @@ public final class InsertStatementTest {
public void assertNotUseDefaultColumnsWithSetAssignment() {
InsertStatement insertStatement = new InsertStatement();
insertStatement.setSetAssignment(new SetAssignmentSegment(0, 0, Collections.<AssignmentSegment>emptyList()));
insertStatement.setColumns(new InsertColumnsSegment(0, 0, Collections.<ColumnSegment>emptyList()));
insertStatement.setInsertColumns(new InsertColumnsSegment(0, 0, Collections.<ColumnSegment>emptyList()));
assertFalse(insertStatement.useDefaultColumns());
}
......@@ -61,7 +61,7 @@ public final class InsertStatementTest {
public void assertGetColumnNamesForInsertColumns() {
InsertStatement insertStatement = new InsertStatement();
InsertColumnsSegment insertColumnsSegment = new InsertColumnsSegment(0, 0, Collections.singletonList(new ColumnSegment(0, 0, new IdentifierValue("col"))));
insertStatement.setColumns(insertColumnsSegment);
insertStatement.setInsertColumns(insertColumnsSegment);
assertThat(insertStatement.getColumnNames().size(), is(1));
assertThat(insertStatement.getColumnNames().iterator().next(), is("col"));
}
......
......@@ -132,7 +132,7 @@ public final class MySQLDMLVisitor extends MySQLVisitor {
InsertStatement result = new InsertStatement();
if (null != ctx.columnNames()) {
InsertColumnsSegment insertColumnsSegment = (InsertColumnsSegment) visit(ctx.columnNames());
result.setColumns(insertColumnsSegment);
result.setInsertColumns(insertColumnsSegment);
result.getAllSQLSegments().add(insertColumnsSegment);
}
Collection<InsertValuesSegment> insertValuesSegments = createInsertValuesSegments(ctx.assignmentValues());
......
......@@ -107,7 +107,7 @@ public final class OracleDMLVisitor extends OracleVisitor {
InsertStatement result = new InsertStatement();
if (null != ctx.columnNames()) {
InsertColumnsSegment insertColumnsSegment = (InsertColumnsSegment) visit(ctx.columnNames());
result.setColumns(insertColumnsSegment);
result.setInsertColumns(insertColumnsSegment);
result.getAllSQLSegments().add(insertColumnsSegment);
}
Collection<InsertValuesSegment> insertValuesSegments = createInsertValuesSegments(ctx.assignmentValues());
......
......@@ -113,7 +113,7 @@ public final class PostgreSQLDMLVisitor extends PostgreSQLVisitor {
InsertStatement result = new InsertStatement();
if (null != ctx.columnNames()) {
InsertColumnsSegment insertColumnsSegment = (InsertColumnsSegment) visit(ctx.columnNames());
result.setColumns(insertColumnsSegment);
result.setInsertColumns(insertColumnsSegment);
result.getAllSQLSegments().add(insertColumnsSegment);
}
Collection<InsertValuesSegment> insertValuesSegments = createInsertValuesSegments(ctx.assignmentValues());
......
......@@ -44,7 +44,7 @@ public final class InsertSQLStatementContextTest {
insertStatement.setTable(new TableSegment(0, 0, new IdentifierValue("tbl")));
InsertColumnsSegment insertColumnsSegment = new InsertColumnsSegment(0, 0, Arrays.asList(
new ColumnSegment(0, 0, new IdentifierValue("id")), new ColumnSegment(0, 0, new IdentifierValue("name")), new ColumnSegment(0, 0, new IdentifierValue("status"))));
insertStatement.setColumns(insertColumnsSegment);
insertStatement.setInsertColumns(insertColumnsSegment);
setUpInsertValues(insertStatement);
InsertSQLStatementContext actual = new InsertSQLStatementContext(mock(RelationMetas.class), Arrays.<Object>asList(1, "Tom", 2, "Jerry"), insertStatement);
assertInsertSQLStatementContext(actual);
......
......@@ -101,7 +101,7 @@ public final class SQL92DMLVisitor extends SQL92Visitor {
InsertStatement result = new InsertStatement();
if (null != ctx.columnNames()) {
InsertColumnsSegment insertColumnsSegment = (InsertColumnsSegment) visit(ctx.columnNames());
result.setColumns(insertColumnsSegment);
result.setInsertColumns(insertColumnsSegment);
result.getAllSQLSegments().add(insertColumnsSegment);
}
Collection<InsertValuesSegment> insertValuesSegments = createInsertValuesSegments(ctx.assignmentValues());
......
......@@ -102,7 +102,7 @@ public final class SQLServerDMLVisitor extends SQLServerVisitor {
InsertStatement result = new InsertStatement();
if (null != ctx.columnNames()) {
InsertColumnsSegment insertColumnsSegment = (InsertColumnsSegment) visit(ctx.columnNames());
result.setColumns(insertColumnsSegment);
result.setInsertColumns(insertColumnsSegment);
result.getAllSQLSegments().add(insertColumnsSegment);
}
Collection<InsertValuesSegment> insertValuesSegments = createInsertValuesSegments(ctx.assignmentValues());
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册