From b33330671eac0a99dcad7d36d7e4fdcf90d65675 Mon Sep 17 00:00:00 2001 From: Leon Zhang Date: Fri, 12 Apr 2019 18:50:47 +0800 Subject: [PATCH] Fix ARG.008 cases col = 1 OR col IS NULL col1 = 1 OR col2 = 1 --- advisor/heuristic.go | 18 +++++++++++++++++- advisor/heuristic_test.go | 23 ++++++++++++++++++++--- 2 files changed, 37 insertions(+), 4 deletions(-) diff --git a/advisor/heuristic.go b/advisor/heuristic.go index 413808d..243fa05 100644 --- a/advisor/heuristic.go +++ b/advisor/heuristic.go @@ -2116,8 +2116,24 @@ func (q *Query4Audit) RuleORUsage() Rule { switch q.Stmt.(type) { case *sqlparser.Select: err := sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) { - switch node.(type) { + switch n := node.(type) { case *sqlparser.OrExpr: + switch n.Left.(type) { + case *sqlparser.IsExpr: + // IS TRUE|FALSE|NULL eg. a = 1 or a IS NULL 这种情况也需要考虑 + return true, nil + } + switch n.Right.(type) { + case *sqlparser.IsExpr: + // IS TRUE|FALSE|NULL eg. a = 1 or a IS NULL 这种情况也需要考虑 + return true, nil + } + + if strings.Fields(sqlparser.String(n.Left))[0] != strings.Fields(sqlparser.String(n.Right))[0] { + // 不同字段需要区分开,不同字段的 OR 不能改写为 IN + return true, nil + } + rule = HeuristicRules["ARG.008"] return false, nil } diff --git a/advisor/heuristic_test.go b/advisor/heuristic_test.go index c69e630..2ae6326 100644 --- a/advisor/heuristic_test.go +++ b/advisor/heuristic_test.go @@ -1840,10 +1840,16 @@ func TestRuleMultiDBJoin(t *testing.T) { // ARG.008 func TestRuleORUsage(t *testing.T) { common.Log.Debug("Entering function: %s", common.GetFunctionName()) - sqls := []string{ - `SELECT c1,c2,c3 FROM tab WHERE c1 = 14 OR c2 = 17;`, + sqls := [][]string{ + { + `SELECT c1,c2,c3 FROM tab WHERE c1 = 14 OR c1 = 14;`, + }, + { + `SELECT c1,c2,c3 FROM tab WHERE c1 = 14 OR c2 = 17;`, + `SELECT c1,c2,c3 FROM tab WHERE c1 = 14 OR c1 IS NULL;`, + }, } - for _, sql := range sqls { + for _, sql := range sqls[0] { q, err := NewQuery4Audit(sql) if err == nil { rule := q.RuleORUsage() @@ -1854,6 +1860,17 @@ func TestRuleORUsage(t *testing.T) { t.Error("sqlparser.Parse Error:", err) } } + for _, sql := range sqls[1] { + q, err := NewQuery4Audit(sql) + if err == nil { + rule := q.RuleORUsage() + if rule.Item != "OK" { + t.Error("Rule not match:", rule.Item, "Expect : OK") + } + } else { + t.Error("sqlparser.Parse Error:", err) + } + } common.Log.Debug("Exiting function: %s", common.GetFunctionName()) } -- GitLab