提交 fe0765d1 编写于 作者: T terrymanu

refactor api for generate key, multiple keys => single key 2nd version

上级 3f496c7a
...@@ -226,13 +226,11 @@ public class ShardingStatement extends AbstractStatementAdapter { ...@@ -226,13 +226,11 @@ public class ShardingStatement extends AbstractStatementAdapter {
} }
Optional<GeneratedKeyContext> generatedKeyContext = getGeneratedKeyContext(); Optional<GeneratedKeyContext> generatedKeyContext = getGeneratedKeyContext();
if (!generatedKeyContext.isPresent()) { if (!generatedKeyContext.isPresent()) {
return generatedKeysResultSet = new GeneratedKeysResultSet();
}
if (generatedKeyContext.get().getColumnNameToIndexMap().isEmpty()) {
Collection<? extends Statement> routedStatements = getRoutedStatements(); Collection<? extends Statement> routedStatements = getRoutedStatements();
if (1 == routedStatements.size()) { if (1 == routedStatements.size()) {
return generatedKeysResultSet = routedStatements.iterator().next().getGeneratedKeys(); return generatedKeysResultSet = routedStatements.iterator().next().getGeneratedKeys();
} }
return generatedKeysResultSet = new GeneratedKeysResultSet();
} }
if (Statement.RETURN_GENERATED_KEYS != generatedKeyContext.get().getAutoGeneratedKeys() if (Statement.RETURN_GENERATED_KEYS != generatedKeyContext.get().getAutoGeneratedKeys()
&& null == generatedKeyContext.get().getColumnIndexes() && null == generatedKeyContext.get().getColumnNames()) { && null == generatedKeyContext.get().getColumnIndexes() && null == generatedKeyContext.get().getColumnNames()) {
...@@ -274,7 +272,7 @@ public class ShardingStatement extends AbstractStatementAdapter { ...@@ -274,7 +272,7 @@ public class ShardingStatement extends AbstractStatementAdapter {
protected final Optional<GeneratedKeyContext> getGeneratedKeyContext() { protected final Optional<GeneratedKeyContext> getGeneratedKeyContext() {
if (null != sqlContext && sqlContext instanceof InsertSQLContext) { if (null != sqlContext && sqlContext instanceof InsertSQLContext) {
return Optional.of(((InsertSQLContext) sqlContext).getGeneratedKeyContext()); return Optional.fromNullable(((InsertSQLContext) sqlContext).getGeneratedKeyContext());
} }
return Optional.absent(); return Optional.absent();
} }
......
...@@ -20,12 +20,11 @@ package com.dangdang.ddframe.rdb.sharding.parsing.parser.context; ...@@ -20,12 +20,11 @@ package com.dangdang.ddframe.rdb.sharding.parsing.parser.context;
import com.google.common.collect.Table; import com.google.common.collect.Table;
import com.google.common.collect.TreeBasedTable; import com.google.common.collect.TreeBasedTable;
import lombok.Getter; import lombok.Getter;
import lombok.RequiredArgsConstructor;
import lombok.Setter; import lombok.Setter;
import lombok.ToString; import lombok.ToString;
import java.util.HashMap; import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map; import java.util.Map;
/** /**
...@@ -33,11 +32,12 @@ import java.util.Map; ...@@ -33,11 +32,12 @@ import java.util.Map;
* *
* @author gaohongtao * @author gaohongtao
*/ */
@RequiredArgsConstructor
@ToString @ToString
public final class GeneratedKeyContext { public final class GeneratedKeyContext {
@Getter @Getter
private final List<String> columns = new LinkedList<>(); private final String column;
@Getter @Getter
private final Map<String, Integer> columnNameToIndexMap = new HashMap<>(); private final Map<String, Integer> columnNameToIndexMap = new HashMap<>();
......
...@@ -45,10 +45,10 @@ import java.util.Map; ...@@ -45,10 +45,10 @@ import java.util.Map;
@Setter @Setter
public final class InsertSQLContext extends AbstractSQLContext { public final class InsertSQLContext extends AbstractSQLContext {
private final GeneratedKeyContext generatedKeyContext = new GeneratedKeyContext();
private final Collection<ShardingColumnContext> shardingColumnContexts = new LinkedList<>(); private final Collection<ShardingColumnContext> shardingColumnContexts = new LinkedList<>();
private GeneratedKeyContext generatedKeyContext;
private int columnsListLastPosition; private int columnsListLastPosition;
private int valuesListLastPosition; private int valuesListLastPosition;
...@@ -79,25 +79,26 @@ public final class InsertSQLContext extends AbstractSQLContext { ...@@ -79,25 +79,26 @@ public final class InsertSQLContext extends AbstractSQLContext {
* @return 自增列与主键映射表 * @return 自增列与主键映射表
*/ */
public Map<String, Number> generateKeys(final ShardingRule shardingRule) { public Map<String, Number> generateKeys(final ShardingRule shardingRule) {
if (null == generatedKeyContext) {
return Collections.emptyMap();
}
Optional<TableRule> tableRuleOptional = shardingRule.tryFindTableRule(getTables().iterator().next().getName()); Optional<TableRule> tableRuleOptional = shardingRule.tryFindTableRule(getTables().iterator().next().getName());
if (!tableRuleOptional.isPresent()) { if (!tableRuleOptional.isPresent()) {
return Collections.emptyMap(); return Collections.emptyMap();
} }
TableRule tableRule = tableRuleOptional.get(); TableRule tableRule = tableRuleOptional.get();
Map<String, Number> result = new LinkedHashMap<>(generatedKeyContext.getColumns().size()); Map<String, Number> result = new LinkedHashMap<>(1, 1);
for (String each : generatedKeyContext.getColumns()) { Number generatedKey;
Number generatedKey; if (null != tableRule.getIdGenerator()) {
if (null != tableRule.getIdGenerator()) { generatedKey = tableRule.getIdGenerator().generateId();
generatedKey = tableRule.getIdGenerator().generateId(); } else if (null != shardingRule.getIdGenerator()) {
} else if (null != shardingRule.getIdGenerator()) { generatedKey = shardingRule.getIdGenerator().generateId();
generatedKey = shardingRule.getIdGenerator().generateId(); } else {
} else { // TODO 使用default id生成器
// TODO 使用default id生成器 generatedKey = null;
generatedKey = null;
}
result.put(each, generatedKey);
generatedKeyContext.putValue(each, generatedKey);
} }
result.put(generatedKeyContext.getColumn(), generatedKey);
generatedKeyContext.putValue(generatedKeyContext.getColumn(), generatedKey);
return result; return result;
} }
...@@ -129,18 +130,17 @@ public final class InsertSQLContext extends AbstractSQLContext { ...@@ -129,18 +130,17 @@ public final class InsertSQLContext extends AbstractSQLContext {
* @param parametersSize 参数个数 * @param parametersSize 参数个数
*/ */
public void appendGenerateKeysToken(final ShardingRule shardingRule, final int parametersSize) { public void appendGenerateKeysToken(final ShardingRule shardingRule, final int parametersSize) {
if (null == generatedKeyContext) {
return;
}
Optional<AutoGeneratedKeysToken> autoGeneratedKeysToken = findAutoGeneratedKeysToken(); Optional<AutoGeneratedKeysToken> autoGeneratedKeysToken = findAutoGeneratedKeysToken();
if (!autoGeneratedKeysToken.isPresent()) { if (!autoGeneratedKeysToken.isPresent()) {
return; return;
} }
String tableName = getTables().get(0).getName(); String tableName = getTables().get(0).getName();
ItemsToken valuesToken = new ItemsToken(autoGeneratedKeysToken.get().getBeginPosition()); ItemsToken valuesToken = new ItemsToken(autoGeneratedKeysToken.get().getBeginPosition());
int offset = 0; valuesToken.getItems().add("?");
for (String each : generatedKeyContext.getColumns()) { addCondition(shardingRule, new ShardingColumnContext(generatedKeyContext.getColumn(), tableName, true), new SQLPlaceholderExpr(parametersSize));
valuesToken.getItems().add("?");
addCondition(shardingRule, new ShardingColumnContext(each, tableName, true), new SQLPlaceholderExpr(parametersSize + offset));
offset++;
}
getSqlTokens().remove(autoGeneratedKeysToken.get()); getSqlTokens().remove(autoGeneratedKeysToken.get());
getSqlTokens().add(valuesToken); getSqlTokens().add(valuesToken);
} }
......
...@@ -23,6 +23,7 @@ import com.dangdang.ddframe.rdb.sharding.parsing.lexer.token.DefaultKeyword; ...@@ -23,6 +23,7 @@ import com.dangdang.ddframe.rdb.sharding.parsing.lexer.token.DefaultKeyword;
import com.dangdang.ddframe.rdb.sharding.parsing.lexer.token.Symbol; import com.dangdang.ddframe.rdb.sharding.parsing.lexer.token.Symbol;
import com.dangdang.ddframe.rdb.sharding.parsing.lexer.token.TokenType; import com.dangdang.ddframe.rdb.sharding.parsing.lexer.token.TokenType;
import com.dangdang.ddframe.rdb.sharding.parsing.parser.context.ConditionContext; import com.dangdang.ddframe.rdb.sharding.parsing.parser.context.ConditionContext;
import com.dangdang.ddframe.rdb.sharding.parsing.parser.context.GeneratedKeyContext;
import com.dangdang.ddframe.rdb.sharding.parsing.parser.context.InsertSQLContext; import com.dangdang.ddframe.rdb.sharding.parsing.parser.context.InsertSQLContext;
import com.dangdang.ddframe.rdb.sharding.parsing.parser.context.ShardingColumnContext; import com.dangdang.ddframe.rdb.sharding.parsing.parser.context.ShardingColumnContext;
import com.dangdang.ddframe.rdb.sharding.parsing.parser.exception.SQLParsingUnsupportedException; import com.dangdang.ddframe.rdb.sharding.parsing.parser.exception.SQLParsingUnsupportedException;
...@@ -167,7 +168,7 @@ public abstract class AbstractInsertParser implements SQLParser { ...@@ -167,7 +168,7 @@ public abstract class AbstractInsertParser implements SQLParser {
if (autoGeneratedKeyColumn.isPresent()) { if (autoGeneratedKeyColumn.isPresent()) {
if (!sqlContext.hasColumn(autoGeneratedKeyColumn.get())) { if (!sqlContext.hasColumn(autoGeneratedKeyColumn.get())) {
columnsToken.getItems().add(autoGeneratedKeyColumn.get()); columnsToken.getItems().add(autoGeneratedKeyColumn.get());
sqlContext.getGeneratedKeyContext().getColumns().add(autoGeneratedKeyColumn.get()); sqlContext.setGeneratedKeyContext(new GeneratedKeyContext(autoGeneratedKeyColumn.get()));
} }
} }
if (!columnsToken.getItems().isEmpty()) { if (!columnsToken.getItems().isEmpty()) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册