未验证 提交 a4a641a6 编写于 作者: martianzhang's avatar martianzhang 提交者: GitHub

Merge pull request #119 from liipx/master

fix TestFindAllCols, add expression type
......@@ -135,7 +135,7 @@ func NewAdvisor(env *env.VirtualEnv, rEnv database.Connector, q Query4Audit) (*I
whereINEQ: ast.FindWhereINEQ(q.Stmt),
groupBy: ast.FindGroupByCols(q.Stmt),
orderBy: ast.FindOrderByCols(q.Stmt),
where: ast.FindAllCols(q.Stmt, "where"),
where: ast.FindAllCols(q.Stmt, ast.WhereExpression),
IndexMeta: make(map[string]map[string]*database.TableIndexInfo),
}, nil
}
......
......@@ -450,15 +450,15 @@ func FindJoinTable(node sqlparser.SQLNode, meta common.Meta) common.Meta {
switch expr := node.(type) {
case *sqlparser.JoinTableExpr:
switch expr.Join {
case "join", "natural join":
case sqlparser.JoinStr, sqlparser.NaturalJoinStr:
// 两边表都需要
findJoinTable(expr.LeftExpr, meta)
findJoinTable(expr.RightExpr, meta)
case "left join", "natural left join", "straight_join":
case sqlparser.LeftJoinStr, sqlparser.NaturalLeftJoinStr, sqlparser.StraightJoinStr:
// 只需要右表
findJoinTable(expr.RightExpr, meta)
case "right join", "natural right join":
case sqlparser.RightJoinStr, sqlparser.NaturalRightJoinStr:
// 只需要左表
findJoinTable(expr.LeftExpr, meta)
}
......@@ -673,8 +673,22 @@ func FindAllCondition(node sqlparser.SQLNode) []interface{} {
return conditions
}
// Expression describe sql expression type
type Expression string
const (
// WhereExpression 用于标记 where
WhereExpression Expression = "where"
// JoinExpression 用于标记 join
JoinExpression Expression = "join"
// GroupByExpression 用于标记 group by
GroupByExpression Expression = "group by"
// OrderByExpression 用于标记 order by
OrderByExpression Expression = "order by"
)
// FindAllCols 获取 AST 中某个节点下所有的 columns
func FindAllCols(node sqlparser.SQLNode, targets ...string) []*common.Column {
func FindAllCols(node sqlparser.SQLNode, targets ...Expression) []*common.Column {
var result []*common.Column
// 获取节点内所有的列
f := func(node sqlparser.SQLNode) {
......@@ -699,25 +713,24 @@ func FindAllCols(node sqlparser.SQLNode, targets ...string) []*common.Column {
} else {
// 根据target获取所有的节点
for _, target := range targets {
target = strings.Replace(strings.ToLower(target), " ", "", -1)
err := sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) {
switch node := node.(type) {
case *sqlparser.Subquery:
// 忽略子查询
case *sqlparser.JoinTableExpr:
if target == "join" {
if target == JoinExpression {
f(node)
}
case *sqlparser.Where:
if target == "where" {
if target == WhereExpression {
f(node)
}
case *sqlparser.GroupBy:
if target == "groupby" {
case sqlparser.GroupBy:
if target == GroupByExpression {
f(node)
}
case sqlparser.OrderBy:
if target == "orderby" {
if target == OrderByExpression {
f(node)
}
}
......
......@@ -270,29 +270,31 @@ func TestFindColumn(t *testing.T) {
func TestFindAllCols(t *testing.T) {
sqlList := []string{
"SELECT * FROM t1 LEFT JOIN (t2 CROSS JOIN t3 CROSS JOIN t4) ON (t2.a = t1.a AND t3.b = t1.b AND t4.c = t1.c)",
"select t from a LEFT JOIN b USING (c1, c2, c3)",
"select ID,name from (select address from customer_list where SID=1 order by phone limit 50,10) a join customer_list l on (a.address=l.address) join city c on (c.city=l.city) order by phone desc;",
"SELECT * FROM t1 LEFT JOIN (t2, t3, t4) ON (t2.a = t1.a AND t3.b = t1.b AND t4.c = t1.c)",
"SELECT * FROM t1 RIGHT JOIN (t2, t3, t4) ON (t2.a = t1.a AND t3.b = t1.b AND t4.c = t1.c)",
"SELECT left_tbl.* FROM left_tbl LEFT JOIN right_tbl ON left_tbl.id = right_tbl.id WHERE right_tbl.id IS NULL;",
"SELECT left_tbl.* FROM left_tbl RIGHT JOIN right_tbl ON left_tbl.id = right_tbl.id WHERE right_tbl.id IS NULL;",
"SELECT * FROM t1 where a in ('a','b')",
"SELECT * FROM t1 where a BETWEEN 'bar' AND 'foo'",
"SELECT * FROM t1 where a = sum(a,b)",
"SELECT distinct a FROM t1 where a = '2001-01-01 01:01:01'",
"select * from tb where a = '1' order by c",
"select * from tb where a = '1' group by c",
"select * from tb where c = '1' group by a",
"select * from tb join tb2 on c = c where c = '1' group by a",
}
for _, sql := range sqlList {
fmt.Println(sql)
targets := []Expression{
OrderByExpression,
GroupByExpression,
WhereExpression,
JoinExpression,
}
for i, sql := range sqlList {
stmt, err := sqlparser.Parse(sql)
// pretty.Println(stmt)
if err != nil {
panic(err)
t.Error(err)
return
}
columns := FindAllCols(stmt, "order by")
pretty.Println(columns)
columns := FindAllCols(stmt, targets[i])
if columns[0].Name != "c" {
fmt.Println(sql)
t.Error(fmt.Errorf("want 'c' got %v", columns))
}
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册