提交 13f026fc 编写于 作者: J Juan Pan(Trista) 提交者: Liang Zhang

add visitors for `insert values` (#4023)

* add visitors for `insert values`

* modify InsertColumnsSegment.java

* check style
上级 46f60b90
......@@ -22,6 +22,7 @@ import org.apache.shardingsphere.core.rule.ShadowRule;
import org.apache.shardingsphere.sql.parser.relation.metadata.RelationMetas;
import org.apache.shardingsphere.sql.parser.relation.statement.impl.InsertSQLStatementContext;
import org.apache.shardingsphere.sql.parser.sql.segment.dml.column.ColumnSegment;
import org.apache.shardingsphere.sql.parser.sql.segment.dml.column.InsertColumnsSegment;
import org.apache.shardingsphere.sql.parser.sql.statement.dml.InsertStatement;
import org.junit.Assert;
import org.junit.Test;
......@@ -37,18 +38,17 @@ public class PreparedJudgementEngineTest {
public void isShadowSql() {
RelationMetas relationMetas = mock(RelationMetas.class);
when(relationMetas.getAllColumnNames("tbl")).thenReturn(Arrays.asList("id", "name", "shadow"));
ShadowRuleConfiguration shadowRuleConfiguration = new ShadowRuleConfiguration();
shadowRuleConfiguration.setColumn("shadow");
ShadowRule shadowRule = new ShadowRule(shadowRuleConfiguration);
InsertStatement insertStatement = new InsertStatement();
insertStatement.getColumns()
.addAll(Arrays.asList(new ColumnSegment(0, 0, "id"),
new ColumnSegment(0, 0, "name"),
new ColumnSegment(0, 0, "shadow")));
InsertColumnsSegment insertColumnsSegment = new InsertColumnsSegment(0, 0);
insertColumnsSegment.getColumns().addAll(Arrays.asList(new ColumnSegment(0, 0, "id"),
new ColumnSegment(0, 0, "name"),
new ColumnSegment(0, 0, "shadow")));
insertStatement.setColumns(insertColumnsSegment);
InsertSQLStatementContext insertSQLStatementContext = new InsertSQLStatementContext(relationMetas, Arrays.<Object>asList(1, "Tom", 2, "Jerry", 3, true), insertStatement);
PreparedJudgementEngine preparedJudgementEngine = new PreparedJudgementEngine(shadowRule, insertSQLStatementContext, Arrays.<Object>asList(1, "Tom", true));
Assert.assertTrue("should be shadow", preparedJudgementEngine.isShadowSQL());
}
}
......@@ -24,6 +24,7 @@ import org.apache.shardingsphere.sql.parser.relation.statement.impl.InsertSQLSta
import org.apache.shardingsphere.sql.parser.relation.statement.impl.SelectSQLStatementContext;
import org.apache.shardingsphere.sql.parser.sql.segment.dml.assignment.InsertValuesSegment;
import org.apache.shardingsphere.sql.parser.sql.segment.dml.column.ColumnSegment;
import org.apache.shardingsphere.sql.parser.sql.segment.dml.column.InsertColumnsSegment;
import org.apache.shardingsphere.sql.parser.sql.segment.dml.expr.ExpressionSegment;
import org.apache.shardingsphere.sql.parser.sql.segment.dml.expr.simple.LiteralExpressionSegment;
import org.apache.shardingsphere.sql.parser.sql.segment.dml.item.ExpressionProjectionSegment;
......@@ -63,10 +64,11 @@ public class SimpleJudgementEngineTest {
@Test
public void judgeForInsert() {
InsertStatement insertStatement = new InsertStatement();
insertStatement.getColumns()
.addAll(Arrays.asList(new ColumnSegment(0, 0, "id"),
new ColumnSegment(0, 0, "name"),
new ColumnSegment(0, 0, "shadow")));
InsertColumnsSegment insertColumnsSegment = new InsertColumnsSegment(0, 0);
insertColumnsSegment.getColumns().addAll(Arrays.asList(new ColumnSegment(0, 0, "id"),
new ColumnSegment(0, 0, "name"),
new ColumnSegment(0, 0, "shadow")));
insertStatement.setColumns(insertColumnsSegment);
insertStatement.getValues()
.addAll(Collections.singletonList(new InsertValuesSegment(0, 0, new ArrayList<ExpressionSegment>() {
{
......
......@@ -30,6 +30,7 @@ import org.apache.shardingsphere.sql.parser.relation.statement.impl.CommonSQLSta
import org.apache.shardingsphere.sql.parser.relation.statement.impl.InsertSQLStatementContext;
import org.apache.shardingsphere.sql.parser.relation.statement.impl.SelectSQLStatementContext;
import org.apache.shardingsphere.sql.parser.sql.segment.dml.column.ColumnSegment;
import org.apache.shardingsphere.sql.parser.sql.segment.dml.column.InsertColumnsSegment;
import org.apache.shardingsphere.sql.parser.sql.segment.generic.TableSegment;
import org.apache.shardingsphere.sql.parser.sql.statement.dal.DALStatement;
import org.apache.shardingsphere.sql.parser.sql.statement.dml.InsertStatement;
......@@ -65,11 +66,13 @@ public final class ShardingResultMergerEngineTest {
@Test
public void assertNewInstanceWithOtherStatement() {
ShardingSphereProperties properties = new ShardingSphereProperties(new Properties());
InsertStatement insertStatement = new InsertStatement();
insertStatement.getAllSQLSegments().add(new TableSegment(0, 0, "tbl"));
insertStatement.getColumns().add(new ColumnSegment(0, 0, "col"));
InsertColumnsSegment insertColumnsSegment = new InsertColumnsSegment(0, 0);
insertColumnsSegment.getColumns().add(new ColumnSegment(0, 0, "col"));
insertStatement.setColumns(insertColumnsSegment);
SQLStatementContext sqlStatementContext = new InsertSQLStatementContext(null, Collections.emptyList(), insertStatement);
ShardingSphereProperties properties = new ShardingSphereProperties(new Properties());
assertThat(new ShardingResultMergerEngine().newInstance(DatabaseTypes.getActualDatabaseType("MySQL"), null, properties, sqlStatementContext), instanceOf(TransparentResultMerger.class));
}
}
......@@ -18,16 +18,17 @@
package org.apache.shardingsphere.sharding.route.engine.keygen;
import com.google.common.base.Optional;
import org.apache.shardingsphere.underlying.common.metadata.table.TableMetas;
import org.apache.shardingsphere.core.rule.ShardingRule;
import org.apache.shardingsphere.sql.parser.sql.segment.dml.assignment.InsertValuesSegment;
import org.apache.shardingsphere.sql.parser.sql.segment.dml.column.ColumnSegment;
import org.apache.shardingsphere.sql.parser.sql.segment.dml.column.InsertColumnsSegment;
import org.apache.shardingsphere.sql.parser.sql.segment.dml.expr.ExpressionSegment;
import org.apache.shardingsphere.sql.parser.sql.segment.dml.expr.complex.CommonExpressionSegment;
import org.apache.shardingsphere.sql.parser.sql.segment.dml.expr.simple.LiteralExpressionSegment;
import org.apache.shardingsphere.sql.parser.sql.segment.dml.expr.simple.ParameterMarkerExpressionSegment;
import org.apache.shardingsphere.sql.parser.sql.segment.generic.TableSegment;
import org.apache.shardingsphere.sql.parser.sql.statement.dml.InsertStatement;
import org.apache.shardingsphere.core.rule.ShardingRule;
import org.apache.shardingsphere.underlying.common.metadata.table.TableMetas;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
......@@ -56,7 +57,9 @@ public final class GeneratedKeyTest {
@Before
public void setUp() {
insertStatement.setTable(new TableSegment(0, 0, "tbl"));
insertStatement.getColumns().add(new ColumnSegment(0, 0, "id"));
InsertColumnsSegment insertColumnsSegment = new InsertColumnsSegment(0, 0);
insertColumnsSegment.getColumns().add(new ColumnSegment(0, 0, "id"));
insertStatement.setColumns(insertColumnsSegment);
}
@Test
......
......@@ -42,8 +42,12 @@ public final class InsertColumnsExtractor implements OptionalSQLSegmentExtractor
@Override
public Optional<InsertColumnsSegment> extract(final ParserRuleContext ancestorNode, final Map<ParserRuleContext, Integer> parameterMarkerIndexes) {
Optional<ParserRuleContext> insertValuesClause = ExtractorUtils.findFirstChildNode(ancestorNode, RuleName.INSERT_VALUES_CLAUSE);
return insertValuesClause.isPresent() ? Optional.of(new InsertColumnsSegment(insertValuesClause.get().getStart().getStartIndex(),
extractStopIndex(insertValuesClause.get()), extractColumns(insertValuesClause.get(), parameterMarkerIndexes))) : Optional.<InsertColumnsSegment>absent();
if (!insertValuesClause.isPresent()) {
return Optional.absent();
}
InsertColumnsSegment insertColumnsSegment = new InsertColumnsSegment(insertValuesClause.get().getStart().getStartIndex(), extractStopIndex(insertValuesClause.get()));
insertColumnsSegment.getColumns().addAll(extractColumns(insertValuesClause.get(), parameterMarkerIndexes));
return Optional.of(insertColumnsSegment);
}
private Collection<ColumnSegment> extractColumns(final ParserRuleContext ancestorNode, final Map<ParserRuleContext, Integer> parameterMarkerIndexes) {
......
......@@ -32,6 +32,6 @@ public final class InsertColumnsFiller implements SQLSegmentFiller<InsertColumns
@Override
public void fill(final InsertColumnsSegment sqlSegment, final SQLStatement sqlStatement) {
((InsertStatement) sqlStatement).getColumns().addAll(sqlSegment.getColumns());
((InsertStatement) sqlStatement).setColumns(sqlSegment);
}
}
......@@ -22,6 +22,7 @@ import lombok.RequiredArgsConstructor;
import org.apache.shardingsphere.sql.parser.sql.segment.SQLSegment;
import java.util.Collection;
import java.util.LinkedList;
/**
* Insert columns segment.
......@@ -37,5 +38,5 @@ public final class InsertColumnsSegment implements SQLSegment {
private final int stopIndex;
private final Collection<ColumnSegment> columns;
private final Collection<ColumnSegment> columns = new LinkedList<>();
}
......@@ -24,6 +24,7 @@ import org.apache.shardingsphere.sql.parser.sql.segment.dml.assignment.Assignmen
import org.apache.shardingsphere.sql.parser.sql.segment.dml.assignment.InsertValuesSegment;
import org.apache.shardingsphere.sql.parser.sql.segment.dml.assignment.SetAssignmentsSegment;
import org.apache.shardingsphere.sql.parser.sql.segment.dml.column.ColumnSegment;
import org.apache.shardingsphere.sql.parser.sql.segment.dml.column.InsertColumnsSegment;
import org.apache.shardingsphere.sql.parser.sql.segment.dml.expr.ExpressionSegment;
import org.apache.shardingsphere.sql.parser.sql.segment.generic.TableSegment;
import org.apache.shardingsphere.sql.parser.sql.statement.generic.TableSegmentAvailable;
......@@ -49,10 +50,19 @@ public final class InsertStatement extends DMLStatement implements TableSegmentA
private SetAssignmentsSegment setAssignment;
private final Collection<ColumnSegment> columns = new LinkedList<>();
private InsertColumnsSegment columns;
private final Collection<InsertValuesSegment> values = new LinkedList<>();
/**
* Get columns.
*
* @return columns
*/
public Collection<ColumnSegment> getColumns() {
return null == columns ? Collections.<ColumnSegment>emptyList() : columns.getColumns();
}
/**
* Get set assignment segment.
*
......@@ -68,7 +78,7 @@ public final class InsertStatement extends DMLStatement implements TableSegmentA
* @return is use default columns or not
*/
public boolean useDefaultColumns() {
return columns.isEmpty() && null == setAssignment;
return getColumns().isEmpty() && null == setAssignment;
}
/**
......@@ -82,7 +92,7 @@ public final class InsertStatement extends DMLStatement implements TableSegmentA
private List<String> getColumnNamesForInsertColumns() {
List<String> result = new LinkedList<>();
for (ColumnSegment each : columns) {
for (ColumnSegment each : getColumns()) {
result.add(each.getName().toLowerCase());
}
return result;
......
......@@ -21,6 +21,7 @@ import org.apache.shardingsphere.sql.parser.sql.segment.dml.assignment.Assignmen
import org.apache.shardingsphere.sql.parser.sql.segment.dml.assignment.InsertValuesSegment;
import org.apache.shardingsphere.sql.parser.sql.segment.dml.assignment.SetAssignmentsSegment;
import org.apache.shardingsphere.sql.parser.sql.segment.dml.column.ColumnSegment;
import org.apache.shardingsphere.sql.parser.sql.segment.dml.column.InsertColumnsSegment;
import org.apache.shardingsphere.sql.parser.sql.segment.dml.expr.ExpressionSegment;
import org.apache.shardingsphere.sql.parser.sql.segment.dml.expr.simple.LiteralExpressionSegment;
import org.junit.Test;
......@@ -42,7 +43,9 @@ public final class InsertStatementTest {
@Test
public void assertNotUseDefaultColumnsWithColumns() {
InsertStatement insertStatement = new InsertStatement();
insertStatement.getColumns().add(new ColumnSegment(0, 0, "col"));
InsertColumnsSegment insertColumnsSegment = new InsertColumnsSegment(0, 0);
insertColumnsSegment.getColumns().add(new ColumnSegment(0, 0, "col"));
insertStatement.setColumns(insertColumnsSegment);
assertFalse(insertStatement.useDefaultColumns());
}
......@@ -50,13 +53,16 @@ public final class InsertStatementTest {
public void assertNotUseDefaultColumnsWithSetAssignment() {
InsertStatement insertStatement = new InsertStatement();
insertStatement.setSetAssignment(new SetAssignmentsSegment(0, 0, Collections.<AssignmentSegment>emptyList()));
insertStatement.setColumns(new InsertColumnsSegment(0, 0));
assertFalse(insertStatement.useDefaultColumns());
}
@Test
public void assertGetColumnNamesForInsertColumns() {
InsertStatement insertStatement = new InsertStatement();
insertStatement.getColumns().add(new ColumnSegment(0, 0, "col"));
InsertColumnsSegment insertColumnsSegment = new InsertColumnsSegment(0, 0);
insertColumnsSegment.getColumns().add(new ColumnSegment(0, 0, "col"));
insertStatement.setColumns(insertColumnsSegment);
assertThat(insertStatement.getColumnNames().size(), is(1));
assertThat(insertStatement.getColumnNames().iterator().next(), is("col"));
}
......
......@@ -21,15 +21,18 @@ import org.apache.shardingsphere.sql.parser.api.SQLVisitor;
import org.apache.shardingsphere.sql.parser.autogen.MySQLStatementBaseVisitor;
import org.apache.shardingsphere.sql.parser.autogen.MySQLStatementParser.AssignmentContext;
import org.apache.shardingsphere.sql.parser.autogen.MySQLStatementParser.AssignmentValueContext;
import org.apache.shardingsphere.sql.parser.autogen.MySQLStatementParser.AssignmentValuesContext;
import org.apache.shardingsphere.sql.parser.autogen.MySQLStatementParser.BitExprContext;
import org.apache.shardingsphere.sql.parser.autogen.MySQLStatementParser.BlobValueContext;
import org.apache.shardingsphere.sql.parser.autogen.MySQLStatementParser.BooleanLiteralsContext;
import org.apache.shardingsphere.sql.parser.autogen.MySQLStatementParser.BooleanPrimaryContext;
import org.apache.shardingsphere.sql.parser.autogen.MySQLStatementParser.ColumnNameContext;
import org.apache.shardingsphere.sql.parser.autogen.MySQLStatementParser.ColumnNamesContext;
import org.apache.shardingsphere.sql.parser.autogen.MySQLStatementParser.ExprContext;
import org.apache.shardingsphere.sql.parser.autogen.MySQLStatementParser.FromSchemaContext;
import org.apache.shardingsphere.sql.parser.autogen.MySQLStatementParser.IdentifierContext;
import org.apache.shardingsphere.sql.parser.autogen.MySQLStatementParser.InsertContext;
import org.apache.shardingsphere.sql.parser.autogen.MySQLStatementParser.InsertValuesClauseContext;
import org.apache.shardingsphere.sql.parser.autogen.MySQLStatementParser.LiteralsContext;
import org.apache.shardingsphere.sql.parser.autogen.MySQLStatementParser.NumberLiteralsContext;
import org.apache.shardingsphere.sql.parser.autogen.MySQLStatementParser.OwnerContext;
......@@ -48,8 +51,10 @@ import org.apache.shardingsphere.sql.parser.sql.ASTNode;
import org.apache.shardingsphere.sql.parser.sql.segment.dal.FromSchemaSegment;
import org.apache.shardingsphere.sql.parser.sql.segment.dal.ShowLikeSegment;
import org.apache.shardingsphere.sql.parser.sql.segment.dml.assignment.AssignmentSegment;
import org.apache.shardingsphere.sql.parser.sql.segment.dml.assignment.InsertValuesSegment;
import org.apache.shardingsphere.sql.parser.sql.segment.dml.assignment.SetAssignmentsSegment;
import org.apache.shardingsphere.sql.parser.sql.segment.dml.column.ColumnSegment;
import org.apache.shardingsphere.sql.parser.sql.segment.dml.column.InsertColumnsSegment;
import org.apache.shardingsphere.sql.parser.sql.segment.dml.expr.ExpressionSegment;
import org.apache.shardingsphere.sql.parser.sql.segment.dml.expr.complex.CommonExpressionSegment;
import org.apache.shardingsphere.sql.parser.sql.segment.dml.expr.complex.SubquerySegment;
......@@ -67,6 +72,7 @@ import org.apache.shardingsphere.sql.parser.sql.value.ParameterValue;
import java.util.Collection;
import java.util.LinkedList;
import java.util.List;
/**
* MySQL visitor.
......@@ -119,26 +125,47 @@ public final class MySQLVisitor extends MySQLStatementBaseVisitor<ASTNode> imple
@Override
public ASTNode visitInsert(final InsertContext ctx) {
InsertStatement result = new InsertStatement();
// TODO :Since there is no segment for insertValuesClause, InsertStatement is created by sub rule.
InsertStatement result = null != ctx.insertValuesClause() ? (InsertStatement) visit(ctx.insertValuesClause()) : (InsertStatement) visit(ctx.setAssignmentsClause());
TableSegment table = (TableSegment) visit(ctx.tableName());
result.setTable(table);
result.getAllSQLSegments().add(table);
if (null != ctx.setAssignmentsClause()) {
SetAssignmentsSegment segment = (SetAssignmentsSegment) visit(ctx.setAssignmentsClause());
result.setSetAssignment(segment);
result.getAllSQLSegments().add(segment);
}
result.setParametersCount(currentParameterIndex);
return result;
}
@Override
public ASTNode visitInsertValuesClause(final InsertValuesClauseContext ctx) {
InsertStatement result = new InsertStatement();
InsertColumnsSegment insertColumnsSegment = (InsertColumnsSegment) visit(ctx.columnNames());
Collection<InsertValuesSegment> insertValuesSegments = createInsertValuesSegments(ctx.assignmentValues());
result.setColumns(insertColumnsSegment);
result.getValues().addAll(insertValuesSegments);
result.getAllSQLSegments().add(insertColumnsSegment);
result.getAllSQLSegments().addAll(insertValuesSegments);
return result;
}
@Override
public ASTNode visitSetAssignmentsClause(final SetAssignmentsClauseContext ctx) {
InsertStatement result = new InsertStatement();
Collection<AssignmentSegment> assignments = new LinkedList<>();
for (AssignmentContext each : ctx.assignment()) {
assignments.add((AssignmentSegment) visit(each));
}
return new SetAssignmentsSegment(ctx.start.getStartIndex(), ctx.stop.getStopIndex(), assignments);
SetAssignmentsSegment segment = new SetAssignmentsSegment(ctx.start.getStartIndex(), ctx.stop.getStopIndex(), assignments);
result.setSetAssignment(segment);
result.getAllSQLSegments().add(segment);
return result;
}
@Override
public ASTNode visitAssignmentValues(final AssignmentValuesContext ctx) {
List<ExpressionSegment> segments = new LinkedList<>();
for (AssignmentValueContext each : ctx.assignmentValue()) {
segments.add((ExpressionSegment) visit(each));
}
return new InsertValuesSegment(ctx.start.getStartIndex(), ctx.stop.getStopIndex(), segments);
}
@Override
......@@ -183,6 +210,17 @@ public final class MySQLVisitor extends MySQLStatementBaseVisitor<ASTNode> imple
return result;
}
@Override
public ASTNode visitColumnNames(final ColumnNamesContext ctx) {
Collection<ColumnSegment> segments = new LinkedList<>();
for (ColumnNameContext each : ctx.columnName()) {
segments.add((ColumnSegment) visit(each));
}
InsertColumnsSegment result = new InsertColumnsSegment(ctx.start.getStartIndex(), ctx.stop.getStopIndex());
result.getColumns().addAll(segments);
return result;
}
@Override
public ASTNode visitColumnName(final ColumnNameContext ctx) {
LiteralValue columnName = (LiteralValue) visit(ctx.name());
......@@ -320,4 +358,12 @@ public final class MySQLVisitor extends MySQLStatementBaseVisitor<ASTNode> imple
}
return new ParameterMarkerExpressionSegment(expr.start.getStartIndex(), expr.stop.getStopIndex(), ((ParameterValue) astNode).getParameterIndex());
}
private Collection<InsertValuesSegment> createInsertValuesSegments(final Collection<AssignmentValuesContext> assignmentValuesContexts) {
Collection<InsertValuesSegment> result = new LinkedList<>();
for (AssignmentValuesContext each : assignmentValuesContexts) {
result.add((InsertValuesSegment) visit(each));
}
return result;
}
}
......@@ -20,6 +20,7 @@ package org.apache.shardingsphere.sql.parser.relation.statement.impl;
import org.apache.shardingsphere.sql.parser.relation.metadata.RelationMetas;
import org.apache.shardingsphere.sql.parser.sql.segment.dml.assignment.InsertValuesSegment;
import org.apache.shardingsphere.sql.parser.sql.segment.dml.column.ColumnSegment;
import org.apache.shardingsphere.sql.parser.sql.segment.dml.column.InsertColumnsSegment;
import org.apache.shardingsphere.sql.parser.sql.segment.dml.expr.ExpressionSegment;
import org.apache.shardingsphere.sql.parser.sql.segment.dml.expr.simple.LiteralExpressionSegment;
import org.apache.shardingsphere.sql.parser.sql.segment.dml.expr.simple.ParameterMarkerExpressionSegment;
......@@ -40,7 +41,9 @@ public final class InsertSQLStatementContextTest {
public void assertInsertSQLStatementContextWithColumnNames() {
InsertStatement insertStatement = new InsertStatement();
insertStatement.getAllSQLSegments().add(new TableSegment(0, 0, "tbl"));
insertStatement.getColumns().addAll(Arrays.asList(new ColumnSegment(0, 0, "id"), new ColumnSegment(0, 0, "name"), new ColumnSegment(0, 0, "status")));
InsertColumnsSegment insertColumnsSegment = new InsertColumnsSegment(0, 0);
insertColumnsSegment.getColumns().addAll(Arrays.asList(new ColumnSegment(0, 0, "id"), new ColumnSegment(0, 0, "name"), new ColumnSegment(0, 0, "status")));
insertStatement.setColumns(insertColumnsSegment);
setUpInsertValues(insertStatement);
InsertSQLStatementContext actual = new InsertSQLStatementContext(mock(RelationMetas.class), Arrays.<Object>asList(1, "Tom", 2, "Jerry"), insertStatement);
assertInsertSQLStatementContext(actual);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册