提交 b3333067 编写于 作者: martianzhang's avatar martianzhang

Fix ARG.008 cases

  col = 1 OR col IS NULL
  col1 = 1 OR col2 = 1
上级 fc7d57af
...@@ -2116,8 +2116,24 @@ func (q *Query4Audit) RuleORUsage() Rule { ...@@ -2116,8 +2116,24 @@ func (q *Query4Audit) RuleORUsage() Rule {
switch q.Stmt.(type) { switch q.Stmt.(type) {
case *sqlparser.Select: case *sqlparser.Select:
err := sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) { err := sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) {
switch node.(type) { switch n := node.(type) {
case *sqlparser.OrExpr: case *sqlparser.OrExpr:
switch n.Left.(type) {
case *sqlparser.IsExpr:
// IS TRUE|FALSE|NULL eg. a = 1 or a IS NULL 这种情况也需要考虑
return true, nil
}
switch n.Right.(type) {
case *sqlparser.IsExpr:
// IS TRUE|FALSE|NULL eg. a = 1 or a IS NULL 这种情况也需要考虑
return true, nil
}
if strings.Fields(sqlparser.String(n.Left))[0] != strings.Fields(sqlparser.String(n.Right))[0] {
// 不同字段需要区分开,不同字段的 OR 不能改写为 IN
return true, nil
}
rule = HeuristicRules["ARG.008"] rule = HeuristicRules["ARG.008"]
return false, nil return false, nil
} }
......
...@@ -1840,10 +1840,16 @@ func TestRuleMultiDBJoin(t *testing.T) { ...@@ -1840,10 +1840,16 @@ func TestRuleMultiDBJoin(t *testing.T) {
// ARG.008 // ARG.008
func TestRuleORUsage(t *testing.T) { func TestRuleORUsage(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName()) common.Log.Debug("Entering function: %s", common.GetFunctionName())
sqls := []string{ sqls := [][]string{
`SELECT c1,c2,c3 FROM tab WHERE c1 = 14 OR c2 = 17;`, {
`SELECT c1,c2,c3 FROM tab WHERE c1 = 14 OR c1 = 14;`,
},
{
`SELECT c1,c2,c3 FROM tab WHERE c1 = 14 OR c2 = 17;`,
`SELECT c1,c2,c3 FROM tab WHERE c1 = 14 OR c1 IS NULL;`,
},
} }
for _, sql := range sqls { for _, sql := range sqls[0] {
q, err := NewQuery4Audit(sql) q, err := NewQuery4Audit(sql)
if err == nil { if err == nil {
rule := q.RuleORUsage() rule := q.RuleORUsage()
...@@ -1854,6 +1860,17 @@ func TestRuleORUsage(t *testing.T) { ...@@ -1854,6 +1860,17 @@ func TestRuleORUsage(t *testing.T) {
t.Error("sqlparser.Parse Error:", err) t.Error("sqlparser.Parse Error:", err)
} }
} }
for _, sql := range sqls[1] {
q, err := NewQuery4Audit(sql)
if err == nil {
rule := q.RuleORUsage()
if rule.Item != "OK" {
t.Error("Rule not match:", rule.Item, "Expect : OK")
}
} else {
t.Error("sqlparser.Parse Error:", err)
}
}
common.Log.Debug("Exiting function: %s", common.GetFunctionName()) common.Log.Debug("Exiting function: %s", common.GetFunctionName())
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册