diff --git a/advisor/index.go b/advisor/index.go index 03581b86fb1ae35e7e1c75811bbcd32beb959b8c..0c00818eddc5ef96b2e0dc6018f028a4fc75a129 100644 --- a/advisor/index.go +++ b/advisor/index.go @@ -135,7 +135,7 @@ func NewAdvisor(env *env.VirtualEnv, rEnv database.Connector, q Query4Audit) (*I whereINEQ: ast.FindWhereINEQ(q.Stmt), groupBy: ast.FindGroupByCols(q.Stmt), orderBy: ast.FindOrderByCols(q.Stmt), - where: ast.FindAllCols(q.Stmt, "where"), + where: ast.FindAllCols(q.Stmt, ast.WhereExpression), IndexMeta: make(map[string]map[string]*database.TableIndexInfo), }, nil } diff --git a/ast/meta.go b/ast/meta.go index 25e04a9da4f1991d838ba8e3110bc9ac37022e81..48323ffd86c1c75c9adb8ea281aeb71770e91c77 100644 --- a/ast/meta.go +++ b/ast/meta.go @@ -450,15 +450,15 @@ func FindJoinTable(node sqlparser.SQLNode, meta common.Meta) common.Meta { switch expr := node.(type) { case *sqlparser.JoinTableExpr: switch expr.Join { - case "join", "natural join": + case sqlparser.JoinStr, sqlparser.NaturalJoinStr: // 两边表都需要 findJoinTable(expr.LeftExpr, meta) findJoinTable(expr.RightExpr, meta) - case "left join", "natural left join", "straight_join": + case sqlparser.LeftJoinStr, sqlparser.NaturalLeftJoinStr, sqlparser.StraightJoinStr: // 只需要右表 findJoinTable(expr.RightExpr, meta) - case "right join", "natural right join": + case sqlparser.RightJoinStr, sqlparser.NaturalRightJoinStr: // 只需要左表 findJoinTable(expr.LeftExpr, meta) } @@ -673,8 +673,22 @@ func FindAllCondition(node sqlparser.SQLNode) []interface{} { return conditions } +// Expression describe sql expression type +type Expression string + +const ( + // WhereExpression 用于标记 where + WhereExpression Expression = "where" + // JoinExpression 用于标记 join + JoinExpression Expression = "join" + // GroupByExpression 用于标记 group by + GroupByExpression Expression = "group by" + // OrderByExpression 用于标记 order by + OrderByExpression Expression = "order by" +) + // FindAllCols 获取 AST 中某个节点下所有的 columns -func FindAllCols(node sqlparser.SQLNode, targets ...string) []*common.Column { +func FindAllCols(node sqlparser.SQLNode, targets ...Expression) []*common.Column { var result []*common.Column // 获取节点内所有的列 f := func(node sqlparser.SQLNode) { @@ -699,25 +713,24 @@ func FindAllCols(node sqlparser.SQLNode, targets ...string) []*common.Column { } else { // 根据target获取所有的节点 for _, target := range targets { - target = strings.Replace(strings.ToLower(target), " ", "", -1) err := sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) { switch node := node.(type) { case *sqlparser.Subquery: // 忽略子查询 case *sqlparser.JoinTableExpr: - if target == "join" { + if target == JoinExpression { f(node) } case *sqlparser.Where: - if target == "where" { + if target == WhereExpression { f(node) } - case *sqlparser.GroupBy: - if target == "groupby" { + case sqlparser.GroupBy: + if target == GroupByExpression { f(node) } case sqlparser.OrderBy: - if target == "orderby" { + if target == OrderByExpression { f(node) } } diff --git a/ast/meta_test.go b/ast/meta_test.go index 60b7c54768eede4784dd3179cb323c0a15e63470..f072978ec88b4cc376b8f8e486bae1e98a15a964 100644 --- a/ast/meta_test.go +++ b/ast/meta_test.go @@ -270,29 +270,31 @@ func TestFindColumn(t *testing.T) { func TestFindAllCols(t *testing.T) { sqlList := []string{ - "SELECT * FROM t1 LEFT JOIN (t2 CROSS JOIN t3 CROSS JOIN t4) ON (t2.a = t1.a AND t3.b = t1.b AND t4.c = t1.c)", - "select t from a LEFT JOIN b USING (c1, c2, c3)", - "select ID,name from (select address from customer_list where SID=1 order by phone limit 50,10) a join customer_list l on (a.address=l.address) join city c on (c.city=l.city) order by phone desc;", - "SELECT * FROM t1 LEFT JOIN (t2, t3, t4) ON (t2.a = t1.a AND t3.b = t1.b AND t4.c = t1.c)", - "SELECT * FROM t1 RIGHT JOIN (t2, t3, t4) ON (t2.a = t1.a AND t3.b = t1.b AND t4.c = t1.c)", - "SELECT left_tbl.* FROM left_tbl LEFT JOIN right_tbl ON left_tbl.id = right_tbl.id WHERE right_tbl.id IS NULL;", - "SELECT left_tbl.* FROM left_tbl RIGHT JOIN right_tbl ON left_tbl.id = right_tbl.id WHERE right_tbl.id IS NULL;", - "SELECT * FROM t1 where a in ('a','b')", - "SELECT * FROM t1 where a BETWEEN 'bar' AND 'foo'", - "SELECT * FROM t1 where a = sum(a,b)", - "SELECT distinct a FROM t1 where a = '2001-01-01 01:01:01'", + "select * from tb where a = '1' order by c", + "select * from tb where a = '1' group by c", + "select * from tb where c = '1' group by a", + "select * from tb join tb2 on c = c where c = '1' group by a", } - for _, sql := range sqlList { - fmt.Println(sql) + targets := []Expression{ + OrderByExpression, + GroupByExpression, + WhereExpression, + JoinExpression, + } + + for i, sql := range sqlList { stmt, err := sqlparser.Parse(sql) - // pretty.Println(stmt) if err != nil { - panic(err) + t.Error(err) + return } - columns := FindAllCols(stmt, "order by") - pretty.Println(columns) + columns := FindAllCols(stmt, targets[i]) + if columns[0].Name != "c" { + fmt.Println(sql) + t.Error(fmt.Errorf("want 'c' got %v", columns)) + } } }