diff --git a/ast/testdata/TestQueryType.golden b/ast/testdata/TestQueryType.golden index 7b9b540aee5194d1a983c3f57393b4df8b1bc68d..8f016975030de5f801610d28d1e7c32e3b06f46b 100644 --- a/ast/testdata/TestQueryType.golden +++ b/ast/testdata/TestQueryType.golden @@ -1,5 +1,8 @@ SELECT SELECT +GRANT +REVOKE +SELECT SELECT SELECT SELECT diff --git a/ast/token.go b/ast/token.go index 9021661ba5bd9e5b57532ef058b7a69198ba6b9a..70afb522bde040a08a3db897bed70213d08e316b 100644 --- a/ast/token.go +++ b/ast/token.go @@ -446,6 +446,7 @@ var mySQLKeywords = map[string]string{ "geometry": "GEOMETRY", "geometrycollection": "GEOMETRYCOLLECTION", "global": "GLOBAL", + "grant": "GRANT", "group": "GROUP", "group_concat": "GROUP_CONCAT", "having": "HAVING", @@ -512,6 +513,7 @@ var mySQLKeywords = map[string]string{ "reorganize": "REORGANIZE", "repair": "REPAIR", "replace": "REPLACE", + "revoke": "REVOKE", "right": "RIGHT", "rlike": "REGEXP", "rollback": "ROLLBACK", @@ -989,13 +991,14 @@ func NewLines(buf []byte) int { // QueryType get query type such as SELECT/INSERT/DELETE/CREATE/ALTER func QueryType(sql string) string { - var typ string - tokens := Tokenizer(sql) + tokens := Tokenize(sql) for _, token := range tokens { - if val, ok := mySQLKeywords[token.Val]; ok { - typ = val - break + // use strings.Fields for 'ALTER TABLE' token split + for _, tk := range strings.Fields(strings.TrimSpace(token.Val)) { + if val, ok := mySQLKeywords[strings.ToLower(tk)]; ok { + return val + } } } - return typ + return "" } diff --git a/ast/token_test.go b/ast/token_test.go index d72e718d229197b3a371c759a68ddd66fd2159eb..c8f7ec998ba9792012ad365000433376390d9b2e 100644 --- a/ast/token_test.go +++ b/ast/token_test.go @@ -241,8 +241,14 @@ func TestNewLines(t *testing.T) { func TestQueryType(t *testing.T) { common.Log.Debug("Entering function: %s", common.GetFunctionName()) var testSQLs = []string{ + `/*comment*/ select 1`, `(select 1)`, + `grant select on *.* to user@'localhost'`, + `REVOKE INSERT ON *.* FROM 'jeffrey'@'localhost';`, } + // fmt.Println(testSQLs[len(testSQLs)-1]) + // fmt.Println(QueryType(testSQLs[len(testSQLs)-1])) + // return err := common.GoldenDiff(func() { for _, buf := range append(testSQLs, common.TestSQLs...) { fmt.Println(QueryType(buf))