diff --git a/Makefile b/Makefile index 885f2350505e40f5162c681d1be6bb9158e22cd7..5e6ce8fd650706eaba9cfaec3524b7d3db250ff3 100644 --- a/Makefile +++ b/Makefile @@ -181,7 +181,7 @@ docker: .PHONY: connect connect: - mysql -h 127.0.0.1 -u root -p1tIsB1g3rt -c + mysql -h 127.0.0.1 -u root -p1tIsB1g3rt sakila -c .PHONY: main_test main_test: install diff --git a/advisor/heuristic.go b/advisor/heuristic.go index 1ccab69d1bdcd092f87bbfe19e15bdd8134bee1a..c8feba434285e91c5fadcc80b02fef1298d91861 100644 --- a/advisor/heuristic.go +++ b/advisor/heuristic.go @@ -1990,7 +1990,7 @@ func (idxAdv *IndexAdvisor) RuleUpdatePrimaryKey() Rule { if idxMeta == nil { return rule } - for _, idx := range idxMeta.IdxRows { + for _, idx := range idxMeta.Rows { if idx.KeyName == "PRIMARY" { if col.Name == idx.ColumnName { rule = HeuristicRules["CLA.016"] diff --git a/advisor/index.go b/advisor/index.go index 0c00818eddc5ef96b2e0dc6018f028a4fc75a129..9a2ffc4fedcffe2d1c1c4c1efbc3ee8899876516 100644 --- a/advisor/index.go +++ b/advisor/index.go @@ -914,7 +914,7 @@ func (idxAdv *IndexAdvisor) calcCardinality(cols []*common.Column) []*common.Col // 检查对应列是否为主键或单列唯一索引,如果满足直接返回1,不再重复计算,提高效率 // 多列复合唯一索引不能跳过计算,单列普通索引不能跳过计算 - for _, index := range idxAdv.IndexMeta[realDB][col.Table].IdxRows { + for _, index := range idxAdv.IndexMeta[realDB][col.Table].Rows { // 根据索引的名称判断该索引包含的列数,列数大于1即为复合索引 columnCount := len(idxAdv.IndexMeta[realDB][col.Table].FindIndex(database.IndexKeyName, index.KeyName)) if col.Name == index.ColumnName { @@ -1079,7 +1079,7 @@ func DuplicateKeyChecker(conn *database.Connector, databases ...string) map[stri } // 枚举所有的索引信息,提取用到的列 - for _, idx := range idxInfo.IdxRows { + for _, idx := range idxInfo.Rows { if _, ok := idxMap[idx.KeyName]; !ok { idxMap[idx.KeyName] = make([]*common.Column, 0) for _, col := range idxInfo.FindIndex(database.IndexKeyName, idx.KeyName) { diff --git a/advisor/rules.go b/advisor/rules.go index dc553b424ccb1ede4a00ee0022fb255cbdbce362..404e84518a6a4e4ad3aa30f8ca884f66ac88c3df 100644 --- a/advisor/rules.go +++ b/advisor/rules.go @@ -102,7 +102,7 @@ type Rule struct { * SEC Security * STA Standard * SUB Subquery -* TBL Table +* TBL TableName * TRA Trace, 由trace模块给 */ diff --git a/common/cases.go b/common/cases.go index 25763f7d831b9aa29ec9adc107fa0eb3e1fbdfc9..b49e9b76cf64983b8a69002a45989c68643a8e09 100644 --- a/common/cases.go +++ b/common/cases.go @@ -67,7 +67,7 @@ func init() { "SELECT * FROM customer WHERE address_id in (224,510) ORDER BY last_name;", // INDEX(address_id) "SELECT * FROM film WHERE release_year = 2016 AND length != 1 ORDER BY title;", // INDEX(`release_year`, `length`, `title`) - // "Covering" IdxRows + // "Covering" Rows "SELECT title FROM film WHERE release_year = 1995;", // INDEX(release_year, title)", "SELECT title, replacement_cost FROM film WHERE language_id = 5 AND length = 70;", // INDEX(language_id, length, title, replacement_cos film ), title, replacement_cost顺序无关,language_id, length顺序视散粒度情况. "SELECT title FROM film WHERE language_id > 5 AND length > 70;", // INDEX(language_id, length, title) language_id or length first (that's as far as the Algorithm goes), then the other two fields afterwards. diff --git a/common/config.go b/common/config.go index 4f84116c407b4870c96943b1559912f4354ad480..cdfacb2c7c4e5d29a69c93e8cd8bac70e27eb1fb 100644 --- a/common/config.go +++ b/common/config.go @@ -59,8 +59,6 @@ type Configuration struct { Profiling bool `yaml:"profiling"` // 在开启数据采样的情况下,在测试环境执行进行profile Trace bool `yaml:"trace"` // 在开启数据采样的情况下,在测试环境执行进行Trace Explain bool `yaml:"explain"` // Explain开关 - ConnTimeOut int `yaml:"conn-time-out"` // 数据库连接超时时间,单位秒 - QueryTimeOut int `yaml:"query-time-out"` // 数据库SQL执行超时时间,单位秒 Delimiter string `yaml:"delimiter"` // SQL分隔符 // +++++++++++++++日志相关+++++++++++++++++ @@ -97,8 +95,8 @@ type Configuration struct { MaxInCount int `yaml:"max-in-count"` // IN()最大数量 MaxIdxBytesPerColumn int `yaml:"max-index-bytes-percolumn"` // 索引中单列最大字节数,默认767 MaxIdxBytes int `yaml:"max-index-bytes"` // 索引总长度限制,默认3072 - TableAllowCharsets []string `yaml:"table-allow-charsets"` // Table 允许使用的 DEFAULT CHARSET - TableAllowEngines []string `yaml:"table-allow-engines"` // Table 允许使用的 Engine + TableAllowCharsets []string `yaml:"table-allow-charsets"` // TableName 允许使用的 DEFAULT CHARSET + TableAllowEngines []string `yaml:"table-allow-engines"` // TableName 允许使用的 Engine MaxIdxCount int `yaml:"max-index-count"` // 单张表允许最多索引数 MaxColCount int `yaml:"max-column-count"` // 单张表允许最大列数 MaxValueCount int `yaml:"max-value-count"` // INSERT/REPLACE 单次允许批量写入的行数 @@ -135,12 +133,14 @@ type Configuration struct { // Config 默认设置 var Config = &Configuration{ OnlineDSN: &dsn{ + Net: "tcp", Schema: "information_schema", Charset: "utf8mb4", Disable: true, Version: 99999, }, TestDSN: &dsn{ + Net: "tcp", Schema: "information_schema", Charset: "utf8mb4", Disable: true, @@ -156,8 +156,6 @@ var Config = &Configuration{ Profiling: false, Trace: false, Explain: true, - ConnTimeOut: 3, - QueryTimeOut: 30, Delimiter: ";", MaxJoinTableCount: 5, @@ -227,6 +225,7 @@ var Config = &Configuration{ } type dsn struct { + Net string `yaml:"net"` Addr string `yaml:"addr"` Schema string `yaml:"schema"` @@ -236,6 +235,10 @@ type dsn struct { Charset string `yaml:"charset"` Disable bool `yaml:"disable"` + Timeout int `yaml:"timeout"` + ReadTimeout int `yaml:"read-timeout"` + WriteTimeout int `yaml:"write-timeout"` + Version int `yaml:"-"` // 版本自动检查,不可配置 } @@ -502,8 +505,6 @@ func readCmdFlags() error { explain := flag.Bool("explain", Config.Explain, "Explain, 是否开启Explain执行计划分析") sampling := flag.Bool("sampling", Config.Sampling, "Sampling, 数据采样开关") samplingStatisticTarget := flag.Int("sampling-statistic-target", Config.SamplingStatisticTarget, "SamplingStatisticTarget, 数据采样因子,对应 PostgreSQL 的 default_statistics_target") - connTimeOut := flag.Int("conn-time-out", Config.ConnTimeOut, "ConnTimeOut, 数据库连接超时时间,单位秒") - queryTimeOut := flag.Int("query-time-out", Config.QueryTimeOut, "QueryTimeOut, 数据库SQL执行超时时间,单位秒") delimiter := flag.String("delimiter", Config.Delimiter, "Delimiter, SQL分隔符") // +++++++++++++++日志相关+++++++++++++++++ logLevel := flag.Int("log-level", Config.LogLevel, "LogLevel, 日志级别, [0:Emergency, 1:Alert, 2:Critical, 3:Error, 4:Warning, 5:Notice, 6:Informational, 7:Debug]") @@ -583,8 +584,6 @@ func readCmdFlags() error { Config.Explain = *explain Config.Sampling = *sampling Config.SamplingStatisticTarget = *samplingStatisticTarget - Config.ConnTimeOut = *connTimeOut - Config.QueryTimeOut = *queryTimeOut Config.LogLevel = *logLevel if strings.HasPrefix(*logOutput, "/") { diff --git a/common/meta.go b/common/meta.go index 433704eff8c0059df662a48542a877a2165707e1..e95bb37d91351fa100bb67fe66c4c50ed96fe043 100644 --- a/common/meta.go +++ b/common/meta.go @@ -27,7 +27,7 @@ type Meta map[string]*DB // DB 数据库相关的结构体 type DB struct { Name string - Table map[string]*Table // ['table_name']*Table + Table map[string]*Table // ['table_name']*TableName } // NewDB 用于初始化*DB @@ -38,14 +38,14 @@ func NewDB(db string) *DB { } } -// Table 含有表的属性 +// TableName 含有表的属性 type Table struct { TableName string TableAliases []string Column map[string]*Column } -// NewTable 初始化*Table +// NewTable 初始化*TableName func NewTable(tb string) *Table { return &Table{ TableName: tb, diff --git a/common/testdata/TestMarkdown2Html.golden b/common/testdata/TestMarkdown2Html.golden index 8e3b056be373d08649288e305d11e07e528e68b4..6c7f231029daa9daf28080fc16a4e4ec6a025385 100644 --- a/common/testdata/TestMarkdown2Html.golden +++ b/common/testdata/TestMarkdown2Html.golden @@ -176,9 +176,9 @@ $$

Typora support YAML Front Matter now. Input --- at the top of the article and then press Enter will introduce one. Or insert one metadata block from the menu.

-

Table of Contents (TOC)

+

TableName of Contents (TOC)

-

Input [toc] then press Return key will create a section for “Table of Contents” extracting all headers from one’s writing, its contents will be updated automatically.

+

Input [toc] then press Return key will create a section for “TableName of Contents” extracting all headers from one’s writing, its contents will be updated automatically.

Diagrams (Sequence, Flowchart and Mermaid)

diff --git a/common/testdata/TestMarkdown2Html.md b/common/testdata/TestMarkdown2Html.md index 654b7f3f760d63f436ac81561ec7a7d24ad95a7e..1ab82ea6e490f6f052a1346374eead46c111f2be 100644 --- a/common/testdata/TestMarkdown2Html.md +++ b/common/testdata/TestMarkdown2Html.md @@ -186,9 +186,9 @@ Input `***` or `---` on a blank line and press `return` will draw a horizontal l Typora support [YAML Front Matter](http://jekyllrb.com/docs/frontmatter/) now. Input `---` at the top of the article and then press `Enter` will introduce one. Or insert one metadata block from the menu. -### Table of Contents (TOC) +### TableName of Contents (TOC) -Input `[toc]` then press `Return` key will create a section for “Table of Contents” extracting all headers from one’s writing, its contents will be updated automatically. +Input `[toc]` then press `Return` key will create a section for “TableName of Contents” extracting all headers from one’s writing, its contents will be updated automatically. ### Diagrams (Sequence, Flowchart and Mermaid) diff --git a/database/explain.go b/database/explain.go index 0e17e6479ca31dab4331487e06a6d9834392ed50..1a0d8173c979635465a1f2fd19060ff964a9a124 100644 --- a/database/explain.go +++ b/database/explain.go @@ -267,6 +267,7 @@ var ExplainKeyWords = []string{ "using_temporary_table", } +/* // ExplainColumnIndent EXPLAIN表头 var ExplainColumnIndent = map[string]string{ "id": "id为SELECT的标识符. 它是在SELECT查询中的顺序编号. 如果这一行表示其他行的union结果, 这个值可以为空. 在这种情况下, table列会显示为形如, 表示它是id为M和N的查询行的联合结果.", @@ -281,6 +282,7 @@ var ExplainColumnIndent = map[string]string{ "filtered": "表示返回结果的行占需要读到的行(rows列的值)的百分比.", "Extra": "该列显示MySQL在查询过程中的一些详细信息, MySQL查询优化器执行查询的过程中对查询计划的重要补充信息.", } +*/ // ExplainSelectType EXPLAIN中SELECT TYPE会出现的类型 var ExplainSelectType = map[string]string{ @@ -555,14 +557,16 @@ func (db *Connector) explainAbleSQL(sql string) (string, error) { } // 执行explain请求,返回mysql.Result执行结果 -func (db *Connector) executeExplain(sql string, explainType int, formatType int) (*QueryResult, error) { +func (db *Connector) executeExplain(sql string, explainType int, formatType int) (QueryResult, error) { + var res QueryResult var err error + var explainQuery string sql, err = db.explainAbleSQL(sql) if sql == "" { - return nil, err + return res, err } - // 5.6以上支持FORMAT=JSON + // 5.6以上支持 FORMAT=JSON explainFormat := "" switch formatType { case JSONFormatExplain: @@ -570,22 +574,23 @@ func (db *Connector) executeExplain(sql string, explainType int, formatType int) explainFormat = "FORMAT=JSON" } } - // 执行explain - var res *QueryResult + + // 执行 explain switch explainType { case ExtendedExplainType: // 5.6以上extended关键字已经不推荐使用,8.0废弃了这个关键字 if common.Config.TestDSN.Version >= 50600 { - res, err = db.Query("explain %s", sql) + explainQuery = fmt.Sprintf("explain %s", sql) } else { - res, err = db.Query("explain extended %s", sql) + explainQuery = fmt.Sprintf("explain extended %s", sql) } case PartitionsExplainType: - res, err = db.Query("explain partitions %s", sql) + explainQuery = fmt.Sprintf("explain partitions %s", sql) default: - res, err = db.Query("explain %s %s", explainFormat, sql) + explainQuery = fmt.Sprintf("explain %s %s", explainFormat, sql) } + res, err = db.Query(explainQuery) return res, err } @@ -928,76 +933,75 @@ func parseJSONExplainText(content string) (*ExplainJSON, error) { } // ParseExplainResult 分析 mysql 执行 explain 的结果,返回 ExplainInfo 结构化数据 -func ParseExplainResult(res *QueryResult, formatType int) (exp *ExplainInfo, err error) { +func ParseExplainResult(res QueryResult, formatType int) (exp *ExplainInfo, err error) { exp = &ExplainInfo{ ExplainFormat: formatType, } // JSON 格式直接调用文本方式解析 if formatType == JSONFormatExplain { - exp.ExplainJSON, err = parseJSONExplainText(res.Rows[0].Str(0)) + if res.Rows.Next() { + var explainString string + err = res.Rows.Scan(&explainString) + exp.ExplainJSON, err = parseJSONExplainText(explainString) + } return exp, err } - // 生成表头 - colIdx := make(map[int]string) - for i, f := range res.Result.Fields() { - colIdx[i] = strings.ToLower(f.Name) + // Different MySQL version has different columns define + var possibleKeys string + expRow := &ExplainRow{} + explainFields := make([]interface{}, 0) + fields := map[string]interface{}{ + "id": expRow.ID, + "select_type": expRow.SelectType, + "table": expRow.TableName, + "partitions": expRow.Partitions, + "type": expRow.AccessType, + "possible_keys": &possibleKeys, + "key": expRow.Key, + "key_len": expRow.KeyLen, + "ref": expRow.Ref, + "rows": expRow.Rows, + "filtered": expRow.Filtered, + "extra": expRow.Extra, } + cols, err := res.Rows.Columns() + common.LogIfError(err, "") + for _, col := range cols { + explainFields = append(explainFields, fields[col]) + } + // 补全 ExplainRows - var explainrows []*ExplainRow - for _, row := range res.Rows { - expRow := &ExplainRow{Partitions: "NULL", Filtered: 0.00} - // list 到 map 的转换 - for i := range row { - switch colIdx[i] { - case "id": - expRow.ID = row.ForceInt(i) - case "select_type": - expRow.SelectType = row.Str(i) - case "table": - expRow.TableName = row.Str(i) - if expRow.TableName == "" { - expRow.TableName = "NULL" - } - case "type": - expRow.AccessType = row.Str(i) - if expRow.AccessType == "" { - expRow.AccessType = "NULL" - } - expRow.Scalability = ExplainScalability[expRow.AccessType] - case "possible_keys": - expRow.PossibleKeys = strings.Split(row.Str(i), ",") - case "key": - expRow.Key = row.Str(i) - if expRow.Key == "" { - expRow.Key = "NULL" - } - case "key_len": - expRow.KeyLen = row.Str(i) - case "ref": - expRow.Ref = strings.Split(row.Str(i), ",") - case "rows": - expRow.Rows = row.ForceInt(i) - case "extra": - expRow.Extra = row.Str(i) - if expRow.Extra == "" { - expRow.Extra = "NULL" - } - case "filtered": - expRow.Filtered = row.ForceFloat(i) - // MySQL bug: https://bugs.mysql.com/bug.php?id=34124 - if expRow.Filtered > 100.00 { - expRow.Filtered = 100.00 - } - } + var explainRows []*ExplainRow + + for res.Rows.Next() { + res.Rows.Scan(explainFields...) + expRow.PossibleKeys = strings.Split(possibleKeys, ",") + + // MySQL bug: https://bugs.mysql.com/bug.php?id=34124 + if expRow.Filtered > 100.00 { + expRow.Filtered = 100.00 } - explainrows = append(explainrows, expRow) + + expRow.Scalability = ExplainScalability[expRow.AccessType] + explainRows = append(explainRows, expRow) } - exp.ExplainRows = explainrows - for _, w := range res.Warning { - // 'EXTENDED' is deprecated and will be removed in a future release. - if w.Int(1) != 1681 { - exp.Warnings = append(exp.Warnings, &ExplainWarning{Level: w.Str(0), Code: w.Int(1), Message: w.Str(2)}) + exp.ExplainRows = explainRows + + // check explain warning info + if common.Config.ShowWarnings { + for res.Warning.Next() { + var expWarning *ExplainWarning + res.Warning.Scan( + expWarning.Level, + expWarning.Code, + expWarning.Message, + ) + + // 'EXTENDED' is deprecated and will be removed in a future release. + if expWarning.Code != 1681 { + exp.Warnings = append(exp.Warnings, expWarning) + } } } @@ -1009,7 +1013,6 @@ func ParseExplainResult(res *QueryResult, formatType int) (exp *ExplainInfo, err // Explain 获取 SQL 的 explain 信息 func (db *Connector) Explain(sql string, explainType int, formatType int) (exp *ExplainInfo, err error) { - exp = &ExplainInfo{SQL: sql} if explainType != TraditionalExplainType { formatType = TraditionalFormatExplain } @@ -1025,12 +1028,16 @@ func (db *Connector) Explain(sql string, explainType int, formatType int) (exp * // 执行EXPLAIN请求 res, err := db.executeExplain(sql, explainType, formatType) - if err != nil || res == nil { + if err != nil { return exp, err } + if res.Error != nil { + return exp, res.Error + } // 解析mysql结果,输出ExplainInfo exp, err = ParseExplainResult(res, formatType) + exp.SQL = sql return exp, err } diff --git a/database/explain_test.go b/database/explain_test.go index 01ef50f923322d55c2fb48bdf1d12207e39537b8..9366a18ac3afa1cb864100d9b308207aca021e28 100644 --- a/database/explain_test.go +++ b/database/explain_test.go @@ -17,8 +17,6 @@ package database import ( - "fmt" - "os" "testing" "github.com/XiaoMi/soar/common" @@ -26,23 +24,6 @@ import ( "github.com/kr/pretty" ) -var connTest *Connector - -func init() { - common.BaseDir = common.DevPath - common.ParseConfig("") - connTest = &Connector{ - Addr: common.Config.OnlineDSN.Addr, - User: common.Config.OnlineDSN.User, - Pass: common.Config.OnlineDSN.Password, - Database: common.Config.OnlineDSN.Schema, - } - if _, err := connTest.Version(); err != nil { - common.Log.Critical("Test env Error: %v", err) - os.Exit(0) - } -} - var sqls = []string{ `select * from city where country_id = 44;`, `select * from address where address2 is not null;`, @@ -54,10 +35,10 @@ var sqls = []string{ `select * from city where country_id > 31 and city = 'Aden';`, `select * from address where address_id > 8 and city_id < 400 and district = 'Nantou';`, `select * from address where address_id > 8 and city_id < 400;`, - `select * from actor where last_update='2006-02-15 04:34:33' and last_name='CHASE' group by first_name;`, - `select * from address where last_update >='2014-09-25 22:33:47' group by district;`, - `select * from address group by address,district;`, - `select * from address where last_update='2014-09-25 22:30:27' group by district,(address_id+city_id);`, + `select first_name from actor where last_update='2006-02-15 04:34:33' and last_name='CHASE' group by first_name;`, + `select district from address where last_update >='2014-09-25 22:33:47' group by district;`, + `select address from address group by address,district;`, + `select district from address where last_update='2014-09-25 22:30:27' group by district,(address_id+city_id);`, `select * from customer where active=1 order by last_name limit 10;`, `select * from customer order by last_name limit 10;`, `select * from customer where address_id > 224 order by address_id limit 10;`, @@ -2351,16 +2332,23 @@ possible_keys: idx_fk_country_id,idx_country_id_city,idx_all,idx_other } func TestExplain(t *testing.T) { - for _, sql := range sqls { + // TraditionalFormatExplain + for idx, sql := range sqls { exp, err := connTest.Explain(sql, TraditionalExplainType, TraditionalFormatExplain) - //exp, err := conn.Explain(sql, TraditionalExplainType, JSONFormatExplain) - fmt.Println("Old: ", sql) - fmt.Println("New: ", exp.SQL) if err != nil { - fmt.Println(err) + t.Error(err) } + pretty.Println("No.:", idx, "\nOld: ", sql, "\nNew: ", exp.SQL) + pretty.Println(exp) + } + // JSONFormatExplain + for idx, sql := range sqls { + exp, err := connTest.Explain(sql, TraditionalExplainType, JSONFormatExplain) + if err != nil { + t.Error(err) + } + pretty.Println("No.:", idx, "\nOld: ", sql, "\nNew: ", exp.SQL) pretty.Println(exp) - fmt.Println() } } @@ -2400,6 +2388,7 @@ func TestPrintMarkdownExplainTable(t *testing.T) { if err != nil { t.Error(err) } + err = common.GoldenDiff(func() { PrintMarkdownExplainTable(expInfo) }, t.Name(), update) diff --git a/database/mysql.go b/database/mysql.go index fa6927d46d17252829fc5e594f7b35d3beb0c843..68aa971e80f121b6168a46bb2438f0feafd3b88b 100644 --- a/database/mysql.go +++ b/database/mysql.go @@ -17,21 +17,17 @@ package database import ( + "database/sql" "errors" "fmt" - "io/ioutil" - "os" "regexp" "strconv" "strings" - "time" - "github.com/XiaoMi/soar/ast" "github.com/XiaoMi/soar/common" - "github.com/ziutek/mymysql/mysql" - // mymysql driver - _ "github.com/ziutek/mymysql/native" + // for database/sql + _ "github.com/go-sql-driver/mysql" "vitess.io/vitess/go/vt/sqlparser" ) @@ -42,81 +38,62 @@ type Connector struct { Pass string Database string Charset string + Net string } // QueryResult 数据库查询返回值 type QueryResult struct { - Rows []mysql.Row - Result mysql.Result + Rows *sql.Rows Error error - Warning []mysql.Row + Warning *sql.Rows QueryCost float64 } // NewConnection 创建新连接 -func (db *Connector) NewConnection() mysql.Conn { - return mysql.New("tcp", "", db.Addr, db.User, db.Pass, db.Database) +func (db *Connector) NewConnection() (*sql.DB, error) { + dsn := fmt.Sprintf("%s:%s@%s(%s)/%s?parseTime=true&charset=%s", db.User, db.Pass, db.Net, db.Addr, db.Database, db.Charset) + return sql.Open("mysql", dsn) } // Query 执行SQL -func (db *Connector) Query(sql string, params ...interface{}) (*QueryResult, error) { +func (db *Connector) Query(sql string, params ...interface{}) (QueryResult, error) { + var res QueryResult // 测试环境如果检查是关闭的,则SQL不会被执行 if common.Config.TestDSN.Disable { - return nil, errors.New("Dsn Disable") + return res, errors.New("dsn is disable") } // 数据库安全性检查:如果 Connector 的 IP 端口与 TEST 环境不一致,则启用SQL白名单 // 不在白名单中的SQL不允许执行 // 执行环境与test环境不相同 if db.Addr != common.Config.TestDSN.Addr && db.dangerousQuery(sql) { - return nil, fmt.Errorf("query execution deny: execute SQL with DSN(%s/%s) '%s'", + return res, fmt.Errorf("query execution deny: execute SQL with DSN(%s/%s) '%s'", db.Addr, db.Database, fmt.Sprintf(sql, params...)) } common.Log.Debug("Execute SQL with DSN(%s/%s) : %s", db.Addr, db.Database, fmt.Sprintf(sql, params...)) - conn := db.NewConnection() - - // 设置SQL连接超时时间 - conn.SetTimeout(time.Duration(common.Config.ConnTimeOut) * time.Second) + conn, err := db.NewConnection() defer conn.Close() - err := conn.Connect() if err != nil { - return nil, err + return res, err } + res.Rows, res.Error = conn.Query(sql, params...) - // 添加SQL执行超时限制 - ch := make(chan QueryResult, 1) - go func() { - res := QueryResult{} - res.Rows, res.Result, res.Error = conn.Query(sql, params...) - - if common.Config.ShowWarnings { - warning, _, err := conn.Query("SHOW WARNINGS") - if err == nil { - res.Warning = warning - } - } + if common.Config.ShowWarnings { + res.Warning, err = conn.Query("SHOW WARNINGS") + } - // SHOW WARNINGS 并不会影响 last_query_cost - if common.Config.ShowLastQueryCost { - cost, _, err := conn.Query("SHOW SESSION STATUS LIKE 'last_query_cost'") - if err == nil { - if len(cost) > 0 { - res.QueryCost = cost[0].Float(1) - } + // SHOW WARNINGS 并不会影响 last_query_cost + if common.Config.ShowLastQueryCost { + cost, err := conn.Query("SHOW SESSION STATUS LIKE 'last_query_cost'") + if err == nil { + if cost.Next() { + err = cost.Scan(res.QueryCost) } } - - ch <- res - }() - - select { - case res := <-ch: - return &res, res.Error - case <-time.After(time.Duration(common.Config.QueryTimeOut) * time.Second): - return nil, errors.New("query execution timeout") } + return res, err } // Version 获取MySQL数据库版本 @@ -124,77 +101,43 @@ func (db *Connector) Version() (int, error) { version := 99999 // 从数据库中获取版本信息 res, err := db.Query("select @@version") - if err != nil { - common.Log.Warn("(db *Connector) Version() Error: %v", err) + if err != nil || res.Error != nil { + common.Log.Warn("(db *Connector) Version() Error: %v, MySQL Error: %v", err, res.Error) return version, err } // MariaDB https://mariadb.com/kb/en/library/comment-syntax/ // MySQL https://dev.mysql.com/doc/refman/8.0/en/comments.html - versionStr := strings.Split(res.Rows[0].Str(0), "-")[0] - versionSeg := strings.Split(versionStr, ".") - if len(versionSeg) == 3 { - versionStr = fmt.Sprintf("%s%02s%02s", versionSeg[0], versionSeg[1], versionSeg[2]) - version, err = strconv.Atoi(versionStr) - } - return version, err -} - -// Source execute sql from file -func (db *Connector) Source(file string) ([]*QueryResult, error) { - var sqlCounter int // SQL 计数器 - var result []*QueryResult - - fd, err := os.Open(file) - defer func() { - err = fd.Close() - if err != nil { - common.Log.Error("(db *Connector) Source(%s) fd.Close failed: %s", file, err.Error()) + var versionStr string + var versionSeg []string + for res.Rows.Next() { + err = res.Rows.Scan(&versionStr) + versionStr = strings.Split(versionStr, "-")[0] + versionSeg = strings.Split(versionStr, ".") + if len(versionSeg) == 3 { + versionStr = fmt.Sprintf("%s%02s%02s", versionSeg[0], versionSeg[1], versionSeg[2]) + version, err = strconv.Atoi(versionStr) } - }() - if err != nil { - common.Log.Warning("(db *Connector) Source(%s) os.Open failed: %s", file, err.Error()) - return nil, err - } - data, err := ioutil.ReadAll(fd) - if err != nil { - common.Log.Critical("ioutil.ReadAll Error: %s", err.Error()) - return nil, err + break } - sql := strings.TrimSpace(string(data)) - buf := strings.TrimSpace(sql) - for ; ; sqlCounter++ { - if buf == "" { - break - } - - // 查询请求切分 - _, sql, bufBytes := ast.SplitStatement([]byte(buf), []byte(common.Config.Delimiter)) - buf = string(bufBytes) - sql = strings.TrimSpace(sql) - common.Log.Debug("Source Query SQL: %s", sql) - - res, e := db.Query(sql) - if e != nil { - common.Log.Error("(db *Connector) Source Filename: %s, SQLCounter.: %d", file, sqlCounter) - return result, e - } - result = append(result, res) - } - return result, nil + return version, err } // SingleIntValue 获取某个int型变量的值 func (db *Connector) SingleIntValue(option string) (int, error) { // 从数据库中获取信息 - res, err := db.Query("select @@%s", option) + res, err := db.Query("select @@" + option) if err != nil { common.Log.Warn("(db *Connector) SingleIntValue() Error: %v", err) return -1, err } - return res.Rows[0].Int(0), err + var intVal int + if res.Rows.Next() { + err = res.Rows.Scan(&intVal) + } + return intVal, err } // ColumnCardinality 粒度计算 @@ -228,13 +171,20 @@ func (db *Connector) ColumnCardinality(tb, col string) float64 { } // 计算该列散粒度 - res, err := db.Query("select count(distinct `%s`) from `%s`.`%s`", col, db.Database, tb) + res, err := db.Query(fmt.Sprintf("select count(distinct `%s`) from `%s`.`%s`", col, db.Database, tb)) if err != nil { common.Log.Warn("(db *Connector) ColumnCardinality() Query Error: %v", err) return 0 } - colNum := res.Rows[0].Float(0) + var colNum float64 + if res.Rows.Next() { + err = res.Rows.Scan(&colNum) + if err != nil { + common.Log.Warn("(db *Connector) ColumnCardinality() Query Error: %v", err) + return 0 + } + } // 散粒度区间:[0,1] return colNum / float64(rowTotal) @@ -249,13 +199,12 @@ func (db *Connector) IsView(tbName string) bool { } if len(tbStatus.Rows) > 0 { - if tbStatus.Rows[0].Comment == "VIEW" { + if string(tbStatus.Rows[0].Comment) == "VIEW" { return true } } return false - } // RemoveSQLComments 去除SQL中的注释 @@ -281,7 +230,7 @@ func (db *Connector) dangerousQuery(query string) bool { return true } - for _, sql := range queries { + for _, query := range queries { dangerous := true whiteList := []string{ "select", @@ -291,7 +240,7 @@ func (db *Connector) dangerousQuery(query string) bool { } for _, prefix := range whiteList { - if strings.HasPrefix(sql, prefix) { + if strings.HasPrefix(query, prefix) { dangerous = false break } diff --git a/database/mysql_test.go b/database/mysql_test.go index 79ea7bdd2a84c6254086dbdeb665d08bd528391e..f3239a5e5bbe7de0cd5e6f13fc2bf78071752b79 100644 --- a/database/mysql_test.go +++ b/database/mysql_test.go @@ -17,26 +17,70 @@ package database import ( + "flag" "fmt" + "os" "testing" "github.com/XiaoMi/soar/common" + "github.com/kr/pretty" ) +var connTest *Connector + +var update = flag.Bool("update", false, "update .golden files") + +func init() { + common.BaseDir = common.DevPath + common.ParseConfig("") + connTest = &Connector{ + Addr: common.Config.OnlineDSN.Addr, + User: common.Config.OnlineDSN.User, + Pass: common.Config.OnlineDSN.Password, + Database: common.Config.OnlineDSN.Schema, + Charset: common.Config.OnlineDSN.Charset, + } + if _, err := connTest.Version(); err != nil { + common.Log.Critical("Test env Error: %v", err) + os.Exit(0) + } +} + +func TestNewConnection(t *testing.T) { + _, err := connTest.NewConnection() + if err != nil { + t.Errorf("TestNewConnection, Error: %s", err.Error()) + } +} + // TODO: go test -race不通过待解决 func TestQuery(t *testing.T) { - common.Config.QueryTimeOut = 1 - _, err := connTest.Query("select sleep(2)") - if err == nil { - t.Error("connTest.Query not timeout") + res, err := connTest.Query("select 0") + if err != nil { + t.Error(err.Error()) + } + for res.Rows.Next() { + var val int + err = res.Rows.Scan(&val) + if err != nil { + t.Error(err.Error()) + } + if val != 0 { + t.Error("should return 0") + } } + // TODO: timeout test } -func TestColumnCardinality(_ *testing.T) { - connTest.Database = "information_schema" - a := connTest.ColumnCardinality("TABLES", "TABLE_SCHEMA") - fmt.Println("TABLES.TABLE_SCHEMA:", a) +func TestColumnCardinality(t *testing.T) { + orgDatabase := connTest.Database + connTest.Database = "sakila" + a := connTest.ColumnCardinality("actor", "first_name") + if a >= 1 || a <= 0 { + t.Error("sakila.actor.first_name cardinality should in (0, 1), now it's", a) + } + connTest.Database = orgDatabase } func TestDangerousSQL(t *testing.T) { @@ -63,11 +107,15 @@ func TestWarningsAndQueryCost(t *testing.T) { if err != nil { t.Error("Query Error: ", err) } else { - for _, w := range res.Warning { - pretty.Println(w.Str(2)) + for res.Warning.Next() { + var str string + err = res.Warning.Scan(str) + if err != nil { + t.Error(err.Error()) + } + pretty.Println(str) } - fmt.Println(res.QueryCost) - pretty.Println(err) + fmt.Println(res.QueryCost, err) } } @@ -79,16 +127,6 @@ func TestVersion(t *testing.T) { fmt.Println(version) } -func TestSource(t *testing.T) { - res, err := connTest.Source("testdata/" + t.Name() + ".sql") - if err != nil { - t.Error("Query Error: ", err) - } - if res[0].Rows[0].Int(0) != 1 || res[1].Rows[0].Int(0) != 1 { - t.Error("Source result not match, expect 1, 1") - } -} - func TestRemoveSQLComments(t *testing.T) { SQLs := []string{ `-- comment`, @@ -109,3 +147,22 @@ comment*/`, t.Error(err) } } + +func TestSingleIntValue(t *testing.T) { + val, err := connTest.SingleIntValue("read_only") + if err != nil { + t.Error(err) + } + if val < 0 { + t.Error("SingleIntValue, return should large than zero") + } +} + +func TestIsView(t *testing.T) { + originalDatabase := connTest.Database + connTest.Database = "sakila" + if !connTest.IsView("actor_info") { + t.Error("actor_info should be a VIEW") + } + connTest.Database = originalDatabase +} diff --git a/database/privilege.go b/database/privilege.go index 762de2898d046ae29aca169a4c92ee971fd74f00..dca833e3b9d11636675dd87ad18220f449f5dfe5 100644 --- a/database/privilege.go +++ b/database/privilege.go @@ -18,6 +18,7 @@ package database import ( "errors" + "fmt" "strings" "github.com/XiaoMi/soar/common" @@ -29,8 +30,13 @@ func (db *Connector) CurrentUser() (string, string, error) { if err != nil { return "", "", err } - if len(res.Rows) > 0 { - cols := strings.Split(res.Rows[0].Str(0), "@") + if res.Rows.Next() { + var currentUser string + err = res.Rows.Scan(¤tUser) + if err != nil { + return "", "", err + } + cols := strings.Split(currentUser, "@") if len(cols) == 2 { user := strings.Trim(cols[0], "'") host := strings.Trim(cols[1], "'") @@ -51,14 +57,20 @@ func (db *Connector) HasSelectPrivilege() bool { common.Log.Error("User: %s, HasSelectPrivilege: %s", db.User, err.Error()) return false } - res, err := db.Query("select Select_priv from mysql.user where user='%s' and host='%s'", user, host) + res, err := db.Query(fmt.Sprintf("select Select_priv from mysql.user where user='%s' and host='%s'", user, host)) if err != nil { common.Log.Error("HasSelectPrivilege, DSN: %s, Error: %s", db.Addr, err.Error()) return false } // Select_priv - if len(res.Rows) > 0 { - if res.Rows[0].Str(0) == "Y" { + if res.Rows.Next() { + var selectPrivilege string + err = res.Rows.Scan(&selectPrivilege) + if err != nil { + common.Log.Error("HasSelectPrivilege, Scan Error: %s", err.Error()) + return false + } + if selectPrivilege == "Y" { return true } } @@ -79,24 +91,31 @@ func (db *Connector) HasAllPrivilege() bool { common.Log.Error("HasAllPrivilege, DSN: %s, Error: %s", db.Addr, err.Error()) return false } + var priv string - if len(res.Rows) > 0 { - priv = res.Rows[0].Str(0) - } else { - common.Log.Error("HasAllPrivilege, DSN: %s, get privilege string error", db.Addr) - return false + if res.Rows.Next() { + err = res.Rows.Scan(&priv) + if err != nil { + common.Log.Error("HasAllPrivilege, DSN: %s, Scan error", db.Addr) + return false + } } // get all privilege status - res, err = db.Query("select concat("+priv+") from mysql.user where user='%s' and host='%s'", user, host) + res, err = db.Query(fmt.Sprintf("select concat("+priv+") from mysql.user where user='%s' and host='%s'", user, host)) if err != nil { common.Log.Error("HasAllPrivilege, DSN: %s, Error: %s", db.Addr, err.Error()) return false } // %_priv - if len(res.Rows) > 0 { - if strings.Replace(res.Rows[0].Str(0), "Y", "", -1) == "" { + if res.Rows.Next() { + err = res.Rows.Scan(&priv) + if err != nil { + common.Log.Error("HasAllPrivilege, DSN: %s, Scan error", db.Addr) + return false + } + if strings.Replace(priv, "Y", "", -1) == "" { return true } } diff --git a/database/privilege_test.go b/database/privilege_test.go index 4ed0fbc609d70550cca119c59b6e7583975e0fdc..3690fd3dc5c99ac6fc51051f535e63fd4c8489db 100644 --- a/database/privilege_test.go +++ b/database/privilege_test.go @@ -18,6 +18,16 @@ package database import "testing" +func TestCurrentUser(t *testing.T) { + user, host, err := connTest.CurrentUser() + if err != nil { + t.Error(err.Error()) + } + if user != "root" || host != "%" { + t.Errorf("Want user: root, host: %%. Get user: %s, host: %s", user, host) + } +} + func TestHasSelectPrivilege(t *testing.T) { if !connTest.HasSelectPrivilege() { t.Errorf("DSN: %s, User: %s, should has select privilege", connTest.Addr, connTest.User) diff --git a/database/profiling.go b/database/profiling.go index a6072313e09e554ecb04ac2fd89aa91bac3a9be1..f85a6b0231bfd259660db702f13b144e4a7193b6 100644 --- a/database/profiling.go +++ b/database/profiling.go @@ -19,9 +19,7 @@ package database import ( "errors" "fmt" - "io" "strings" - "time" "github.com/XiaoMi/soar/common" @@ -40,95 +38,79 @@ type ProfilingRow struct { // TODO: 支持show profile all,不过目前看all的信息过多有点眼花缭乱 } -// Profiling 执行SQL,并对其Profiling -func (db *Connector) Profiling(sql string, params ...interface{}) (*QueryResult, error) { +// Profiling 执行SQL,并对其 Profile +func (db *Connector) Profiling(sql string, params ...interface{}) ([]ProfilingRow, error) { + var rows []ProfilingRow // 过滤不需要 profiling 的 SQL switch sqlparser.Preview(sql) { case sqlparser.StmtSelect, sqlparser.StmtUpdate, sqlparser.StmtDelete: default: - return nil, errors.New("no need profiling") + return rows, errors.New("no need profiling") } // 测试环境如果检查是关闭的,则SQL不会被执行 if common.Config.TestDSN.Disable { - return nil, errors.New("Dsn Disable") + return rows, errors.New("dsn is disable") } // 数据库安全性检查:如果 Connector 的 IP 端口与 TEST 环境不一致,则启用 SQL 白名单 - // 不在白名单中的SQL不允许执行 + // 不在白名单中的 SQL 不允许执行 // 执行环境与test环境不相同 if db.Addr != common.Config.TestDSN.Addr && db.dangerousQuery(sql) { - return nil, fmt.Errorf("query execution deny: Execute SQL with DSN(%s/%s) '%s'", + return rows, fmt.Errorf("query execution deny: Execute SQL with DSN(%s/%s) '%s'", db.Addr, db.Database, fmt.Sprintf(sql, params...)) } common.Log.Debug("Execute SQL with DSN(%s/%s) : %s", db.Addr, db.Database, sql) - conn := db.NewConnection() - - // 设置SQL连接超时时间 - conn.SetTimeout(time.Duration(common.Config.ConnTimeOut) * time.Second) + conn, err := db.NewConnection() + if err != nil { + return rows, err + } defer conn.Close() - err := conn.Connect() + + // Keep connection + // https://github.com/go-sql-driver/mysql/issues/208 + trx, err := conn.Begin() if err != nil { - return nil, err + return rows, err } + defer trx.Rollback() + + // 开启 Profiling + _, err = trx.Query("set @@profiling=1") + common.LogIfError(err, "") - // 添加SQL执行超时限制 - ch := make(chan QueryResult, 1) - go func() { - // 开启Profiling - _, _, err = conn.Query("set @@profiling=1") - common.LogIfError(err, "") + // 执行 SQL,抛弃返回结果 + tmpRes, err := trx.Query(sql, params...) + if err != nil { + return rows, err + } + for tmpRes.Next() { + continue + } - // 执行SQL,抛弃返回结果 - result, err := conn.Start(sql, params...) + // 返回 Profiling 结果 + res, err := trx.Query("show profile") + for res.Next() { + var profileRow ProfilingRow + err := res.Scan(&profileRow.Status, &profileRow.Duration) if err != nil { - ch <- QueryResult{ - Error: err, - } - return + common.LogIfError(err, "") } - row := result.MakeRow() - for { - err = result.ScanRow(row) - if err == io.EOF { - break - } - } - - // 返回Profiling结果 - res := QueryResult{} - res.Rows, res.Result, res.Error = conn.Query("show profile") - _, _, err = conn.Query("set @@profiling=0") - common.LogIfError(err, "") - ch <- res - }() - - select { - case res := <-ch: - return &res, res.Error - case <-time.After(time.Duration(common.Config.QueryTimeOut) * time.Second): - return nil, errors.New("query execution timeout") + rows = append(rows, profileRow) } -} -func getProfiling(res *QueryResult) Profiling { - var rows []ProfilingRow - for _, row := range res.Rows { - rows = append(rows, ProfilingRow{ - Status: row.Str(0), - Duration: row.Float(1), - }) - } - return Profiling{Rows: rows} + // 关闭 Profiling + _, err = trx.Query("set @@profiling=0") + common.LogIfError(err, "") + return rows, err } // FormatProfiling 格式化输出Profiling信息 -func FormatProfiling(res *QueryResult) string { - profiling := getProfiling(res) +func FormatProfiling(rows []ProfilingRow) string { str := []string{"| Status | Duration |"} str = append(str, "| --- | --- |") - for _, row := range profiling.Rows { + for _, row := range rows { str = append(str, fmt.Sprintf("| %s | %f |", row.Status, row.Duration)) } return strings.Join(str, "\n") diff --git a/database/profiling_test.go b/database/profiling_test.go index 4b226ba90e4db8f482f3f55fc6f16c92a247ed8a..78d28fef46876fb1aca1cb2155cf74bab0886bae 100644 --- a/database/profiling_test.go +++ b/database/profiling_test.go @@ -19,35 +19,21 @@ package database import ( "testing" - "github.com/XiaoMi/soar/common" - "github.com/kr/pretty" ) func TestProfiling(t *testing.T) { - common.Config.QueryTimeOut = 1 - res, err := connTest.Profiling("select 1") - if err == nil { - pretty.Println(res) - } else { + rows, err := connTest.Profiling("select 1") + if err != nil { t.Error(err) } + pretty.Println(rows) } func TestFormatProfiling(t *testing.T) { res, err := connTest.Profiling("select 1") - if err == nil { - pretty.Println(FormatProfiling(res)) - } else { - t.Error(err) - } -} - -func TestGetProfiling(t *testing.T) { - res, err := connTest.Profiling("select 1") - if err == nil { - pretty.Println(getProfiling(res)) - } else { + if err != nil { t.Error(err) } + pretty.Println(FormatProfiling(res)) } diff --git a/database/sampling.go b/database/sampling.go index f8565b4cd610fc6259ccc089729418a6832c6af7..c9c3e60f4db366daeb022d6e94ba1b030c71c500 100644 --- a/database/sampling.go +++ b/database/sampling.go @@ -17,14 +17,9 @@ package database import ( + "database/sql" "fmt" - "io" - "strconv" - "strings" - "time" - "github.com/XiaoMi/soar/common" - "github.com/ziutek/mymysql/mysql" ) /*-------------------- @@ -57,21 +52,17 @@ func (db *Connector) SamplingData(remote Connector, tables ...string) error { maxValCount := 200 // 获取数据库连接对象 - conn := remote.NewConnection() - localConn := db.NewConnection() - - // 连接数据库 - err := conn.Connect() - defer conn.Close() + conn, err := remote.NewConnection() if err != nil { return err } + defer conn.Close() - err = localConn.Connect() - defer localConn.Close() + localConn, err := db.NewConnection() if err != nil { return err } + defer localConn.Close() for _, table := range tables { // 表类型检查 @@ -109,119 +100,128 @@ func (db *Connector) SamplingData(remote Connector, tables ...string) error { // 开始从环境中泵取数据 // 因为涉及到的数据量问题,所以泵取与插入时同时进行的 // TODO 加 ref link -func startSampling(conn, localConn mysql.Conn, database, table string, factor float64, wants, maxValCount int) error { - // 从线上数据库获取所需dump的表中所有列的数据类型,备用 - // 由于测试库中的库表为刚建立的,所以在information_schema中很可能没有这个表的信息 - var dataTypes []string - q := fmt.Sprintf("select DATA_TYPE from information_schema.COLUMNS where TABLE_SCHEMA='%s' and TABLE_NAME = '%s'", - database, table) - common.Log.Debug("Sampling data execute: %s", q) - rs, _, err := localConn.Query(q) - if err != nil { - common.Log.Debug("Sampling data got data type Err: %v", err) - } else { - for _, r := range rs { - dataTypes = append(dataTypes, r.Str(0)) +func startSampling(conn, localConn *sql.DB, database, table string, factor float64, wants, maxValCount int) error { + return nil + // TODO: + /* + // 从线上数据库获取所需dump的表中所有列的数据类型,备用 + // 由于测试库中的库表为刚建立的,所以在information_schema中很可能没有这个表的信息 + var dataTypes []string + q := fmt.Sprintf("select DATA_TYPE from information_schema.COLUMNS where TABLE_SCHEMA='%s' and TABLE_NAME = '%s'", + database, table) + common.Log.Debug("Sampling data execute: %s", q) + rs, err := localConn.Query(q) + if err != nil { + common.Log.Debug("Sampling data got data type Err: %v", err) + } else { + for rs.Next() { + var dataType string + err = rs.Scan(&dataType) + if err != nil { + return err + } + dataTypes = append(dataTypes, dataType) + } } - } - - // 生成where条件 - where := fmt.Sprintf("where RAND()<=%f", factor) - if factor >= 1 { - where = "" - } - - sql := fmt.Sprintf("select * from `%s` %s limit %d;", table, where, wants) - res, err := conn.Start(sql) - if err != nil { - return err - } - // GetRow method allocates a new chunk of memory for every received row. - row := res.MakeRow() - rowCount := 0 - valCount := 0 - - // 获取所有的列名 - columns := make([]string, len(res.Fields())) - for i, filed := range res.Fields() { - columns[i] = filed.Name - } - colDef := strings.Join(columns, ",") - - // 开始填充数据 - var valList []string - for { - err := res.ScanRow(row) - if err == io.EOF { - // 扫描结束 - if len(valList) > 0 { - // 如果缓存中还存在未插入的数据,则把缓存中的数据刷新到DB中 - doSampling(localConn, database, table, colDef, strings.Join(valList, ",")) - } - break + // 生成where条件 + where := fmt.Sprintf("where RAND()<=%f", factor) + if factor >= 1 { + where = "" } + sql := fmt.Sprintf("select * from `%s` %s limit %d;", table, where, wants) + res, err := conn.Query(sql) if err != nil { return err } - values := make([]string, len(columns)) - for i := range row { - // TODO 不支持坐标类型的导出 - switch data := row[i].(type) { - case nil: - // str = "" - case []byte: - // 先尝试转成数字,如果报错则转换成string - if v, err := row.Int64Err(i); err != nil { - values[i] = string(data) - } else { - values[i] = strconv.FormatInt(v, 10) + // GetRow method allocates a new chunk of memory for every received row. + row := res.MakeRow() + rowCount := 0 + valCount := 0 + + // 获取所有的列名 + columns := make([]string, len(res.Fields())) + for i, filed := range res.Fields() { + columns[i] = filed.Name + } + colDef := strings.Join(columns, ",") + + // 开始填充数据 + var valList []string + for { + err := res.ScanRow(row) + if err == io.EOF { + // 扫描结束 + if len(valList) > 0 { + // 如果缓存中还存在未插入的数据,则把缓存中的数据刷新到DB中 + doSampling(localConn, database, table, colDef, strings.Join(valList, ",")) } - case time.Time: - values[i] = mysql.TimeString(data) - case time.Duration: - values[i] = mysql.DurationString(data) - default: - values[i] = fmt.Sprint(data) + break } - // 非text/varchar类的数据类型,如果dump出的数据为空,则说明该值为null值 - // 应转换其 value 为 null,如果用空('')进行替代,会导致出现语法错误。 - if len(dataTypes) == len(res.Fields()) && values[i] == "" && - (!strings.Contains(dataTypes[i], "char") || - !strings.Contains(dataTypes[i], "text")) { - values[i] = "null" - } else { - values[i] = "'" + values[i] + "'" + if err != nil { + return err } - } - valuesStr := fmt.Sprintf(`(%s)`, strings.Join(values, `,`)) - valList = append(valList, valuesStr) + values := make([]string, len(columns)) + for i := range row { + // TODO 不支持坐标类型的导出 + switch data := row[i].(type) { + case nil: + // str = "" + case []byte: + // 先尝试转成数字,如果报错则转换成string + if v, err := row.Int64Err(i); err != nil { + values[i] = string(data) + } else { + values[i] = strconv.FormatInt(v, 10) + } + case time.Time: + values[i] = mysql.TimeString(data) + case time.Duration: + values[i] = mysql.DurationString(data) + default: + values[i] = fmt.Sprint(data) + } + + // 非text/varchar类的数据类型,如果dump出的数据为空,则说明该值为null值 + // 应转换其 value 为 null,如果用空('')进行替代,会导致出现语法错误。 + if len(dataTypes) == len(res.Fields()) && values[i] == "" && + (!strings.Contains(dataTypes[i], "char") || + !strings.Contains(dataTypes[i], "text")) { + values[i] = "null" + } else { + values[i] = "'" + values[i] + "'" + } + } - rowCount++ - valCount++ + valuesStr := fmt.Sprintf(`(%s)`, strings.Join(values, `,`)) + valList = append(valList, valuesStr) - if rowCount%maxValCount == 0 { - doSampling(localConn, database, table, colDef, strings.Join(valList, ",")) - valCount = 0 - valList = make([]string, 0) + rowCount++ + valCount++ + if rowCount%maxValCount == 0 { + doSampling(localConn, database, table, colDef, strings.Join(valList, ",")) + valCount = 0 + valList = make([]string, 0) + + } } - } - common.Log.Debug("%d rows sampling out", rowCount) - return nil + common.Log.Debug("%d rows sampling out", rowCount) + return nil + */ } // 将泵取的数据转换成Insert语句并在数据库中执行 -func doSampling(conn mysql.Conn, dbName, table, colDef, values string) { - sql := fmt.Sprintf("Insert into `%s`.`%s`(%s) values%s;", dbName, table, +func doSampling(conn *sql.DB, dbName, table, colDef, values string) { + query := fmt.Sprintf("Insert into `%s`.`%s`(%s) values%s;", dbName, table, colDef, values) - _, _, err := conn.Query(sql) + _, err := conn.Query(query) if err != nil { common.Log.Error("doSampling Error from %s.%s: %v", dbName, table, err) diff --git a/database/sampling_test.go b/database/sampling_test.go index 082d5e9fc3bd03202da2fb7a3897f318e782b7ac..d164b0b936e692cbdb976bc1f417c24733ffbfda 100644 --- a/database/sampling_test.go +++ b/database/sampling_test.go @@ -32,6 +32,8 @@ func TestSamplingData(t *testing.T) { User: common.Config.OnlineDSN.User, Pass: common.Config.OnlineDSN.Password, Database: common.Config.OnlineDSN.Schema, + Charset: common.Config.OnlineDSN.Charset, + Net: common.Config.OnlineDSN.Net, } offline := &Connector{ @@ -39,6 +41,8 @@ func TestSamplingData(t *testing.T) { User: common.Config.TestDSN.User, Pass: common.Config.TestDSN.Password, Database: common.Config.TestDSN.Schema, + Charset: common.Config.TestDSN.Charset, + Net: common.Config.TestDSN.Net, } offline.Database = "test" diff --git a/database/show.go b/database/show.go index a14a5fbe90e555bb8a3f513e95333657eab7e70d..7628bea9f910b6e58b9a0861dc2a749a029505c1 100644 --- a/database/show.go +++ b/database/show.go @@ -18,12 +18,10 @@ package database import ( "fmt" + "github.com/XiaoMi/soar/common" "regexp" "strconv" "strings" - "time" - - "github.com/XiaoMi/soar/common" ) // SHOW TABLE STATUS Syntax @@ -36,11 +34,12 @@ type TableStatInfo struct { } // tableStatusRow 用于 show table status value +// use []byte instead of string, because []byte allow to be null, string not type tableStatusRow struct { Name string // 表名 - Engine string // 该表使用的存储引擎 - Version int // 该表的 .frm 文件版本号 - RowFormat string // 该表使用的行存储格式 + Engine []byte // 该表使用的存储引擎 + Version []byte // 该表的 .frm 文件版本号 + RowFormat []byte // 该表使用的行存储格式 Rows int64 // 表行数, InnoDB 引擎中为预估值,甚至可能会有40%~50%的数值偏差 AvgRowLength int // 平均行长度 @@ -59,15 +58,15 @@ type tableStatusRow struct { // 其他不同的存储引擎中该值的意义可能不尽相同 IndexLength int - DataFree int // 已分配但未使用的字节数 - AutoIncrement int // 下一个自增值 - CreateTime time.Time // 创建时间 - UpdateTime time.Time // 最近一次更新时间,该值不准确 - CheckTime time.Time // 上次检查时间 - Collation string // 字符集及排序规则信息 - Checksum string // 校验和 - CreateOptions string // 创建表的时候的时候一切其他属性 - Comment string // 注释 + DataFree int // 已分配但未使用的字节数 + AutoIncrement []byte // 下一个自增值 + CreateTime []byte // 创建时间 + UpdateTime []byte // 最近一次更新时间,该值不准确 + CheckTime []byte // 上次检查时间 + Collation []byte // 字符集及排序规则信息 + Checksum []byte // 校验和 + CreateOptions []byte // 创建表的时候的时候一切其他属性 + Comment []byte // 注释 } // newTableStat 构造 table Stat 对象 @@ -83,7 +82,7 @@ func (db *Connector) ShowTables() ([]string, error) { defer func() { err := recover() if err != nil { - common.Log.Error("recover ShowTableStatus()", err) + common.Log.Error("recover ShowTables()", err) } }() @@ -92,79 +91,70 @@ func (db *Connector) ShowTables() ([]string, error) { if err != nil { return []string{}, err } + if res.Error != nil { + return []string{}, res.Error + } // 获取值 var tables []string - for _, row := range res.Rows { - tables = append(tables, row.Str(0)) + for res.Rows.Next() { + var table string + err = res.Rows.Scan(&table) + if err != nil { + return []string{}, err + } + tables = append(tables, table) } - return tables, err } // ShowTableStatus 执行 show table status func (db *Connector) ShowTableStatus(tableName string) (*TableStatInfo, error) { - defer func() { - err := recover() - if err != nil { - common.Log.Error("recover ShowTableStatus()", err) - } - }() - // 初始化struct - ts := newTableStat(tableName) + tbStatus := newTableStat(tableName) // 执行 show table status - res, err := db.Query("show table status where name = '%s'", ts.Name) + res, err := db.Query(fmt.Sprintf("show table status where name = '%s'", tbStatus.Name)) if err != nil { - return ts, err - } - - rs := res.Result.Map("Rows") - name := res.Result.Map("Name") - df := res.Result.Map("Data_free") - sum := res.Result.Map("Checksum") - engine := res.Result.Map("Engine") - version := res.Result.Map("Version") - comment := res.Result.Map("Comment") - ai := res.Result.Map("Auto_increment") - collation := res.Result.Map("Collation") - rowFormat := res.Result.Map("Row_format") - checkTime := res.Result.Map("Check_time") - dataLength := res.Result.Map("Data_length") - idxLength := res.Result.Map("Index_length") - createTime := res.Result.Map("Create_time") - updateTime := res.Result.Map("Update_time") - options := res.Result.Map("Create_options") - avgRowLength := res.Result.Map("Avg_row_length") - maxDataLength := res.Result.Map("Max_data_length") + return tbStatus, err + } + if res.Error != nil { + return tbStatus, res.Error + } + ts := tableStatusRow{} + statusFields := make([]interface{}, 0) + fields := map[string]interface{}{ + "Name": &ts.Name, + "Engine": &ts.Engine, + "Version": &ts.Version, + "Row_format": &ts.RowFormat, + "Rows": &ts.Rows, + "Avg_row_length": &ts.AvgRowLength, + "Data_length": &ts.DataLength, + "Max_data_length": &ts.MaxDataLength, + "Index_length": &ts.IndexLength, + "Data_free": &ts.DataFree, + "Auto_increment": &ts.AutoIncrement, + "Create_time": &ts.CreateTime, + "Update_time": &ts.UpdateTime, + "Check_time": &ts.CheckTime, + "Collation": &ts.Collation, + "Checksum": &ts.Checksum, + "Create_options": &ts.CreateOptions, + "Comment": &ts.Comment, + } + cols, err := res.Rows.Columns() + common.LogIfError(err, "") + for _, col := range cols { + statusFields = append(statusFields, fields[col]) + } // 获取值 - for _, row := range res.Rows { - value := tableStatusRow{ - Name: row.Str(name), - Engine: row.Str(engine), - Version: row.Int(version), - Rows: row.Int64(rs), - RowFormat: row.Str(rowFormat), - AvgRowLength: row.Int(avgRowLength), - DataLength: row.Int(dataLength), - MaxDataLength: row.Int(maxDataLength), - IndexLength: row.Int(idxLength), - DataFree: row.Int(df), - AutoIncrement: row.Int(ai), - CreateTime: row.Time(createTime, time.Local), - UpdateTime: row.Time(updateTime, time.Local), - CheckTime: row.Time(checkTime, time.Local), - Collation: row.Str(collation), - Checksum: row.Str(sum), - CreateOptions: row.Str(options), - Comment: row.Str(comment), - } - ts.Rows = append(ts.Rows, value) + for res.Rows.Next() { + res.Rows.Scan(statusFields...) + tbStatus.Rows = append(tbStatus.Rows, ts) } - - return ts, err + return tbStatus, err } // https://dev.mysql.com/doc/refman/5.7/en/show-index.html @@ -172,7 +162,7 @@ func (db *Connector) ShowTableStatus(tableName string) (*TableStatInfo, error) { // TableIndexInfo 用以保存 show index 之后获取的 index 信息 type TableIndexInfo struct { TableName string - IdxRows []TableIndexRow + Rows []TableIndexRow } // TableIndexRow 用以存放show index之后获取的每一条index信息 @@ -190,13 +180,14 @@ type TableIndexRow struct { IndexType string // BTREE, FULLTEXT, HASH, RTREE Comment string IndexComment string + Visible string } // NewTableIndexInfo 构造 TableIndexInfo func NewTableIndexInfo(tableName string) *TableIndexInfo { return &TableIndexInfo{ TableName: tableName, - IdxRows: make([]TableIndexRow, 0), + Rows: make([]TableIndexRow, 0), } } @@ -205,43 +196,32 @@ func (db *Connector) ShowIndex(tableName string) (*TableIndexInfo, error) { tbIndex := NewTableIndexInfo(tableName) // 执行 show create table - res, err := db.Query("show index from `%s`.`%s`", db.Database, tableName) + res, err := db.Query(fmt.Sprintf("show index from `%s`.`%s`", db.Database, tableName)) if err != nil { return nil, err } - - table := res.Result.Map("Table") - unique := res.Result.Map("Non_unique") - keyName := res.Result.Map("Key_name") - seq := res.Result.Map("Seq_in_index") - cName := res.Result.Map("Column_name") - collation := res.Result.Map("Collation") - cardinality := res.Result.Map("Cardinality") - subPart := res.Result.Map("Sub_part") - packed := res.Result.Map("Packed") - null := res.Result.Map("Null") - idxType := res.Result.Map("Index_type") - comment := res.Result.Map("Comment") - idxComment := res.Result.Map("Index_comment") + if res.Error != nil { + return nil, res.Error + } // 获取值 - for _, row := range res.Rows { - value := TableIndexRow{ - Table: row.Str(table), - NonUnique: row.Int(unique), - KeyName: row.Str(keyName), - SeqInIndex: row.Int(seq), - ColumnName: row.Str(cName), - Collation: row.Str(collation), - Cardinality: row.Int(cardinality), - SubPart: row.Int(subPart), - Packed: row.Int(packed), - Null: row.Str(null), - IndexType: row.Str(idxType), - Comment: row.Str(comment), - IndexComment: row.Str(idxComment), - } - tbIndex.IdxRows = append(tbIndex.IdxRows, value) + for res.Rows.Next() { + var ti TableIndexRow + res.Rows.Scan(&ti.Table, + &ti.NonUnique, + &ti.KeyName, + &ti.SeqInIndex, + &ti.ColumnName, + &ti.Collation, + &ti.Cardinality, + &ti.SubPart, + &ti.Packed, + &ti.Null, + &ti.IndexType, + &ti.Comment, + &ti.IndexComment, + &ti.Visible) + tbIndex.Rows = append(tbIndex.Rows, ti) } return tbIndex, err } @@ -257,7 +237,7 @@ const ( IndexNonUnique = IndexSelectKey("NonUnique") // 唯一索引 ) -// FindIndex 获取TableIndexInfo中需要的索引 +// FindIndex 获取 TableIndexInfo 中需要的索引 func (tbIndex *TableIndexInfo) FindIndex(arg IndexSelectKey, value string) []TableIndexRow { var result []TableIndexRow if tbIndex == nil { @@ -268,28 +248,28 @@ func (tbIndex *TableIndexInfo) FindIndex(arg IndexSelectKey, value string) []Tab switch arg { case IndexKeyName: - for _, index := range tbIndex.IdxRows { + for _, index := range tbIndex.Rows { if strings.ToLower(index.KeyName) == value { result = append(result, index) } } case IndexColumnName: - for _, index := range tbIndex.IdxRows { + for _, index := range tbIndex.Rows { if strings.ToLower(index.ColumnName) == value { result = append(result, index) } } case IndexIndexType: - for _, index := range tbIndex.IdxRows { + for _, index := range tbIndex.Rows { if strings.ToLower(index.IndexType) == value { result = append(result, index) } } case IndexNonUnique: - for _, index := range tbIndex.IdxRows { + for _, index := range tbIndex.Rows { unique := strconv.Itoa(index.NonUnique) if unique == value { result = append(result, index) @@ -316,12 +296,12 @@ type TableDesc struct { type TableDescValue struct { Field string // 列名 Type string // 数据类型 + Collation []byte // 字符集 Null string // 是否有NULL(NO、YES) - Collation string // 字符集 - Privileges string // 权限s Key string // 键类型 - Default string // 默认值 + Default []byte // 默认值 Extra string // 其他 + Privileges string // 权限 Comment string // 备注 } @@ -338,35 +318,27 @@ func (db *Connector) ShowColumns(tableName string) (*TableDesc, error) { tbDesc := NewTableDesc(tableName) // 执行 show create table - res, err := db.Query("show full columns from `%s`.`%s`", db.Database, tableName) + res, err := db.Query(fmt.Sprintf("show full columns from `%s`.`%s`", db.Database, tableName)) if err != nil { return nil, err } - - field := res.Result.Map("Field") - tp := res.Result.Map("Type") - null := res.Result.Map("Null") - key := res.Result.Map("Key") - def := res.Result.Map("Default") - extra := res.Result.Map("Extra") - collation := res.Result.Map("Collation") - privileges := res.Result.Map("Privileges") - comm := res.Result.Map("Comment") + if res.Error != nil { + return nil, res.Error + } // 获取值 - for _, row := range res.Rows { - value := TableDescValue{ - Field: row.Str(field), - Type: row.Str(tp), - Null: row.Str(null), - Key: row.Str(key), - Default: row.Str(def), - Extra: row.Str(extra), - Privileges: row.Str(privileges), - Collation: row.Str(collation), - Comment: row.Str(comm), - } - tbDesc.DescValues = append(tbDesc.DescValues, value) + for res.Rows.Next() { + var tc TableDescValue + res.Rows.Scan(&tc.Field, + &tc.Type, + &tc.Collation, + &tc.Null, + &tc.Key, + &tc.Default, + &tc.Extra, + &tc.Privileges, + &tc.Comment) + tbDesc.DescValues = append(tbDesc.DescValues, tc) } return tbDesc, err } @@ -383,18 +355,21 @@ func (td TableDesc) Columns() []string { // showCreate show create func (db *Connector) showCreate(createType, name string) (string, error) { // 执行 show create table - res, err := db.Query("show create %s `%s`", createType, name) + res, err := db.Query(fmt.Sprintf("show create %s `%s`", createType, name)) if err != nil { return "", err } + if res.Error != nil { + return "", res.Error + } - // 获取ddl - var ddl string - for _, row := range res.Rows { - ddl = row.Str(1) + // 获取 CREATE TABLE 语句 + var tableName, createTable string + for res.Rows.Next() { + res.Rows.Scan(&tableName, &createTable) } - return ddl, err + return createTable, err } // ShowCreateDatabase show create database @@ -451,6 +426,10 @@ func (db *Connector) FindColumn(name, dbName string, tables ...string) ([]*commo "c.TABLE_NAME,c.TABLE_SCHEMA,c.COLUMN_TYPE,c.CHARACTER_SET_NAME, c.COLLATION_NAME "+ "FROM `INFORMATION_SCHEMA`.`COLUMNS` as c where c.COLUMN_NAME = '%s' ", name) + if dbName != "" { + sql += fmt.Sprintf(" and c.table_schema = '%s'", dbName) + } + if len(tables) > 0 { var tmp []string for _, table := range tables { @@ -459,32 +438,24 @@ func (db *Connector) FindColumn(name, dbName string, tables ...string) ([]*commo sql += fmt.Sprintf(" and c.table_name in (%s)", strings.Join(tmp, ",")) } - if dbName != "" { - sql += fmt.Sprintf(" and c.table_schema = '%s'", dbName) - } - + common.Log.Debug("FindColumn, execute SQL: %s", sql) res, err := db.Query(sql) if err != nil { common.Log.Error("(db *Connector) FindColumn Error : ", err) return columns, err } + if res.Error != nil { + common.Log.Error("(db *Connector) FindColumn Error : ", res.Error) + return columns, res.Error + } - tbName := res.Result.Map("TABLE_NAME") - schema := res.Result.Map("TABLE_SCHEMA") - colTyp := res.Result.Map("COLUMN_TYPE") - colCharset := res.Result.Map("CHARACTER_SET_NAME") - collation := res.Result.Map("COLLATION_NAME") - - // 获取ddl - for _, row := range res.Rows { - col := &common.Column{ - Name: name, - Table: row.Str(tbName), - DB: row.Str(schema), - DataType: row.Str(colTyp), - Character: row.Str(colCharset), - Collation: row.Str(collation), - } + var col common.Column + for res.Rows.Next() { + res.Rows.Scan(&col.Table, + &col.DB, + &col.DataType, + &col.Character, + &col.Collation) // 填充字符集和排序规则 if col.Character == "" { @@ -494,40 +465,56 @@ func (db *Connector) FindColumn(name, dbName string, tables ...string) ([]*commo sql = fmt.Sprintf("SELECT `t`.`TABLE_COLLATION` FROM `INFORMATION_SCHEMA`.`TABLES` AS `t` "+ "WHERE `t`.`TABLE_NAME`='%s' AND `t`.`TABLE_SCHEMA` = '%s'", col.Table, col.DB) - var newRes *QueryResult + + common.Log.Debug("FindColumn, execute SQL: %s", sql) + var newRes QueryResult newRes, err = db.Query(sql) if err != nil { common.Log.Error("(db *Connector) FindColumn Error : ", err) return columns, err } + if res.Error != nil { + common.Log.Error("(db *Connector) FindColumn Error : ", res.Error) + return columns, res.Error + } - tbCollation := newRes.Rows[0].Str(0) + var tbCollation string + if newRes.Rows.Next() { + newRes.Rows.Scan(&tbCollation) + } if tbCollation != "" { col.Character = strings.Split(tbCollation, "_")[0] col.Collation = tbCollation } } - - columns = append(columns, col) + columns = append(columns, &col) } - return columns, err } -// IsFKey 判断列是否是外键 -func (db *Connector) IsFKey(dbName, tbName, column string) bool { +// IsForeignKey 判断列是否是外键 +func (db *Connector) IsForeignKey(dbName, tbName, column string) bool { sql := fmt.Sprintf("SELECT REFERENCED_COLUMN_NAME FROM INFORMATION_SCHEMA.KEY_COLUMN_USAGE C "+ "WHERE REFERENCED_TABLE_SCHEMA <> 'NULL' AND"+ " TABLE_NAME='%s' AND"+ " TABLE_SCHEMA='%s' AND"+ " COLUMN_NAME='%s'", tbName, dbName, column) + common.Log.Debug("IsForeignKey, execute SQL: %s", sql) res, err := db.Query(sql) - if err == nil && len(res.Rows) == 0 { + if err != nil { + common.Log.Error("IsForeignKey, Error: %s", err.Error()) + return false + } + if res.Error != nil { + common.Log.Error("IsForeignKey, Error: %s", res.Error.Error()) return false } + if res.Rows.Next() { + return true + } - return true + return false } // Reference 用于存储关系 @@ -535,11 +522,11 @@ type Reference map[string][]ReferenceValue // ReferenceValue 用于处理表之间的关系 type ReferenceValue struct { - RefDBName string // 夫表所属数据库 - RefTable string // 父表 - DBName string // 子表所属数据库 - Table string // 子表 - ConstraintName string // 关系名称 + ReferencedTableSchema string // 夫表所属数据库 + ReferencedTableName string // 父表 + TableSchema string // 子表所属数据库 + TableName string // 子表 + ConstraintName string // 关系名称 } // ShowReference 查找所有的外键信息 @@ -555,30 +542,26 @@ WHERE C.REFERENCED_TABLE_NAME IS NOT NULL` sql = sql + extra } + common.Log.Debug("ShowReference, execute SQL: %s", sql) // 执行SQL查找外键关联关系 res, err := db.Query(sql) if err != nil { return referenceValues, err } - - refDb := res.Result.Map("REFERENCED_TABLE_SCHEMA") - refTb := res.Result.Map("REFERENCED_TABLE_NAME") - schema := res.Result.Map("TABLE_SCHEMA") - tb := res.Result.Map("TABLE_NAME") - cName := res.Result.Map("CONSTRAINT_NAME") + if res.Error != nil { + return referenceValues, res.Error + } // 获取值 - for _, row := range res.Rows { - value := ReferenceValue{ - RefDBName: row.Str(refDb), - RefTable: row.Str(refTb), - DBName: row.Str(schema), - Table: row.Str(tb), - ConstraintName: row.Str(cName), - } - referenceValues = append(referenceValues, value) + for res.Rows.Next() { + var rv ReferenceValue + res.Rows.Scan(&rv.ReferencedTableSchema, + &rv.ReferencedTableName, + &rv.TableSchema, + &rv.TableName, + &rv.ConstraintName) + referenceValues = append(referenceValues, rv) } return referenceValues, err - } diff --git a/database/show_test.go b/database/show_test.go index 66f98d4714b43f5dae50f0611891569105c1ebfe..68ae97d0290f06223e4de54ab93cf4f5678363da 100644 --- a/database/show_test.go +++ b/database/show_test.go @@ -20,75 +20,144 @@ import ( "fmt" "testing" + "github.com/XiaoMi/soar/common" + "github.com/kr/pretty" "vitess.io/vitess/go/vt/sqlparser" ) func TestShowTableStatus(t *testing.T) { - connTest.Database = "information_schema" - ts, err := connTest.ShowTableStatus("TABLES") + orgDatabase := connTest.Database + connTest.Database = "sakila" + ts, err := connTest.ShowTableStatus("film") if err != nil { t.Error("ShowTableStatus Error: ", err) } + if string(ts.Rows[0].Engine) != "InnoDB" { + t.Error("film table should be InnoDB engine") + } pretty.Println(ts) + + connTest.Database = "sakila" + ts, err = connTest.ShowTableStatus("actor_info") + if err != nil { + t.Error("ShowTableStatus Error: ", err) + } + if string(ts.Rows[0].Comment) != "VIEW" { + t.Error("actor_info should be VIEW") + } + pretty.Println(ts) + connTest.Database = orgDatabase } func TestShowTables(t *testing.T) { - connTest.Database = "information_schema" + orgDatabase := connTest.Database + connTest.Database = "sakila" ts, err := connTest.ShowTables() if err != nil { t.Error("ShowTableStatus Error: ", err) } - pretty.Println(ts) + + err = common.GoldenDiff(func() { + for _, table := range ts { + fmt.Println(table) + } + }, t.Name(), update) + if err != nil { + t.Error(err) + } + connTest.Database = orgDatabase } func TestShowCreateTable(t *testing.T) { - connTest.Database = "information_schema" - ts, err := connTest.ShowCreateTable("TABLES") + orgDatabase := connTest.Database + connTest.Database = "sakila" + ts, err := connTest.ShowCreateTable("film") if err != nil { t.Error("ShowCreateTable Error: ", err) } - fmt.Println(ts) - stmt, err := sqlparser.Parse(ts) - pretty.Println(stmt, err) + + err = common.GoldenDiff(func() { + fmt.Println(ts) + stmt, err := sqlparser.Parse(ts) + if err != nil { + t.Error(err.Error()) + } + pretty.Println(stmt, err) + }, t.Name(), update) + if err != nil { + t.Error(err) + } + + connTest.Database = orgDatabase } func TestShowIndex(t *testing.T) { - connTest.Database = "information_schema" - ti, err := connTest.ShowIndex("TABLES") + orgDatabase := connTest.Database + connTest.Database = "sakila" + ti, err := connTest.ShowIndex("film") if err != nil { t.Error("ShowIndex Error: ", err) } - pretty.Println(ti.FindIndex(IndexKeyName, "idx_store_id_film_id")) + + err = common.GoldenDiff(func() { + pretty.Println(ti) + pretty.Println(ti.FindIndex(IndexKeyName, "idx_title")) + }, t.Name(), update) + if err != nil { + t.Error(err) + } + + connTest.Database = orgDatabase } func TestShowColumns(t *testing.T) { - connTest.Database = "information_schema" - ti, err := connTest.ShowColumns("TABLES") + orgDatabase := connTest.Database + connTest.Database = "sakila" + ti, err := connTest.ShowColumns("film") if err != nil { t.Error("ShowColumns Error: ", err) } - pretty.Println(ti) + + err = common.GoldenDiff(func() { + pretty.Println(ti) + }, t.Name(), update) + if err != nil { + t.Error(err) + } + + connTest.Database = orgDatabase } func TestFindColumn(t *testing.T) { - ti, err := connTest.FindColumn("id", "") + ti, err := connTest.FindColumn("film_id", "sakila", "film") if err != nil { t.Error("FindColumn Error: ", err) } - pretty.Println(ti) + err = common.GoldenDiff(func() { + pretty.Println(ti) + }, t.Name(), update) + if err != nil { + t.Error(err) + } +} + +func TestIsFKey(t *testing.T) { + if !connTest.IsForeignKey("sakila", "film", "language_id") { + t.Error("want True. got false") + } } func TestShowReference(t *testing.T) { - rv, err := connTest.ShowReference("test2", "homeImg") + rv, err := connTest.ShowReference("sakila", "film") if err != nil { t.Error("ShowReference Error: ", err) } - pretty.Println(rv) -} -func TestIsFKey(t *testing.T) { - if !connTest.IsFKey("sakila", "film", "language_id") { - t.Error("want True. got false") + err = common.GoldenDiff(func() { + pretty.Println(rv) + }, t.Name(), update) + if err != nil { + t.Error(err) } } diff --git a/database/testdata/TestFindColumn.golden b/database/testdata/TestFindColumn.golden new file mode 100644 index 0000000000000000000000000000000000000000..c7deffee353e83627e818875dd3bf283304ce80d --- /dev/null +++ b/database/testdata/TestFindColumn.golden @@ -0,0 +1,18 @@ +[]*common.Column{ + &common.Column{ + Name: "", + Alias: nil, + Table: "film", + DB: "sakila", + DataType: "smallint(5) unsigned", + Character: "utf8", + Collation: "utf8_general_ci", + Cardinality: 0, + Null: "", + Key: "", + Default: "", + Extra: "", + Comment: "", + Privileges: "", + }, +} diff --git a/database/testdata/TestFormatTrace.golden b/database/testdata/TestFormatTrace.golden new file mode 100644 index 0000000000000000000000000000000000000000..b216399779df18fb839bc7e5bcd44a4fc7982b90 --- /dev/null +++ b/database/testdata/TestFormatTrace.golden @@ -0,0 +1,36 @@ + +```sql +select 1 +``` + +```json +{ + "steps": [ + { + "join_preparation": { + "select#": 1, + "steps": [ + { + "expanded_query": "/* select#1 */ select 1 AS `1`" + } + ] + } + }, + { + "join_optimization": { + "select#": 1, + "steps": [ + ] + } + }, + { + "join_explain": { + "select#": 1, + "steps": [ + ] + } + } + ] +} +``` + diff --git a/database/testdata/TestShowColumns.golden b/database/testdata/TestShowColumns.golden new file mode 100644 index 0000000000000000000000000000000000000000..782a412be2134426a1fa811deeade89911ec6452 --- /dev/null +++ b/database/testdata/TestShowColumns.golden @@ -0,0 +1,148 @@ +&database.TableDesc{ + Name: "film", + DescValues: { + { + Field: "film_id", + Type: "smallint(5) unsigned", + Collation: nil, + Null: "NO", + Key: "PRI", + Default: nil, + Extra: "auto_increment", + Privileges: "select,insert,update,references", + Comment: "", + }, + { + Field: "title", + Type: "varchar(255)", + Collation: {0x75, 0x74, 0x66, 0x38, 0x5f, 0x67, 0x65, 0x6e, 0x65, 0x72, 0x61, 0x6c, 0x5f, 0x63, 0x69}, + Null: "NO", + Key: "MUL", + Default: nil, + Extra: "", + Privileges: "select,insert,update,references", + Comment: "", + }, + { + Field: "description", + Type: "text", + Collation: {0x75, 0x74, 0x66, 0x38, 0x5f, 0x67, 0x65, 0x6e, 0x65, 0x72, 0x61, 0x6c, 0x5f, 0x63, 0x69}, + Null: "YES", + Key: "", + Default: nil, + Extra: "", + Privileges: "select,insert,update,references", + Comment: "", + }, + { + Field: "release_year", + Type: "year(4)", + Collation: nil, + Null: "YES", + Key: "", + Default: nil, + Extra: "", + Privileges: "select,insert,update,references", + Comment: "", + }, + { + Field: "language_id", + Type: "tinyint(3) unsigned", + Collation: nil, + Null: "NO", + Key: "MUL", + Default: nil, + Extra: "", + Privileges: "select,insert,update,references", + Comment: "", + }, + { + Field: "original_language_id", + Type: "tinyint(3) unsigned", + Collation: nil, + Null: "YES", + Key: "MUL", + Default: nil, + Extra: "", + Privileges: "select,insert,update,references", + Comment: "", + }, + { + Field: "rental_duration", + Type: "tinyint(3) unsigned", + Collation: nil, + Null: "NO", + Key: "", + Default: {0x33}, + Extra: "", + Privileges: "select,insert,update,references", + Comment: "", + }, + { + Field: "rental_rate", + Type: "decimal(4,2)", + Collation: nil, + Null: "NO", + Key: "", + Default: {0x34, 0x2e, 0x39, 0x39}, + Extra: "", + Privileges: "select,insert,update,references", + Comment: "", + }, + { + Field: "length", + Type: "smallint(5) unsigned", + Collation: nil, + Null: "YES", + Key: "", + Default: nil, + Extra: "", + Privileges: "select,insert,update,references", + Comment: "", + }, + { + Field: "replacement_cost", + Type: "decimal(5,2)", + Collation: nil, + Null: "NO", + Key: "", + Default: {0x31, 0x39, 0x2e, 0x39, 0x39}, + Extra: "", + Privileges: "select,insert,update,references", + Comment: "", + }, + { + Field: "rating", + Type: "enum('G','PG','PG-13','R','NC-17')", + Collation: {0x75, 0x74, 0x66, 0x38, 0x5f, 0x67, 0x65, 0x6e, 0x65, 0x72, 0x61, 0x6c, 0x5f, 0x63, 0x69}, + Null: "YES", + Key: "", + Default: {0x47}, + Extra: "", + Privileges: "select,insert,update,references", + Comment: "", + }, + { + Field: "special_features", + Type: "set('Trailers','Commentaries','Deleted Scenes','Behind the Scenes')", + Collation: {0x75, 0x74, 0x66, 0x38, 0x5f, 0x67, 0x65, 0x6e, 0x65, 0x72, 0x61, 0x6c, 0x5f, 0x63, 0x69}, + Null: "YES", + Key: "", + Default: nil, + Extra: "", + Privileges: "select,insert,update,references", + Comment: "", + }, + { + Field: "last_update", + Type: "timestamp", + Collation: nil, + Null: "NO", + Key: "", + Default: {0x43, 0x55, 0x52, 0x52, 0x45, 0x4e, 0x54, 0x5f, 0x54, 0x49, 0x4d, 0x45, 0x53, 0x54, 0x41, 0x4d, 0x50}, + Extra: "on update CURRENT_TIMESTAMP", + Privileges: "select,insert,update,references", + Comment: "", + }, + }, +} diff --git a/database/testdata/TestShowCreateTable.golden b/database/testdata/TestShowCreateTable.golden new file mode 100644 index 0000000000000000000000000000000000000000..377ef8d32fc8da8685f4b7e4268ba613371b489a --- /dev/null +++ b/database/testdata/TestShowCreateTable.golden @@ -0,0 +1,34 @@ +CREATE TABLE `film` ( + `film_id` smallint(5) unsigned NOT NULL AUTO_INCREMENT, + `title` varchar(255) NOT NULL, + `description` text, + `release_year` year(4) DEFAULT NULL, + `language_id` tinyint(3) unsigned NOT NULL, + `original_language_id` tinyint(3) unsigned DEFAULT NULL, + `rental_duration` tinyint(3) unsigned NOT NULL DEFAULT '3', + `rental_rate` decimal(4,2) NOT NULL DEFAULT '4.99', + `length` smallint(5) unsigned DEFAULT NULL, + `replacement_cost` decimal(5,2) NOT NULL DEFAULT '19.99', + `rating` enum('G','PG','PG-13','R','NC-17') DEFAULT 'G', + `special_features` set('Trailers','Commentaries','Deleted Scenes','Behind the Scenes') DEFAULT NULL, + `last_update` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY (`film_id`), + KEY `idx_title` (`title`), + KEY `idx_fk_language_id` (`language_id`), + KEY `idx_fk_original_language_id` (`original_language_id`) +) ENGINE=InnoDB AUTO_INCREMENT=1001 DEFAULT CHARSET=utf8 +&sqlparser.DDL{ + Action: "create", + FromTables: nil, + ToTables: nil, + Table: sqlparser.TableName{ + Name: sqlparser.TableIdent{v:"film"}, + Qualifier: sqlparser.TableIdent{}, + }, + IfExists: false, + TableSpec: (*sqlparser.TableSpec)(nil), + OptLike: (*sqlparser.OptLike)(nil), + PartitionSpec: (*sqlparser.PartitionSpec)(nil), + VindexSpec: (*sqlparser.VindexSpec)(nil), + VindexCols: nil, +} nil diff --git a/database/testdata/TestShowIndex.golden b/database/testdata/TestShowIndex.golden new file mode 100644 index 0000000000000000000000000000000000000000..48a9f8cb5e40db045185f7f9cf87a088dd21eb8b --- /dev/null +++ b/database/testdata/TestShowIndex.golden @@ -0,0 +1,12 @@ +&database.TableIndexInfo{ + TableName: "film", + Rows: { + {Table:"film", NonUnique:0, KeyName:"PRIMARY", SeqInIndex:1, ColumnName:"film_id", Collation:"A", Cardinality:1000, SubPart:0, Packed:0, Null:"", IndexType:"", Comment:"", IndexComment:"", Visible:""}, + {Table:"film", NonUnique:1, KeyName:"idx_title", SeqInIndex:1, ColumnName:"title", Collation:"A", Cardinality:1000, SubPart:0, Packed:0, Null:"", IndexType:"", Comment:"", IndexComment:"", Visible:""}, + {Table:"film", NonUnique:1, KeyName:"idx_fk_language_id", SeqInIndex:1, ColumnName:"language_id", Collation:"A", Cardinality:1, SubPart:0, Packed:0, Null:"", IndexType:"", Comment:"", IndexComment:"", Visible:""}, + {Table:"film", NonUnique:1, KeyName:"idx_fk_original_language_id", SeqInIndex:1, ColumnName:"original_language_id", Collation:"A", Cardinality:1, SubPart:0, Packed:0, Null:"", IndexType:"", Comment:"", IndexComment:"", Visible:""}, + }, +} +[]database.TableIndexRow{ + {Table:"film", NonUnique:1, KeyName:"idx_title", SeqInIndex:1, ColumnName:"title", Collation:"A", Cardinality:1000, SubPart:0, Packed:0, Null:"", IndexType:"", Comment:"", IndexComment:"", Visible:""}, +} diff --git a/database/testdata/TestShowReference.golden b/database/testdata/TestShowReference.golden new file mode 100644 index 0000000000000000000000000000000000000000..b9c79ef096e00ae871530048b68b082fb09515ef --- /dev/null +++ b/database/testdata/TestShowReference.golden @@ -0,0 +1,4 @@ +[]database.ReferenceValue{ + {ReferencedTableSchema:"sakila", ReferencedTableName:"language", TableSchema:"sakila", TableName:"film", ConstraintName:"fk_film_language"}, + {ReferencedTableSchema:"sakila", ReferencedTableName:"language", TableSchema:"sakila", TableName:"film", ConstraintName:"fk_film_language_original"}, +} diff --git a/database/testdata/TestShowTables.golden b/database/testdata/TestShowTables.golden new file mode 100644 index 0000000000000000000000000000000000000000..27a6ad5d1f84e6e101e9dfec5f3baf6fccfe182a --- /dev/null +++ b/database/testdata/TestShowTables.golden @@ -0,0 +1,23 @@ +actor +actor_info +address +category +city +country +customer +customer_list +film +film_actor +film_category +film_list +film_text +inventory +language +nicer_but_slower_film_list +payment +rental +sales_by_film_category +sales_by_store +staff +staff_list +store diff --git a/database/testdata/TestSource.sql b/database/testdata/TestSource.sql deleted file mode 100644 index 2bba4d129a6a838eb13a07583d2ca0eb013843ce..0000000000000000000000000000000000000000 --- a/database/testdata/TestSource.sql +++ /dev/null @@ -1,2 +0,0 @@ -select 1; -select 1; diff --git a/database/testdata/TestTrace.golden b/database/testdata/TestTrace.golden new file mode 100644 index 0000000000000000000000000000000000000000..b76661057d72c5ae3ff5a60adb1d0342fcde21ce --- /dev/null +++ b/database/testdata/TestTrace.golden @@ -0,0 +1,3 @@ +[]database.TraceRow{ + {Query:"explain select 1", Trace:"{\n \"steps\": [\n {\n \"join_preparation\": {\n \"select#\": 1,\n \"steps\": [\n {\n \"expanded_query\": \"/* select#1 */ select 1 AS `1`\"\n }\n ]\n }\n },\n {\n \"join_optimization\": {\n \"select#\": 1,\n \"steps\": [\n ]\n }\n },\n {\n \"join_explain\": {\n \"select#\": 1,\n \"steps\": [\n ]\n }\n }\n ]\n}", MissingBytesBeyondMaxMemSize:0, InsufficientPrivileges:0}, +} diff --git a/database/trace.go b/database/trace.go index 0ba611aba68fc5f62bc41aa2518f7011b763da27..7c2d2b5a1fe58a642d134848a1309856fbdfef3f 100644 --- a/database/trace.go +++ b/database/trace.go @@ -19,12 +19,9 @@ package database import ( "errors" "fmt" - "io" + "github.com/XiaoMi/soar/common" "regexp" "strings" - "time" - - "github.com/XiaoMi/soar/common" "vitess.io/vitess/go/vt/sqlparser" ) @@ -43,10 +40,11 @@ type TraceRow struct { } // Trace 执行SQL,并对其Trace -func (db *Connector) Trace(sql string, params ...interface{}) (*QueryResult, error) { +func (db *Connector) Trace(sql string, params ...interface{}) ([]TraceRow, error) { common.Log.Debug("Trace SQL: %s", sql) + var rows []TraceRow if common.Config.TestDSN.Version < 50600 { - return nil, errors.New("version < 5.6, not support trace") + return rows, errors.New("version < 5.6, not support trace") } // 过滤不需要 Trace 的 SQL @@ -55,98 +53,71 @@ func (db *Connector) Trace(sql string, params ...interface{}) (*QueryResult, err sql = "explain " + sql case sqlparser.EXPLAIN: default: - return nil, errors.New("no need trace") + return rows, errors.New("no need trace") } // 测试环境如果检查是关闭的,则SQL不会被执行 if common.Config.TestDSN.Disable { - return nil, errors.New("Dsn Disable") + return rows, errors.New("dsn is disable") } // 数据库安全性检查:如果 Connector 的 IP 端口与 TEST 环境不一致,则启用SQL白名单 // 不在白名单中的SQL不允许执行 // 执行环境与test环境不相同 if db.Addr != common.Config.TestDSN.Addr && db.dangerousQuery(sql) { - return nil, fmt.Errorf("query Execution Deny: Execute SQL with DSN(%s/%s) '%s'", + return rows, fmt.Errorf("query Execution Deny: Execute SQL with DSN(%s/%s) '%s'", db.Addr, db.Database, fmt.Sprintf(sql, params...)) } common.Log.Debug("Execute SQL with DSN(%s/%s) : %s", db.Addr, db.Database, sql) - conn := db.NewConnection() - - // 设置SQL连接超时时间 - conn.SetTimeout(time.Duration(common.Config.ConnTimeOut) * time.Second) - defer conn.Close() - err := conn.Connect() + conn, err := db.NewConnection() if err != nil { - return nil, err + return rows, err } + defer conn.Close() - // 添加SQL执行超时限制 - ch := make(chan QueryResult, 1) - go func() { - // 开启Trace - common.Log.Debug("SET SESSION OPTIMIZER_TRACE='enabled=on'") - _, _, err = conn.Query("SET SESSION OPTIMIZER_TRACE='enabled=on'") - common.LogIfError(err, "") - - // 执行SQL,抛弃返回结果 - result, err := conn.Start(sql, params...) - if err != nil { - ch <- QueryResult{ - Error: err, - } - return - } - row := result.MakeRow() - for { - err = result.ScanRow(row) - if err == io.EOF { - break - } - } + // 开启Trace + common.Log.Debug("SET SESSION OPTIMIZER_TRACE='enabled=on'") + trx, err := conn.Begin() + if err != nil { + return rows, err + } + defer trx.Rollback() + _, err = trx.Query("SET SESSION OPTIMIZER_TRACE='enabled=on'") + common.LogIfError(err, "") - // 返回Trace结果 - res := QueryResult{} - res.Rows, res.Result, res.Error = conn.Query("SELECT * FROM information_schema.OPTIMIZER_TRACE") + // 执行SQL,抛弃返回结果 + tmpRes, err := trx.Query(sql, params...) + if err != nil { + return rows, err + } + for tmpRes.Next() { + continue + } - // 关闭Trace - common.Log.Debug("SET SESSION OPTIMIZER_TRACE='enabled=off'") - _, _, err = conn.Query("SET SESSION OPTIMIZER_TRACE='enabled=off'") + // 返回Trace结果 + res, err := trx.Query("SELECT * FROM information_schema.OPTIMIZER_TRACE") + for res.Next() { + var traceRow TraceRow + err = res.Scan(&traceRow.Query, &traceRow.Trace, &traceRow.MissingBytesBeyondMaxMemSize, &traceRow.InsufficientPrivileges) if err != nil { - fmt.Println(err.Error()) + common.LogIfError(err, "") } - ch <- res - }() - - select { - case res := <-ch: - return &res, res.Error - case <-time.After(time.Duration(common.Config.QueryTimeOut) * time.Second): - return nil, errors.New("query execution timeout") + rows = append(rows, traceRow) } -} -// getTrace 获取trace信息 -func getTrace(res *QueryResult) Trace { - var rows []TraceRow - for _, row := range res.Rows { - rows = append(rows, TraceRow{ - Query: row.Str(0), - Trace: row.Str(1), - MissingBytesBeyondMaxMemSize: row.Int(2), - InsufficientPrivileges: row.Int(3), - }) - } - return Trace{Rows: rows} + // 关闭Trace + common.Log.Debug("SET SESSION OPTIMIZER_TRACE='enabled=off'") + _, err = trx.Query("SET SESSION OPTIMIZER_TRACE='enabled=off'") + common.LogIfError(err, "") + return rows, err } // FormatTrace 格式化输出Trace信息 -func FormatTrace(res *QueryResult) string { +func FormatTrace(rows []TraceRow) string { explainReg := regexp.MustCompile(`(?i)^explain\s+`) - trace := getTrace(res) str := []string{""} - for _, row := range trace.Rows { + for _, row := range rows { str = append(str, "```sql") sql := explainReg.ReplaceAllString(row.Query, "") str = append(str, sql) diff --git a/database/trace_test.go b/database/trace_test.go index 8dea2d7ffa8270facee7c4003566625e5ac04221..535d3f58cf948fbb54c1a8407a63bb3317225753 100644 --- a/database/trace_test.go +++ b/database/trace_test.go @@ -17,7 +17,6 @@ package database import ( - "flag" "testing" "github.com/XiaoMi/soar/common" @@ -25,34 +24,30 @@ import ( "github.com/kr/pretty" ) -var update = flag.Bool("update", false, "update .golden files") - func TestTrace(t *testing.T) { - common.Config.QueryTimeOut = 1 res, err := connTest.Trace("select 1") - if err == nil { - common.GoldenDiff(func() { - pretty.Println(res) - }, t.Name(), update) - } else { + if err != nil { + t.Error(err) + } + + err = common.GoldenDiff(func() { + pretty.Println(res) + }, t.Name(), update) + if err != nil { t.Error(err) } } func TestFormatTrace(t *testing.T) { res, err := connTest.Trace("select 1") - if err == nil { - pretty.Println(FormatTrace(res)) - } else { + if err != nil { t.Error(err) } -} -func TestGetTrace(t *testing.T) { - res, err := connTest.Trace("select 1") - if err == nil { - pretty.Println(getTrace(res)) - } else { + err = common.GoldenDiff(func() { + pretty.Println(FormatTrace(res)) + }, t.Name(), update) + if err != nil { t.Error(err) } } diff --git a/env/env.go b/env/env.go index 9b3f6ea0910b441df8a4b121d1a3ba7f0ed8e2b2..cc21a03ccae7bc946621569ebed1b6c9ee33a7a3 100644 --- a/env/env.go +++ b/env/env.go @@ -161,8 +161,9 @@ func (ve *VirtualEnv) CleanupTestDatabase() { // TODO: 1 hour should be config-able minHour := 1 - for _, row := range dbs.Rows { - testDatabase := row.Str(0) + for dbs.Rows.Next() { + var testDatabase string + dbs.Rows.Scan(&testDatabase) // test temporary database format `optimizer_YYMMDDHHmmss_randomString(16)` if len(testDatabase) != 39 { common.Log.Debug("CleanupTestDatabase by pass %s", testDatabase) @@ -218,7 +219,7 @@ func (ve *VirtualEnv) BuildVirtualEnv(rEnv *database.Connector, SQLs ...string) meta := make(map[string]*common.DB) for _, sql := range SQLs { - common.Log.Debug("BuildVirtualEnv Database&Table Mapping, SQL: %s", sql) + common.Log.Debug("BuildVirtualEnv Database&TableName Mapping, SQL: %s", sql) stmt, err = sqlparser.Parse(sql) if err != nil { @@ -320,7 +321,7 @@ func (ve *VirtualEnv) BuildVirtualEnv(rEnv *database.Connector, SQLs ...string) } // 如果是视图,解析语句 - if len(tbStatus.Rows) > 0 && tbStatus.Rows[0].Comment == "VIEW" { + if len(tbStatus.Rows) > 0 && string(tbStatus.Rows[0].Comment) == "VIEW" { tmpEnv.Database = db var viewDDL string viewDDL, err = tmpEnv.ShowCreateTable(tb.TableName) @@ -418,7 +419,7 @@ func (ve VirtualEnv) createTable(rEnv database.Connector, dbName, tbName string) return nil } - common.Log.Debug("createTable, Database: %s, Table: %s", dbName, tbName) + common.Log.Debug("createTable, Database: %s, TableName: %s", dbName, tbName) // TODO:查看是否有外键关联(done),对外键的支持 (未解决循环依赖的问题) @@ -506,9 +507,9 @@ func (ve *VirtualEnv) GenTableColumns(meta common.Meta) common.TableColumns { DB: dbName, Table: tb.TableName, DataType: colInfo.Type, - Character: colInfo.Collation, + Character: string(colInfo.Collation), Key: colInfo.Key, - Default: colInfo.Default, + Default: string(colInfo.Default), Extra: colInfo.Extra, Comment: colInfo.Comment, Privileges: colInfo.Privileges, @@ -525,9 +526,9 @@ func (ve *VirtualEnv) GenTableColumns(meta common.Meta) common.TableColumns { col.DB = dbName col.Table = tb.TableName col.DataType = colInfo.Type - col.Character = colInfo.Collation + col.Character = string(colInfo.Collation) col.Key = colInfo.Key - col.Default = colInfo.Default + col.Default = string(colInfo.Default) col.Extra = colInfo.Extra col.Comment = colInfo.Comment col.Privileges = colInfo.Privileges