提交 59094bf3 编写于 作者: martianzhang's avatar martianzhang

string escape no_backslash_escapes

上级 84a67026
......@@ -17,9 +17,11 @@
package database
import (
"bytes"
"database/sql"
"errors"
"fmt"
"io"
"regexp"
"strconv"
"strings"
......@@ -39,7 +41,6 @@ type Connector struct {
Pass string
Database string
Charset string
Net string
Conn *sql.DB
}
......@@ -202,7 +203,11 @@ func (db *Connector) ColumnCardinality(tb, col string) float64 {
}
// 计算该列散粒度
res, err := db.Query(fmt.Sprintf("select count(distinct `%s`) from `%s`.`%s`", StringEscape(col), StringEscape(db.Database), StringEscape(tb)))
db.Conn.Stats()
res, err := db.Query(fmt.Sprintf("select count(distinct `%s`) from `%s`.`%s`",
Escape(col, false),
Escape(db.Database, false),
Escape(tb, false)))
if err != nil {
common.Log.Warn("(db *Connector) ColumnCardinality() Query Error: %v", err)
return 0
......@@ -313,19 +318,37 @@ func NullString(buf []byte) string {
return string(buf)
}
// StringEscape like C API mysql_escape_string()
// quoteEscape sql_mode=no_backslash_escapes
func quoteEscape(source string) string {
var buf bytes.Buffer
last := 0
for ii, bb := range source {
if bb == '\'' {
_, err := io.WriteString(&buf, source[last:ii])
common.LogIfWarn(err, "")
_, err = io.WriteString(&buf, `''`)
common.LogIfWarn(err, "")
last = ii + 1
}
}
_, err := io.WriteString(&buf, source[last:])
common.LogIfWarn(err, "")
return buf.String()
}
// stringEscape mysql_escape_string
// https://github.com/liule/golang_escape
func StringEscape(source string) string {
func stringEscape(source string) string {
var j int
if source == "" {
return source
}
tempStr := source[:]
desc := make([]byte, len(tempStr)*2)
for i := 0; i < len(tempStr); i++ {
for i, b := range tempStr {
flag := false
var escape byte
switch tempStr[i] {
switch b {
case '\000':
flag = true
escape = '\000'
......@@ -360,3 +383,13 @@ func StringEscape(source string) string {
}
return string(desc[0:j])
}
// Escape like C API mysql_escape_string()
func Escape(source string, NoBackslashEscapes bool) string {
// NoBackslashEscapes https://dev.mysql.com/doc/refman/8.0/en/sql-mode.html#sqlmode_no_backslash_escapes
// TODO: NoBackslashEscapes always false
if NoBackslashEscapes {
return quoteEscape(source)
}
return stringEscape(source)
}
......@@ -199,7 +199,7 @@ func TestNullString(t *testing.T) {
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
}
func TestStringEscaple(t *testing.T) {
func TestEscape(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
cases := []string{
"",
......@@ -214,7 +214,8 @@ func TestStringEscaple(t *testing.T) {
}
err := common.GoldenDiff(func() {
for _, str := range cases {
fmt.Println(StringEscape(str))
fmt.Println(Escape(str, false))
fmt.Println(Escape(str, true))
}
}, t.Name(), update)
if err != nil {
......
......@@ -96,7 +96,10 @@ func (db *Connector) SamplingData(onlineConn *Connector, tables ...string) error
// startSampling sampling data from OnlineDSN to TestDSN
func (db *Connector) startSampling(onlineConn *sql.DB, database, table string, where string) error {
samplingQuery := fmt.Sprintf("select * from `%s`.`%s` %s", StringEscape(database), StringEscape(table), StringEscape(where))
samplingQuery := fmt.Sprintf("select * from `%s`.`%s` %s",
Escape(database, false),
Escape(table, false),
Escape(where, false))
common.Log.Debug("startSampling with Query: %s", samplingQuery)
res, err := onlineConn.Query(samplingQuery)
if err != nil {
......@@ -136,8 +139,11 @@ func (db *Connector) startSampling(onlineConn *sql.DB, database, table string, w
switch columnTypes[i].DatabaseTypeName() {
case "TIMESTAMP", "DATETIME":
t, err := time.Parse(time.RFC3339, string(val))
common.LogIfWarn(err, "")
values = append(values, fmt.Sprintf(`"%s"`, TimeString(t)))
if err != nil {
values = append(values, fmt.Sprintf(`"%s"`, string(val)))
} else {
values = append(values, fmt.Sprintf(`"%s"`, TimeString(t)))
}
default:
values = append(values, fmt.Sprintf(`unhex("%s")`, fmt.Sprintf("%x", val)))
}
......@@ -167,7 +173,10 @@ func (db *Connector) startSampling(onlineConn *sql.DB, database, table string, w
// 将泵取的数据转换成 insert 语句并在 testConn 数据库中执行
func (db *Connector) doSampling(table, colDef, values string) error {
// db.Database is hashed database name
query := fmt.Sprintf("insert into `%s`.`%s` (%s) values %s;", StringEscape(db.Database), StringEscape(table), StringEscape(colDef), values)
query := fmt.Sprintf("insert into `%s`.`%s` (%s) values %s;",
Escape(db.Database, false),
Escape(table, false),
Escape(colDef, false), values)
res, err := db.Query(query)
if res.Rows != nil {
res.Rows.Close()
......
......@@ -113,7 +113,7 @@ func (db *Connector) ShowTableStatus(tableName string) (*TableStatInfo, error) {
tbStatus := newTableStat(tableName)
// 执行 show table status
res, err := db.Query(fmt.Sprintf("show table status where name = '%s'", StringEscape(tbStatus.Name)))
res, err := db.Query(fmt.Sprintf("show table status where name = '%s'", Escape(tbStatus.Name, false)))
if err != nil {
return tbStatus, err
}
......@@ -208,7 +208,7 @@ func (db *Connector) ShowIndex(tableName string) (*TableIndexInfo, error) {
}
// 执行 show create table
res, err := db.Query(fmt.Sprintf("show index from `%s`.`%s`", StringEscape(db.Database), StringEscape(tableName)))
res, err := db.Query(fmt.Sprintf("show index from `%s`.`%s`", Escape(db.Database, false), Escape(tableName, false)))
if err != nil {
return nil, err
}
......@@ -348,7 +348,7 @@ func (db *Connector) ShowColumns(tableName string) (*TableDesc, error) {
tbDesc := NewTableDesc(tableName)
// 执行 show create table
res, err := db.Query(fmt.Sprintf("show full columns from `%s`.`%s`", StringEscape(db.Database), StringEscape(tableName)))
res, err := db.Query(fmt.Sprintf("show full columns from `%s`.`%s`", Escape(db.Database, false), Escape(tableName, false)))
if err != nil {
return nil, err
}
......@@ -408,7 +408,7 @@ func (db *Connector) showCreate(createType, name string) (string, error) {
// SHOW CREATE TABLE tbl_name
// SHOW CREATE TRIGGER trigger_name
// SHOW CREATE VIEW view_name
res, err := db.Query(fmt.Sprintf("SHOW CREATE %s `%s`", createType, StringEscape(name)))
res, err := db.Query(fmt.Sprintf("SHOW CREATE %s `%s`", createType, Escape(name, false)))
if err != nil {
return "", err
}
......@@ -500,16 +500,16 @@ func (db *Connector) FindColumn(name, dbName string, tables ...string) ([]*commo
var columns []*common.Column
sql := fmt.Sprintf("SELECT "+
"c.TABLE_NAME,c.TABLE_SCHEMA,c.COLUMN_TYPE,c.CHARACTER_SET_NAME, c.COLLATION_NAME "+
"FROM `INFORMATION_SCHEMA`.`COLUMNS` as c where c.COLUMN_NAME = '%s' ", StringEscape(name))
"FROM `INFORMATION_SCHEMA`.`COLUMNS` as c where c.COLUMN_NAME = '%s' ", Escape(name, false))
if dbName != "" {
sql += fmt.Sprintf(" and c.table_schema = '%s'", StringEscape(dbName))
sql += fmt.Sprintf(" and c.table_schema = '%s'", Escape(dbName, false))
}
if len(tables) > 0 {
var tmp []string
for _, table := range tables {
tmp = append(tmp, "'"+StringEscape(table)+"'")
tmp = append(tmp, "'"+Escape(table, false)+"'")
}
sql += fmt.Sprintf(" and c.table_name in (%s)", strings.Join(tmp, ","))
}
......@@ -538,7 +538,7 @@ func (db *Connector) FindColumn(name, dbName string, tables ...string) ([]*commo
// 由于 `INFORMATION_SCHEMA`.`TABLES` 表中未找到表的 character,所以从按照 MySQL 中 collation 的规则从中截取 character
sql = fmt.Sprintf("SELECT `t`.`TABLE_COLLATION` FROM `INFORMATION_SCHEMA`.`TABLES` AS `t` "+
"WHERE `t`.`TABLE_NAME`='%s' AND `t`.`TABLE_SCHEMA` = '%s'", StringEscape(col.Table), StringEscape(col.DB))
"WHERE `t`.`TABLE_NAME`='%s' AND `t`.`TABLE_SCHEMA` = '%s'", Escape(col.Table, false), Escape(col.DB, false))
common.Log.Debug("FindColumn, execute SQL: %s", sql)
var newRes QueryResult
......@@ -573,7 +573,7 @@ func (db *Connector) IsForeignKey(dbName, tbName, column string) bool {
"WHERE REFERENCED_TABLE_SCHEMA <> 'NULL' AND"+
" TABLE_NAME='%s' AND"+
" TABLE_SCHEMA='%s' AND"+
" COLUMN_NAME='%s'", StringEscape(tbName), StringEscape(dbName), StringEscape(column))
" COLUMN_NAME='%s'", Escape(tbName, false), Escape(dbName, false), Escape(column, false))
common.Log.Debug("IsForeignKey, execute SQL: %s", sql)
res, err := db.Query(sql)
......@@ -604,11 +604,11 @@ type ReferenceValue struct {
func (db *Connector) ShowReference(dbName string, tbName ...string) ([]ReferenceValue, error) {
var referenceValues []ReferenceValue
sql := `SELECT DISTINCT C.REFERENCED_TABLE_SCHEMA,C.REFERENCED_TABLE_NAME,C.TABLE_SCHEMA,C.TABLE_NAME,C.CONSTRAINT_NAME FROM INFORMATION_SCHEMA.KEY_COLUMN_USAGE C JOIN INFORMATION_SCHEMA. TABLES T ON T.TABLE_NAME = C.TABLE_NAME WHERE C.REFERENCED_TABLE_NAME IS NOT NULL`
sql = sql + fmt.Sprintf(` AND C.TABLE_SCHEMA = "%s"`, StringEscape(dbName))
sql = sql + fmt.Sprintf(` AND C.TABLE_SCHEMA = "%s"`, Escape(dbName, false))
var tables []string
for _, tb := range tbName {
tables = append(tables, "'"+StringEscape(tb)+"'")
tables = append(tables, "'"+Escape(tb, false)+"'")
}
if len(tbName) > 0 {
extra := fmt.Sprintf(` AND C.TABLE_NAME IN ("%s")`, strings.Join(tables, ","))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册