提交 97d6d3bb 编写于 作者: T terrymanu

for #2441, remove database type dependency of rewrite module

上级 c48e56c4
...@@ -19,7 +19,6 @@ package org.apache.shardingsphere.core; ...@@ -19,7 +19,6 @@ package org.apache.shardingsphere.core;
import lombok.RequiredArgsConstructor; import lombok.RequiredArgsConstructor;
import org.apache.shardingsphere.api.hint.HintManager; import org.apache.shardingsphere.api.hint.HintManager;
import org.apache.shardingsphere.core.constant.DatabaseType;
import org.apache.shardingsphere.core.constant.properties.ShardingProperties; import org.apache.shardingsphere.core.constant.properties.ShardingProperties;
import org.apache.shardingsphere.core.constant.properties.ShardingPropertiesConstant; import org.apache.shardingsphere.core.constant.properties.ShardingPropertiesConstant;
import org.apache.shardingsphere.core.metadata.ShardingMetaData; import org.apache.shardingsphere.core.metadata.ShardingMetaData;
...@@ -57,8 +56,6 @@ public abstract class BaseShardingEngine { ...@@ -57,8 +56,6 @@ public abstract class BaseShardingEngine {
private final ShardingMetaData metaData; private final ShardingMetaData metaData;
private final DatabaseType databaseType;
private final SPIRoutingHook routingHook = new SPIRoutingHook(); private final SPIRoutingHook routingHook = new SPIRoutingHook();
/** /**
...@@ -107,8 +104,8 @@ public abstract class BaseShardingEngine { ...@@ -107,8 +104,8 @@ public abstract class BaseShardingEngine {
private Collection<RouteUnit> rewriteAndConvert(final List<Object> parameters, final SQLRouteResult sqlRouteResult) { private Collection<RouteUnit> rewriteAndConvert(final List<Object> parameters, final SQLRouteResult sqlRouteResult) {
SQLRewriteEngine rewriteEngine = new SQLRewriteEngine(shardingRule, sqlRouteResult.getSqlStatement(), parameters, sqlRouteResult.getRoutingResult().isSingleRouting()); SQLRewriteEngine rewriteEngine = new SQLRewriteEngine(shardingRule, sqlRouteResult.getSqlStatement(), parameters, sqlRouteResult.getRoutingResult().isSingleRouting());
ShardingParameterRewriter shardingParameterRewriter = new ShardingParameterRewriter(databaseType, sqlRouteResult); ShardingParameterRewriter shardingParameterRewriter = new ShardingParameterRewriter(sqlRouteResult);
ShardingSQLRewriter shardingSQLRewriter = new ShardingSQLRewriter(shardingRule, databaseType, sqlRouteResult, sqlRouteResult.getOptimizeResult()); ShardingSQLRewriter shardingSQLRewriter = new ShardingSQLRewriter(shardingRule, sqlRouteResult, sqlRouteResult.getOptimizeResult());
EncryptSQLRewriter encryptSQLRewriter = new EncryptSQLRewriter(shardingRule.getEncryptRule().getEncryptorEngine(), sqlRouteResult.getSqlStatement(), sqlRouteResult.getOptimizeResult()); EncryptSQLRewriter encryptSQLRewriter = new EncryptSQLRewriter(shardingRule.getEncryptRule().getEncryptorEngine(), sqlRouteResult.getSqlStatement(), sqlRouteResult.getOptimizeResult());
rewriteEngine.init(Collections.<ParameterRewriter>singletonList(shardingParameterRewriter), Arrays.asList(shardingSQLRewriter, encryptSQLRewriter)); rewriteEngine.init(Collections.<ParameterRewriter>singletonList(shardingParameterRewriter), Arrays.asList(shardingSQLRewriter, encryptSQLRewriter));
Collection<RouteUnit> result = new LinkedHashSet<>(); Collection<RouteUnit> result = new LinkedHashSet<>();
......
...@@ -46,7 +46,7 @@ public final class PreparedQueryShardingEngine extends BaseShardingEngine { ...@@ -46,7 +46,7 @@ public final class PreparedQueryShardingEngine extends BaseShardingEngine {
public PreparedQueryShardingEngine(final String sql, final ShardingRule shardingRule, final ShardingProperties shardingProperties, public PreparedQueryShardingEngine(final String sql, final ShardingRule shardingRule, final ShardingProperties shardingProperties,
final ShardingMetaData metaData, final DatabaseType databaseType, final ParsingResultCache cache) { final ShardingMetaData metaData, final DatabaseType databaseType, final ParsingResultCache cache) {
super(shardingRule, shardingProperties, metaData, databaseType); super(shardingRule, shardingProperties, metaData);
routingEngine = new PreparedStatementRoutingEngine(sql, shardingRule, metaData, databaseType, cache); routingEngine = new PreparedStatementRoutingEngine(sql, shardingRule, metaData, databaseType, cache);
} }
......
...@@ -46,7 +46,7 @@ public final class SimpleQueryShardingEngine extends BaseShardingEngine { ...@@ -46,7 +46,7 @@ public final class SimpleQueryShardingEngine extends BaseShardingEngine {
public SimpleQueryShardingEngine(final ShardingRule shardingRule, public SimpleQueryShardingEngine(final ShardingRule shardingRule,
final ShardingProperties shardingProperties, final ShardingMetaData metaData, final DatabaseType databaseType, final ParsingResultCache cache) { final ShardingProperties shardingProperties, final ShardingMetaData metaData, final DatabaseType databaseType, final ParsingResultCache cache) {
super(shardingRule, shardingProperties, metaData, databaseType); super(shardingRule, shardingProperties, metaData);
routingEngine = new StatementRoutingEngine(shardingRule, metaData, databaseType, cache); routingEngine = new StatementRoutingEngine(shardingRule, metaData, databaseType, cache);
} }
......
...@@ -27,9 +27,9 @@ import org.apache.shardingsphere.core.parse.extractor.util.RuleName; ...@@ -27,9 +27,9 @@ import org.apache.shardingsphere.core.parse.extractor.util.RuleName;
import org.apache.shardingsphere.core.parse.sql.segment.dml.expr.ExpressionSegment; import org.apache.shardingsphere.core.parse.sql.segment.dml.expr.ExpressionSegment;
import org.apache.shardingsphere.core.parse.sql.segment.dml.expr.simple.LiteralExpressionSegment; import org.apache.shardingsphere.core.parse.sql.segment.dml.expr.simple.LiteralExpressionSegment;
import org.apache.shardingsphere.core.parse.sql.segment.dml.expr.simple.ParameterMarkerExpressionSegment; import org.apache.shardingsphere.core.parse.sql.segment.dml.expr.simple.ParameterMarkerExpressionSegment;
import org.apache.shardingsphere.core.parse.sql.segment.dml.pagination.limit.LimitValueSegment; import org.apache.shardingsphere.core.parse.sql.segment.dml.pagination.rownum.NumberLiteralRowNumberValueSegment;
import org.apache.shardingsphere.core.parse.sql.segment.dml.pagination.limit.NumberLiteralLimitValueSegment; import org.apache.shardingsphere.core.parse.sql.segment.dml.pagination.rownum.ParameterMarkerRowNumberValueSegment;
import org.apache.shardingsphere.core.parse.sql.segment.dml.pagination.limit.ParameterMarkerLimitValueSegment; import org.apache.shardingsphere.core.parse.sql.segment.dml.pagination.rownum.RowNumberValueSegment;
import org.apache.shardingsphere.core.parse.sql.segment.dml.pagination.top.TopSegment; import org.apache.shardingsphere.core.parse.sql.segment.dml.pagination.top.TopSegment;
import java.util.Map; import java.util.Map;
...@@ -52,21 +52,21 @@ public final class TopSelectItemExtractor implements OptionalSQLSegmentExtractor ...@@ -52,21 +52,21 @@ public final class TopSelectItemExtractor implements OptionalSQLSegmentExtractor
ParserRuleContext topExprNode = ExtractorUtils.getFirstChildNode(topNode.get(), RuleName.EXPR); ParserRuleContext topExprNode = ExtractorUtils.getFirstChildNode(topNode.get(), RuleName.EXPR);
Optional<? extends ExpressionSegment> topExpr = expressionExtractor.extract(topExprNode, parameterMarkerIndexes); Optional<? extends ExpressionSegment> topExpr = expressionExtractor.extract(topExprNode, parameterMarkerIndexes);
Preconditions.checkState(topExpr.isPresent()); Preconditions.checkState(topExpr.isPresent());
Optional<LimitValueSegment> limitValueSegment = createLimitValueSegment(topExpr.get()); Optional<RowNumberValueSegment> rowNumberValueSegment = createRowNumberValueSegment(topExpr.get());
Preconditions.checkState(limitValueSegment.isPresent()); Preconditions.checkState(rowNumberValueSegment.isPresent());
ParserRuleContext rowNumberAliasNode = ExtractorUtils.getFirstChildNode(topNode.get().getParent(), RuleName.ALIAS); ParserRuleContext rowNumberAliasNode = ExtractorUtils.getFirstChildNode(topNode.get().getParent(), RuleName.ALIAS);
return Optional.of( return Optional.of(
new TopSegment(topNode.get().getStart().getStartIndex(), topNode.get().getStop().getStopIndex(), topNode.get().getText(), limitValueSegment.get(), rowNumberAliasNode.getText())); new TopSegment(topNode.get().getStart().getStartIndex(), topNode.get().getStop().getStopIndex(), topNode.get().getText(), rowNumberValueSegment.get(), rowNumberAliasNode.getText()));
} }
private Optional<LimitValueSegment> createLimitValueSegment(final ExpressionSegment topExpr) { private Optional<RowNumberValueSegment> createRowNumberValueSegment(final ExpressionSegment topExpr) {
if (topExpr instanceof ParameterMarkerExpressionSegment) { if (topExpr instanceof ParameterMarkerExpressionSegment) {
return Optional.<LimitValueSegment>of( return Optional.<RowNumberValueSegment>of(
new ParameterMarkerLimitValueSegment(topExpr.getStartIndex(), topExpr.getStopIndex(), ((ParameterMarkerExpressionSegment) topExpr).getParameterMarkerIndex())); new ParameterMarkerRowNumberValueSegment(topExpr.getStartIndex(), topExpr.getStopIndex(), ((ParameterMarkerExpressionSegment) topExpr).getParameterMarkerIndex(), false));
} }
if (topExpr instanceof LiteralExpressionSegment && ((LiteralExpressionSegment) topExpr).getLiterals() instanceof Number) { if (topExpr instanceof LiteralExpressionSegment && ((LiteralExpressionSegment) topExpr).getLiterals() instanceof Number) {
return Optional.<LimitValueSegment>of( return Optional.<RowNumberValueSegment>of(
new NumberLiteralLimitValueSegment(topExpr.getStartIndex(), topExpr.getStopIndex(), ((Number) ((LiteralExpressionSegment) topExpr).getLiterals()).intValue())); new NumberLiteralRowNumberValueSegment(topExpr.getStartIndex(), topExpr.getStopIndex(), ((Number) ((LiteralExpressionSegment) topExpr).getLiterals()).intValue(), false));
} }
return Optional.absent(); return Optional.absent();
} }
......
...@@ -20,7 +20,7 @@ package org.apache.shardingsphere.core.parse.sql.segment.dml.pagination.top; ...@@ -20,7 +20,7 @@ package org.apache.shardingsphere.core.parse.sql.segment.dml.pagination.top;
import lombok.Getter; import lombok.Getter;
import lombok.RequiredArgsConstructor; import lombok.RequiredArgsConstructor;
import org.apache.shardingsphere.core.parse.sql.segment.dml.item.SelectItemSegment; import org.apache.shardingsphere.core.parse.sql.segment.dml.item.SelectItemSegment;
import org.apache.shardingsphere.core.parse.sql.segment.dml.pagination.limit.LimitValueSegment; import org.apache.shardingsphere.core.parse.sql.segment.dml.pagination.rownum.RowNumberValueSegment;
/** /**
* Top segment. * Top segment.
...@@ -37,7 +37,7 @@ public final class TopSegment implements SelectItemSegment { ...@@ -37,7 +37,7 @@ public final class TopSegment implements SelectItemSegment {
private final String text; private final String text;
private final LimitValueSegment top; private final RowNumberValueSegment top;
private final String rowNumberAlias; private final String rowNumberAlias;
} }
...@@ -18,7 +18,6 @@ ...@@ -18,7 +18,6 @@
package org.apache.shardingsphere.core.rewrite.rewriter.parameter; package org.apache.shardingsphere.core.rewrite.rewriter.parameter;
import lombok.RequiredArgsConstructor; import lombok.RequiredArgsConstructor;
import org.apache.shardingsphere.core.constant.DatabaseType;
import org.apache.shardingsphere.core.parse.sql.statement.dml.SelectStatement; import org.apache.shardingsphere.core.parse.sql.statement.dml.SelectStatement;
import org.apache.shardingsphere.core.rewrite.builder.ParameterBuilder; import org.apache.shardingsphere.core.rewrite.builder.ParameterBuilder;
import org.apache.shardingsphere.core.route.SQLRouteResult; import org.apache.shardingsphere.core.route.SQLRouteResult;
...@@ -31,8 +30,6 @@ import org.apache.shardingsphere.core.route.SQLRouteResult; ...@@ -31,8 +30,6 @@ import org.apache.shardingsphere.core.route.SQLRouteResult;
@RequiredArgsConstructor @RequiredArgsConstructor
public final class ShardingParameterRewriter implements ParameterRewriter { public final class ShardingParameterRewriter implements ParameterRewriter {
private final DatabaseType databaseType;
private final SQLRouteResult sqlRouteResult; private final SQLRouteResult sqlRouteResult;
@Override @Override
...@@ -43,7 +40,7 @@ public final class ShardingParameterRewriter implements ParameterRewriter { ...@@ -43,7 +40,7 @@ public final class ShardingParameterRewriter implements ParameterRewriter {
} }
private void rewriteLimit(final SelectStatement selectStatement, final ParameterBuilder parameterBuilder) { private void rewriteLimit(final SelectStatement selectStatement, final ParameterBuilder parameterBuilder) {
boolean isNeedFetchAll = (!selectStatement.getGroupByItems().isEmpty() || !selectStatement.getAggregationSelectItems().isEmpty()) && !selectStatement.isSameGroupByAndOrderByItems(); boolean isMaxRowCount = (!selectStatement.getGroupByItems().isEmpty() || !selectStatement.getAggregationSelectItems().isEmpty()) && !selectStatement.isSameGroupByAndOrderByItems();
parameterBuilder.getReplacedIndexAndParameters().putAll(sqlRouteResult.getPagination().getRevisedParameters(isNeedFetchAll, databaseType.name())); parameterBuilder.getReplacedIndexAndParameters().putAll(sqlRouteResult.getPagination().getRevisedParameters(isMaxRowCount));
} }
} }
...@@ -19,7 +19,6 @@ package org.apache.shardingsphere.core.rewrite.rewriter.sql; ...@@ -19,7 +19,6 @@ package org.apache.shardingsphere.core.rewrite.rewriter.sql;
import com.google.common.base.Optional; import com.google.common.base.Optional;
import com.google.common.base.Strings; import com.google.common.base.Strings;
import org.apache.shardingsphere.core.constant.DatabaseType;
import org.apache.shardingsphere.core.optimize.result.OptimizeResult; import org.apache.shardingsphere.core.optimize.result.OptimizeResult;
import org.apache.shardingsphere.core.optimize.result.insert.InsertOptimizeResult; import org.apache.shardingsphere.core.optimize.result.insert.InsertOptimizeResult;
import org.apache.shardingsphere.core.parse.sql.segment.dml.order.item.OrderByItemSegment; import org.apache.shardingsphere.core.parse.sql.segment.dml.order.item.OrderByItemSegment;
...@@ -63,15 +62,12 @@ public final class ShardingSQLRewriter implements SQLRewriter { ...@@ -63,15 +62,12 @@ public final class ShardingSQLRewriter implements SQLRewriter {
private final ShardingRule shardingRule; private final ShardingRule shardingRule;
private final DatabaseType databaseType;
private final SQLRouteResult sqlRouteResult; private final SQLRouteResult sqlRouteResult;
private final InsertOptimizeResult insertOptimizeResult; private final InsertOptimizeResult insertOptimizeResult;
public ShardingSQLRewriter(final ShardingRule shardingRule, final DatabaseType databaseType, final SQLRouteResult sqlRouteResult, final OptimizeResult optimizeResult) { public ShardingSQLRewriter(final ShardingRule shardingRule, final SQLRouteResult sqlRouteResult, final OptimizeResult optimizeResult) {
this.shardingRule = shardingRule; this.shardingRule = shardingRule;
this.databaseType = databaseType;
this.sqlRouteResult = sqlRouteResult; this.sqlRouteResult = sqlRouteResult;
this.insertOptimizeResult = getInsertOptimizeResult(optimizeResult); this.insertOptimizeResult = getInsertOptimizeResult(optimizeResult);
} }
...@@ -147,10 +143,7 @@ public final class ShardingSQLRewriter implements SQLRewriter { ...@@ -147,10 +143,7 @@ public final class ShardingSQLRewriter implements SQLRewriter {
if (!isRewrite) { if (!isRewrite) {
return rowCountToken.getRowCount(); return rowCountToken.getRowCount();
} }
if (isMaxRowCount(selectStatement)) { return isMaxRowCount(selectStatement) ? Integer.MAX_VALUE : pagination.getRevisedRowCount();
return Integer.MAX_VALUE;
}
return pagination.isNeedRewriteRowCount(databaseType.name()) ? rowCountToken.getRowCount() + pagination.getOffsetValue() : rowCountToken.getRowCount();
} }
private boolean isMaxRowCount(final SelectStatement selectStatement) { private boolean isMaxRowCount(final SelectStatement selectStatement) {
......
...@@ -22,6 +22,7 @@ import org.apache.shardingsphere.core.parse.sql.segment.dml.pagination.NumberLit ...@@ -22,6 +22,7 @@ import org.apache.shardingsphere.core.parse.sql.segment.dml.pagination.NumberLit
import org.apache.shardingsphere.core.parse.sql.segment.dml.pagination.PaginationSegment; import org.apache.shardingsphere.core.parse.sql.segment.dml.pagination.PaginationSegment;
import org.apache.shardingsphere.core.parse.sql.segment.dml.pagination.PaginationValueSegment; import org.apache.shardingsphere.core.parse.sql.segment.dml.pagination.PaginationValueSegment;
import org.apache.shardingsphere.core.parse.sql.segment.dml.pagination.ParameterMarkerPaginationValueSegment; import org.apache.shardingsphere.core.parse.sql.segment.dml.pagination.ParameterMarkerPaginationValueSegment;
import org.apache.shardingsphere.core.parse.sql.segment.dml.pagination.limit.LimitValueSegment;
import java.util.HashMap; import java.util.HashMap;
import java.util.List; import java.util.List;
...@@ -74,35 +75,26 @@ public final class Pagination { ...@@ -74,35 +75,26 @@ public final class Pagination {
/** /**
* Get revise parameters. * Get revise parameters.
* *
* @param isFetchAll is fetch all data or not * @param isMaxRowCount is max row count
* @param databaseType database type
* @return revised parameters and parameters' indexes * @return revised parameters and parameters' indexes
*/ */
public Map<Integer, Object> getRevisedParameters(final boolean isFetchAll, final String databaseType) { public Map<Integer, Object> getRevisedParameters(final boolean isMaxRowCount) {
Map<Integer, Object> result = new HashMap<>(2, 1); Map<Integer, Object> result = new HashMap<>(2, 1);
if (null != offset && offset.getSegment() instanceof ParameterMarkerPaginationValueSegment) { if (null != offset && offset.getSegment() instanceof ParameterMarkerPaginationValueSegment) {
result.put(((ParameterMarkerPaginationValueSegment) offset.getSegment()).getParameterIndex(), 0); result.put(((ParameterMarkerPaginationValueSegment) offset.getSegment()).getParameterIndex(), 0);
} }
if (null != rowCount && rowCount.getSegment() instanceof ParameterMarkerPaginationValueSegment) { if (null != rowCount && rowCount.getSegment() instanceof ParameterMarkerPaginationValueSegment) {
result.put(((ParameterMarkerPaginationValueSegment) rowCount.getSegment()).getParameterIndex(), getRewriteRowCount(isFetchAll, databaseType)); result.put(((ParameterMarkerPaginationValueSegment) rowCount.getSegment()).getParameterIndex(), isMaxRowCount ? Integer.MAX_VALUE : getRevisedRowCount());
} }
return result; return result;
} }
private int getRewriteRowCount(final boolean isFetchAll, final String databaseType) {
if (isFetchAll) {
return Integer.MAX_VALUE;
}
return isNeedRewriteRowCount(databaseType) ? getOffsetValue() + rowCount.getValue() : rowCount.getValue();
}
/** /**
* Judge is need rewrite row count or not. * Get revised row count.
* *
* @param databaseType database type * @return revised row count
* @return is need rewrite row count or not
*/ */
public boolean isNeedRewriteRowCount(final String databaseType) { public int getRevisedRowCount() {
return "MySQL".equals(databaseType) || "PostgreSQL".equals(databaseType) || "H2".equals(databaseType); return rowCount.getSegment() instanceof LimitValueSegment ? getOffsetValue() + rowCount.getValue() : rowCount.getValue();
} }
} }
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册