diff --git a/ast/pretty_test.go b/ast/pretty_test.go index 7bc531eda56d9fd8676601fe5eea2e90c7d3a040..15a4ef5657dd76aa12178df68b8f2a0e2c0904b0 100644 --- a/ast/pretty_test.go +++ b/ast/pretty_test.go @@ -169,38 +169,3 @@ func TestRemoveComments(t *testing.T) { } common.Log.Debug("Exiting function: %s", common.GetFunctionName()) } - -func TestMysqlEscapeString(t *testing.T) { - common.Log.Debug("Entering function: %s", common.GetFunctionName()) - var strs = []map[string]string{ - { - "input": "abc", - "output": "abc", - }, - { - "input": "'abc", - "output": "\\'abc", - }, - { - "input": ` -abc`, - "output": `\ -abc`, - }, - { - "input": "\"abc", - "output": "\\\"abc", - }, - } - for _, str := range strs { - output, err := MysqlEscapeString(str["input"]) - if err != nil { - t.Error("TestMysqlEscapeString", err) - } else { - if output != str["output"] { - t.Error("TestMysqlEscapeString", output, str["output"]) - } - } - } - common.Log.Debug("Exiting function: %s", common.GetFunctionName()) -} diff --git a/ast/token.go b/ast/token.go index 4d1d2ecc1ea1fc76a7e87b86e0edc4b797aa5563..3e1493761ad8188ec2bd0d21e585e912e4cd4edc 100644 --- a/ast/token.go +++ b/ast/token.go @@ -17,7 +17,6 @@ package ast import ( - "errors" "fmt" "regexp" "strings" @@ -614,51 +613,6 @@ func Tokenizer(sql string) []Token { return tokens } -// MysqlEscapeString mysql_real_escape_string -// https://github.com/liule/golang_escape -func MysqlEscapeString(source string) (string, error) { - var j = 0 - if len(source) == 0 { - return "", errors.New("source is null") - } - tempStr := source[:] - desc := make([]byte, len(tempStr)*2) - for i := 0; i < len(tempStr); i++ { - flag := false - var escape byte - switch tempStr[i] { - case '\r': - flag = true - escape = '\r' - case '\n': - flag = true - escape = '\n' - case '\\': - flag = true - escape = '\\' - case '\'': - flag = true - escape = '\'' - case '"': - flag = true - escape = '"' - case '\032': - flag = true - escape = 'Z' - default: - } - if flag { - desc[j] = '\\' - desc[j+1] = escape - j = j + 2 - } else { - desc[j] = tempStr[i] - j = j + 1 - } - } - return string(desc[0:j]), nil -} - // IsMysqlKeyword 判断是否是关键字 func IsMysqlKeyword(name string) bool { _, ok := mySQLKeywords[strings.ToLower(strings.TrimSpace(name))] diff --git a/database/mysql.go b/database/mysql.go index df2153b6c7adfc61fd6fde443ef988cd398849b9..738f1a9878ab13b8fb7c11da40752823e5c38c22 100644 --- a/database/mysql.go +++ b/database/mysql.go @@ -209,7 +209,7 @@ func (db *Connector) ColumnCardinality(tb, col string) float64 { } // 计算该列散粒度 - res, err := db.Query(fmt.Sprintf("select count(distinct `%s`) from `%s`.`%s`", col, db.Database, tb)) + res, err := db.Query(fmt.Sprintf("select count(distinct `%s`) from `%s`.`%s`", StringEscape(col), StringEscape(db.Database), StringEscape(tb))) if err != nil { common.Log.Warn("(db *Connector) ColumnCardinality() Query Error: %v", err) return 0 @@ -319,3 +319,51 @@ func NullString(buf []byte) string { } return string(buf) } + +// StringEscape like C API mysql_escape_string() +// https://github.com/liule/golang_escape +func StringEscape(source string) string { + var j int + if source == "" { + return source + } + tempStr := source[:] + desc := make([]byte, len(tempStr)*2) + for i := 0; i < len(tempStr); i++ { + flag := false + var escape byte + switch tempStr[i] { + case '\000': + flag = true + escape = '\000' + case '\r': + flag = true + escape = '\r' + case '\n': + flag = true + escape = '\n' + case '\\': + flag = true + escape = '\\' + case '\'': + flag = true + escape = '\'' + case '"': + flag = true + escape = '"' + case '\032': + flag = true + escape = 'Z' + default: + } + if flag { + desc[j] = '\\' + desc[j+1] = escape + j = j + 2 + } else { + desc[j] = tempStr[i] + j = j + 1 + } + } + return string(desc[0:j]) +} diff --git a/database/mysql_test.go b/database/mysql_test.go index d85949dca8e8930522cc6fc307fce8113c104de3..b69562847fe5275f1f08eb160aa15f8c543e5249 100644 --- a/database/mysql_test.go +++ b/database/mysql_test.go @@ -198,3 +198,27 @@ func TestNullString(t *testing.T) { } common.Log.Debug("Exiting function: %s", common.GetFunctionName()) } + +func TestStringEscaple(t *testing.T) { + common.Log.Debug("Entering function: %s", common.GetFunctionName()) + cases := []string{ + "", + "hello world", + "hello' world", + `hello" world`, + "hello\000world", + `hello\ world`, + "hello\032world", + "hello\rworld", + "hello\nworld", + } + err := common.GoldenDiff(func() { + for _, str := range cases { + fmt.Println(StringEscape(str)) + } + }, t.Name(), update) + if err != nil { + t.Error(err) + } + common.Log.Debug("Exiting function: %s", common.GetFunctionName()) +} diff --git a/database/sampling.go b/database/sampling.go index bb0daef915c3d6e06acc051d7b7cd6b996fccfe3..707ad03d43af9fb51c2bffd2746dcf422767daf2 100644 --- a/database/sampling.go +++ b/database/sampling.go @@ -96,7 +96,7 @@ func (db *Connector) SamplingData(onlineConn *Connector, tables ...string) error // startSampling sampling data from OnlineDSN to TestDSN func (db *Connector) startSampling(onlineConn *sql.DB, database, table string, where string) error { - samplingQuery := fmt.Sprintf("select * from `%s`.`%s` %s", database, table, where) + samplingQuery := fmt.Sprintf("select * from `%s`.`%s` %s", StringEscape(database), StringEscape(table), StringEscape(where)) common.Log.Debug("startSampling with Query: %s", samplingQuery) res, err := onlineConn.Query(samplingQuery) if err != nil { @@ -167,7 +167,7 @@ func (db *Connector) startSampling(onlineConn *sql.DB, database, table string, w // 将泵取的数据转换成 insert 语句并在 testConn 数据库中执行 func (db *Connector) doSampling(table, colDef, values string) error { // db.Database is hashed database name - query := fmt.Sprintf("insert into `%s`.`%s` (%s) values %s;", db.Database, table, colDef, values) + query := fmt.Sprintf("insert into `%s`.`%s` (%s) values %s;", StringEscape(db.Database), StringEscape(table), StringEscape(colDef), values) res, err := db.Query(query) if res.Rows != nil { res.Rows.Close() diff --git a/database/show.go b/database/show.go index b0a47e41cad8eb180520e14f950b6d249045b75b..3396a88954aa4e50af44d32f4edfdbba4c7aea8e 100644 --- a/database/show.go +++ b/database/show.go @@ -113,7 +113,7 @@ func (db *Connector) ShowTableStatus(tableName string) (*TableStatInfo, error) { tbStatus := newTableStat(tableName) // 执行 show table status - res, err := db.Query(fmt.Sprintf("show table status where name = '%s'", tbStatus.Name)) + res, err := db.Query(fmt.Sprintf("show table status where name = '%s'", StringEscape(tbStatus.Name))) if err != nil { return tbStatus, err } @@ -208,7 +208,7 @@ func (db *Connector) ShowIndex(tableName string) (*TableIndexInfo, error) { } // 执行 show create table - res, err := db.Query(fmt.Sprintf("show index from `%s`.`%s`", db.Database, tableName)) + res, err := db.Query(fmt.Sprintf("show index from `%s`.`%s`", StringEscape(db.Database), StringEscape(tableName))) if err != nil { return nil, err } @@ -348,7 +348,7 @@ func (db *Connector) ShowColumns(tableName string) (*TableDesc, error) { tbDesc := NewTableDesc(tableName) // 执行 show create table - res, err := db.Query(fmt.Sprintf("show full columns from `%s`.`%s`", db.Database, tableName)) + res, err := db.Query(fmt.Sprintf("show full columns from `%s`.`%s`", StringEscape(db.Database), StringEscape(tableName))) if err != nil { return nil, err } @@ -408,7 +408,7 @@ func (db *Connector) showCreate(createType, name string) (string, error) { // SHOW CREATE TABLE tbl_name // SHOW CREATE TRIGGER trigger_name // SHOW CREATE VIEW view_name - res, err := db.Query(fmt.Sprintf("SHOW CREATE %s `%s`", createType, name)) + res, err := db.Query(fmt.Sprintf("SHOW CREATE %s `%s`", StringEscape(createType), StringEscape(name))) if err != nil { return "", err } @@ -500,10 +500,10 @@ func (db *Connector) FindColumn(name, dbName string, tables ...string) ([]*commo var columns []*common.Column sql := fmt.Sprintf("SELECT "+ "c.TABLE_NAME,c.TABLE_SCHEMA,c.COLUMN_TYPE,c.CHARACTER_SET_NAME, c.COLLATION_NAME "+ - "FROM `INFORMATION_SCHEMA`.`COLUMNS` as c where c.COLUMN_NAME = '%s' ", name) + "FROM `INFORMATION_SCHEMA`.`COLUMNS` as c where c.COLUMN_NAME = '%s' ", StringEscape(name)) if dbName != "" { - sql += fmt.Sprintf(" and c.table_schema = '%s'", dbName) + sql += fmt.Sprintf(" and c.table_schema = '%s'", StringEscape(dbName)) } if len(tables) > 0 { @@ -511,7 +511,7 @@ func (db *Connector) FindColumn(name, dbName string, tables ...string) ([]*commo for _, table := range tables { tmp = append(tmp, "'"+table+"'") } - sql += fmt.Sprintf(" and c.table_name in (%s)", strings.Join(tmp, ",")) + sql += fmt.Sprintf(" and c.table_name in (%s)", StringEscape(strings.Join(tmp, ","))) } common.Log.Debug("FindColumn, execute SQL: %s", sql) @@ -538,7 +538,7 @@ func (db *Connector) FindColumn(name, dbName string, tables ...string) ([]*commo // 由于 `INFORMATION_SCHEMA`.`TABLES` 表中未找到表的 character,所以从按照 MySQL 中 collation 的规则从中截取 character sql = fmt.Sprintf("SELECT `t`.`TABLE_COLLATION` FROM `INFORMATION_SCHEMA`.`TABLES` AS `t` "+ - "WHERE `t`.`TABLE_NAME`='%s' AND `t`.`TABLE_SCHEMA` = '%s'", col.Table, col.DB) + "WHERE `t`.`TABLE_NAME`='%s' AND `t`.`TABLE_SCHEMA` = '%s'", StringEscape(col.Table), StringEscape(col.DB)) common.Log.Debug("FindColumn, execute SQL: %s", sql) var newRes QueryResult @@ -573,7 +573,7 @@ func (db *Connector) IsForeignKey(dbName, tbName, column string) bool { "WHERE REFERENCED_TABLE_SCHEMA <> 'NULL' AND"+ " TABLE_NAME='%s' AND"+ " TABLE_SCHEMA='%s' AND"+ - " COLUMN_NAME='%s'", tbName, dbName, column) + " COLUMN_NAME='%s'", StringEscape(tbName), StringEscape(dbName), StringEscape(column)) common.Log.Debug("IsForeignKey, execute SQL: %s", sql) res, err := db.Query(sql) @@ -604,10 +604,10 @@ type ReferenceValue struct { func (db *Connector) ShowReference(dbName string, tbName ...string) ([]ReferenceValue, error) { var referenceValues []ReferenceValue sql := `SELECT DISTINCT C.REFERENCED_TABLE_SCHEMA,C.REFERENCED_TABLE_NAME,C.TABLE_SCHEMA,C.TABLE_NAME,C.CONSTRAINT_NAME FROM INFORMATION_SCHEMA.KEY_COLUMN_USAGE C JOIN INFORMATION_SCHEMA. TABLES T ON T.TABLE_NAME = C.TABLE_NAME WHERE C.REFERENCED_TABLE_NAME IS NOT NULL` - sql = sql + fmt.Sprintf(` AND C.TABLE_SCHEMA = "%s"`, dbName) + sql = sql + fmt.Sprintf(` AND C.TABLE_SCHEMA = "%s"`, StringEscape(dbName)) if len(tbName) > 0 { - extra := fmt.Sprintf(` AND C.TABLE_NAME IN ("%s")`, strings.Join(tbName, `","`)) + extra := fmt.Sprintf(` AND C.TABLE_NAME IN ("%s")`, StringEscape(strings.Join(tbName, `","`))) sql = sql + extra } diff --git a/database/testdata/TestStringEscaple.golden b/database/testdata/TestStringEscaple.golden new file mode 100644 index 0000000000000000000000000000000000000000..bb8f853c5b655d9476c1d2be848e5f5c7f87f566 Binary files /dev/null and b/database/testdata/TestStringEscaple.golden differ