diff --git a/advisor/heuristic.go b/advisor/heuristic.go index 55a978ddd832885281b483cb266d906d8dc5e7ec..6707073abdae9ca2a4906325b489e27b476c5281 100644 --- a/advisor/heuristic.go +++ b/advisor/heuristic.go @@ -2460,6 +2460,8 @@ func (q *Query4Audit) RuleInjection() Rule { // RuleCompareWithFunction FUN.001 func (q *Query4Audit) RuleCompareWithFunction() Rule { var rule = q.RuleOK() + + // `select id from t where num/2 = 100`, err := sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) { // Vitess 中有些函数进行了单独定义不在 FuncExpr 中,如: substring。所以不能直接用 FuncExpr 判断。 switch n := node.(type) { @@ -2470,38 +2472,26 @@ func (q *Query4Audit) RuleCompareWithFunction() Rule { rule = HeuristicRules["FUN.001"] return false, nil } - /* - // func always has bracket - if strings.HasSuffix(sqlparser.String(n.Left), ")") { - rule = HeuristicRules["FUN.001"] - return false, nil - } - */ - - case *sqlparser.RangeCond: - // func(a) between func(c) and func(d) - switch n.Left.(type) { - case *sqlparser.SQLVal, *sqlparser.ColName: - default: - rule = HeuristicRules["FUN.001"] - return false, nil - } - switch n.From.(type) { - case *sqlparser.SQLVal, *sqlparser.ColName: - default: - rule = HeuristicRules["FUN.001"] - return false, nil - } - switch n.To.(type) { - case *sqlparser.SQLVal, *sqlparser.ColName: - default: - rule = HeuristicRules["FUN.001"] - return false, nil - } } return true, nil }, q.Stmt) common.LogIfError(err, "") + + // select id from t where substring(name,1,3)='abc'; + for _, tiStmt := range q.TiStmt { + switch tiStmt.(type) { + case *tidb.SelectStmt, *tidb.UpdateStmt, *tidb.DeleteStmt: + json := ast.StmtNode2JSON(q.Query, "", "") + whereJSON := common.JSONFind(json, "Where") + for _, where := range whereJSON { + if len(common.JSONFind(where, "FnName")) > 0 { + rule = HeuristicRules["FUN.001"] + } + break + } + } + } + return rule } diff --git a/advisor/heuristic_test.go b/advisor/heuristic_test.go index 89e5a384494cd3c1540a8e8ab62f235e3a31ee1a..f0579107621cd0034b6c78cd79bf266f068d4a49 100644 --- a/advisor/heuristic_test.go +++ b/advisor/heuristic_test.go @@ -2455,19 +2455,29 @@ func TestCompareWithFunction(t *testing.T) { `select id from t where substring(name,1,3)='abc';`, `SELECT * FROM tbl WHERE UNIX_TIMESTAMP(loginTime) BETWEEN UNIX_TIMESTAMP('2018-11-16 09:46:00 +0800 CST') AND UNIX_TIMESTAMP('2018-11-22 00:00:00 +0800 CST')`, `select id from t where num/2 = 100`, + `select id from t where num/2 < 100`, + // 时间 builtin 函数 + `SELECT * FROM tb WHERE DATE '2020-01-01'`, + `DELETE FROM tb WHERE DATE '2020-01-01'`, + `UPDATE tb SET col = 1 WHERE DATE '2020-01-01'`, + `SELECT * FROM tb WHERE TIME '10:01:01'`, + `SELECT * FROM tb WHERE TIMESTAMP '1587181360'`, + `select * from mysql.user where user = "root" and date '2020-02-01'`, + // 右侧使用函数比较 + `select id from t where 'abc'=substring(name,1,3);`, }, - // TODO: 右侧使用函数比较 + // 正常 SQL { - `select id from t where 'abc'=substring(name,1,3);`, `select id from t where col = (select 1)`, + `select id from t where col = 1`, }, } - for _, sql := range sqls[0] { + for i, sql := range sqls[0] { q, err := NewQuery4Audit(sql) if err == nil { rule := q.RuleCompareWithFunction() if rule.Item != "FUN.001" { - t.Error("Rule not match:", rule.Item, "Expect : FUN.001") + t.Errorf("SQL: %d, Rule not match: %s Expect : FUN.001", i, rule.Item) } } else { t.Error("sqlparser.Parse Error:", err)