提交 156a0248 编写于 作者: T terrymanu

add generateId to parameter: parser => router

上级 78aaf333
......@@ -226,9 +226,7 @@ public class ShardingStatement extends AbstractStatementAdapter {
return generatedKeyResultSet = routedStatements.iterator().next().getGeneratedKeys();
}
}
if (Statement.RETURN_GENERATED_KEYS != generatedKeyContext.getAutoGeneratedKeys() && null == generatedKeyContext.getColumnIndexes()
&& null == generatedKeyContext.getColumnNames()) {
if (Statement.RETURN_GENERATED_KEYS != generatedKeyContext.getAutoGeneratedKeys() && null == generatedKeyContext.getColumnIndexes() && null == generatedKeyContext.getColumnNames()) {
return generatedKeyResultSet = new GeneratedKeysResultSet();
}
return generatedKeyResultSet = new GeneratedKeysResultSet(generateAutoIncrementTable(), generatedKeyContext.getColumnNameToIndexMap(), this);
......
......@@ -67,7 +67,7 @@ public class GeneratedKeyContext {
* @param columnName 列名称
* @param value 键值
*/
public void putValue(final String columnName, final Object value) {
public void putValue(final String columnName, final Number value) {
valueTable.put(rowIndex, columnIndex, value);
columnNameToIndexMap.put(columnName, columnIndex);
columnIndex++;
......
......@@ -159,19 +159,21 @@ public abstract class AbstractInsertParser implements SQLStatementParser {
} while (sqlParser.skipIfEqual(Symbol.COMMA));
ItemsToken itemsToken = new ItemsToken(sqlParser.getLexer().getCurrentToken().getEndPosition() - sqlParser.getLexer().getCurrentToken().getLiterals().length());
int count = 0;
int offset = 0;
int parametersSize = parameters.size();
for (ShardingColumnContext each : shardingColumnContexts) {
if (each.isAutoIncrement()) {
Number autoIncrementedValue = getShardingRule().findTableRule(sqlContext.getTables().get(0).getName()).generateId(each.getColumnName());
Number generatedId = getShardingRule().findTableRule(sqlContext.getTables().get(0).getName()).generateId(each.getColumnName());
if (parameters.isEmpty()) {
itemsToken.getItems().add(autoIncrementedValue.toString());
sqlExprs.add(new SQLNumberExpr(autoIncrementedValue));
itemsToken.getItems().add(generatedId.toString());
sqlExprs.add(new SQLNumberExpr(generatedId));
} else {
itemsToken.getItems().add("?");
parameters.add(autoIncrementedValue);
sqlExprs.add(new SQLPlaceholderExpr(parameters.size() - 1));
offset++;
sqlExprs.add(new SQLPlaceholderExpr(parametersSize + offset - 1));
}
sqlContext.getGeneratedKeyContext().getColumns().add(each.getColumnName());
sqlContext.getGeneratedKeyContext().putValue(each.getColumnName(), autoIncrementedValue);
sqlContext.getGeneratedKeyContext().putValue(each.getColumnName(), generatedId);
}
if (getShardingRule().isShardingColumn(each)) {
conditionContext.add(new ConditionContext.Condition(each, sqlExprs.get(count)));
......
......@@ -28,6 +28,7 @@ import lombok.RequiredArgsConstructor;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.TreeSet;
/**
* 预解析功能的SQL路由器.
......@@ -55,6 +56,12 @@ public class PreparedSQLRouter {
public SQLRouteResult route(final List<Object> parameters) {
if (null == sqlContext) {
sqlContext = engine.parseSQL(logicSql, parameters);
if (sqlContext instanceof InsertSQLContext) {
InsertSQLContext insertSQLContext = (InsertSQLContext) sqlContext;
for (Integer each : new TreeSet<>(insertSQLContext.getGeneratedKeyContext().getColumnNameToIndexMap().values())) {
parameters.add(insertSQLContext.getGeneratedKeyContext().getValueTable().get(0, each));
}
}
} else {
List<Number> generatedIds = generateId();
parameters.addAll(generatedIds);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册