提交 fd016354 编写于 作者: martianzhang's avatar martianzhang

escape mysql database, table, column name

上级 431027ed
...@@ -169,38 +169,3 @@ func TestRemoveComments(t *testing.T) { ...@@ -169,38 +169,3 @@ func TestRemoveComments(t *testing.T) {
} }
common.Log.Debug("Exiting function: %s", common.GetFunctionName()) common.Log.Debug("Exiting function: %s", common.GetFunctionName())
} }
func TestMysqlEscapeString(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
var strs = []map[string]string{
{
"input": "abc",
"output": "abc",
},
{
"input": "'abc",
"output": "\\'abc",
},
{
"input": `
abc`,
"output": `\
abc`,
},
{
"input": "\"abc",
"output": "\\\"abc",
},
}
for _, str := range strs {
output, err := MysqlEscapeString(str["input"])
if err != nil {
t.Error("TestMysqlEscapeString", err)
} else {
if output != str["output"] {
t.Error("TestMysqlEscapeString", output, str["output"])
}
}
}
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
}
...@@ -17,7 +17,6 @@ ...@@ -17,7 +17,6 @@
package ast package ast
import ( import (
"errors"
"fmt" "fmt"
"regexp" "regexp"
"strings" "strings"
...@@ -614,51 +613,6 @@ func Tokenizer(sql string) []Token { ...@@ -614,51 +613,6 @@ func Tokenizer(sql string) []Token {
return tokens return tokens
} }
// MysqlEscapeString mysql_real_escape_string
// https://github.com/liule/golang_escape
func MysqlEscapeString(source string) (string, error) {
var j = 0
if len(source) == 0 {
return "", errors.New("source is null")
}
tempStr := source[:]
desc := make([]byte, len(tempStr)*2)
for i := 0; i < len(tempStr); i++ {
flag := false
var escape byte
switch tempStr[i] {
case '\r':
flag = true
escape = '\r'
case '\n':
flag = true
escape = '\n'
case '\\':
flag = true
escape = '\\'
case '\'':
flag = true
escape = '\''
case '"':
flag = true
escape = '"'
case '\032':
flag = true
escape = 'Z'
default:
}
if flag {
desc[j] = '\\'
desc[j+1] = escape
j = j + 2
} else {
desc[j] = tempStr[i]
j = j + 1
}
}
return string(desc[0:j]), nil
}
// IsMysqlKeyword 判断是否是关键字 // IsMysqlKeyword 判断是否是关键字
func IsMysqlKeyword(name string) bool { func IsMysqlKeyword(name string) bool {
_, ok := mySQLKeywords[strings.ToLower(strings.TrimSpace(name))] _, ok := mySQLKeywords[strings.ToLower(strings.TrimSpace(name))]
......
...@@ -209,7 +209,7 @@ func (db *Connector) ColumnCardinality(tb, col string) float64 { ...@@ -209,7 +209,7 @@ func (db *Connector) ColumnCardinality(tb, col string) float64 {
} }
// 计算该列散粒度 // 计算该列散粒度
res, err := db.Query(fmt.Sprintf("select count(distinct `%s`) from `%s`.`%s`", col, db.Database, tb)) res, err := db.Query(fmt.Sprintf("select count(distinct `%s`) from `%s`.`%s`", StringEscape(col), StringEscape(db.Database), StringEscape(tb)))
if err != nil { if err != nil {
common.Log.Warn("(db *Connector) ColumnCardinality() Query Error: %v", err) common.Log.Warn("(db *Connector) ColumnCardinality() Query Error: %v", err)
return 0 return 0
...@@ -319,3 +319,51 @@ func NullString(buf []byte) string { ...@@ -319,3 +319,51 @@ func NullString(buf []byte) string {
} }
return string(buf) return string(buf)
} }
// StringEscape like C API mysql_escape_string()
// https://github.com/liule/golang_escape
func StringEscape(source string) string {
var j int
if source == "" {
return source
}
tempStr := source[:]
desc := make([]byte, len(tempStr)*2)
for i := 0; i < len(tempStr); i++ {
flag := false
var escape byte
switch tempStr[i] {
case '\000':
flag = true
escape = '\000'
case '\r':
flag = true
escape = '\r'
case '\n':
flag = true
escape = '\n'
case '\\':
flag = true
escape = '\\'
case '\'':
flag = true
escape = '\''
case '"':
flag = true
escape = '"'
case '\032':
flag = true
escape = 'Z'
default:
}
if flag {
desc[j] = '\\'
desc[j+1] = escape
j = j + 2
} else {
desc[j] = tempStr[i]
j = j + 1
}
}
return string(desc[0:j])
}
...@@ -198,3 +198,27 @@ func TestNullString(t *testing.T) { ...@@ -198,3 +198,27 @@ func TestNullString(t *testing.T) {
} }
common.Log.Debug("Exiting function: %s", common.GetFunctionName()) common.Log.Debug("Exiting function: %s", common.GetFunctionName())
} }
func TestStringEscaple(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
cases := []string{
"",
"hello world",
"hello' world",
`hello" world`,
"hello\000world",
`hello\ world`,
"hello\032world",
"hello\rworld",
"hello\nworld",
}
err := common.GoldenDiff(func() {
for _, str := range cases {
fmt.Println(StringEscape(str))
}
}, t.Name(), update)
if err != nil {
t.Error(err)
}
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
}
...@@ -96,7 +96,7 @@ func (db *Connector) SamplingData(onlineConn *Connector, tables ...string) error ...@@ -96,7 +96,7 @@ func (db *Connector) SamplingData(onlineConn *Connector, tables ...string) error
// startSampling sampling data from OnlineDSN to TestDSN // startSampling sampling data from OnlineDSN to TestDSN
func (db *Connector) startSampling(onlineConn *sql.DB, database, table string, where string) error { func (db *Connector) startSampling(onlineConn *sql.DB, database, table string, where string) error {
samplingQuery := fmt.Sprintf("select * from `%s`.`%s` %s", database, table, where) samplingQuery := fmt.Sprintf("select * from `%s`.`%s` %s", StringEscape(database), StringEscape(table), StringEscape(where))
common.Log.Debug("startSampling with Query: %s", samplingQuery) common.Log.Debug("startSampling with Query: %s", samplingQuery)
res, err := onlineConn.Query(samplingQuery) res, err := onlineConn.Query(samplingQuery)
if err != nil { if err != nil {
...@@ -167,7 +167,7 @@ func (db *Connector) startSampling(onlineConn *sql.DB, database, table string, w ...@@ -167,7 +167,7 @@ func (db *Connector) startSampling(onlineConn *sql.DB, database, table string, w
// 将泵取的数据转换成 insert 语句并在 testConn 数据库中执行 // 将泵取的数据转换成 insert 语句并在 testConn 数据库中执行
func (db *Connector) doSampling(table, colDef, values string) error { func (db *Connector) doSampling(table, colDef, values string) error {
// db.Database is hashed database name // db.Database is hashed database name
query := fmt.Sprintf("insert into `%s`.`%s` (%s) values %s;", db.Database, table, colDef, values) query := fmt.Sprintf("insert into `%s`.`%s` (%s) values %s;", StringEscape(db.Database), StringEscape(table), StringEscape(colDef), values)
res, err := db.Query(query) res, err := db.Query(query)
if res.Rows != nil { if res.Rows != nil {
res.Rows.Close() res.Rows.Close()
......
...@@ -113,7 +113,7 @@ func (db *Connector) ShowTableStatus(tableName string) (*TableStatInfo, error) { ...@@ -113,7 +113,7 @@ func (db *Connector) ShowTableStatus(tableName string) (*TableStatInfo, error) {
tbStatus := newTableStat(tableName) tbStatus := newTableStat(tableName)
// 执行 show table status // 执行 show table status
res, err := db.Query(fmt.Sprintf("show table status where name = '%s'", tbStatus.Name)) res, err := db.Query(fmt.Sprintf("show table status where name = '%s'", StringEscape(tbStatus.Name)))
if err != nil { if err != nil {
return tbStatus, err return tbStatus, err
} }
...@@ -208,7 +208,7 @@ func (db *Connector) ShowIndex(tableName string) (*TableIndexInfo, error) { ...@@ -208,7 +208,7 @@ func (db *Connector) ShowIndex(tableName string) (*TableIndexInfo, error) {
} }
// 执行 show create table // 执行 show create table
res, err := db.Query(fmt.Sprintf("show index from `%s`.`%s`", db.Database, tableName)) res, err := db.Query(fmt.Sprintf("show index from `%s`.`%s`", StringEscape(db.Database), StringEscape(tableName)))
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -348,7 +348,7 @@ func (db *Connector) ShowColumns(tableName string) (*TableDesc, error) { ...@@ -348,7 +348,7 @@ func (db *Connector) ShowColumns(tableName string) (*TableDesc, error) {
tbDesc := NewTableDesc(tableName) tbDesc := NewTableDesc(tableName)
// 执行 show create table // 执行 show create table
res, err := db.Query(fmt.Sprintf("show full columns from `%s`.`%s`", db.Database, tableName)) res, err := db.Query(fmt.Sprintf("show full columns from `%s`.`%s`", StringEscape(db.Database), StringEscape(tableName)))
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -408,7 +408,7 @@ func (db *Connector) showCreate(createType, name string) (string, error) { ...@@ -408,7 +408,7 @@ func (db *Connector) showCreate(createType, name string) (string, error) {
// SHOW CREATE TABLE tbl_name // SHOW CREATE TABLE tbl_name
// SHOW CREATE TRIGGER trigger_name // SHOW CREATE TRIGGER trigger_name
// SHOW CREATE VIEW view_name // SHOW CREATE VIEW view_name
res, err := db.Query(fmt.Sprintf("SHOW CREATE %s `%s`", createType, name)) res, err := db.Query(fmt.Sprintf("SHOW CREATE %s `%s`", StringEscape(createType), StringEscape(name)))
if err != nil { if err != nil {
return "", err return "", err
} }
...@@ -500,10 +500,10 @@ func (db *Connector) FindColumn(name, dbName string, tables ...string) ([]*commo ...@@ -500,10 +500,10 @@ func (db *Connector) FindColumn(name, dbName string, tables ...string) ([]*commo
var columns []*common.Column var columns []*common.Column
sql := fmt.Sprintf("SELECT "+ sql := fmt.Sprintf("SELECT "+
"c.TABLE_NAME,c.TABLE_SCHEMA,c.COLUMN_TYPE,c.CHARACTER_SET_NAME, c.COLLATION_NAME "+ "c.TABLE_NAME,c.TABLE_SCHEMA,c.COLUMN_TYPE,c.CHARACTER_SET_NAME, c.COLLATION_NAME "+
"FROM `INFORMATION_SCHEMA`.`COLUMNS` as c where c.COLUMN_NAME = '%s' ", name) "FROM `INFORMATION_SCHEMA`.`COLUMNS` as c where c.COLUMN_NAME = '%s' ", StringEscape(name))
if dbName != "" { if dbName != "" {
sql += fmt.Sprintf(" and c.table_schema = '%s'", dbName) sql += fmt.Sprintf(" and c.table_schema = '%s'", StringEscape(dbName))
} }
if len(tables) > 0 { if len(tables) > 0 {
...@@ -511,7 +511,7 @@ func (db *Connector) FindColumn(name, dbName string, tables ...string) ([]*commo ...@@ -511,7 +511,7 @@ func (db *Connector) FindColumn(name, dbName string, tables ...string) ([]*commo
for _, table := range tables { for _, table := range tables {
tmp = append(tmp, "'"+table+"'") tmp = append(tmp, "'"+table+"'")
} }
sql += fmt.Sprintf(" and c.table_name in (%s)", strings.Join(tmp, ",")) sql += fmt.Sprintf(" and c.table_name in (%s)", StringEscape(strings.Join(tmp, ",")))
} }
common.Log.Debug("FindColumn, execute SQL: %s", sql) common.Log.Debug("FindColumn, execute SQL: %s", sql)
...@@ -538,7 +538,7 @@ func (db *Connector) FindColumn(name, dbName string, tables ...string) ([]*commo ...@@ -538,7 +538,7 @@ func (db *Connector) FindColumn(name, dbName string, tables ...string) ([]*commo
// 由于 `INFORMATION_SCHEMA`.`TABLES` 表中未找到表的 character,所以从按照 MySQL 中 collation 的规则从中截取 character // 由于 `INFORMATION_SCHEMA`.`TABLES` 表中未找到表的 character,所以从按照 MySQL 中 collation 的规则从中截取 character
sql = fmt.Sprintf("SELECT `t`.`TABLE_COLLATION` FROM `INFORMATION_SCHEMA`.`TABLES` AS `t` "+ sql = fmt.Sprintf("SELECT `t`.`TABLE_COLLATION` FROM `INFORMATION_SCHEMA`.`TABLES` AS `t` "+
"WHERE `t`.`TABLE_NAME`='%s' AND `t`.`TABLE_SCHEMA` = '%s'", col.Table, col.DB) "WHERE `t`.`TABLE_NAME`='%s' AND `t`.`TABLE_SCHEMA` = '%s'", StringEscape(col.Table), StringEscape(col.DB))
common.Log.Debug("FindColumn, execute SQL: %s", sql) common.Log.Debug("FindColumn, execute SQL: %s", sql)
var newRes QueryResult var newRes QueryResult
...@@ -573,7 +573,7 @@ func (db *Connector) IsForeignKey(dbName, tbName, column string) bool { ...@@ -573,7 +573,7 @@ func (db *Connector) IsForeignKey(dbName, tbName, column string) bool {
"WHERE REFERENCED_TABLE_SCHEMA <> 'NULL' AND"+ "WHERE REFERENCED_TABLE_SCHEMA <> 'NULL' AND"+
" TABLE_NAME='%s' AND"+ " TABLE_NAME='%s' AND"+
" TABLE_SCHEMA='%s' AND"+ " TABLE_SCHEMA='%s' AND"+
" COLUMN_NAME='%s'", tbName, dbName, column) " COLUMN_NAME='%s'", StringEscape(tbName), StringEscape(dbName), StringEscape(column))
common.Log.Debug("IsForeignKey, execute SQL: %s", sql) common.Log.Debug("IsForeignKey, execute SQL: %s", sql)
res, err := db.Query(sql) res, err := db.Query(sql)
...@@ -604,10 +604,10 @@ type ReferenceValue struct { ...@@ -604,10 +604,10 @@ type ReferenceValue struct {
func (db *Connector) ShowReference(dbName string, tbName ...string) ([]ReferenceValue, error) { func (db *Connector) ShowReference(dbName string, tbName ...string) ([]ReferenceValue, error) {
var referenceValues []ReferenceValue var referenceValues []ReferenceValue
sql := `SELECT DISTINCT C.REFERENCED_TABLE_SCHEMA,C.REFERENCED_TABLE_NAME,C.TABLE_SCHEMA,C.TABLE_NAME,C.CONSTRAINT_NAME FROM INFORMATION_SCHEMA.KEY_COLUMN_USAGE C JOIN INFORMATION_SCHEMA. TABLES T ON T.TABLE_NAME = C.TABLE_NAME WHERE C.REFERENCED_TABLE_NAME IS NOT NULL` sql := `SELECT DISTINCT C.REFERENCED_TABLE_SCHEMA,C.REFERENCED_TABLE_NAME,C.TABLE_SCHEMA,C.TABLE_NAME,C.CONSTRAINT_NAME FROM INFORMATION_SCHEMA.KEY_COLUMN_USAGE C JOIN INFORMATION_SCHEMA. TABLES T ON T.TABLE_NAME = C.TABLE_NAME WHERE C.REFERENCED_TABLE_NAME IS NOT NULL`
sql = sql + fmt.Sprintf(` AND C.TABLE_SCHEMA = "%s"`, dbName) sql = sql + fmt.Sprintf(` AND C.TABLE_SCHEMA = "%s"`, StringEscape(dbName))
if len(tbName) > 0 { if len(tbName) > 0 {
extra := fmt.Sprintf(` AND C.TABLE_NAME IN ("%s")`, strings.Join(tbName, `","`)) extra := fmt.Sprintf(` AND C.TABLE_NAME IN ("%s")`, StringEscape(strings.Join(tbName, `","`)))
sql = sql + extra sql = sql + extra
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册