From 59094bf39bcaf83908fcabb86c7563296fa95516 Mon Sep 17 00:00:00 2001 From: Leon Zhang Date: Fri, 28 Dec 2018 10:33:52 +0800 Subject: [PATCH] string escape no_backslash_escapes --- database/mysql.go | 45 +++++++++++++++--- database/mysql_test.go | 5 +- database/sampling.go | 17 +++++-- database/show.go | 22 ++++----- ...StringEscaple.golden => TestEscape.golden} | Bin 107 -> 208 bytes 5 files changed, 66 insertions(+), 23 deletions(-) rename database/testdata/{TestStringEscaple.golden => TestEscape.golden} (51%) diff --git a/database/mysql.go b/database/mysql.go index a623c55..06388a6 100644 --- a/database/mysql.go +++ b/database/mysql.go @@ -17,9 +17,11 @@ package database import ( + "bytes" "database/sql" "errors" "fmt" + "io" "regexp" "strconv" "strings" @@ -39,7 +41,6 @@ type Connector struct { Pass string Database string Charset string - Net string Conn *sql.DB } @@ -202,7 +203,11 @@ func (db *Connector) ColumnCardinality(tb, col string) float64 { } // 计算该列散粒度 - res, err := db.Query(fmt.Sprintf("select count(distinct `%s`) from `%s`.`%s`", StringEscape(col), StringEscape(db.Database), StringEscape(tb))) + db.Conn.Stats() + res, err := db.Query(fmt.Sprintf("select count(distinct `%s`) from `%s`.`%s`", + Escape(col, false), + Escape(db.Database, false), + Escape(tb, false))) if err != nil { common.Log.Warn("(db *Connector) ColumnCardinality() Query Error: %v", err) return 0 @@ -313,19 +318,37 @@ func NullString(buf []byte) string { return string(buf) } -// StringEscape like C API mysql_escape_string() +// quoteEscape sql_mode=no_backslash_escapes +func quoteEscape(source string) string { + var buf bytes.Buffer + last := 0 + for ii, bb := range source { + if bb == '\'' { + _, err := io.WriteString(&buf, source[last:ii]) + common.LogIfWarn(err, "") + _, err = io.WriteString(&buf, `''`) + common.LogIfWarn(err, "") + last = ii + 1 + } + } + _, err := io.WriteString(&buf, source[last:]) + common.LogIfWarn(err, "") + return buf.String() +} + +// stringEscape mysql_escape_string // https://github.com/liule/golang_escape -func StringEscape(source string) string { +func stringEscape(source string) string { var j int if source == "" { return source } tempStr := source[:] desc := make([]byte, len(tempStr)*2) - for i := 0; i < len(tempStr); i++ { + for i, b := range tempStr { flag := false var escape byte - switch tempStr[i] { + switch b { case '\000': flag = true escape = '\000' @@ -360,3 +383,13 @@ func StringEscape(source string) string { } return string(desc[0:j]) } + +// Escape like C API mysql_escape_string() +func Escape(source string, NoBackslashEscapes bool) string { + // NoBackslashEscapes https://dev.mysql.com/doc/refman/8.0/en/sql-mode.html#sqlmode_no_backslash_escapes + // TODO: NoBackslashEscapes always false + if NoBackslashEscapes { + return quoteEscape(source) + } + return stringEscape(source) +} diff --git a/database/mysql_test.go b/database/mysql_test.go index b695628..18f20fc 100644 --- a/database/mysql_test.go +++ b/database/mysql_test.go @@ -199,7 +199,7 @@ func TestNullString(t *testing.T) { common.Log.Debug("Exiting function: %s", common.GetFunctionName()) } -func TestStringEscaple(t *testing.T) { +func TestEscape(t *testing.T) { common.Log.Debug("Entering function: %s", common.GetFunctionName()) cases := []string{ "", @@ -214,7 +214,8 @@ func TestStringEscaple(t *testing.T) { } err := common.GoldenDiff(func() { for _, str := range cases { - fmt.Println(StringEscape(str)) + fmt.Println(Escape(str, false)) + fmt.Println(Escape(str, true)) } }, t.Name(), update) if err != nil { diff --git a/database/sampling.go b/database/sampling.go index 707ad03..46bfdef 100644 --- a/database/sampling.go +++ b/database/sampling.go @@ -96,7 +96,10 @@ func (db *Connector) SamplingData(onlineConn *Connector, tables ...string) error // startSampling sampling data from OnlineDSN to TestDSN func (db *Connector) startSampling(onlineConn *sql.DB, database, table string, where string) error { - samplingQuery := fmt.Sprintf("select * from `%s`.`%s` %s", StringEscape(database), StringEscape(table), StringEscape(where)) + samplingQuery := fmt.Sprintf("select * from `%s`.`%s` %s", + Escape(database, false), + Escape(table, false), + Escape(where, false)) common.Log.Debug("startSampling with Query: %s", samplingQuery) res, err := onlineConn.Query(samplingQuery) if err != nil { @@ -136,8 +139,11 @@ func (db *Connector) startSampling(onlineConn *sql.DB, database, table string, w switch columnTypes[i].DatabaseTypeName() { case "TIMESTAMP", "DATETIME": t, err := time.Parse(time.RFC3339, string(val)) - common.LogIfWarn(err, "") - values = append(values, fmt.Sprintf(`"%s"`, TimeString(t))) + if err != nil { + values = append(values, fmt.Sprintf(`"%s"`, string(val))) + } else { + values = append(values, fmt.Sprintf(`"%s"`, TimeString(t))) + } default: values = append(values, fmt.Sprintf(`unhex("%s")`, fmt.Sprintf("%x", val))) } @@ -167,7 +173,10 @@ func (db *Connector) startSampling(onlineConn *sql.DB, database, table string, w // 将泵取的数据转换成 insert 语句并在 testConn 数据库中执行 func (db *Connector) doSampling(table, colDef, values string) error { // db.Database is hashed database name - query := fmt.Sprintf("insert into `%s`.`%s` (%s) values %s;", StringEscape(db.Database), StringEscape(table), StringEscape(colDef), values) + query := fmt.Sprintf("insert into `%s`.`%s` (%s) values %s;", + Escape(db.Database, false), + Escape(table, false), + Escape(colDef, false), values) res, err := db.Query(query) if res.Rows != nil { res.Rows.Close() diff --git a/database/show.go b/database/show.go index fbe9620..905b069 100644 --- a/database/show.go +++ b/database/show.go @@ -113,7 +113,7 @@ func (db *Connector) ShowTableStatus(tableName string) (*TableStatInfo, error) { tbStatus := newTableStat(tableName) // 执行 show table status - res, err := db.Query(fmt.Sprintf("show table status where name = '%s'", StringEscape(tbStatus.Name))) + res, err := db.Query(fmt.Sprintf("show table status where name = '%s'", Escape(tbStatus.Name, false))) if err != nil { return tbStatus, err } @@ -208,7 +208,7 @@ func (db *Connector) ShowIndex(tableName string) (*TableIndexInfo, error) { } // 执行 show create table - res, err := db.Query(fmt.Sprintf("show index from `%s`.`%s`", StringEscape(db.Database), StringEscape(tableName))) + res, err := db.Query(fmt.Sprintf("show index from `%s`.`%s`", Escape(db.Database, false), Escape(tableName, false))) if err != nil { return nil, err } @@ -348,7 +348,7 @@ func (db *Connector) ShowColumns(tableName string) (*TableDesc, error) { tbDesc := NewTableDesc(tableName) // 执行 show create table - res, err := db.Query(fmt.Sprintf("show full columns from `%s`.`%s`", StringEscape(db.Database), StringEscape(tableName))) + res, err := db.Query(fmt.Sprintf("show full columns from `%s`.`%s`", Escape(db.Database, false), Escape(tableName, false))) if err != nil { return nil, err } @@ -408,7 +408,7 @@ func (db *Connector) showCreate(createType, name string) (string, error) { // SHOW CREATE TABLE tbl_name // SHOW CREATE TRIGGER trigger_name // SHOW CREATE VIEW view_name - res, err := db.Query(fmt.Sprintf("SHOW CREATE %s `%s`", createType, StringEscape(name))) + res, err := db.Query(fmt.Sprintf("SHOW CREATE %s `%s`", createType, Escape(name, false))) if err != nil { return "", err } @@ -500,16 +500,16 @@ func (db *Connector) FindColumn(name, dbName string, tables ...string) ([]*commo var columns []*common.Column sql := fmt.Sprintf("SELECT "+ "c.TABLE_NAME,c.TABLE_SCHEMA,c.COLUMN_TYPE,c.CHARACTER_SET_NAME, c.COLLATION_NAME "+ - "FROM `INFORMATION_SCHEMA`.`COLUMNS` as c where c.COLUMN_NAME = '%s' ", StringEscape(name)) + "FROM `INFORMATION_SCHEMA`.`COLUMNS` as c where c.COLUMN_NAME = '%s' ", Escape(name, false)) if dbName != "" { - sql += fmt.Sprintf(" and c.table_schema = '%s'", StringEscape(dbName)) + sql += fmt.Sprintf(" and c.table_schema = '%s'", Escape(dbName, false)) } if len(tables) > 0 { var tmp []string for _, table := range tables { - tmp = append(tmp, "'"+StringEscape(table)+"'") + tmp = append(tmp, "'"+Escape(table, false)+"'") } sql += fmt.Sprintf(" and c.table_name in (%s)", strings.Join(tmp, ",")) } @@ -538,7 +538,7 @@ func (db *Connector) FindColumn(name, dbName string, tables ...string) ([]*commo // 由于 `INFORMATION_SCHEMA`.`TABLES` 表中未找到表的 character,所以从按照 MySQL 中 collation 的规则从中截取 character sql = fmt.Sprintf("SELECT `t`.`TABLE_COLLATION` FROM `INFORMATION_SCHEMA`.`TABLES` AS `t` "+ - "WHERE `t`.`TABLE_NAME`='%s' AND `t`.`TABLE_SCHEMA` = '%s'", StringEscape(col.Table), StringEscape(col.DB)) + "WHERE `t`.`TABLE_NAME`='%s' AND `t`.`TABLE_SCHEMA` = '%s'", Escape(col.Table, false), Escape(col.DB, false)) common.Log.Debug("FindColumn, execute SQL: %s", sql) var newRes QueryResult @@ -573,7 +573,7 @@ func (db *Connector) IsForeignKey(dbName, tbName, column string) bool { "WHERE REFERENCED_TABLE_SCHEMA <> 'NULL' AND"+ " TABLE_NAME='%s' AND"+ " TABLE_SCHEMA='%s' AND"+ - " COLUMN_NAME='%s'", StringEscape(tbName), StringEscape(dbName), StringEscape(column)) + " COLUMN_NAME='%s'", Escape(tbName, false), Escape(dbName, false), Escape(column, false)) common.Log.Debug("IsForeignKey, execute SQL: %s", sql) res, err := db.Query(sql) @@ -604,11 +604,11 @@ type ReferenceValue struct { func (db *Connector) ShowReference(dbName string, tbName ...string) ([]ReferenceValue, error) { var referenceValues []ReferenceValue sql := `SELECT DISTINCT C.REFERENCED_TABLE_SCHEMA,C.REFERENCED_TABLE_NAME,C.TABLE_SCHEMA,C.TABLE_NAME,C.CONSTRAINT_NAME FROM INFORMATION_SCHEMA.KEY_COLUMN_USAGE C JOIN INFORMATION_SCHEMA. TABLES T ON T.TABLE_NAME = C.TABLE_NAME WHERE C.REFERENCED_TABLE_NAME IS NOT NULL` - sql = sql + fmt.Sprintf(` AND C.TABLE_SCHEMA = "%s"`, StringEscape(dbName)) + sql = sql + fmt.Sprintf(` AND C.TABLE_SCHEMA = "%s"`, Escape(dbName, false)) var tables []string for _, tb := range tbName { - tables = append(tables, "'"+StringEscape(tb)+"'") + tables = append(tables, "'"+Escape(tb, false)+"'") } if len(tbName) > 0 { extra := fmt.Sprintf(` AND C.TABLE_NAME IN ("%s")`, strings.Join(tables, ",")) diff --git a/database/testdata/TestStringEscaple.golden b/database/testdata/TestEscape.golden similarity index 51% rename from database/testdata/TestStringEscaple.golden rename to database/testdata/TestEscape.golden index bb8f853c5b655d9476c1d2be848e5f5c7f87f566..3e14bff8db03ab878d54a87a50f790464ba9af60 100644 GIT binary patch literal 208 zcmd<$%1F)0$yX@PFUm>5WXGr@%c!I9Vw8~OP&hFRNai84V`7l?pmL&+^hhDGV|bA` M$m|#{Bnc=R0Pq${3;+NC literal 107 rcmd