提交 b0f79f0e 编写于 作者: T terrymanu

fixed # 292

上级 ef14a227
......@@ -8,6 +8,12 @@ next = "/00-overview/contribution/"
+++
## 1.5.0.M3
### 缺陷修正
1. [ISSUE #292](https://github.com/dangdangdotcom/sharding-jdbc/issues/292) 内存方式处理GROUP BY语句如有分页信息则需改写
## 1.5.0.M2
### 功能提升
......
......@@ -39,6 +39,7 @@ import com.dangdang.ddframe.rdb.sharding.parsing.parser.expression.SQLTextExpres
import com.dangdang.ddframe.rdb.sharding.parsing.parser.statement.SQLStatement;
import com.dangdang.ddframe.rdb.sharding.parsing.parser.statement.select.SelectStatement;
import com.dangdang.ddframe.rdb.sharding.parsing.parser.token.OffsetToken;
import com.dangdang.ddframe.rdb.sharding.parsing.parser.token.RowCountToken;
import com.dangdang.ddframe.rdb.sharding.parsing.parser.token.TableToken;
import com.dangdang.ddframe.rdb.sharding.util.SQLUtil;
import com.google.common.base.Optional;
......@@ -350,7 +351,10 @@ public class SQLParser extends AbstractParser {
}
if (Symbol.LT == symbol || Symbol.LT_EQ == symbol) {
if (sqlExpression instanceof SQLNumberExpression) {
selectStatement.getLimit().setRowCount(new LimitValue(((SQLNumberExpression) sqlExpression).getNumber().intValue(), -1));
int rowCount = ((SQLNumberExpression) sqlExpression).getNumber().intValue();
selectStatement.getLimit().setRowCount(new LimitValue(rowCount, -1));
selectStatement.getSqlTokens().add(
new RowCountToken(getLexer().getCurrentToken().getEndPosition() - String.valueOf(rowCount).length() - getLexer().getCurrentToken().getLiterals().length(), rowCount));
} else if (sqlExpression instanceof SQLPlaceholderExpression) {
selectStatement.getLimit().setRowCount(new LimitValue(-1, ((SQLPlaceholderExpression) sqlExpression).getIndex()));
}
......
......@@ -68,11 +68,12 @@ public final class Limit {
*
* @param parameters 参数
* @param isRewrite 是否重写参数
* @param isFetchAll 是否获取所有数据
*/
public void processParameters(final List<Object> parameters, final boolean isRewrite) {
public void processParameters(final List<Object> parameters, final boolean isRewrite, final boolean isFetchAll) {
fill(parameters);
if (isRewrite) {
rewrite(parameters);
rewrite(parameters, isFetchAll);
}
}
......@@ -92,10 +93,12 @@ public final class Limit {
}
}
private void rewrite(final List<Object> parameters) {
private void rewrite(final List<Object> parameters, final boolean isFetchAll) {
int rewriteOffset = 0;
int rewriteRowCount;
if (rowCountRewriteFlag) {
if (isFetchAll) {
rewriteRowCount = Integer.MAX_VALUE;
} else if (rowCountRewriteFlag) {
rewriteRowCount = null == rowCount ? -1 : getOffsetValue() + rowCount.getValue();
} else {
rewriteRowCount = rowCount.getValue();
......
......@@ -33,6 +33,7 @@ import com.dangdang.ddframe.rdb.sharding.parsing.parser.expression.SQLExpression
import com.dangdang.ddframe.rdb.sharding.parsing.parser.expression.SQLNumberExpression;
import com.dangdang.ddframe.rdb.sharding.parsing.parser.expression.SQLPlaceholderExpression;
import com.dangdang.ddframe.rdb.sharding.parsing.parser.statement.select.SelectStatement;
import com.dangdang.ddframe.rdb.sharding.parsing.parser.token.RowCountToken;
import com.google.common.base.Optional;
/**
......@@ -68,11 +69,14 @@ public final class SQLServerParser extends SQLParser {
skipIfEqual(Symbol.LEFT_PAREN);
SQLExpression sqlExpression = parseExpression();
skipIfEqual(Symbol.RIGHT_PAREN);
LimitValue rowCount;
LimitValue rowCountValue;
if (sqlExpression instanceof SQLNumberExpression) {
rowCount = new LimitValue(((SQLNumberExpression) sqlExpression).getNumber().intValue(), -1);
int rowCount = ((SQLNumberExpression) sqlExpression).getNumber().intValue();
rowCountValue = new LimitValue(rowCount, -1);
selectStatement.getSqlTokens().add(
new RowCountToken(getLexer().getCurrentToken().getEndPosition() - String.valueOf(rowCount).length() - getLexer().getCurrentToken().getLiterals().length(), rowCount));
} else if (sqlExpression instanceof SQLPlaceholderExpression) {
rowCount = new LimitValue(-1, ((SQLPlaceholderExpression) sqlExpression).getIndex());
rowCountValue = new LimitValue(-1, ((SQLPlaceholderExpression) sqlExpression).getIndex());
} else {
throw new SQLParsingException(getLexer());
}
......@@ -81,10 +85,10 @@ public final class SQLServerParser extends SQLParser {
}
if (null == selectStatement.getLimit()) {
Limit limit = new Limit(false);
limit.setRowCount(rowCount);
limit.setRowCount(rowCountValue);
selectStatement.setLimit(limit);
} else {
selectStatement.getLimit().setRowCount(rowCount);
selectStatement.getLimit().setRowCount(rowCountValue);
}
}
}
......
......@@ -32,7 +32,6 @@ import com.dangdang.ddframe.rdb.sharding.routing.type.TableUnit;
import com.dangdang.ddframe.rdb.sharding.routing.type.complex.CartesianTableReference;
import com.google.common.base.Optional;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
......@@ -53,20 +52,13 @@ public final class SQLRewriteEngine {
private final List<SQLToken> sqlTokens = new LinkedList<>();
private final Collection<String> tableNames;
private final Limit limit;
private final SQLStatement sqlStatement;
public SQLRewriteEngine(final ShardingRule shardingRule, final String originalSQL, final SQLStatement sqlStatement) {
this.shardingRule = shardingRule;
this.originalSQL = originalSQL;
this.sqlStatement = sqlStatement;
sqlTokens.addAll(sqlStatement.getSqlTokens());
tableNames = sqlStatement.getTables().getTableNames();
if (sqlStatement instanceof SelectStatement) {
limit = ((SelectStatement) sqlStatement).getLimit();
} else {
limit = null;
}
}
/**
......@@ -112,7 +104,7 @@ public final class SQLRewriteEngine {
}
private void appendTableToken(final SQLBuilder sqlBuilder, final TableToken tableToken, final int count, final List<SQLToken> sqlTokens) {
String tableName = tableNames.contains(tableToken.getTableName()) ? tableToken.getTableName() : tableToken.getOriginalLiterals();
String tableName = sqlStatement.getTables().getTableNames().contains(tableToken.getTableName()) ? tableToken.getTableName() : tableToken.getOriginalLiterals();
sqlBuilder.appendTable(tableName);
int beginPosition = tableToken.getBeginPosition() + tableToken.getOriginalLiterals().length();
int endPosition = sqlTokens.size() - 1 == count ? originalSQL.length() : sqlTokens.get(count + 1).getBeginPosition();
......@@ -130,7 +122,15 @@ public final class SQLRewriteEngine {
}
private void appendLimitRowCount(final SQLBuilder sqlBuilder, final RowCountToken rowCountToken, final int count, final List<SQLToken> sqlTokens, final boolean isRewrite) {
sqlBuilder.appendLiterals(isRewrite ? String.valueOf(rowCountToken.getRowCount() + limit.getOffsetValue()) : String.valueOf(rowCountToken.getRowCount()));
SelectStatement selectStatement = (SelectStatement) sqlStatement;
Limit limit = selectStatement.getLimit();
if (!isRewrite) {
sqlBuilder.appendLiterals(String.valueOf(rowCountToken.getRowCount()));
} else if ((!selectStatement.getGroupByItems().isEmpty() || !selectStatement.getAggregationSelectItems().isEmpty()) && !selectStatement.isSameGroupByAndOrderByItems()) {
sqlBuilder.appendLiterals(String.valueOf(Integer.MAX_VALUE));
} else {
sqlBuilder.appendLiterals(String.valueOf(limit.isRowCountRewriteFlag() ? rowCountToken.getRowCount() + limit.getOffsetValue() : rowCountToken.getRowCount()));
}
int beginPosition = rowCountToken.getBeginPosition() + String.valueOf(rowCountToken.getRowCount()).length();
int endPosition = sqlTokens.size() - 1 == count ? originalSQL.length() : sqlTokens.get(count + 1).getBeginPosition();
sqlBuilder.appendLiterals(originalSQL.substring(beginPosition, endPosition));
......@@ -189,7 +189,7 @@ public final class SQLRewriteEngine {
private Map<String, String> getBindingTableTokens(final TableUnit tableUnit, final BindingTableRule bindingTableRule) {
Map<String, String> result = new HashMap<>();
for (String eachTable : tableNames) {
for (String eachTable : sqlStatement.getTables().getTableNames()) {
if (!eachTable.equalsIgnoreCase(tableUnit.getLogicTableName()) && bindingTableRule.hasLogicTable(eachTable)) {
result.put(eachTable, bindingTableRule.getBindingActualTable(tableUnit.getDataSourceName(), eachTable, tableUnit.getActualTableName()));
}
......
......@@ -88,7 +88,9 @@ public final class ParsingSQLRouter implements SQLRouter {
SQLRewriteEngine rewriteEngine = new SQLRewriteEngine(shardingRule, logicSQL, sqlStatement);
boolean isSingleRouting = routingResult.isSingleRouting();
if (sqlStatement instanceof SelectStatement && null != ((SelectStatement) sqlStatement).getLimit()) {
((SelectStatement) sqlStatement).getLimit().processParameters(parameters, !isSingleRouting);
SelectStatement selectStatement = (SelectStatement) sqlStatement;
boolean isNeedFetchAll = (!selectStatement.getGroupByItems().isEmpty() || !selectStatement.getAggregationSelectItems().isEmpty()) && !selectStatement.isSameGroupByAndOrderByItems();
selectStatement.getLimit().processParameters(parameters, !isSingleRouting, isNeedFetchAll);
}
SQLBuilder sqlBuilder = rewriteEngine.rewrite(!isSingleRouting);
if (routingResult instanceof CartesianRoutingResult) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册