diff --git a/advisor/heuristic.go b/advisor/heuristic.go index 02699f831ac679aea4a8d4c26945c97d4d5d304e..a21088eb0ac42a243bdda62c832af8eb9a66df97 100644 --- a/advisor/heuristic.go +++ b/advisor/heuristic.go @@ -1275,7 +1275,7 @@ func (q *Query4Audit) RuleMeaninglessWhere() Rule { func (q *Query4Audit) RuleLoadFile() Rule { var rule = q.RuleOK() // 去除注释 - sql := string(database.RemoveSQLComments([]byte(q.Query))) + sql := database.RemoveSQLComments(q.Query) // 去除多余的空格和回车 sql = strings.Join(strings.Fields(sql), " ") tks := ast.Tokenize(sql) diff --git a/cmd/soar/soar.go b/cmd/soar/soar.go index ac0c97e25ac0516d452dac612d6919a6dd7df2dd..cb7172c5267cb3295b95f7373da182c125de2507 100644 --- a/cmd/soar/soar.go +++ b/cmd/soar/soar.go @@ -120,8 +120,7 @@ func main() { buf = string(bufBytes) // 去除无用的备注和空格 - sql = strings.TrimSpace(sql) - sql = string(database.RemoveSQLComments([]byte(sql))) + sql = database.RemoveSQLComments(sql) if sql == "" { common.Log.Debug("empty query or comment, buf: %s", buf) continue diff --git a/cmd/soar/tool.go b/cmd/soar/tool.go index ccad724ffa6b2c7096578445aeffce3044ac40f9..7e550b577fd49b0c32fc321e415206193c158cbe 100644 --- a/cmd/soar/tool.go +++ b/cmd/soar/tool.go @@ -175,7 +175,7 @@ func reportTool(sql string, bom []byte) (isContinue bool, exitCode int) { fmt.Println(charset) return false, 0 case "remove-comment": - fmt.Println(string(database.RemoveSQLComments([]byte(sql)))) + fmt.Println(database.RemoveSQLComments(sql)) return false, 0 default: return true, 0 diff --git a/database/explain.go b/database/explain.go index 2924740f84a30a0d0bbedc9eb07bf9980c524688..1f5880bf0ecea89be99d1ad8f196524c2335f9e5 100644 --- a/database/explain.go +++ b/database/explain.go @@ -384,7 +384,7 @@ var ExplainExtra = map[string]string{ func findTablesInJSON(explainJSON string, depth int) { common.Log.Debug("findTablesInJSON Enter: depth(%d), json(%s)", depth, explainJSON) // 去除注释,语法检查 - explainJSON = string(RemoveSQLComments([]byte(explainJSON))) + explainJSON = RemoveSQLComments(explainJSON) if !gjson.Valid(explainJSON) { return } @@ -923,7 +923,7 @@ func parseVerticalExplainText(content string) (explainRows []*ExplainRow, err er // 解析文本形式JSON Explain信息 func parseJSONExplainText(content string) (*ExplainJSON, error) { explainJSON := new(ExplainJSON) - err := json.Unmarshal(RemoveSQLComments([]byte(content)), explainJSON) + err := json.Unmarshal([]byte(RemoveSQLComments(content)), explainJSON) return explainJSON, err } diff --git a/database/explain_test.go b/database/explain_test.go index 1aa5715db6fd7bd29e03db9e4b0d9b9216da4f51..01ef50f923322d55c2fb48bdf1d12207e39537b8 100644 --- a/database/explain_test.go +++ b/database/explain_test.go @@ -2366,7 +2366,7 @@ func TestExplain(t *testing.T) { func TestParseExplainText(t *testing.T) { for _, content := range exp { - pretty.Println(string(RemoveSQLComments([]byte(content)))) + pretty.Println(RemoveSQLComments(content)) pretty.Println(ParseExplainText(content)) } /* diff --git a/database/mysql.go b/database/mysql.go index 6f1ce4527d036e14efa1dd7547b82bdc1312e009..fa6927d46d17252829fc5e594f7b35d3beb0c843 100644 --- a/database/mysql.go +++ b/database/mysql.go @@ -259,10 +259,11 @@ func (db *Connector) IsView(tbName string) bool { } // RemoveSQLComments 去除SQL中的注释 -func RemoveSQLComments(sql []byte) []byte { +func RemoveSQLComments(sql string) string { + buf := []byte(sql) cmtReg := regexp.MustCompile(`("(""|[^"])*")|('(''|[^'])*')|(--[^\n\r]*)|(#.*)|(/\*([^*]|[\r\n]|(\*+([^*/]|[\r\n])))*\*+/)`) - return cmtReg.ReplaceAllFunc(sql, func(s []byte) []byte { + res := cmtReg.ReplaceAllFunc(buf, func(s []byte) []byte { if (s[0] == '"' && s[len(s)-1] == '"') || (s[0] == '\'' && s[len(s)-1] == '\'') || (string(s[:3]) == "/*!") { @@ -270,6 +271,7 @@ func RemoveSQLComments(sql []byte) []byte { } return []byte("") }) + return strings.TrimSpace(string(res)) } // 为了防止在 Online 环境进行误操作,通过 dangerousQuery 来判断能否在 Online 执行 diff --git a/database/mysql_test.go b/database/mysql_test.go index 2914c8896cc0e7c2c3aa9eae2c4ac7c1bc31083b..79ea7bdd2a84c6254086dbdeb665d08bd528391e 100644 --- a/database/mysql_test.go +++ b/database/mysql_test.go @@ -88,3 +88,24 @@ func TestSource(t *testing.T) { t.Error("Source result not match, expect 1, 1") } } + +func TestRemoveSQLComments(t *testing.T) { + SQLs := []string{ + `-- comment`, + `--`, + `# comment`, + `/* multi-line +comment*/`, + `-- +-- comment`, + } + + err := common.GoldenDiff(func() { + for _, sql := range SQLs { + fmt.Println(RemoveSQLComments(sql)) + } + }, t.Name(), update) + if err != nil { + t.Error(err) + } +} diff --git a/database/testdata/TestRemoveSQLComments.golden b/database/testdata/TestRemoveSQLComments.golden new file mode 100644 index 0000000000000000000000000000000000000000..3f2ff2d6cc8f257ffcade7ead1ca4042c0e884b9 --- /dev/null +++ b/database/testdata/TestRemoveSQLComments.golden @@ -0,0 +1,5 @@ + + + + +