提交 fe0765d1 编写于 作者: T terrymanu

refactor api for generate key, multiple keys => single key 2nd version

上级 3f496c7a
......@@ -226,13 +226,11 @@ public class ShardingStatement extends AbstractStatementAdapter {
}
Optional<GeneratedKeyContext> generatedKeyContext = getGeneratedKeyContext();
if (!generatedKeyContext.isPresent()) {
return generatedKeysResultSet = new GeneratedKeysResultSet();
}
if (generatedKeyContext.get().getColumnNameToIndexMap().isEmpty()) {
Collection<? extends Statement> routedStatements = getRoutedStatements();
if (1 == routedStatements.size()) {
return generatedKeysResultSet = routedStatements.iterator().next().getGeneratedKeys();
}
return generatedKeysResultSet = new GeneratedKeysResultSet();
}
if (Statement.RETURN_GENERATED_KEYS != generatedKeyContext.get().getAutoGeneratedKeys()
&& null == generatedKeyContext.get().getColumnIndexes() && null == generatedKeyContext.get().getColumnNames()) {
......@@ -274,7 +272,7 @@ public class ShardingStatement extends AbstractStatementAdapter {
protected final Optional<GeneratedKeyContext> getGeneratedKeyContext() {
if (null != sqlContext && sqlContext instanceof InsertSQLContext) {
return Optional.of(((InsertSQLContext) sqlContext).getGeneratedKeyContext());
return Optional.fromNullable(((InsertSQLContext) sqlContext).getGeneratedKeyContext());
}
return Optional.absent();
}
......
......@@ -20,12 +20,11 @@ package com.dangdang.ddframe.rdb.sharding.parsing.parser.context;
import com.google.common.collect.Table;
import com.google.common.collect.TreeBasedTable;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import lombok.Setter;
import lombok.ToString;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
/**
......@@ -33,11 +32,12 @@ import java.util.Map;
*
* @author gaohongtao
*/
@RequiredArgsConstructor
@ToString
public final class GeneratedKeyContext {
@Getter
private final List<String> columns = new LinkedList<>();
private final String column;
@Getter
private final Map<String, Integer> columnNameToIndexMap = new HashMap<>();
......
......@@ -45,10 +45,10 @@ import java.util.Map;
@Setter
public final class InsertSQLContext extends AbstractSQLContext {
private final GeneratedKeyContext generatedKeyContext = new GeneratedKeyContext();
private final Collection<ShardingColumnContext> shardingColumnContexts = new LinkedList<>();
private GeneratedKeyContext generatedKeyContext;
private int columnsListLastPosition;
private int valuesListLastPosition;
......@@ -79,25 +79,26 @@ public final class InsertSQLContext extends AbstractSQLContext {
* @return 自增列与主键映射表
*/
public Map<String, Number> generateKeys(final ShardingRule shardingRule) {
if (null == generatedKeyContext) {
return Collections.emptyMap();
}
Optional<TableRule> tableRuleOptional = shardingRule.tryFindTableRule(getTables().iterator().next().getName());
if (!tableRuleOptional.isPresent()) {
return Collections.emptyMap();
}
TableRule tableRule = tableRuleOptional.get();
Map<String, Number> result = new LinkedHashMap<>(generatedKeyContext.getColumns().size());
for (String each : generatedKeyContext.getColumns()) {
Number generatedKey;
if (null != tableRule.getIdGenerator()) {
generatedKey = tableRule.getIdGenerator().generateId();
} else if (null != shardingRule.getIdGenerator()) {
generatedKey = shardingRule.getIdGenerator().generateId();
} else {
// TODO 使用default id生成器
generatedKey = null;
}
result.put(each, generatedKey);
generatedKeyContext.putValue(each, generatedKey);
Map<String, Number> result = new LinkedHashMap<>(1, 1);
Number generatedKey;
if (null != tableRule.getIdGenerator()) {
generatedKey = tableRule.getIdGenerator().generateId();
} else if (null != shardingRule.getIdGenerator()) {
generatedKey = shardingRule.getIdGenerator().generateId();
} else {
// TODO 使用default id生成器
generatedKey = null;
}
result.put(generatedKeyContext.getColumn(), generatedKey);
generatedKeyContext.putValue(generatedKeyContext.getColumn(), generatedKey);
return result;
}
......@@ -129,18 +130,17 @@ public final class InsertSQLContext extends AbstractSQLContext {
* @param parametersSize 参数个数
*/
public void appendGenerateKeysToken(final ShardingRule shardingRule, final int parametersSize) {
if (null == generatedKeyContext) {
return;
}
Optional<AutoGeneratedKeysToken> autoGeneratedKeysToken = findAutoGeneratedKeysToken();
if (!autoGeneratedKeysToken.isPresent()) {
return;
}
String tableName = getTables().get(0).getName();
ItemsToken valuesToken = new ItemsToken(autoGeneratedKeysToken.get().getBeginPosition());
int offset = 0;
for (String each : generatedKeyContext.getColumns()) {
valuesToken.getItems().add("?");
addCondition(shardingRule, new ShardingColumnContext(each, tableName, true), new SQLPlaceholderExpr(parametersSize + offset));
offset++;
}
valuesToken.getItems().add("?");
addCondition(shardingRule, new ShardingColumnContext(generatedKeyContext.getColumn(), tableName, true), new SQLPlaceholderExpr(parametersSize));
getSqlTokens().remove(autoGeneratedKeysToken.get());
getSqlTokens().add(valuesToken);
}
......
......@@ -23,6 +23,7 @@ import com.dangdang.ddframe.rdb.sharding.parsing.lexer.token.DefaultKeyword;
import com.dangdang.ddframe.rdb.sharding.parsing.lexer.token.Symbol;
import com.dangdang.ddframe.rdb.sharding.parsing.lexer.token.TokenType;
import com.dangdang.ddframe.rdb.sharding.parsing.parser.context.ConditionContext;
import com.dangdang.ddframe.rdb.sharding.parsing.parser.context.GeneratedKeyContext;
import com.dangdang.ddframe.rdb.sharding.parsing.parser.context.InsertSQLContext;
import com.dangdang.ddframe.rdb.sharding.parsing.parser.context.ShardingColumnContext;
import com.dangdang.ddframe.rdb.sharding.parsing.parser.exception.SQLParsingUnsupportedException;
......@@ -167,7 +168,7 @@ public abstract class AbstractInsertParser implements SQLParser {
if (autoGeneratedKeyColumn.isPresent()) {
if (!sqlContext.hasColumn(autoGeneratedKeyColumn.get())) {
columnsToken.getItems().add(autoGeneratedKeyColumn.get());
sqlContext.getGeneratedKeyContext().getColumns().add(autoGeneratedKeyColumn.get());
sqlContext.setGeneratedKeyContext(new GeneratedKeyContext(autoGeneratedKeyColumn.get()));
}
}
if (!columnsToken.getItems().isEmpty()) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册