未验证 提交 79c34582 编写于 作者: J JingShang Lu 提交者: GitHub

ref support of subquery (#7583)

* ref issue#6497

* delete useless import

* fix
上级 60aea13c
......@@ -67,17 +67,25 @@ public final class ProjectionsTokenGenerator implements OptionalSQLTokenGenerato
private Map<RouteUnit, Collection<String>> getDerivedProjectionTexts(final SelectStatementContext selectStatementContext) {
Map<RouteUnit, Collection<String>> result = new HashMap<>();
for (RouteUnit routeUnit : routeContext.getRouteResult().getRouteUnits()) {
result.put(routeUnit, new LinkedList<>());
for (Projection each : selectStatementContext.getProjectionsContext().getProjections()) {
if (each instanceof AggregationProjection && !((AggregationProjection) each).getDerivedAggregationProjections().isEmpty()) {
result.get(routeUnit).addAll(((AggregationProjection) each).getDerivedAggregationProjections().stream().map(this::getDerivedProjectionText).collect(Collectors.toList()));
} else if (each instanceof DerivedProjection && ((DerivedProjection) each).getDerivedProjection() instanceof ColumnOrderByItemSegment) {
TableExtractor tableExtractor = new TableExtractor();
tableExtractor.extractTablesFromSelect(selectStatementContext.getSqlStatement());
result.get(routeUnit).add(getDerivedProjectionTextFromColumnOrderByItemSegment((DerivedProjection) each, tableExtractor, routeUnit));
} else if (each instanceof DerivedProjection) {
result.get(routeUnit).add(getDerivedProjectionText(each));
}
Collection<String> projectionTexts = getDerivedProjectionTextsByRouteUnit(selectStatementContext, routeUnit);
if (!projectionTexts.isEmpty()) {
result.put(routeUnit, projectionTexts);
}
}
return result;
}
private Collection<String> getDerivedProjectionTextsByRouteUnit(final SelectStatementContext selectStatementContext, final RouteUnit routeUnit) {
Collection<String> result = new LinkedList<>();
for (Projection each : selectStatementContext.getProjectionsContext().getProjections()) {
if (each instanceof AggregationProjection && !((AggregationProjection) each).getDerivedAggregationProjections().isEmpty()) {
result.addAll(((AggregationProjection) each).getDerivedAggregationProjections().stream().map(this::getDerivedProjectionText).collect(Collectors.toList()));
} else if (each instanceof DerivedProjection && ((DerivedProjection) each).getDerivedProjection() instanceof ColumnOrderByItemSegment) {
TableExtractor tableExtractor = new TableExtractor();
tableExtractor.extractTablesFromSelect(selectStatementContext.getSqlStatement());
result.add(getDerivedProjectionTextFromColumnOrderByItemSegment((DerivedProjection) each, tableExtractor, routeUnit));
} else if (each instanceof DerivedProjection) {
result.add(getDerivedProjectionText(each));
}
}
return result;
......
......@@ -64,6 +64,12 @@
<output sql="SELECT (select id from t_account_0 limit 1) as myid FROM (select b.account_id from (select t_account_0.account_id from t_account_0) b where b.account_id=?) a WHERE account_id >= (select account_id from t_account_0 limit 1)" parameters="100"/>
</rewrite-assertion>
<rewrite-assertion id="select_with_subquery_only_in_projection" db-type="MySQL">
<input sql="SELECT (select id from t_account)"/>
<output sql="SELECT (select id from t_account_0)"/>
<output sql="SELECT (select id from t_account_1)"/>
</rewrite-assertion>
<rewrite-assertion id="select_with_subquery_for_where_in_predicate" db-type="MySQL">
<input sql="SELECT * FROM t_account WHERE account_id = ? AND amount IN (SELECT amount FROM t_account WHERE account_id = ?)" parameters="100, 100"/>
<output sql="SELECT * FROM t_account_0 WHERE account_id = ? AND amount IN (SELECT amount FROM t_account_0 WHERE account_id = ?)" parameters="100, 100"/>
......
......@@ -36,6 +36,7 @@ import org.apache.shardingsphere.sql.parser.binder.segment.table.TablesContext;
import org.apache.shardingsphere.sql.parser.binder.statement.CommonSQLStatementContext;
import org.apache.shardingsphere.sql.parser.binder.type.TableAvailable;
import org.apache.shardingsphere.sql.parser.binder.type.WhereAvailable;
import org.apache.shardingsphere.sql.parser.sql.common.extractor.TableExtractor;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.order.item.ColumnOrderByItemSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.order.item.ExpressionOrderByItemSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.order.item.IndexOrderByItemSegment;
......@@ -44,7 +45,6 @@ import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.predicate.Whe
import org.apache.shardingsphere.sql.parser.sql.common.segment.generic.table.SimpleTableSegment;
import org.apache.shardingsphere.sql.parser.sql.common.statement.dml.SelectStatement;
import org.apache.shardingsphere.sql.parser.sql.common.util.SQLUtil;
import org.apache.shardingsphere.sql.parser.sql.common.extractor.TableExtractor;
import org.apache.shardingsphere.sql.parser.sql.common.util.WhereSegmentExtractUtils;
import java.util.Collection;
......
......@@ -303,7 +303,7 @@ public abstract class MySQLVisitor extends MySQLStatementBaseVisitor<ASTNode> {
if (null != ctx.predicate()) {
right = (ExpressionSegment) visit(ctx.predicate());
} else {
right = (ExpressionSegment) visit(ctx.subquery());
right = new SubqueryExpressionSegment(new SubquerySegment(ctx.subquery().start.getStartIndex(), ctx.subquery().stop.getStopIndex(), (MySQLSelectStatement) visit(ctx.subquery())));
}
String operator = null != ctx.SAFE_EQ_() ? ctx.SAFE_EQ_().getText() : ctx.comparisonOperator().getText();
String text = ctx.start.getInputStream().getText(new Interval(ctx.start.getStartIndex(), ctx.stop.getStopIndex()));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册