提交 97d6d3bb 编写于 作者: T terrymanu

for #2441, remove database type dependency of rewrite module

上级 c48e56c4
......@@ -19,7 +19,6 @@ package org.apache.shardingsphere.core;
import lombok.RequiredArgsConstructor;
import org.apache.shardingsphere.api.hint.HintManager;
import org.apache.shardingsphere.core.constant.DatabaseType;
import org.apache.shardingsphere.core.constant.properties.ShardingProperties;
import org.apache.shardingsphere.core.constant.properties.ShardingPropertiesConstant;
import org.apache.shardingsphere.core.metadata.ShardingMetaData;
......@@ -57,8 +56,6 @@ public abstract class BaseShardingEngine {
private final ShardingMetaData metaData;
private final DatabaseType databaseType;
private final SPIRoutingHook routingHook = new SPIRoutingHook();
/**
......@@ -107,8 +104,8 @@ public abstract class BaseShardingEngine {
private Collection<RouteUnit> rewriteAndConvert(final List<Object> parameters, final SQLRouteResult sqlRouteResult) {
SQLRewriteEngine rewriteEngine = new SQLRewriteEngine(shardingRule, sqlRouteResult.getSqlStatement(), parameters, sqlRouteResult.getRoutingResult().isSingleRouting());
ShardingParameterRewriter shardingParameterRewriter = new ShardingParameterRewriter(databaseType, sqlRouteResult);
ShardingSQLRewriter shardingSQLRewriter = new ShardingSQLRewriter(shardingRule, databaseType, sqlRouteResult, sqlRouteResult.getOptimizeResult());
ShardingParameterRewriter shardingParameterRewriter = new ShardingParameterRewriter(sqlRouteResult);
ShardingSQLRewriter shardingSQLRewriter = new ShardingSQLRewriter(shardingRule, sqlRouteResult, sqlRouteResult.getOptimizeResult());
EncryptSQLRewriter encryptSQLRewriter = new EncryptSQLRewriter(shardingRule.getEncryptRule().getEncryptorEngine(), sqlRouteResult.getSqlStatement(), sqlRouteResult.getOptimizeResult());
rewriteEngine.init(Collections.<ParameterRewriter>singletonList(shardingParameterRewriter), Arrays.asList(shardingSQLRewriter, encryptSQLRewriter));
Collection<RouteUnit> result = new LinkedHashSet<>();
......
......@@ -46,7 +46,7 @@ public final class PreparedQueryShardingEngine extends BaseShardingEngine {
public PreparedQueryShardingEngine(final String sql, final ShardingRule shardingRule, final ShardingProperties shardingProperties,
final ShardingMetaData metaData, final DatabaseType databaseType, final ParsingResultCache cache) {
super(shardingRule, shardingProperties, metaData, databaseType);
super(shardingRule, shardingProperties, metaData);
routingEngine = new PreparedStatementRoutingEngine(sql, shardingRule, metaData, databaseType, cache);
}
......
......@@ -46,7 +46,7 @@ public final class SimpleQueryShardingEngine extends BaseShardingEngine {
public SimpleQueryShardingEngine(final ShardingRule shardingRule,
final ShardingProperties shardingProperties, final ShardingMetaData metaData, final DatabaseType databaseType, final ParsingResultCache cache) {
super(shardingRule, shardingProperties, metaData, databaseType);
super(shardingRule, shardingProperties, metaData);
routingEngine = new StatementRoutingEngine(shardingRule, metaData, databaseType, cache);
}
......
......@@ -27,9 +27,9 @@ import org.apache.shardingsphere.core.parse.extractor.util.RuleName;
import org.apache.shardingsphere.core.parse.sql.segment.dml.expr.ExpressionSegment;
import org.apache.shardingsphere.core.parse.sql.segment.dml.expr.simple.LiteralExpressionSegment;
import org.apache.shardingsphere.core.parse.sql.segment.dml.expr.simple.ParameterMarkerExpressionSegment;
import org.apache.shardingsphere.core.parse.sql.segment.dml.pagination.limit.LimitValueSegment;
import org.apache.shardingsphere.core.parse.sql.segment.dml.pagination.limit.NumberLiteralLimitValueSegment;
import org.apache.shardingsphere.core.parse.sql.segment.dml.pagination.limit.ParameterMarkerLimitValueSegment;
import org.apache.shardingsphere.core.parse.sql.segment.dml.pagination.rownum.NumberLiteralRowNumberValueSegment;
import org.apache.shardingsphere.core.parse.sql.segment.dml.pagination.rownum.ParameterMarkerRowNumberValueSegment;
import org.apache.shardingsphere.core.parse.sql.segment.dml.pagination.rownum.RowNumberValueSegment;
import org.apache.shardingsphere.core.parse.sql.segment.dml.pagination.top.TopSegment;
import java.util.Map;
......@@ -52,21 +52,21 @@ public final class TopSelectItemExtractor implements OptionalSQLSegmentExtractor
ParserRuleContext topExprNode = ExtractorUtils.getFirstChildNode(topNode.get(), RuleName.EXPR);
Optional<? extends ExpressionSegment> topExpr = expressionExtractor.extract(topExprNode, parameterMarkerIndexes);
Preconditions.checkState(topExpr.isPresent());
Optional<LimitValueSegment> limitValueSegment = createLimitValueSegment(topExpr.get());
Preconditions.checkState(limitValueSegment.isPresent());
Optional<RowNumberValueSegment> rowNumberValueSegment = createRowNumberValueSegment(topExpr.get());
Preconditions.checkState(rowNumberValueSegment.isPresent());
ParserRuleContext rowNumberAliasNode = ExtractorUtils.getFirstChildNode(topNode.get().getParent(), RuleName.ALIAS);
return Optional.of(
new TopSegment(topNode.get().getStart().getStartIndex(), topNode.get().getStop().getStopIndex(), topNode.get().getText(), limitValueSegment.get(), rowNumberAliasNode.getText()));
new TopSegment(topNode.get().getStart().getStartIndex(), topNode.get().getStop().getStopIndex(), topNode.get().getText(), rowNumberValueSegment.get(), rowNumberAliasNode.getText()));
}
private Optional<LimitValueSegment> createLimitValueSegment(final ExpressionSegment topExpr) {
private Optional<RowNumberValueSegment> createRowNumberValueSegment(final ExpressionSegment topExpr) {
if (topExpr instanceof ParameterMarkerExpressionSegment) {
return Optional.<LimitValueSegment>of(
new ParameterMarkerLimitValueSegment(topExpr.getStartIndex(), topExpr.getStopIndex(), ((ParameterMarkerExpressionSegment) topExpr).getParameterMarkerIndex()));
return Optional.<RowNumberValueSegment>of(
new ParameterMarkerRowNumberValueSegment(topExpr.getStartIndex(), topExpr.getStopIndex(), ((ParameterMarkerExpressionSegment) topExpr).getParameterMarkerIndex(), false));
}
if (topExpr instanceof LiteralExpressionSegment && ((LiteralExpressionSegment) topExpr).getLiterals() instanceof Number) {
return Optional.<LimitValueSegment>of(
new NumberLiteralLimitValueSegment(topExpr.getStartIndex(), topExpr.getStopIndex(), ((Number) ((LiteralExpressionSegment) topExpr).getLiterals()).intValue()));
return Optional.<RowNumberValueSegment>of(
new NumberLiteralRowNumberValueSegment(topExpr.getStartIndex(), topExpr.getStopIndex(), ((Number) ((LiteralExpressionSegment) topExpr).getLiterals()).intValue(), false));
}
return Optional.absent();
}
......
......@@ -20,7 +20,7 @@ package org.apache.shardingsphere.core.parse.sql.segment.dml.pagination.top;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import org.apache.shardingsphere.core.parse.sql.segment.dml.item.SelectItemSegment;
import org.apache.shardingsphere.core.parse.sql.segment.dml.pagination.limit.LimitValueSegment;
import org.apache.shardingsphere.core.parse.sql.segment.dml.pagination.rownum.RowNumberValueSegment;
/**
* Top segment.
......@@ -37,7 +37,7 @@ public final class TopSegment implements SelectItemSegment {
private final String text;
private final LimitValueSegment top;
private final RowNumberValueSegment top;
private final String rowNumberAlias;
}
......@@ -18,7 +18,6 @@
package org.apache.shardingsphere.core.rewrite.rewriter.parameter;
import lombok.RequiredArgsConstructor;
import org.apache.shardingsphere.core.constant.DatabaseType;
import org.apache.shardingsphere.core.parse.sql.statement.dml.SelectStatement;
import org.apache.shardingsphere.core.rewrite.builder.ParameterBuilder;
import org.apache.shardingsphere.core.route.SQLRouteResult;
......@@ -31,8 +30,6 @@ import org.apache.shardingsphere.core.route.SQLRouteResult;
@RequiredArgsConstructor
public final class ShardingParameterRewriter implements ParameterRewriter {
private final DatabaseType databaseType;
private final SQLRouteResult sqlRouteResult;
@Override
......@@ -43,7 +40,7 @@ public final class ShardingParameterRewriter implements ParameterRewriter {
}
private void rewriteLimit(final SelectStatement selectStatement, final ParameterBuilder parameterBuilder) {
boolean isNeedFetchAll = (!selectStatement.getGroupByItems().isEmpty() || !selectStatement.getAggregationSelectItems().isEmpty()) && !selectStatement.isSameGroupByAndOrderByItems();
parameterBuilder.getReplacedIndexAndParameters().putAll(sqlRouteResult.getPagination().getRevisedParameters(isNeedFetchAll, databaseType.name()));
boolean isMaxRowCount = (!selectStatement.getGroupByItems().isEmpty() || !selectStatement.getAggregationSelectItems().isEmpty()) && !selectStatement.isSameGroupByAndOrderByItems();
parameterBuilder.getReplacedIndexAndParameters().putAll(sqlRouteResult.getPagination().getRevisedParameters(isMaxRowCount));
}
}
......@@ -19,7 +19,6 @@ package org.apache.shardingsphere.core.rewrite.rewriter.sql;
import com.google.common.base.Optional;
import com.google.common.base.Strings;
import org.apache.shardingsphere.core.constant.DatabaseType;
import org.apache.shardingsphere.core.optimize.result.OptimizeResult;
import org.apache.shardingsphere.core.optimize.result.insert.InsertOptimizeResult;
import org.apache.shardingsphere.core.parse.sql.segment.dml.order.item.OrderByItemSegment;
......@@ -63,15 +62,12 @@ public final class ShardingSQLRewriter implements SQLRewriter {
private final ShardingRule shardingRule;
private final DatabaseType databaseType;
private final SQLRouteResult sqlRouteResult;
private final InsertOptimizeResult insertOptimizeResult;
public ShardingSQLRewriter(final ShardingRule shardingRule, final DatabaseType databaseType, final SQLRouteResult sqlRouteResult, final OptimizeResult optimizeResult) {
public ShardingSQLRewriter(final ShardingRule shardingRule, final SQLRouteResult sqlRouteResult, final OptimizeResult optimizeResult) {
this.shardingRule = shardingRule;
this.databaseType = databaseType;
this.sqlRouteResult = sqlRouteResult;
this.insertOptimizeResult = getInsertOptimizeResult(optimizeResult);
}
......@@ -147,10 +143,7 @@ public final class ShardingSQLRewriter implements SQLRewriter {
if (!isRewrite) {
return rowCountToken.getRowCount();
}
if (isMaxRowCount(selectStatement)) {
return Integer.MAX_VALUE;
}
return pagination.isNeedRewriteRowCount(databaseType.name()) ? rowCountToken.getRowCount() + pagination.getOffsetValue() : rowCountToken.getRowCount();
return isMaxRowCount(selectStatement) ? Integer.MAX_VALUE : pagination.getRevisedRowCount();
}
private boolean isMaxRowCount(final SelectStatement selectStatement) {
......
......@@ -22,6 +22,7 @@ import org.apache.shardingsphere.core.parse.sql.segment.dml.pagination.NumberLit
import org.apache.shardingsphere.core.parse.sql.segment.dml.pagination.PaginationSegment;
import org.apache.shardingsphere.core.parse.sql.segment.dml.pagination.PaginationValueSegment;
import org.apache.shardingsphere.core.parse.sql.segment.dml.pagination.ParameterMarkerPaginationValueSegment;
import org.apache.shardingsphere.core.parse.sql.segment.dml.pagination.limit.LimitValueSegment;
import java.util.HashMap;
import java.util.List;
......@@ -74,35 +75,26 @@ public final class Pagination {
/**
* Get revise parameters.
*
* @param isFetchAll is fetch all data or not
* @param databaseType database type
* @param isMaxRowCount is max row count
* @return revised parameters and parameters' indexes
*/
public Map<Integer, Object> getRevisedParameters(final boolean isFetchAll, final String databaseType) {
public Map<Integer, Object> getRevisedParameters(final boolean isMaxRowCount) {
Map<Integer, Object> result = new HashMap<>(2, 1);
if (null != offset && offset.getSegment() instanceof ParameterMarkerPaginationValueSegment) {
result.put(((ParameterMarkerPaginationValueSegment) offset.getSegment()).getParameterIndex(), 0);
}
if (null != rowCount && rowCount.getSegment() instanceof ParameterMarkerPaginationValueSegment) {
result.put(((ParameterMarkerPaginationValueSegment) rowCount.getSegment()).getParameterIndex(), getRewriteRowCount(isFetchAll, databaseType));
result.put(((ParameterMarkerPaginationValueSegment) rowCount.getSegment()).getParameterIndex(), isMaxRowCount ? Integer.MAX_VALUE : getRevisedRowCount());
}
return result;
}
private int getRewriteRowCount(final boolean isFetchAll, final String databaseType) {
if (isFetchAll) {
return Integer.MAX_VALUE;
}
return isNeedRewriteRowCount(databaseType) ? getOffsetValue() + rowCount.getValue() : rowCount.getValue();
}
/**
* Judge is need rewrite row count or not.
* Get revised row count.
*
* @param databaseType database type
* @return is need rewrite row count or not
* @return revised row count
*/
public boolean isNeedRewriteRowCount(final String databaseType) {
return "MySQL".equals(databaseType) || "PostgreSQL".equals(databaseType) || "H2".equals(databaseType);
public int getRevisedRowCount() {
return rowCount.getSegment() instanceof LimitValueSegment ? getOffsetValue() + rowCount.getValue() : rowCount.getValue();
}
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册