提交 812cd956 编写于 作者: T terrymanu

Merge remote-tracking branch 'origin/master'

...@@ -17,12 +17,6 @@ ...@@ -17,12 +17,6 @@
package com.dangdang.ddframe.rdb.sharding.parser.visitor; package com.dangdang.ddframe.rdb.sharding.parser.visitor;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import com.alibaba.druid.sql.ast.SQLExpr; import com.alibaba.druid.sql.ast.SQLExpr;
import com.alibaba.druid.sql.ast.SQLObject; import com.alibaba.druid.sql.ast.SQLObject;
import com.alibaba.druid.sql.ast.expr.SQLBinaryOpExpr; import com.alibaba.druid.sql.ast.expr.SQLBinaryOpExpr;
...@@ -48,6 +42,13 @@ import com.google.common.base.Optional; ...@@ -48,6 +42,13 @@ import com.google.common.base.Optional;
import lombok.Getter; import lombok.Getter;
import lombok.Setter; import lombok.Setter;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
/** /**
* 解析过程的上下文对象. * 解析过程的上下文对象.
* *
...@@ -56,8 +57,12 @@ import lombok.Setter; ...@@ -56,8 +57,12 @@ import lombok.Setter;
@Getter @Getter
public final class ParseContext { public final class ParseContext {
private static final String AUTO_GEN_TOKE_KEY_TEMPLATE = "sharding_auto_gen_%d";
private static final String SHARDING_GEN_ALIAS = "sharding_gen_%s"; private static final String SHARDING_GEN_ALIAS = "sharding_gen_%s";
private final String autoGenTokenKey;
private final SQLParsedResult parsedResult = new SQLParsedResult(); private final SQLParsedResult parsedResult = new SQLParsedResult();
@Setter @Setter
...@@ -76,6 +81,24 @@ public final class ParseContext { ...@@ -76,6 +81,24 @@ public final class ParseContext {
private boolean hasAllColumn; private boolean hasAllColumn;
@Setter
private ParseContext parentParseContext;
private List<ParseContext> subParseContext = new LinkedList<>();
private int itemIndex;
public ParseContext(final int parseContextIndex) {
this.autoGenTokenKey = String.format(AUTO_GEN_TOKE_KEY_TEMPLATE, parseContextIndex);
}
/**
* 增加查询投射项数量.
*/
public void increaseItemIndex() {
itemIndex++;
}
/** /**
* 设置当前正在访问的表. * 设置当前正在访问的表.
* *
...@@ -329,5 +352,4 @@ public final class ParseContext { ...@@ -329,5 +352,4 @@ public final class ParseContext {
} }
selectItems.add(rawItemExpr); selectItems.add(rawItemExpr);
} }
} }
...@@ -17,9 +17,6 @@ ...@@ -17,9 +17,6 @@
package com.dangdang.ddframe.rdb.sharding.parser.visitor.basic.mysql; package com.dangdang.ddframe.rdb.sharding.parser.visitor.basic.mysql;
import java.util.Arrays;
import java.util.Collections;
import com.alibaba.druid.sql.ast.SQLHint; import com.alibaba.druid.sql.ast.SQLHint;
import com.alibaba.druid.sql.ast.expr.SQLBetweenExpr; import com.alibaba.druid.sql.ast.expr.SQLBetweenExpr;
import com.alibaba.druid.sql.ast.expr.SQLBinaryOpExpr; import com.alibaba.druid.sql.ast.expr.SQLBinaryOpExpr;
...@@ -38,6 +35,9 @@ import com.dangdang.ddframe.rdb.sharding.parser.visitor.ParseContext; ...@@ -38,6 +35,9 @@ import com.dangdang.ddframe.rdb.sharding.parser.visitor.ParseContext;
import com.dangdang.ddframe.rdb.sharding.parser.visitor.SQLVisitor; import com.dangdang.ddframe.rdb.sharding.parser.visitor.SQLVisitor;
import com.dangdang.ddframe.rdb.sharding.util.SQLUtil; import com.dangdang.ddframe.rdb.sharding.util.SQLUtil;
import java.util.Arrays;
import java.util.Collections;
/** /**
* MySQL解析基础访问器. * MySQL解析基础访问器.
* *
...@@ -45,11 +45,14 @@ import com.dangdang.ddframe.rdb.sharding.util.SQLUtil; ...@@ -45,11 +45,14 @@ import com.dangdang.ddframe.rdb.sharding.util.SQLUtil;
*/ */
public abstract class AbstractMySQLVisitor extends MySqlOutputVisitor implements SQLVisitor { public abstract class AbstractMySQLVisitor extends MySqlOutputVisitor implements SQLVisitor {
private final ParseContext parseContext = new ParseContext(); private ParseContext parseContext;
private int parseContextIndex;
protected AbstractMySQLVisitor() { protected AbstractMySQLVisitor() {
super(new SQLBuilder()); super(new SQLBuilder());
setPrettyFormat(false); setPrettyFormat(false);
parseContext = new ParseContext(parseContextIndex);
} }
@Override @Override
...@@ -62,6 +65,25 @@ public abstract class AbstractMySQLVisitor extends MySqlOutputVisitor implements ...@@ -62,6 +65,25 @@ public abstract class AbstractMySQLVisitor extends MySqlOutputVisitor implements
return parseContext; return parseContext;
} }
protected final void stepInQuery() {
if (0 == parseContextIndex) {
parseContextIndex++;
return;
}
ParseContext parseContext = new ParseContext(parseContextIndex++);
parseContext.setShardingColumns(this.parseContext.getShardingColumns());
parseContext.setParentParseContext(this.parseContext);
this.parseContext.getSubParseContext().add(parseContext);
this.parseContext = parseContext;
}
protected final void stepOutQuery() {
if (null == parseContext.getParentParseContext()) {
return;
}
parseContext = parseContext.getParentParseContext();
}
@Override @Override
public final SQLBuilder getSQLBuilder() { public final SQLBuilder getSQLBuilder() {
return (SQLBuilder) appender; return (SQLBuilder) appender;
...@@ -86,7 +108,7 @@ public abstract class AbstractMySQLVisitor extends MySqlOutputVisitor implements ...@@ -86,7 +108,7 @@ public abstract class AbstractMySQLVisitor extends MySqlOutputVisitor implements
@Override @Override
public final boolean visit(final SQLExprTableSource x) { public final boolean visit(final SQLExprTableSource x) {
return visit(x, parseContext.addTable(x)); return visit(x, getParseContext().addTable(x));
} }
private boolean visit(final SQLExprTableSource x, final Table table) { private boolean visit(final SQLExprTableSource x, final Table table) {
...@@ -128,7 +150,7 @@ public abstract class AbstractMySQLVisitor extends MySqlOutputVisitor implements ...@@ -128,7 +150,7 @@ public abstract class AbstractMySQLVisitor extends MySqlOutputVisitor implements
return super.visit(x); return super.visit(x);
} }
String tableOrAliasName = ((SQLIdentifierExpr) x.getOwner()).getLowerName(); String tableOrAliasName = ((SQLIdentifierExpr) x.getOwner()).getLowerName();
if (parseContext.isBinaryOperateWithAlias(x, tableOrAliasName)) { if (getParseContext().isBinaryOperateWithAlias(x, tableOrAliasName)) {
return super.visit(x); return super.visit(x);
} }
printToken(tableOrAliasName); printToken(tableOrAliasName);
......
...@@ -17,8 +17,6 @@ ...@@ -17,8 +17,6 @@
package com.dangdang.ddframe.rdb.sharding.parser.visitor.basic.mysql; package com.dangdang.ddframe.rdb.sharding.parser.visitor.basic.mysql;
import java.util.List;
import com.alibaba.druid.sql.ast.SQLExpr; import com.alibaba.druid.sql.ast.SQLExpr;
import com.alibaba.druid.sql.ast.SQLOrderBy; import com.alibaba.druid.sql.ast.SQLOrderBy;
import com.alibaba.druid.sql.ast.expr.SQLAggregateExpr; import com.alibaba.druid.sql.ast.expr.SQLAggregateExpr;
...@@ -31,7 +29,6 @@ import com.alibaba.druid.sql.ast.expr.SQLVariantRefExpr; ...@@ -31,7 +29,6 @@ import com.alibaba.druid.sql.ast.expr.SQLVariantRefExpr;
import com.alibaba.druid.sql.ast.statement.SQLExprTableSource; import com.alibaba.druid.sql.ast.statement.SQLExprTableSource;
import com.alibaba.druid.sql.ast.statement.SQLSelectItem; import com.alibaba.druid.sql.ast.statement.SQLSelectItem;
import com.alibaba.druid.sql.ast.statement.SQLSelectOrderByItem; import com.alibaba.druid.sql.ast.statement.SQLSelectOrderByItem;
import com.alibaba.druid.sql.ast.statement.SQLSelectStatement;
import com.alibaba.druid.sql.dialect.mysql.ast.expr.MySqlSelectGroupByExpr; import com.alibaba.druid.sql.dialect.mysql.ast.expr.MySqlSelectGroupByExpr;
import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlSelectQueryBlock; import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlSelectQueryBlock;
import com.alibaba.druid.sql.dialect.mysql.visitor.MySqlOutputVisitor; import com.alibaba.druid.sql.dialect.mysql.visitor.MySqlOutputVisitor;
...@@ -44,6 +41,8 @@ import com.dangdang.ddframe.rdb.sharding.parser.result.merger.OrderByColumn.Orde ...@@ -44,6 +41,8 @@ import com.dangdang.ddframe.rdb.sharding.parser.result.merger.OrderByColumn.Orde
import com.google.common.base.Optional; import com.google.common.base.Optional;
import com.google.common.base.Strings; import com.google.common.base.Strings;
import java.util.List;
/** /**
* MySQL的SELECT语句访问器. * MySQL的SELECT语句访问器.
* *
...@@ -51,20 +50,16 @@ import com.google.common.base.Strings; ...@@ -51,20 +50,16 @@ import com.google.common.base.Strings;
*/ */
public class MySQLSelectVisitor extends AbstractMySQLVisitor { public class MySQLSelectVisitor extends AbstractMySQLVisitor {
private static final String AUTO_GEN_TOKE_KEY = "sharding_auto_gen";
// TODO 封装到方法内部
private int itemIndex;
@Override @Override
protected void printSelectList(final List<SQLSelectItem> selectList) { protected void printSelectList(final List<SQLSelectItem> selectList) {
super.printSelectList(selectList); super.printSelectList(selectList);
// TODO 提炼成print,或者是否不应该由token的方式替换? // TODO 提炼成print,或者是否不应该由token的方式替换?
getSQLBuilder().appendToken(AUTO_GEN_TOKE_KEY, false); getSQLBuilder().appendToken(getParseContext().getAutoGenTokenKey(), false);
} }
@Override @Override
public boolean visit(final MySqlSelectQueryBlock x) { public boolean visit(final MySqlSelectQueryBlock x) {
stepInQuery();
if (x.getFrom() instanceof SQLExprTableSource) { if (x.getFrom() instanceof SQLExprTableSource) {
SQLExprTableSource tableExpr = (SQLExprTableSource) x.getFrom(); SQLExprTableSource tableExpr = (SQLExprTableSource) x.getFrom();
getParseContext().setCurrentTable(tableExpr.getExpr().toString(), Optional.fromNullable(tableExpr.getAlias())); getParseContext().setCurrentTable(tableExpr.getExpr().toString(), Optional.fromNullable(tableExpr.getAlias()));
...@@ -80,7 +75,7 @@ public class MySQLSelectVisitor extends AbstractMySQLVisitor { ...@@ -80,7 +75,7 @@ public class MySQLSelectVisitor extends AbstractMySQLVisitor {
*/ */
// TODO SELECT * 导致index不准,不支持SELECT *,且生产环境不建议使用SELECT * // TODO SELECT * 导致index不准,不支持SELECT *,且生产环境不建议使用SELECT *
public boolean visit(final SQLSelectItem x) { public boolean visit(final SQLSelectItem x) {
itemIndex++; getParseContext().increaseItemIndex();
if (Strings.isNullOrEmpty(x.getAlias())) { if (Strings.isNullOrEmpty(x.getAlias())) {
SQLExpr expr = x.getExpr(); SQLExpr expr = x.getExpr();
if (expr instanceof SQLIdentifierExpr) { if (expr instanceof SQLIdentifierExpr) {
...@@ -111,7 +106,7 @@ public class MySQLSelectVisitor extends AbstractMySQLVisitor { ...@@ -111,7 +106,7 @@ public class MySQLSelectVisitor extends AbstractMySQLVisitor {
x.accept(new MySqlOutputVisitor(expression)); x.accept(new MySqlOutputVisitor(expression));
// TODO index获取不准,考虑使用别名替换 // TODO index获取不准,考虑使用别名替换
AggregationColumn column = new AggregationColumn(expression.toString(), aggregationType, Optional.fromNullable(((SQLSelectItem) x.getParent()).getAlias()), AggregationColumn column = new AggregationColumn(expression.toString(), aggregationType, Optional.fromNullable(((SQLSelectItem) x.getParent()).getAlias()),
null == x.getOption() ? Optional.<String>absent() : Optional.of(x.getOption().toString()), itemIndex); null == x.getOption() ? Optional.<String>absent() : Optional.of(x.getOption().toString()), getParseContext().getItemIndex());
getParseContext().getParsedResult().getMergeContext().getAggregationColumns().add(column); getParseContext().getParsedResult().getMergeContext().getAggregationColumns().add(column);
if (AggregationType.AVG.equals(aggregationType)) { if (AggregationType.AVG.equals(aggregationType)) {
getParseContext().addDerivedColumnsForAvgColumn(column); getParseContext().addDerivedColumnsForAvgColumn(column);
...@@ -190,7 +185,7 @@ public class MySQLSelectVisitor extends AbstractMySQLVisitor { ...@@ -190,7 +185,7 @@ public class MySQLSelectVisitor extends AbstractMySQLVisitor {
} }
@Override @Override
public void endVisit(final SQLSelectStatement x) { public void endVisit(final MySqlSelectQueryBlock x) {
StringBuilder derivedSelectItems = new StringBuilder(); StringBuilder derivedSelectItems = new StringBuilder();
for (AggregationColumn aggregationColumn : getParseContext().getParsedResult().getMergeContext().getAggregationColumns()) { for (AggregationColumn aggregationColumn : getParseContext().getParsedResult().getMergeContext().getAggregationColumns()) {
for (AggregationColumn derivedColumn : aggregationColumn.getDerivedColumns()) { for (AggregationColumn derivedColumn : aggregationColumn.getDerivedColumns()) {
...@@ -206,8 +201,9 @@ public class MySQLSelectVisitor extends AbstractMySQLVisitor { ...@@ -206,8 +201,9 @@ public class MySQLSelectVisitor extends AbstractMySQLVisitor {
} }
} }
if (0 != derivedSelectItems.length()) { if (0 != derivedSelectItems.length()) {
getSQLBuilder().buildSQL(AUTO_GEN_TOKE_KEY, derivedSelectItems.toString()); getSQLBuilder().buildSQL(getParseContext().getAutoGenTokenKey(), derivedSelectItems.toString());
} }
super.endVisit(x); super.endVisit(x);
stepOutQuery();
} }
} }
<?xml version="1.0" encoding="UTF-8"?>
<asserts>
<assert id="assertSelectWithOrderByForAliasAndSubQuery" sql="SELECT price FROM (SELECT o.user_id,o.price FROM order o WHERE o.order_id = 1 ORDER BY o.order_id) order by user_id" expected-sql="SELECT price[Token(, user_id AS sharding_gen_1)] FROM (SELECT [Token(o)].user_id, [Token(o)].price[Token(, order_id AS sharding_gen_1)] FROM [Token(order)] o WHERE o.order_id = 1 ORDER BY o.order_id ) ORDER BY user_id">
<tables>
</tables>
<condition-contexts>
<condition-context>
</condition-context>
</condition-contexts>
<order-by-columns>
<order-by-column name="user_id" alias="sharding_gen_1" order-by-type="ASC" />
</order-by-columns>
</assert>
<assert id="assertSelectWithGroupByAndSubQuery" sql="SELECT AVG(i.SUM_PRICE) avg FROM (SELECT o.order_id,SUM(o.price) AS SUM_PRICE FROM order o WHERE o.order_id = 1 GROUP BY o.order_id) i" expected-sql="SELECT AVG(i.SUM_PRICE) AS avg[Token(, COUNT(i.SUM_PRICE) AS sharding_gen_1, SUM(i.SUM_PRICE) AS sharding_gen_2)] FROM (SELECT [Token(o)].order_id, SUM(o.price) AS SUM_PRICE[Token(, o.order_id AS sharding_gen_1)] FROM [Token(order)] o WHERE o.order_id = 1 GROUP BY o.order_id ) i">
<tables>
</tables>
<condition-contexts>
<condition-context>
</condition-context>
</condition-contexts>
<aggregation-columns>
<aggregation-column expression="AVG(i.SUM_PRICE)" aggregation-type="AVG" alias="avg" index="1">
<derived-column expression="COUNT(i.SUM_PRICE)" aggregation-type="COUNT" alias="sharding_gen_1"/>
<derived-column expression="SUM(i.SUM_PRICE)" aggregation-type="SUM" alias="sharding_gen_2" />
</aggregation-column>
<aggregation-column expression="COUNT(i.SUM_PRICE)" aggregation-type="COUNT" alias="sharding_gen_1" />
<aggregation-column expression="SUM(i.SUM_PRICE)" aggregation-type="SUM" alias="sharding_gen_2" />
</aggregation-columns>
</assert>
<assert id="assertSelectWithWhereSubQuery" sql="SELECT * FROM order o WHERE o.order_id = 2 and exists (select 1 from t_user u where u.user_id = o.user_id and u.user_id = 1)" expected-sql="SELECT * FROM [Token(order)] o WHERE o.order_id = 2 AND EXISTS (SELECT 1 FROM [Token(t_user)] u WHERE u.user_id = [Token(o)].user_id AND u.user_id = 1)">
<tables>
<table name="order" alias="o" />
</tables>
<condition-contexts>
<condition-context>
<condition column-name="order_id" table-name="order" operator="EQUAL">
<value value="2" type="java.lang.Integer" />
</condition>
</condition-context>
</condition-contexts>
</assert>
</asserts>
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册