提交 abee4c7d 编写于 作者: T terrymanu

split generate id form parsing to router, will move to rewrite in future

上级 c1d91655
......@@ -48,7 +48,6 @@ public abstract class AbstractSQLContext implements SQLContext {
private ConditionContext conditionContext = new ConditionContext();
@Getter(AccessLevel.NONE)
private SQLBuilderContext sqlBuilderContext;
@Setter(AccessLevel.NONE)
......
......@@ -19,6 +19,10 @@ package com.dangdang.ddframe.rdb.sharding.parsing.parser.context;
import com.dangdang.ddframe.rdb.sharding.constant.SQLType;
import lombok.Getter;
import lombok.Setter;
import java.util.Collection;
import java.util.LinkedList;
/**
* Insert SQL上下文.
......@@ -26,10 +30,17 @@ import lombok.Getter;
* @author zhangliang
*/
@Getter
@Setter
public final class InsertSQLContext extends AbstractSQLContext {
private final GeneratedKeyContext generatedKeyContext = new GeneratedKeyContext();
private final Collection<ShardingColumnContext> shardingColumnContexts = new LinkedList<>();
private int columnsListLastPosition;
private int valuesListLastPosition;
public InsertSQLContext() {
super(SQLType.INSERT);
}
......
......@@ -66,6 +66,13 @@ public interface SQLContext {
*/
Optional<ShardingColumnContext> findColumn(SQLExpr expr);
/**
* 获取SQL构建器上下文.
*
* @return SQL构建器上下文
*/
SQLBuilderContext getSqlBuilderContext();
/**
* 获取SQL构建器.
*
......
......@@ -33,9 +33,9 @@ import com.dangdang.ddframe.rdb.sharding.parsing.parser.expr.SQLNumberExpr;
import com.dangdang.ddframe.rdb.sharding.parsing.parser.expr.SQLPlaceholderExpr;
import com.dangdang.ddframe.rdb.sharding.parsing.parser.expr.SQLTextExpr;
import com.dangdang.ddframe.rdb.sharding.parsing.parser.statement.insert.AbstractInsertParser;
import com.dangdang.ddframe.rdb.sharding.util.SQLUtil;
import com.google.common.collect.Sets;
import java.util.Collection;
import java.util.Set;
/**
......@@ -55,11 +55,11 @@ public final class MySQLInsertParser extends AbstractInsertParser {
}
private void parseInsertSet() {
Collection<String> autoIncrementColumns = getShardingRule().getAutoIncrementColumns(getSqlContext().getTables().get(0).getName());
ConditionContext conditionContext = new ConditionContext();
do {
getSqlParser().getLexer().nextToken();
ShardingColumnContext shardingColumnContext = getColumn(autoIncrementColumns);
ShardingColumnContext shardingColumnContext = new ShardingColumnContext(
SQLUtil.getExactlyValue(getSqlParser().getLexer().getCurrentToken().getLiterals()), getSqlContext().getTables().get(0).getName());
getSqlParser().getLexer().nextToken();
getSqlParser().accept(Symbol.EQ);
SQLExpr sqlExpr;
......
......@@ -18,19 +18,16 @@
package com.dangdang.ddframe.rdb.sharding.parsing.parser.statement.insert;
import com.dangdang.ddframe.rdb.sharding.api.rule.ShardingRule;
import com.dangdang.ddframe.rdb.sharding.parsing.parser.context.ConditionContext;
import com.dangdang.ddframe.rdb.sharding.parsing.parser.context.InsertSQLContext;
import com.dangdang.ddframe.rdb.sharding.parsing.parser.context.ItemsToken;
import com.dangdang.ddframe.rdb.sharding.parsing.parser.context.ShardingColumnContext;
import com.dangdang.ddframe.rdb.sharding.parsing.parser.expr.SQLExpr;
import com.dangdang.ddframe.rdb.sharding.parsing.parser.expr.SQLNumberExpr;
import com.dangdang.ddframe.rdb.sharding.parsing.parser.expr.SQLPlaceholderExpr;
import com.dangdang.ddframe.rdb.sharding.parsing.lexer.token.Assist;
import com.dangdang.ddframe.rdb.sharding.parsing.lexer.token.DefaultKeyword;
import com.dangdang.ddframe.rdb.sharding.parsing.lexer.token.Symbol;
import com.dangdang.ddframe.rdb.sharding.parsing.lexer.token.TokenType;
import com.dangdang.ddframe.rdb.sharding.parsing.parser.exception.SQLParsingUnsupportedException;
import com.dangdang.ddframe.rdb.sharding.parsing.parser.SQLParser;
import com.dangdang.ddframe.rdb.sharding.parsing.parser.context.ConditionContext;
import com.dangdang.ddframe.rdb.sharding.parsing.parser.context.InsertSQLContext;
import com.dangdang.ddframe.rdb.sharding.parsing.parser.context.ShardingColumnContext;
import com.dangdang.ddframe.rdb.sharding.parsing.parser.exception.SQLParsingUnsupportedException;
import com.dangdang.ddframe.rdb.sharding.parsing.parser.expr.SQLExpr;
import com.dangdang.ddframe.rdb.sharding.parsing.parser.statement.SQLStatementParser;
import com.dangdang.ddframe.rdb.sharding.util.SQLUtil;
import com.google.common.collect.Sets;
......@@ -68,12 +65,12 @@ public abstract class AbstractInsertParser implements SQLStatementParser {
public final InsertSQLContext parse() {
sqlParser.getLexer().nextToken();
parseInto();
Collection<ShardingColumnContext> shardingColumnContexts = parseColumns();
parseColumns();
if (sqlParser.equalAny(DefaultKeyword.SELECT, Symbol.LEFT_PAREN)) {
throw new UnsupportedOperationException("Cannot support subquery");
}
if (getValuesKeywords().contains(sqlParser.getLexer().getCurrentToken().getType())) {
parseValues(shardingColumnContexts);
parseValues();
} else if (getCustomizedInsertKeywords().contains(sqlParser.getLexer().getCurrentToken().getType())) {
parseCustomizedInsert();
}
......@@ -107,41 +104,25 @@ public abstract class AbstractInsertParser implements SQLStatementParser {
return Collections.emptySet();
}
private Collection<ShardingColumnContext> parseColumns() {
private void parseColumns() {
Collection<ShardingColumnContext> result = new LinkedList<>();
Collection<String> autoIncrementColumns = shardingRule.getAutoIncrementColumns(sqlContext.getTables().get(0).getName());
if (sqlParser.equalAny(Symbol.LEFT_PAREN)) {
do {
sqlParser.getLexer().nextToken();
result.add(getColumn(autoIncrementColumns));
result.add(new ShardingColumnContext(SQLUtil.getExactlyValue(sqlParser.getLexer().getCurrentToken().getLiterals()), sqlContext.getTables().get(0).getName()));
sqlParser.getLexer().nextToken();
} while (!sqlParser.equalAny(Symbol.RIGHT_PAREN) && !sqlParser.equalAny(Assist.END));
ItemsToken itemsToken = new ItemsToken(sqlParser.getLexer().getCurrentToken().getEndPosition() - sqlParser.getLexer().getCurrentToken().getLiterals().length());
for (String each : autoIncrementColumns) {
itemsToken.getItems().add(each);
result.add(new ShardingColumnContext(each, sqlContext.getTables().get(0).getName(), true));
}
if (!itemsToken.getItems().isEmpty()) {
sqlParser.getSqlBuilderContext().getSqlTokens().add(itemsToken);
}
sqlContext.setColumnsListLastPosition(sqlParser.getLexer().getCurrentToken().getEndPosition() - sqlParser.getLexer().getCurrentToken().getLiterals().length());
sqlParser.getLexer().nextToken();
}
return result;
}
protected final ShardingColumnContext getColumn(final Collection<String> autoIncrementColumns) {
String columnName = SQLUtil.getExactlyValue(sqlParser.getLexer().getCurrentToken().getLiterals());
if (autoIncrementColumns.contains(columnName)) {
autoIncrementColumns.remove(columnName);
}
return new ShardingColumnContext(columnName, sqlContext.getTables().get(0).getName());
sqlContext.getShardingColumnContexts().addAll(result);
}
protected Set<TokenType> getValuesKeywords() {
return Sets.<TokenType>newHashSet(DefaultKeyword.VALUES);
}
private void parseValues(final Collection<ShardingColumnContext> shardingColumnContexts) {
private void parseValues() {
boolean parsed = false;
do {
if (parsed) {
......@@ -154,31 +135,14 @@ public abstract class AbstractInsertParser implements SQLStatementParser {
do {
sqlExprs.add(sqlParser.parseExpression());
} while (sqlParser.skipIfEqual(Symbol.COMMA));
ItemsToken itemsToken = new ItemsToken(sqlParser.getLexer().getCurrentToken().getEndPosition() - sqlParser.getLexer().getCurrentToken().getLiterals().length());
sqlContext.setValuesListLastPosition(sqlParser.getLexer().getCurrentToken().getEndPosition() - sqlParser.getLexer().getCurrentToken().getLiterals().length());
int count = 0;
int offset = 0;
for (ShardingColumnContext each : shardingColumnContexts) {
if (each.isAutoIncrement()) {
Number generatedId = getShardingRule().findTableRule(sqlContext.getTables().get(0).getName()).generateId(each.getColumnName());
if (0 == sqlParser.getParametersIndex()) {
itemsToken.getItems().add(generatedId.toString());
sqlExprs.add(new SQLNumberExpr(generatedId));
} else {
itemsToken.getItems().add("?");
offset++;
sqlExprs.add(new SQLPlaceholderExpr(sqlParser.getParametersIndex() + offset - 1));
}
sqlContext.getGeneratedKeyContext().getColumns().add(each.getColumnName());
sqlContext.getGeneratedKeyContext().putValue(each.getColumnName(), generatedId);
}
for (ShardingColumnContext each : sqlContext.getShardingColumnContexts()) {
if (getShardingRule().isShardingColumn(each)) {
conditionContext.add(new ConditionContext.Condition(each, sqlExprs.get(count)));
}
count++;
}
if (!itemsToken.getItems().isEmpty()) {
sqlParser.getSqlBuilderContext().getSqlTokens().add(itemsToken);
}
sqlParser.accept(Symbol.RIGHT_PAREN);
parsed = true;
sqlContext.setConditionContext(conditionContext);
......
......@@ -30,7 +30,6 @@ import lombok.RequiredArgsConstructor;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.TreeSet;
/**
* 预解析功能的SQL路由器.
......@@ -58,8 +57,6 @@ public class PreparedSQLRouter {
public SQLRouteResult route(final List<Object> parameters) {
if (null == sqlContext) {
sqlContext = engine.parseSQL(logicSql, parameters);
// TODO 提炼至rewrite模块
fillGeneratedId(parameters);
} else {
List<Number> generatedIds = generateId();
parameters.addAll(generatedIds);
......@@ -69,15 +66,6 @@ public class PreparedSQLRouter {
return engine.routeSQL(sqlContext, parameters);
}
private void fillGeneratedId(final List<Object> parameters) {
if (sqlContext instanceof InsertSQLContext) {
InsertSQLContext insertSQLContext = (InsertSQLContext) sqlContext;
for (Integer each : new TreeSet<>(insertSQLContext.getGeneratedKeyContext().getColumnNameToIndexMap().values())) {
parameters.add(insertSQLContext.getGeneratedKeyContext().getValueTable().get(0, each));
}
}
}
private void setLimit(final List<Object> parameters) {
if (null == sqlContext.getLimitContext()) {
return;
......
......@@ -23,17 +23,21 @@ import com.dangdang.ddframe.rdb.sharding.constant.DatabaseType;
import com.dangdang.ddframe.rdb.sharding.hint.HintManagerHolder;
import com.dangdang.ddframe.rdb.sharding.metrics.MetricsContext;
import com.dangdang.ddframe.rdb.sharding.parsing.SQLParsingEngine;
import com.dangdang.ddframe.rdb.sharding.parsing.parser.exception.SQLParsingException;
import com.dangdang.ddframe.rdb.sharding.parsing.parser.context.ConditionContext;
import com.dangdang.ddframe.rdb.sharding.parsing.parser.context.DeleteSQLContext;
import com.dangdang.ddframe.rdb.sharding.parsing.parser.context.InsertSQLContext;
import com.dangdang.ddframe.rdb.sharding.parsing.parser.context.ItemsToken;
import com.dangdang.ddframe.rdb.sharding.parsing.parser.context.LimitContext;
import com.dangdang.ddframe.rdb.sharding.parsing.parser.context.SQLBuilder;
import com.dangdang.ddframe.rdb.sharding.parsing.parser.context.SQLBuilderContext;
import com.dangdang.ddframe.rdb.sharding.parsing.parser.context.SQLContext;
import com.dangdang.ddframe.rdb.sharding.parsing.parser.context.SelectSQLContext;
import com.dangdang.ddframe.rdb.sharding.parsing.parser.context.ShardingColumnContext;
import com.dangdang.ddframe.rdb.sharding.parsing.parser.context.TableContext;
import com.dangdang.ddframe.rdb.sharding.parsing.parser.context.UpdateSQLContext;
import com.dangdang.ddframe.rdb.sharding.parsing.parser.exception.SQLParsingException;
import com.dangdang.ddframe.rdb.sharding.parsing.parser.expr.SQLNumberExpr;
import com.dangdang.ddframe.rdb.sharding.parsing.parser.expr.SQLPlaceholderExpr;
import com.dangdang.ddframe.rdb.sharding.router.binding.BindingTablesRouter;
import com.dangdang.ddframe.rdb.sharding.router.database.DatabaseRouter;
import com.dangdang.ddframe.rdb.sharding.router.mixed.MixedTablesRouter;
......@@ -45,6 +49,7 @@ import com.google.common.collect.Sets;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Set;
......@@ -92,9 +97,59 @@ public final class SQLRouteEngine {
log.debug("Logic SQL: {}, {}", logicSql, parameters);
SQLContext result = new SQLParsingEngine(databaseType, logicSql, shardingRule).parseStatement();
MetricsContext.stop(context);
// TODO 提炼至rewrite模块
if (result instanceof InsertSQLContext) {
String tableName = result.getTables().get(0).getName();
ItemsToken columnsToken = new ItemsToken(((InsertSQLContext) result).getColumnsListLastPosition());
Collection<String> autoIncrementColumns = shardingRule.getAutoIncrementColumns(tableName);
for (String each : autoIncrementColumns) {
if (!isIncluded((InsertSQLContext) result, each)) {
columnsToken.getItems().add(each);
}
}
if (!columnsToken.getItems().isEmpty()) {
result.getSqlBuilderContext().getSqlTokens().add(columnsToken);
}
ItemsToken valuesToken = new ItemsToken(((InsertSQLContext) result).getValuesListLastPosition());
int offset = parameters.size() - 1;
for (String each : autoIncrementColumns) {
if (isIncluded((InsertSQLContext) result, each)) {
continue;
}
Number generatedId = shardingRule.findTableRule(tableName).generateId(each);
ShardingColumnContext shardingColumnContext = new ShardingColumnContext(each, tableName, true);
if (parameters.isEmpty()) {
valuesToken.getItems().add(generatedId.toString());
if (shardingRule.isShardingColumn(shardingColumnContext)) {
result.getConditionContext().add(new ConditionContext.Condition(shardingColumnContext, new SQLNumberExpr(generatedId)));
}
} else {
valuesToken.getItems().add("?");
parameters.add(generatedId);
offset++;
if (shardingRule.isShardingColumn(shardingColumnContext)) {
result.getConditionContext().add(new ConditionContext.Condition(shardingColumnContext, new SQLPlaceholderExpr(offset)));
}
}
((InsertSQLContext) result).getGeneratedKeyContext().getColumns().add(each);
((InsertSQLContext) result).getGeneratedKeyContext().putValue(each, generatedId);
}
if (!valuesToken.getItems().isEmpty()) {
result.getSqlBuilderContext().getSqlTokens().add(valuesToken);
}
}
return result;
}
private boolean isIncluded(final InsertSQLContext insertSQLContext, final String autoIncrementColumn) {
for (ShardingColumnContext shardingColumnContext : insertSQLContext.getShardingColumnContexts()) {
if (shardingColumnContext.getColumnName().equalsIgnoreCase(autoIncrementColumn)) {
return true;
}
}
return false;
}
private SQLContext buildHintParsedResult(final String logicSql) {
SQLContext result;
switch (SQLUtil.getTypeByStart(logicSql)) {
......
......@@ -73,7 +73,7 @@ public final class InsertStatementParserTest extends AbstractStatementParserTest
SQLParsingEngine statementParser = new SQLParsingEngine(DatabaseType.MySQL, "INSERT INTO `TABLE_XXX` (`field1`) VALUES (10)", shardingRule);
InsertSQLContext sqlContext = (InsertSQLContext) statementParser.parseStatement();
assertInsertStatementWithoutParameter(sqlContext);
assertThat(sqlContext.getSqlBuilder().toString(), is("INSERT INTO [Token(TABLE_XXX)] (`field1`, field2) VALUES (10, 1)"));
assertThat(sqlContext.getSqlBuilder().toString(), is("INSERT INTO [Token(TABLE_XXX)] (`field1`) VALUES (10)"));
}
@Test
......@@ -82,37 +82,24 @@ public final class InsertStatementParserTest extends AbstractStatementParserTest
SQLParsingEngine statementParser = new SQLParsingEngine(DatabaseType.MySQL, "INSERT INTO `TABLE_XXX` (`field1`) VALUES (?)", shardingRule);
InsertSQLContext sqlContext = (InsertSQLContext) statementParser.parseStatement();
assertInsertStatementWithParameter(sqlContext);
assertThat(sqlContext.getSqlBuilder().toString(), is("INSERT INTO [Token(TABLE_XXX)] (`field1`, field2) VALUES (?, ?)"));
assertThat(sqlContext.getSqlBuilder().toString(), is("INSERT INTO [Token(TABLE_XXX)] (`field1`) VALUES (?)"));
}
private void assertInsertStatementWithoutParameter(final InsertSQLContext sqlContext) {
assertThat(sqlContext.getTables().get(0).getName(), is("TABLE_XXX"));
ConditionContext.Condition condition1 = sqlContext.getConditionContext().find("TABLE_XXX", "field1").get();
assertThat(condition1.getOperator(), is(ShardingOperator.EQUAL));
assertThat(condition1.getValues().size(), is(1));
assertThat(condition1.getValues().get(0), is((Comparable) 10));
ConditionContext.Condition condition2 = sqlContext.getConditionContext().find("TABLE_XXX", "field2").get();
assertThat(condition2.getShardingColumnContext().getColumnName(), is("field2"));
assertThat(condition2.getShardingColumnContext().getTableName(), is("TABLE_XXX"));
assertThat(condition2.getOperator(), is(ShardingOperator.EQUAL));
assertThat(condition2.getValues().size(), is(1));
assertThat(condition2.getValues().get(0), is((Comparable) 1));
ConditionContext.Condition condition = sqlContext.getConditionContext().find("TABLE_XXX", "field1").get();
assertThat(condition.getOperator(), is(ShardingOperator.EQUAL));
assertThat(condition.getValues().size(), is(1));
assertThat(condition.getValues().get(0), is((Comparable) 10));
}
private void assertInsertStatementWithParameter(final InsertSQLContext sqlContext) {
assertThat(sqlContext.getTables().get(0).getName(), is("TABLE_XXX"));
ConditionContext.Condition condition1 = sqlContext.getConditionContext().find("TABLE_XXX", "field1").get();
assertThat(condition1.getOperator(), is(ShardingOperator.EQUAL));
assertTrue(condition1.getValues().isEmpty());
assertThat(condition1.getValueIndices().size(), is(1));
assertThat(condition1.getValueIndices().get(0), is(0));
ConditionContext.Condition condition2 = sqlContext.getConditionContext().find("TABLE_XXX", "field2").get();
assertThat(condition2.getShardingColumnContext().getColumnName(), is("field2"));
assertThat(condition2.getShardingColumnContext().getTableName(), is("TABLE_XXX"));
assertThat(condition2.getOperator(), is(ShardingOperator.EQUAL));
assertTrue(condition2.getValues().isEmpty());
assertThat(condition2.getValueIndices().size(), is(1));
assertThat(condition2.getValueIndices().get(0), is(1));
ConditionContext.Condition condition = sqlContext.getConditionContext().find("TABLE_XXX", "field1").get();
assertThat(condition.getOperator(), is(ShardingOperator.EQUAL));
assertTrue(condition.getValues().isEmpty());
assertThat(condition.getValueIndices().size(), is(1));
assertThat(condition.getValueIndices().get(0), is(0));
}
private ShardingRule createShardingRuleWithAutoIncrementColumns() {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册