“74691789e9e5ee782adb003642f66699603b20e2”上不存在“paddle/fluid/framework/unused_var_check.h”
提交 59094bf3 编写于 作者: martianzhang's avatar martianzhang

string escape no_backslash_escapes

上级 84a67026
...@@ -17,9 +17,11 @@ ...@@ -17,9 +17,11 @@
package database package database
import ( import (
"bytes"
"database/sql" "database/sql"
"errors" "errors"
"fmt" "fmt"
"io"
"regexp" "regexp"
"strconv" "strconv"
"strings" "strings"
...@@ -39,7 +41,6 @@ type Connector struct { ...@@ -39,7 +41,6 @@ type Connector struct {
Pass string Pass string
Database string Database string
Charset string Charset string
Net string
Conn *sql.DB Conn *sql.DB
} }
...@@ -202,7 +203,11 @@ func (db *Connector) ColumnCardinality(tb, col string) float64 { ...@@ -202,7 +203,11 @@ func (db *Connector) ColumnCardinality(tb, col string) float64 {
} }
// 计算该列散粒度 // 计算该列散粒度
res, err := db.Query(fmt.Sprintf("select count(distinct `%s`) from `%s`.`%s`", StringEscape(col), StringEscape(db.Database), StringEscape(tb))) db.Conn.Stats()
res, err := db.Query(fmt.Sprintf("select count(distinct `%s`) from `%s`.`%s`",
Escape(col, false),
Escape(db.Database, false),
Escape(tb, false)))
if err != nil { if err != nil {
common.Log.Warn("(db *Connector) ColumnCardinality() Query Error: %v", err) common.Log.Warn("(db *Connector) ColumnCardinality() Query Error: %v", err)
return 0 return 0
...@@ -313,19 +318,37 @@ func NullString(buf []byte) string { ...@@ -313,19 +318,37 @@ func NullString(buf []byte) string {
return string(buf) return string(buf)
} }
// StringEscape like C API mysql_escape_string() // quoteEscape sql_mode=no_backslash_escapes
func quoteEscape(source string) string {
var buf bytes.Buffer
last := 0
for ii, bb := range source {
if bb == '\'' {
_, err := io.WriteString(&buf, source[last:ii])
common.LogIfWarn(err, "")
_, err = io.WriteString(&buf, `''`)
common.LogIfWarn(err, "")
last = ii + 1
}
}
_, err := io.WriteString(&buf, source[last:])
common.LogIfWarn(err, "")
return buf.String()
}
// stringEscape mysql_escape_string
// https://github.com/liule/golang_escape // https://github.com/liule/golang_escape
func StringEscape(source string) string { func stringEscape(source string) string {
var j int var j int
if source == "" { if source == "" {
return source return source
} }
tempStr := source[:] tempStr := source[:]
desc := make([]byte, len(tempStr)*2) desc := make([]byte, len(tempStr)*2)
for i := 0; i < len(tempStr); i++ { for i, b := range tempStr {
flag := false flag := false
var escape byte var escape byte
switch tempStr[i] { switch b {
case '\000': case '\000':
flag = true flag = true
escape = '\000' escape = '\000'
...@@ -360,3 +383,13 @@ func StringEscape(source string) string { ...@@ -360,3 +383,13 @@ func StringEscape(source string) string {
} }
return string(desc[0:j]) return string(desc[0:j])
} }
// Escape like C API mysql_escape_string()
func Escape(source string, NoBackslashEscapes bool) string {
// NoBackslashEscapes https://dev.mysql.com/doc/refman/8.0/en/sql-mode.html#sqlmode_no_backslash_escapes
// TODO: NoBackslashEscapes always false
if NoBackslashEscapes {
return quoteEscape(source)
}
return stringEscape(source)
}
...@@ -199,7 +199,7 @@ func TestNullString(t *testing.T) { ...@@ -199,7 +199,7 @@ func TestNullString(t *testing.T) {
common.Log.Debug("Exiting function: %s", common.GetFunctionName()) common.Log.Debug("Exiting function: %s", common.GetFunctionName())
} }
func TestStringEscaple(t *testing.T) { func TestEscape(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName()) common.Log.Debug("Entering function: %s", common.GetFunctionName())
cases := []string{ cases := []string{
"", "",
...@@ -214,7 +214,8 @@ func TestStringEscaple(t *testing.T) { ...@@ -214,7 +214,8 @@ func TestStringEscaple(t *testing.T) {
} }
err := common.GoldenDiff(func() { err := common.GoldenDiff(func() {
for _, str := range cases { for _, str := range cases {
fmt.Println(StringEscape(str)) fmt.Println(Escape(str, false))
fmt.Println(Escape(str, true))
} }
}, t.Name(), update) }, t.Name(), update)
if err != nil { if err != nil {
......
...@@ -96,7 +96,10 @@ func (db *Connector) SamplingData(onlineConn *Connector, tables ...string) error ...@@ -96,7 +96,10 @@ func (db *Connector) SamplingData(onlineConn *Connector, tables ...string) error
// startSampling sampling data from OnlineDSN to TestDSN // startSampling sampling data from OnlineDSN to TestDSN
func (db *Connector) startSampling(onlineConn *sql.DB, database, table string, where string) error { func (db *Connector) startSampling(onlineConn *sql.DB, database, table string, where string) error {
samplingQuery := fmt.Sprintf("select * from `%s`.`%s` %s", StringEscape(database), StringEscape(table), StringEscape(where)) samplingQuery := fmt.Sprintf("select * from `%s`.`%s` %s",
Escape(database, false),
Escape(table, false),
Escape(where, false))
common.Log.Debug("startSampling with Query: %s", samplingQuery) common.Log.Debug("startSampling with Query: %s", samplingQuery)
res, err := onlineConn.Query(samplingQuery) res, err := onlineConn.Query(samplingQuery)
if err != nil { if err != nil {
...@@ -136,8 +139,11 @@ func (db *Connector) startSampling(onlineConn *sql.DB, database, table string, w ...@@ -136,8 +139,11 @@ func (db *Connector) startSampling(onlineConn *sql.DB, database, table string, w
switch columnTypes[i].DatabaseTypeName() { switch columnTypes[i].DatabaseTypeName() {
case "TIMESTAMP", "DATETIME": case "TIMESTAMP", "DATETIME":
t, err := time.Parse(time.RFC3339, string(val)) t, err := time.Parse(time.RFC3339, string(val))
common.LogIfWarn(err, "") if err != nil {
values = append(values, fmt.Sprintf(`"%s"`, string(val)))
} else {
values = append(values, fmt.Sprintf(`"%s"`, TimeString(t))) values = append(values, fmt.Sprintf(`"%s"`, TimeString(t)))
}
default: default:
values = append(values, fmt.Sprintf(`unhex("%s")`, fmt.Sprintf("%x", val))) values = append(values, fmt.Sprintf(`unhex("%s")`, fmt.Sprintf("%x", val)))
} }
...@@ -167,7 +173,10 @@ func (db *Connector) startSampling(onlineConn *sql.DB, database, table string, w ...@@ -167,7 +173,10 @@ func (db *Connector) startSampling(onlineConn *sql.DB, database, table string, w
// 将泵取的数据转换成 insert 语句并在 testConn 数据库中执行 // 将泵取的数据转换成 insert 语句并在 testConn 数据库中执行
func (db *Connector) doSampling(table, colDef, values string) error { func (db *Connector) doSampling(table, colDef, values string) error {
// db.Database is hashed database name // db.Database is hashed database name
query := fmt.Sprintf("insert into `%s`.`%s` (%s) values %s;", StringEscape(db.Database), StringEscape(table), StringEscape(colDef), values) query := fmt.Sprintf("insert into `%s`.`%s` (%s) values %s;",
Escape(db.Database, false),
Escape(table, false),
Escape(colDef, false), values)
res, err := db.Query(query) res, err := db.Query(query)
if res.Rows != nil { if res.Rows != nil {
res.Rows.Close() res.Rows.Close()
......
...@@ -113,7 +113,7 @@ func (db *Connector) ShowTableStatus(tableName string) (*TableStatInfo, error) { ...@@ -113,7 +113,7 @@ func (db *Connector) ShowTableStatus(tableName string) (*TableStatInfo, error) {
tbStatus := newTableStat(tableName) tbStatus := newTableStat(tableName)
// 执行 show table status // 执行 show table status
res, err := db.Query(fmt.Sprintf("show table status where name = '%s'", StringEscape(tbStatus.Name))) res, err := db.Query(fmt.Sprintf("show table status where name = '%s'", Escape(tbStatus.Name, false)))
if err != nil { if err != nil {
return tbStatus, err return tbStatus, err
} }
...@@ -208,7 +208,7 @@ func (db *Connector) ShowIndex(tableName string) (*TableIndexInfo, error) { ...@@ -208,7 +208,7 @@ func (db *Connector) ShowIndex(tableName string) (*TableIndexInfo, error) {
} }
// 执行 show create table // 执行 show create table
res, err := db.Query(fmt.Sprintf("show index from `%s`.`%s`", StringEscape(db.Database), StringEscape(tableName))) res, err := db.Query(fmt.Sprintf("show index from `%s`.`%s`", Escape(db.Database, false), Escape(tableName, false)))
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -348,7 +348,7 @@ func (db *Connector) ShowColumns(tableName string) (*TableDesc, error) { ...@@ -348,7 +348,7 @@ func (db *Connector) ShowColumns(tableName string) (*TableDesc, error) {
tbDesc := NewTableDesc(tableName) tbDesc := NewTableDesc(tableName)
// 执行 show create table // 执行 show create table
res, err := db.Query(fmt.Sprintf("show full columns from `%s`.`%s`", StringEscape(db.Database), StringEscape(tableName))) res, err := db.Query(fmt.Sprintf("show full columns from `%s`.`%s`", Escape(db.Database, false), Escape(tableName, false)))
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -408,7 +408,7 @@ func (db *Connector) showCreate(createType, name string) (string, error) { ...@@ -408,7 +408,7 @@ func (db *Connector) showCreate(createType, name string) (string, error) {
// SHOW CREATE TABLE tbl_name // SHOW CREATE TABLE tbl_name
// SHOW CREATE TRIGGER trigger_name // SHOW CREATE TRIGGER trigger_name
// SHOW CREATE VIEW view_name // SHOW CREATE VIEW view_name
res, err := db.Query(fmt.Sprintf("SHOW CREATE %s `%s`", createType, StringEscape(name))) res, err := db.Query(fmt.Sprintf("SHOW CREATE %s `%s`", createType, Escape(name, false)))
if err != nil { if err != nil {
return "", err return "", err
} }
...@@ -500,16 +500,16 @@ func (db *Connector) FindColumn(name, dbName string, tables ...string) ([]*commo ...@@ -500,16 +500,16 @@ func (db *Connector) FindColumn(name, dbName string, tables ...string) ([]*commo
var columns []*common.Column var columns []*common.Column
sql := fmt.Sprintf("SELECT "+ sql := fmt.Sprintf("SELECT "+
"c.TABLE_NAME,c.TABLE_SCHEMA,c.COLUMN_TYPE,c.CHARACTER_SET_NAME, c.COLLATION_NAME "+ "c.TABLE_NAME,c.TABLE_SCHEMA,c.COLUMN_TYPE,c.CHARACTER_SET_NAME, c.COLLATION_NAME "+
"FROM `INFORMATION_SCHEMA`.`COLUMNS` as c where c.COLUMN_NAME = '%s' ", StringEscape(name)) "FROM `INFORMATION_SCHEMA`.`COLUMNS` as c where c.COLUMN_NAME = '%s' ", Escape(name, false))
if dbName != "" { if dbName != "" {
sql += fmt.Sprintf(" and c.table_schema = '%s'", StringEscape(dbName)) sql += fmt.Sprintf(" and c.table_schema = '%s'", Escape(dbName, false))
} }
if len(tables) > 0 { if len(tables) > 0 {
var tmp []string var tmp []string
for _, table := range tables { for _, table := range tables {
tmp = append(tmp, "'"+StringEscape(table)+"'") tmp = append(tmp, "'"+Escape(table, false)+"'")
} }
sql += fmt.Sprintf(" and c.table_name in (%s)", strings.Join(tmp, ",")) sql += fmt.Sprintf(" and c.table_name in (%s)", strings.Join(tmp, ","))
} }
...@@ -538,7 +538,7 @@ func (db *Connector) FindColumn(name, dbName string, tables ...string) ([]*commo ...@@ -538,7 +538,7 @@ func (db *Connector) FindColumn(name, dbName string, tables ...string) ([]*commo
// 由于 `INFORMATION_SCHEMA`.`TABLES` 表中未找到表的 character,所以从按照 MySQL 中 collation 的规则从中截取 character // 由于 `INFORMATION_SCHEMA`.`TABLES` 表中未找到表的 character,所以从按照 MySQL 中 collation 的规则从中截取 character
sql = fmt.Sprintf("SELECT `t`.`TABLE_COLLATION` FROM `INFORMATION_SCHEMA`.`TABLES` AS `t` "+ sql = fmt.Sprintf("SELECT `t`.`TABLE_COLLATION` FROM `INFORMATION_SCHEMA`.`TABLES` AS `t` "+
"WHERE `t`.`TABLE_NAME`='%s' AND `t`.`TABLE_SCHEMA` = '%s'", StringEscape(col.Table), StringEscape(col.DB)) "WHERE `t`.`TABLE_NAME`='%s' AND `t`.`TABLE_SCHEMA` = '%s'", Escape(col.Table, false), Escape(col.DB, false))
common.Log.Debug("FindColumn, execute SQL: %s", sql) common.Log.Debug("FindColumn, execute SQL: %s", sql)
var newRes QueryResult var newRes QueryResult
...@@ -573,7 +573,7 @@ func (db *Connector) IsForeignKey(dbName, tbName, column string) bool { ...@@ -573,7 +573,7 @@ func (db *Connector) IsForeignKey(dbName, tbName, column string) bool {
"WHERE REFERENCED_TABLE_SCHEMA <> 'NULL' AND"+ "WHERE REFERENCED_TABLE_SCHEMA <> 'NULL' AND"+
" TABLE_NAME='%s' AND"+ " TABLE_NAME='%s' AND"+
" TABLE_SCHEMA='%s' AND"+ " TABLE_SCHEMA='%s' AND"+
" COLUMN_NAME='%s'", StringEscape(tbName), StringEscape(dbName), StringEscape(column)) " COLUMN_NAME='%s'", Escape(tbName, false), Escape(dbName, false), Escape(column, false))
common.Log.Debug("IsForeignKey, execute SQL: %s", sql) common.Log.Debug("IsForeignKey, execute SQL: %s", sql)
res, err := db.Query(sql) res, err := db.Query(sql)
...@@ -604,11 +604,11 @@ type ReferenceValue struct { ...@@ -604,11 +604,11 @@ type ReferenceValue struct {
func (db *Connector) ShowReference(dbName string, tbName ...string) ([]ReferenceValue, error) { func (db *Connector) ShowReference(dbName string, tbName ...string) ([]ReferenceValue, error) {
var referenceValues []ReferenceValue var referenceValues []ReferenceValue
sql := `SELECT DISTINCT C.REFERENCED_TABLE_SCHEMA,C.REFERENCED_TABLE_NAME,C.TABLE_SCHEMA,C.TABLE_NAME,C.CONSTRAINT_NAME FROM INFORMATION_SCHEMA.KEY_COLUMN_USAGE C JOIN INFORMATION_SCHEMA. TABLES T ON T.TABLE_NAME = C.TABLE_NAME WHERE C.REFERENCED_TABLE_NAME IS NOT NULL` sql := `SELECT DISTINCT C.REFERENCED_TABLE_SCHEMA,C.REFERENCED_TABLE_NAME,C.TABLE_SCHEMA,C.TABLE_NAME,C.CONSTRAINT_NAME FROM INFORMATION_SCHEMA.KEY_COLUMN_USAGE C JOIN INFORMATION_SCHEMA. TABLES T ON T.TABLE_NAME = C.TABLE_NAME WHERE C.REFERENCED_TABLE_NAME IS NOT NULL`
sql = sql + fmt.Sprintf(` AND C.TABLE_SCHEMA = "%s"`, StringEscape(dbName)) sql = sql + fmt.Sprintf(` AND C.TABLE_SCHEMA = "%s"`, Escape(dbName, false))
var tables []string var tables []string
for _, tb := range tbName { for _, tb := range tbName {
tables = append(tables, "'"+StringEscape(tb)+"'") tables = append(tables, "'"+Escape(tb, false)+"'")
} }
if len(tbName) > 0 { if len(tbName) > 0 {
extra := fmt.Sprintf(` AND C.TABLE_NAME IN ("%s")`, strings.Join(tables, ",")) extra := fmt.Sprintf(` AND C.TABLE_NAME IN ("%s")`, strings.Join(tables, ","))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册