提交 edee3928 编写于 作者: T terrymanu

refactor generateKeys to rewrite module

上级 a3d4961f
package com.dangdang.ddframe.rdb.sharding.rewrite;
import com.dangdang.ddframe.rdb.sharding.api.rule.ShardingRule;
import com.dangdang.ddframe.rdb.sharding.parsing.parser.context.ConditionContext;
import com.dangdang.ddframe.rdb.sharding.parsing.parser.context.InsertSQLContext;
import com.dangdang.ddframe.rdb.sharding.parsing.parser.context.ItemsToken;
import com.dangdang.ddframe.rdb.sharding.parsing.parser.context.ShardingColumnContext;
import com.dangdang.ddframe.rdb.sharding.parsing.parser.expr.SQLNumberExpr;
import com.dangdang.ddframe.rdb.sharding.parsing.parser.expr.SQLPlaceholderExpr;
import java.util.Collection;
import java.util.List;
/**
* .
*
* @author zhangliang
*/
public class GenerateKeysUtils {
/**
* 追加自增主键.
*
* @param shardingRule 分片规则
* @param parameters 参数
* @param insertSQLContext 解析结果
*/
public static void appendGenerateKeys(final ShardingRule shardingRule, final List<Object> parameters, final InsertSQLContext insertSQLContext) {
String tableName = insertSQLContext.getTables().get(0).getName();
ItemsToken columnsToken = new ItemsToken(insertSQLContext.getColumnsListLastPosition());
Collection<String> autoIncrementColumns = shardingRule.getAutoIncrementColumns(tableName);
for (String each : autoIncrementColumns) {
if (!isIncluded(insertSQLContext, each)) {
columnsToken.getItems().add(each);
}
}
if (!columnsToken.getItems().isEmpty()) {
insertSQLContext.getSqlBuilderContext().getSqlTokens().add(columnsToken);
}
ItemsToken valuesToken = new ItemsToken(insertSQLContext.getValuesListLastPosition());
int offset = parameters.size() - 1;
for (String each : autoIncrementColumns) {
if (isIncluded(insertSQLContext, each)) {
continue;
}
Number generatedId = shardingRule.findTableRule(tableName).generateId(each);
ShardingColumnContext shardingColumnContext = new ShardingColumnContext(each, tableName, true);
if (parameters.isEmpty()) {
valuesToken.getItems().add(generatedId.toString());
if (shardingRule.isShardingColumn(shardingColumnContext)) {
insertSQLContext.getConditionContext().add(new ConditionContext.Condition(shardingColumnContext, new SQLNumberExpr(generatedId)));
}
} else {
valuesToken.getItems().add("?");
parameters.add(generatedId);
offset++;
if (shardingRule.isShardingColumn(shardingColumnContext)) {
insertSQLContext.getConditionContext().add(new ConditionContext.Condition(shardingColumnContext, new SQLPlaceholderExpr(offset)));
}
}
insertSQLContext.getGeneratedKeyContext().getColumns().add(each);
insertSQLContext.getGeneratedKeyContext().putValue(each, generatedId);
}
if (!valuesToken.getItems().isEmpty()) {
insertSQLContext.getSqlBuilderContext().getSqlTokens().add(valuesToken);
}
}
private static boolean isIncluded(final InsertSQLContext insertSQLContext, final String autoIncrementColumn) {
for (ShardingColumnContext shardingColumnContext : insertSQLContext.getShardingColumnContexts()) {
if (shardingColumnContext.getColumnName().equalsIgnoreCase(autoIncrementColumn)) {
return true;
}
}
return false;
}
}
......@@ -26,18 +26,15 @@ import com.dangdang.ddframe.rdb.sharding.parsing.SQLParsingEngine;
import com.dangdang.ddframe.rdb.sharding.parsing.parser.context.ConditionContext;
import com.dangdang.ddframe.rdb.sharding.parsing.parser.context.DeleteSQLContext;
import com.dangdang.ddframe.rdb.sharding.parsing.parser.context.InsertSQLContext;
import com.dangdang.ddframe.rdb.sharding.parsing.parser.context.ItemsToken;
import com.dangdang.ddframe.rdb.sharding.parsing.parser.context.LimitContext;
import com.dangdang.ddframe.rdb.sharding.parsing.parser.context.OffsetLimitToken;
import com.dangdang.ddframe.rdb.sharding.parsing.parser.context.RowCountLimitToken;
import com.dangdang.ddframe.rdb.sharding.parsing.parser.context.SQLContext;
import com.dangdang.ddframe.rdb.sharding.parsing.parser.context.SelectSQLContext;
import com.dangdang.ddframe.rdb.sharding.parsing.parser.context.ShardingColumnContext;
import com.dangdang.ddframe.rdb.sharding.parsing.parser.context.TableContext;
import com.dangdang.ddframe.rdb.sharding.parsing.parser.context.UpdateSQLContext;
import com.dangdang.ddframe.rdb.sharding.parsing.parser.exception.SQLParsingException;
import com.dangdang.ddframe.rdb.sharding.parsing.parser.expr.SQLNumberExpr;
import com.dangdang.ddframe.rdb.sharding.parsing.parser.expr.SQLPlaceholderExpr;
import com.dangdang.ddframe.rdb.sharding.rewrite.GenerateKeysUtils;
import com.dangdang.ddframe.rdb.sharding.rewrite.SQLBuilder;
import com.dangdang.ddframe.rdb.sharding.rewrite.SQLBuilderContext;
import com.dangdang.ddframe.rdb.sharding.rewrite.SQLRewriteEngine;
......@@ -52,7 +49,6 @@ import com.google.common.collect.Sets;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Set;
......@@ -100,59 +96,12 @@ public final class SQLRouteEngine {
log.debug("Logic SQL: {}, {}", logicSql, parameters);
SQLContext result = new SQLParsingEngine(databaseType, logicSql, shardingRule).parseStatement();
MetricsContext.stop(context);
// TODO 提炼至rewrite模块
if (result instanceof InsertSQLContext) {
String tableName = result.getTables().get(0).getName();
ItemsToken columnsToken = new ItemsToken(((InsertSQLContext) result).getColumnsListLastPosition());
Collection<String> autoIncrementColumns = shardingRule.getAutoIncrementColumns(tableName);
for (String each : autoIncrementColumns) {
if (!isIncluded((InsertSQLContext) result, each)) {
columnsToken.getItems().add(each);
}
}
if (!columnsToken.getItems().isEmpty()) {
result.getSqlBuilderContext().getSqlTokens().add(columnsToken);
}
ItemsToken valuesToken = new ItemsToken(((InsertSQLContext) result).getValuesListLastPosition());
int offset = parameters.size() - 1;
for (String each : autoIncrementColumns) {
if (isIncluded((InsertSQLContext) result, each)) {
continue;
}
Number generatedId = shardingRule.findTableRule(tableName).generateId(each);
ShardingColumnContext shardingColumnContext = new ShardingColumnContext(each, tableName, true);
if (parameters.isEmpty()) {
valuesToken.getItems().add(generatedId.toString());
if (shardingRule.isShardingColumn(shardingColumnContext)) {
result.getConditionContext().add(new ConditionContext.Condition(shardingColumnContext, new SQLNumberExpr(generatedId)));
}
} else {
valuesToken.getItems().add("?");
parameters.add(generatedId);
offset++;
if (shardingRule.isShardingColumn(shardingColumnContext)) {
result.getConditionContext().add(new ConditionContext.Condition(shardingColumnContext, new SQLPlaceholderExpr(offset)));
}
}
((InsertSQLContext) result).getGeneratedKeyContext().getColumns().add(each);
((InsertSQLContext) result).getGeneratedKeyContext().putValue(each, generatedId);
}
if (!valuesToken.getItems().isEmpty()) {
result.getSqlBuilderContext().getSqlTokens().add(valuesToken);
}
GenerateKeysUtils.appendGenerateKeys(shardingRule, parameters, (InsertSQLContext) result);
}
return result;
}
private boolean isIncluded(final InsertSQLContext insertSQLContext, final String autoIncrementColumn) {
for (ShardingColumnContext shardingColumnContext : insertSQLContext.getShardingColumnContexts()) {
if (shardingColumnContext.getColumnName().equalsIgnoreCase(autoIncrementColumn)) {
return true;
}
}
return false;
}
private SQLContext buildHintParsedResult(final String logicSql) {
SQLContext result;
switch (SQLUtil.getTypeByStart(logicSql)) {
......
......@@ -29,6 +29,7 @@ import com.dangdang.ddframe.rdb.sharding.parsing.SQLParsingEngine;
import com.dangdang.ddframe.rdb.sharding.parsing.parser.context.ConditionContext;
import com.dangdang.ddframe.rdb.sharding.parsing.parser.context.InsertSQLContext;
import com.dangdang.ddframe.rdb.sharding.parsing.parser.exception.SQLParsingUnsupportedException;
import com.dangdang.ddframe.rdb.sharding.rewrite.GenerateKeysUtils;
import com.dangdang.ddframe.rdb.sharding.rewrite.SQLRewriteEngine;
import org.junit.Test;
......@@ -77,7 +78,8 @@ public final class InsertStatementParserTest extends AbstractStatementParserTest
InsertSQLContext sqlContext = (InsertSQLContext) statementParser.parseStatement();
assertInsertStatementWithoutParameter(sqlContext);
// TODO 放入rewrite模块断言
assertThat(new SQLRewriteEngine(sqlContext.getSqlBuilderContext()).rewrite().toString(), is("INSERT INTO [Token(TABLE_XXX)] (`field1`) VALUES (10)"));
GenerateKeysUtils.appendGenerateKeys(shardingRule, Collections.emptyList(), sqlContext);
assertThat(new SQLRewriteEngine(sqlContext.getSqlBuilderContext()).rewrite().toString(), is("INSERT INTO [Token(TABLE_XXX)] (`field1`, field2) VALUES (10, 1)"));
}
@Test
......@@ -87,7 +89,8 @@ public final class InsertStatementParserTest extends AbstractStatementParserTest
InsertSQLContext sqlContext = (InsertSQLContext) statementParser.parseStatement();
assertInsertStatementWithParameter(sqlContext);
// TODO 放入rewrite模块断言
assertThat(new SQLRewriteEngine(sqlContext.getSqlBuilderContext()).rewrite().toString(), is("INSERT INTO [Token(TABLE_XXX)] (`field1`) VALUES (?)"));
GenerateKeysUtils.appendGenerateKeys(shardingRule, Collections.emptyList(), sqlContext);
assertThat(new SQLRewriteEngine(sqlContext.getSqlBuilderContext()).rewrite().toString(), is("INSERT INTO [Token(TABLE_XXX)] (`field1`, field2) VALUES (?, 1)"));
}
private void assertInsertStatementWithoutParameter(final InsertSQLContext sqlContext) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册