未验证 提交 79c34582 编写于 作者: J JingShang Lu 提交者: GitHub

ref support of subquery (#7583)

* ref issue#6497

* delete useless import

* fix
上级 60aea13c
...@@ -67,17 +67,25 @@ public final class ProjectionsTokenGenerator implements OptionalSQLTokenGenerato ...@@ -67,17 +67,25 @@ public final class ProjectionsTokenGenerator implements OptionalSQLTokenGenerato
private Map<RouteUnit, Collection<String>> getDerivedProjectionTexts(final SelectStatementContext selectStatementContext) { private Map<RouteUnit, Collection<String>> getDerivedProjectionTexts(final SelectStatementContext selectStatementContext) {
Map<RouteUnit, Collection<String>> result = new HashMap<>(); Map<RouteUnit, Collection<String>> result = new HashMap<>();
for (RouteUnit routeUnit : routeContext.getRouteResult().getRouteUnits()) { for (RouteUnit routeUnit : routeContext.getRouteResult().getRouteUnits()) {
result.put(routeUnit, new LinkedList<>()); Collection<String> projectionTexts = getDerivedProjectionTextsByRouteUnit(selectStatementContext, routeUnit);
for (Projection each : selectStatementContext.getProjectionsContext().getProjections()) { if (!projectionTexts.isEmpty()) {
if (each instanceof AggregationProjection && !((AggregationProjection) each).getDerivedAggregationProjections().isEmpty()) { result.put(routeUnit, projectionTexts);
result.get(routeUnit).addAll(((AggregationProjection) each).getDerivedAggregationProjections().stream().map(this::getDerivedProjectionText).collect(Collectors.toList())); }
} else if (each instanceof DerivedProjection && ((DerivedProjection) each).getDerivedProjection() instanceof ColumnOrderByItemSegment) { }
TableExtractor tableExtractor = new TableExtractor(); return result;
tableExtractor.extractTablesFromSelect(selectStatementContext.getSqlStatement()); }
result.get(routeUnit).add(getDerivedProjectionTextFromColumnOrderByItemSegment((DerivedProjection) each, tableExtractor, routeUnit));
} else if (each instanceof DerivedProjection) { private Collection<String> getDerivedProjectionTextsByRouteUnit(final SelectStatementContext selectStatementContext, final RouteUnit routeUnit) {
result.get(routeUnit).add(getDerivedProjectionText(each)); Collection<String> result = new LinkedList<>();
} for (Projection each : selectStatementContext.getProjectionsContext().getProjections()) {
if (each instanceof AggregationProjection && !((AggregationProjection) each).getDerivedAggregationProjections().isEmpty()) {
result.addAll(((AggregationProjection) each).getDerivedAggregationProjections().stream().map(this::getDerivedProjectionText).collect(Collectors.toList()));
} else if (each instanceof DerivedProjection && ((DerivedProjection) each).getDerivedProjection() instanceof ColumnOrderByItemSegment) {
TableExtractor tableExtractor = new TableExtractor();
tableExtractor.extractTablesFromSelect(selectStatementContext.getSqlStatement());
result.add(getDerivedProjectionTextFromColumnOrderByItemSegment((DerivedProjection) each, tableExtractor, routeUnit));
} else if (each instanceof DerivedProjection) {
result.add(getDerivedProjectionText(each));
} }
} }
return result; return result;
......
...@@ -64,6 +64,12 @@ ...@@ -64,6 +64,12 @@
<output sql="SELECT (select id from t_account_0 limit 1) as myid FROM (select b.account_id from (select t_account_0.account_id from t_account_0) b where b.account_id=?) a WHERE account_id >= (select account_id from t_account_0 limit 1)" parameters="100"/> <output sql="SELECT (select id from t_account_0 limit 1) as myid FROM (select b.account_id from (select t_account_0.account_id from t_account_0) b where b.account_id=?) a WHERE account_id >= (select account_id from t_account_0 limit 1)" parameters="100"/>
</rewrite-assertion> </rewrite-assertion>
<rewrite-assertion id="select_with_subquery_only_in_projection" db-type="MySQL">
<input sql="SELECT (select id from t_account)"/>
<output sql="SELECT (select id from t_account_0)"/>
<output sql="SELECT (select id from t_account_1)"/>
</rewrite-assertion>
<rewrite-assertion id="select_with_subquery_for_where_in_predicate" db-type="MySQL"> <rewrite-assertion id="select_with_subquery_for_where_in_predicate" db-type="MySQL">
<input sql="SELECT * FROM t_account WHERE account_id = ? AND amount IN (SELECT amount FROM t_account WHERE account_id = ?)" parameters="100, 100"/> <input sql="SELECT * FROM t_account WHERE account_id = ? AND amount IN (SELECT amount FROM t_account WHERE account_id = ?)" parameters="100, 100"/>
<output sql="SELECT * FROM t_account_0 WHERE account_id = ? AND amount IN (SELECT amount FROM t_account_0 WHERE account_id = ?)" parameters="100, 100"/> <output sql="SELECT * FROM t_account_0 WHERE account_id = ? AND amount IN (SELECT amount FROM t_account_0 WHERE account_id = ?)" parameters="100, 100"/>
......
...@@ -36,6 +36,7 @@ import org.apache.shardingsphere.sql.parser.binder.segment.table.TablesContext; ...@@ -36,6 +36,7 @@ import org.apache.shardingsphere.sql.parser.binder.segment.table.TablesContext;
import org.apache.shardingsphere.sql.parser.binder.statement.CommonSQLStatementContext; import org.apache.shardingsphere.sql.parser.binder.statement.CommonSQLStatementContext;
import org.apache.shardingsphere.sql.parser.binder.type.TableAvailable; import org.apache.shardingsphere.sql.parser.binder.type.TableAvailable;
import org.apache.shardingsphere.sql.parser.binder.type.WhereAvailable; import org.apache.shardingsphere.sql.parser.binder.type.WhereAvailable;
import org.apache.shardingsphere.sql.parser.sql.common.extractor.TableExtractor;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.order.item.ColumnOrderByItemSegment; import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.order.item.ColumnOrderByItemSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.order.item.ExpressionOrderByItemSegment; import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.order.item.ExpressionOrderByItemSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.order.item.IndexOrderByItemSegment; import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.order.item.IndexOrderByItemSegment;
...@@ -44,7 +45,6 @@ import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.predicate.Whe ...@@ -44,7 +45,6 @@ import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.predicate.Whe
import org.apache.shardingsphere.sql.parser.sql.common.segment.generic.table.SimpleTableSegment; import org.apache.shardingsphere.sql.parser.sql.common.segment.generic.table.SimpleTableSegment;
import org.apache.shardingsphere.sql.parser.sql.common.statement.dml.SelectStatement; import org.apache.shardingsphere.sql.parser.sql.common.statement.dml.SelectStatement;
import org.apache.shardingsphere.sql.parser.sql.common.util.SQLUtil; import org.apache.shardingsphere.sql.parser.sql.common.util.SQLUtil;
import org.apache.shardingsphere.sql.parser.sql.common.extractor.TableExtractor;
import org.apache.shardingsphere.sql.parser.sql.common.util.WhereSegmentExtractUtils; import org.apache.shardingsphere.sql.parser.sql.common.util.WhereSegmentExtractUtils;
import java.util.Collection; import java.util.Collection;
......
...@@ -303,7 +303,7 @@ public abstract class MySQLVisitor extends MySQLStatementBaseVisitor<ASTNode> { ...@@ -303,7 +303,7 @@ public abstract class MySQLVisitor extends MySQLStatementBaseVisitor<ASTNode> {
if (null != ctx.predicate()) { if (null != ctx.predicate()) {
right = (ExpressionSegment) visit(ctx.predicate()); right = (ExpressionSegment) visit(ctx.predicate());
} else { } else {
right = (ExpressionSegment) visit(ctx.subquery()); right = new SubqueryExpressionSegment(new SubquerySegment(ctx.subquery().start.getStartIndex(), ctx.subquery().stop.getStopIndex(), (MySQLSelectStatement) visit(ctx.subquery())));
} }
String operator = null != ctx.SAFE_EQ_() ? ctx.SAFE_EQ_().getText() : ctx.comparisonOperator().getText(); String operator = null != ctx.SAFE_EQ_() ? ctx.SAFE_EQ_().getText() : ctx.comparisonOperator().getText();
String text = ctx.start.getInputStream().getText(new Interval(ctx.start.getStartIndex(), ctx.stop.getStopIndex())); String text = ctx.start.getInputStream().getText(new Interval(ctx.start.getStartIndex(), ctx.stop.getStopIndex()));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册