diff --git a/advisor/heuristic.go b/advisor/heuristic.go index 413808d2f6d88e9717f101146b5a3968ab0f160b..243fa059c3b320ccb6a70f2ec5c419c0391c905a 100644 --- a/advisor/heuristic.go +++ b/advisor/heuristic.go @@ -2116,8 +2116,24 @@ func (q *Query4Audit) RuleORUsage() Rule { switch q.Stmt.(type) { case *sqlparser.Select: err := sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) { - switch node.(type) { + switch n := node.(type) { case *sqlparser.OrExpr: + switch n.Left.(type) { + case *sqlparser.IsExpr: + // IS TRUE|FALSE|NULL eg. a = 1 or a IS NULL 这种情况也需要考虑 + return true, nil + } + switch n.Right.(type) { + case *sqlparser.IsExpr: + // IS TRUE|FALSE|NULL eg. a = 1 or a IS NULL 这种情况也需要考虑 + return true, nil + } + + if strings.Fields(sqlparser.String(n.Left))[0] != strings.Fields(sqlparser.String(n.Right))[0] { + // 不同字段需要区分开,不同字段的 OR 不能改写为 IN + return true, nil + } + rule = HeuristicRules["ARG.008"] return false, nil } diff --git a/advisor/heuristic_test.go b/advisor/heuristic_test.go index c69e63050176e4ef07f245e6ac7aa24007bc1a01..2ae6326abbdcc8230a4dfbf918ff29754f229020 100644 --- a/advisor/heuristic_test.go +++ b/advisor/heuristic_test.go @@ -1840,10 +1840,16 @@ func TestRuleMultiDBJoin(t *testing.T) { // ARG.008 func TestRuleORUsage(t *testing.T) { common.Log.Debug("Entering function: %s", common.GetFunctionName()) - sqls := []string{ - `SELECT c1,c2,c3 FROM tab WHERE c1 = 14 OR c2 = 17;`, + sqls := [][]string{ + { + `SELECT c1,c2,c3 FROM tab WHERE c1 = 14 OR c1 = 14;`, + }, + { + `SELECT c1,c2,c3 FROM tab WHERE c1 = 14 OR c2 = 17;`, + `SELECT c1,c2,c3 FROM tab WHERE c1 = 14 OR c1 IS NULL;`, + }, } - for _, sql := range sqls { + for _, sql := range sqls[0] { q, err := NewQuery4Audit(sql) if err == nil { rule := q.RuleORUsage() @@ -1854,6 +1860,17 @@ func TestRuleORUsage(t *testing.T) { t.Error("sqlparser.Parse Error:", err) } } + for _, sql := range sqls[1] { + q, err := NewQuery4Audit(sql) + if err == nil { + rule := q.RuleORUsage() + if rule.Item != "OK" { + t.Error("Rule not match:", rule.Item, "Expect : OK") + } + } else { + t.Error("sqlparser.Parse Error:", err) + } + } common.Log.Debug("Exiting function: %s", common.GetFunctionName()) }