diff --git a/shardingsphere-features/shardingsphere-sharding/shardingsphere-sharding-rewrite/src/main/java/org/apache/shardingsphere/sharding/rewrite/token/generator/impl/ProjectionsTokenGenerator.java b/shardingsphere-features/shardingsphere-sharding/shardingsphere-sharding-rewrite/src/main/java/org/apache/shardingsphere/sharding/rewrite/token/generator/impl/ProjectionsTokenGenerator.java index e635d26e7c0946abf4eacc2646715e85e71e8be9..426e32cfcef1f7b18c66a8513a3dbabf0e081348 100644 --- a/shardingsphere-features/shardingsphere-sharding/shardingsphere-sharding-rewrite/src/main/java/org/apache/shardingsphere/sharding/rewrite/token/generator/impl/ProjectionsTokenGenerator.java +++ b/shardingsphere-features/shardingsphere-sharding/shardingsphere-sharding-rewrite/src/main/java/org/apache/shardingsphere/sharding/rewrite/token/generator/impl/ProjectionsTokenGenerator.java @@ -67,17 +67,25 @@ public final class ProjectionsTokenGenerator implements OptionalSQLTokenGenerato private Map> getDerivedProjectionTexts(final SelectStatementContext selectStatementContext) { Map> result = new HashMap<>(); for (RouteUnit routeUnit : routeContext.getRouteResult().getRouteUnits()) { - result.put(routeUnit, new LinkedList<>()); - for (Projection each : selectStatementContext.getProjectionsContext().getProjections()) { - if (each instanceof AggregationProjection && !((AggregationProjection) each).getDerivedAggregationProjections().isEmpty()) { - result.get(routeUnit).addAll(((AggregationProjection) each).getDerivedAggregationProjections().stream().map(this::getDerivedProjectionText).collect(Collectors.toList())); - } else if (each instanceof DerivedProjection && ((DerivedProjection) each).getDerivedProjection() instanceof ColumnOrderByItemSegment) { - TableExtractor tableExtractor = new TableExtractor(); - tableExtractor.extractTablesFromSelect(selectStatementContext.getSqlStatement()); - result.get(routeUnit).add(getDerivedProjectionTextFromColumnOrderByItemSegment((DerivedProjection) each, tableExtractor, routeUnit)); - } else if (each instanceof DerivedProjection) { - result.get(routeUnit).add(getDerivedProjectionText(each)); - } + Collection projectionTexts = getDerivedProjectionTextsByRouteUnit(selectStatementContext, routeUnit); + if (!projectionTexts.isEmpty()) { + result.put(routeUnit, projectionTexts); + } + } + return result; + } + + private Collection getDerivedProjectionTextsByRouteUnit(final SelectStatementContext selectStatementContext, final RouteUnit routeUnit) { + Collection result = new LinkedList<>(); + for (Projection each : selectStatementContext.getProjectionsContext().getProjections()) { + if (each instanceof AggregationProjection && !((AggregationProjection) each).getDerivedAggregationProjections().isEmpty()) { + result.addAll(((AggregationProjection) each).getDerivedAggregationProjections().stream().map(this::getDerivedProjectionText).collect(Collectors.toList())); + } else if (each instanceof DerivedProjection && ((DerivedProjection) each).getDerivedProjection() instanceof ColumnOrderByItemSegment) { + TableExtractor tableExtractor = new TableExtractor(); + tableExtractor.extractTablesFromSelect(selectStatementContext.getSqlStatement()); + result.add(getDerivedProjectionTextFromColumnOrderByItemSegment((DerivedProjection) each, tableExtractor, routeUnit)); + } else if (each instanceof DerivedProjection) { + result.add(getDerivedProjectionText(each)); } } return result; diff --git a/shardingsphere-features/shardingsphere-sharding/shardingsphere-sharding-rewrite/src/test/resources/sharding/select.xml b/shardingsphere-features/shardingsphere-sharding/shardingsphere-sharding-rewrite/src/test/resources/sharding/select.xml index b617724f9c51c9c4805c51e05b756beafb46b5f2..c01a58d0b45be0f62ea7b492b90df8c89d568991 100644 --- a/shardingsphere-features/shardingsphere-sharding/shardingsphere-sharding-rewrite/src/test/resources/sharding/select.xml +++ b/shardingsphere-features/shardingsphere-sharding/shardingsphere-sharding-rewrite/src/test/resources/sharding/select.xml @@ -64,6 +64,12 @@ + + + + + + diff --git a/shardingsphere-sql-parser/shardingsphere-sql-parser-binder/src/main/java/org/apache/shardingsphere/sql/parser/binder/statement/dml/SelectStatementContext.java b/shardingsphere-sql-parser/shardingsphere-sql-parser-binder/src/main/java/org/apache/shardingsphere/sql/parser/binder/statement/dml/SelectStatementContext.java index 157509aad7f85b599e5d3ab5a6e2275786bfc9fb..90497400ae3b83944a79be8182d3d34ced72edc9 100644 --- a/shardingsphere-sql-parser/shardingsphere-sql-parser-binder/src/main/java/org/apache/shardingsphere/sql/parser/binder/statement/dml/SelectStatementContext.java +++ b/shardingsphere-sql-parser/shardingsphere-sql-parser-binder/src/main/java/org/apache/shardingsphere/sql/parser/binder/statement/dml/SelectStatementContext.java @@ -36,6 +36,7 @@ import org.apache.shardingsphere.sql.parser.binder.segment.table.TablesContext; import org.apache.shardingsphere.sql.parser.binder.statement.CommonSQLStatementContext; import org.apache.shardingsphere.sql.parser.binder.type.TableAvailable; import org.apache.shardingsphere.sql.parser.binder.type.WhereAvailable; +import org.apache.shardingsphere.sql.parser.sql.common.extractor.TableExtractor; import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.order.item.ColumnOrderByItemSegment; import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.order.item.ExpressionOrderByItemSegment; import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.order.item.IndexOrderByItemSegment; @@ -44,7 +45,6 @@ import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.predicate.Whe import org.apache.shardingsphere.sql.parser.sql.common.segment.generic.table.SimpleTableSegment; import org.apache.shardingsphere.sql.parser.sql.common.statement.dml.SelectStatement; import org.apache.shardingsphere.sql.parser.sql.common.util.SQLUtil; -import org.apache.shardingsphere.sql.parser.sql.common.extractor.TableExtractor; import org.apache.shardingsphere.sql.parser.sql.common.util.WhereSegmentExtractUtils; import java.util.Collection; diff --git a/shardingsphere-sql-parser/shardingsphere-sql-parser-dialect/shardingsphere-sql-parser-mysql/src/main/java/org/apache/shardingsphere/sql/parser/mysql/visitor/MySQLVisitor.java b/shardingsphere-sql-parser/shardingsphere-sql-parser-dialect/shardingsphere-sql-parser-mysql/src/main/java/org/apache/shardingsphere/sql/parser/mysql/visitor/MySQLVisitor.java index 73054929dfdfb5c55f5a6ba6cfb77ebe0ccf44c3..c6e3d0c18c8e9183ad1fc122181d48be23f14904 100644 --- a/shardingsphere-sql-parser/shardingsphere-sql-parser-dialect/shardingsphere-sql-parser-mysql/src/main/java/org/apache/shardingsphere/sql/parser/mysql/visitor/MySQLVisitor.java +++ b/shardingsphere-sql-parser/shardingsphere-sql-parser-dialect/shardingsphere-sql-parser-mysql/src/main/java/org/apache/shardingsphere/sql/parser/mysql/visitor/MySQLVisitor.java @@ -303,7 +303,7 @@ public abstract class MySQLVisitor extends MySQLStatementBaseVisitor { if (null != ctx.predicate()) { right = (ExpressionSegment) visit(ctx.predicate()); } else { - right = (ExpressionSegment) visit(ctx.subquery()); + right = new SubqueryExpressionSegment(new SubquerySegment(ctx.subquery().start.getStartIndex(), ctx.subquery().stop.getStopIndex(), (MySQLSelectStatement) visit(ctx.subquery()))); } String operator = null != ctx.SAFE_EQ_() ? ctx.SAFE_EQ_().getText() : ctx.comparisonOperator().getText(); String text = ctx.start.getInputStream().getText(new Interval(ctx.start.getStartIndex(), ctx.stop.getStopIndex()));