提交 3991dde4 编写于 作者: martianzhang's avatar martianzhang

change mymysql into go-sql-driver/mysql

上级 a60a145f
...@@ -181,7 +181,7 @@ docker: ...@@ -181,7 +181,7 @@ docker:
.PHONY: connect .PHONY: connect
connect: connect:
mysql -h 127.0.0.1 -u root -p1tIsB1g3rt -c mysql -h 127.0.0.1 -u root -p1tIsB1g3rt sakila -c
.PHONY: main_test .PHONY: main_test
main_test: install main_test: install
......
...@@ -1990,7 +1990,7 @@ func (idxAdv *IndexAdvisor) RuleUpdatePrimaryKey() Rule { ...@@ -1990,7 +1990,7 @@ func (idxAdv *IndexAdvisor) RuleUpdatePrimaryKey() Rule {
if idxMeta == nil { if idxMeta == nil {
return rule return rule
} }
for _, idx := range idxMeta.IdxRows { for _, idx := range idxMeta.Rows {
if idx.KeyName == "PRIMARY" { if idx.KeyName == "PRIMARY" {
if col.Name == idx.ColumnName { if col.Name == idx.ColumnName {
rule = HeuristicRules["CLA.016"] rule = HeuristicRules["CLA.016"]
......
...@@ -914,7 +914,7 @@ func (idxAdv *IndexAdvisor) calcCardinality(cols []*common.Column) []*common.Col ...@@ -914,7 +914,7 @@ func (idxAdv *IndexAdvisor) calcCardinality(cols []*common.Column) []*common.Col
// 检查对应列是否为主键或单列唯一索引,如果满足直接返回1,不再重复计算,提高效率 // 检查对应列是否为主键或单列唯一索引,如果满足直接返回1,不再重复计算,提高效率
// 多列复合唯一索引不能跳过计算,单列普通索引不能跳过计算 // 多列复合唯一索引不能跳过计算,单列普通索引不能跳过计算
for _, index := range idxAdv.IndexMeta[realDB][col.Table].IdxRows { for _, index := range idxAdv.IndexMeta[realDB][col.Table].Rows {
// 根据索引的名称判断该索引包含的列数,列数大于1即为复合索引 // 根据索引的名称判断该索引包含的列数,列数大于1即为复合索引
columnCount := len(idxAdv.IndexMeta[realDB][col.Table].FindIndex(database.IndexKeyName, index.KeyName)) columnCount := len(idxAdv.IndexMeta[realDB][col.Table].FindIndex(database.IndexKeyName, index.KeyName))
if col.Name == index.ColumnName { if col.Name == index.ColumnName {
...@@ -1079,7 +1079,7 @@ func DuplicateKeyChecker(conn *database.Connector, databases ...string) map[stri ...@@ -1079,7 +1079,7 @@ func DuplicateKeyChecker(conn *database.Connector, databases ...string) map[stri
} }
// 枚举所有的索引信息,提取用到的列 // 枚举所有的索引信息,提取用到的列
for _, idx := range idxInfo.IdxRows { for _, idx := range idxInfo.Rows {
if _, ok := idxMap[idx.KeyName]; !ok { if _, ok := idxMap[idx.KeyName]; !ok {
idxMap[idx.KeyName] = make([]*common.Column, 0) idxMap[idx.KeyName] = make([]*common.Column, 0)
for _, col := range idxInfo.FindIndex(database.IndexKeyName, idx.KeyName) { for _, col := range idxInfo.FindIndex(database.IndexKeyName, idx.KeyName) {
......
...@@ -102,7 +102,7 @@ type Rule struct { ...@@ -102,7 +102,7 @@ type Rule struct {
* SEC Security * SEC Security
* STA Standard * STA Standard
* SUB Subquery * SUB Subquery
* TBL Table * TBL TableName
* TRA Trace, 由trace模块给 * TRA Trace, 由trace模块给
*/ */
......
...@@ -67,7 +67,7 @@ func init() { ...@@ -67,7 +67,7 @@ func init() {
"SELECT * FROM customer WHERE address_id in (224,510) ORDER BY last_name;", // INDEX(address_id) "SELECT * FROM customer WHERE address_id in (224,510) ORDER BY last_name;", // INDEX(address_id)
"SELECT * FROM film WHERE release_year = 2016 AND length != 1 ORDER BY title;", // INDEX(`release_year`, `length`, `title`) "SELECT * FROM film WHERE release_year = 2016 AND length != 1 ORDER BY title;", // INDEX(`release_year`, `length`, `title`)
// "Covering" IdxRows // "Covering" Rows
"SELECT title FROM film WHERE release_year = 1995;", // INDEX(release_year, title)", "SELECT title FROM film WHERE release_year = 1995;", // INDEX(release_year, title)",
"SELECT title, replacement_cost FROM film WHERE language_id = 5 AND length = 70;", // INDEX(language_id, length, title, replacement_cos film ), title, replacement_cost顺序无关,language_id, length顺序视散粒度情况. "SELECT title, replacement_cost FROM film WHERE language_id = 5 AND length = 70;", // INDEX(language_id, length, title, replacement_cos film ), title, replacement_cost顺序无关,language_id, length顺序视散粒度情况.
"SELECT title FROM film WHERE language_id > 5 AND length > 70;", // INDEX(language_id, length, title) language_id or length first (that's as far as the Algorithm goes), then the other two fields afterwards. "SELECT title FROM film WHERE language_id > 5 AND length > 70;", // INDEX(language_id, length, title) language_id or length first (that's as far as the Algorithm goes), then the other two fields afterwards.
......
...@@ -59,8 +59,6 @@ type Configuration struct { ...@@ -59,8 +59,6 @@ type Configuration struct {
Profiling bool `yaml:"profiling"` // 在开启数据采样的情况下,在测试环境执行进行profile Profiling bool `yaml:"profiling"` // 在开启数据采样的情况下,在测试环境执行进行profile
Trace bool `yaml:"trace"` // 在开启数据采样的情况下,在测试环境执行进行Trace Trace bool `yaml:"trace"` // 在开启数据采样的情况下,在测试环境执行进行Trace
Explain bool `yaml:"explain"` // Explain开关 Explain bool `yaml:"explain"` // Explain开关
ConnTimeOut int `yaml:"conn-time-out"` // 数据库连接超时时间,单位秒
QueryTimeOut int `yaml:"query-time-out"` // 数据库SQL执行超时时间,单位秒
Delimiter string `yaml:"delimiter"` // SQL分隔符 Delimiter string `yaml:"delimiter"` // SQL分隔符
// +++++++++++++++日志相关+++++++++++++++++ // +++++++++++++++日志相关+++++++++++++++++
...@@ -97,8 +95,8 @@ type Configuration struct { ...@@ -97,8 +95,8 @@ type Configuration struct {
MaxInCount int `yaml:"max-in-count"` // IN()最大数量 MaxInCount int `yaml:"max-in-count"` // IN()最大数量
MaxIdxBytesPerColumn int `yaml:"max-index-bytes-percolumn"` // 索引中单列最大字节数,默认767 MaxIdxBytesPerColumn int `yaml:"max-index-bytes-percolumn"` // 索引中单列最大字节数,默认767
MaxIdxBytes int `yaml:"max-index-bytes"` // 索引总长度限制,默认3072 MaxIdxBytes int `yaml:"max-index-bytes"` // 索引总长度限制,默认3072
TableAllowCharsets []string `yaml:"table-allow-charsets"` // Table 允许使用的 DEFAULT CHARSET TableAllowCharsets []string `yaml:"table-allow-charsets"` // TableName 允许使用的 DEFAULT CHARSET
TableAllowEngines []string `yaml:"table-allow-engines"` // Table 允许使用的 Engine TableAllowEngines []string `yaml:"table-allow-engines"` // TableName 允许使用的 Engine
MaxIdxCount int `yaml:"max-index-count"` // 单张表允许最多索引数 MaxIdxCount int `yaml:"max-index-count"` // 单张表允许最多索引数
MaxColCount int `yaml:"max-column-count"` // 单张表允许最大列数 MaxColCount int `yaml:"max-column-count"` // 单张表允许最大列数
MaxValueCount int `yaml:"max-value-count"` // INSERT/REPLACE 单次允许批量写入的行数 MaxValueCount int `yaml:"max-value-count"` // INSERT/REPLACE 单次允许批量写入的行数
...@@ -135,12 +133,14 @@ type Configuration struct { ...@@ -135,12 +133,14 @@ type Configuration struct {
// Config 默认设置 // Config 默认设置
var Config = &Configuration{ var Config = &Configuration{
OnlineDSN: &dsn{ OnlineDSN: &dsn{
Net: "tcp",
Schema: "information_schema", Schema: "information_schema",
Charset: "utf8mb4", Charset: "utf8mb4",
Disable: true, Disable: true,
Version: 99999, Version: 99999,
}, },
TestDSN: &dsn{ TestDSN: &dsn{
Net: "tcp",
Schema: "information_schema", Schema: "information_schema",
Charset: "utf8mb4", Charset: "utf8mb4",
Disable: true, Disable: true,
...@@ -156,8 +156,6 @@ var Config = &Configuration{ ...@@ -156,8 +156,6 @@ var Config = &Configuration{
Profiling: false, Profiling: false,
Trace: false, Trace: false,
Explain: true, Explain: true,
ConnTimeOut: 3,
QueryTimeOut: 30,
Delimiter: ";", Delimiter: ";",
MaxJoinTableCount: 5, MaxJoinTableCount: 5,
...@@ -227,6 +225,7 @@ var Config = &Configuration{ ...@@ -227,6 +225,7 @@ var Config = &Configuration{
} }
type dsn struct { type dsn struct {
Net string `yaml:"net"`
Addr string `yaml:"addr"` Addr string `yaml:"addr"`
Schema string `yaml:"schema"` Schema string `yaml:"schema"`
...@@ -236,6 +235,10 @@ type dsn struct { ...@@ -236,6 +235,10 @@ type dsn struct {
Charset string `yaml:"charset"` Charset string `yaml:"charset"`
Disable bool `yaml:"disable"` Disable bool `yaml:"disable"`
Timeout int `yaml:"timeout"`
ReadTimeout int `yaml:"read-timeout"`
WriteTimeout int `yaml:"write-timeout"`
Version int `yaml:"-"` // 版本自动检查,不可配置 Version int `yaml:"-"` // 版本自动检查,不可配置
} }
...@@ -502,8 +505,6 @@ func readCmdFlags() error { ...@@ -502,8 +505,6 @@ func readCmdFlags() error {
explain := flag.Bool("explain", Config.Explain, "Explain, 是否开启Explain执行计划分析") explain := flag.Bool("explain", Config.Explain, "Explain, 是否开启Explain执行计划分析")
sampling := flag.Bool("sampling", Config.Sampling, "Sampling, 数据采样开关") sampling := flag.Bool("sampling", Config.Sampling, "Sampling, 数据采样开关")
samplingStatisticTarget := flag.Int("sampling-statistic-target", Config.SamplingStatisticTarget, "SamplingStatisticTarget, 数据采样因子,对应 PostgreSQL 的 default_statistics_target") samplingStatisticTarget := flag.Int("sampling-statistic-target", Config.SamplingStatisticTarget, "SamplingStatisticTarget, 数据采样因子,对应 PostgreSQL 的 default_statistics_target")
connTimeOut := flag.Int("conn-time-out", Config.ConnTimeOut, "ConnTimeOut, 数据库连接超时时间,单位秒")
queryTimeOut := flag.Int("query-time-out", Config.QueryTimeOut, "QueryTimeOut, 数据库SQL执行超时时间,单位秒")
delimiter := flag.String("delimiter", Config.Delimiter, "Delimiter, SQL分隔符") delimiter := flag.String("delimiter", Config.Delimiter, "Delimiter, SQL分隔符")
// +++++++++++++++日志相关+++++++++++++++++ // +++++++++++++++日志相关+++++++++++++++++
logLevel := flag.Int("log-level", Config.LogLevel, "LogLevel, 日志级别, [0:Emergency, 1:Alert, 2:Critical, 3:Error, 4:Warning, 5:Notice, 6:Informational, 7:Debug]") logLevel := flag.Int("log-level", Config.LogLevel, "LogLevel, 日志级别, [0:Emergency, 1:Alert, 2:Critical, 3:Error, 4:Warning, 5:Notice, 6:Informational, 7:Debug]")
...@@ -583,8 +584,6 @@ func readCmdFlags() error { ...@@ -583,8 +584,6 @@ func readCmdFlags() error {
Config.Explain = *explain Config.Explain = *explain
Config.Sampling = *sampling Config.Sampling = *sampling
Config.SamplingStatisticTarget = *samplingStatisticTarget Config.SamplingStatisticTarget = *samplingStatisticTarget
Config.ConnTimeOut = *connTimeOut
Config.QueryTimeOut = *queryTimeOut
Config.LogLevel = *logLevel Config.LogLevel = *logLevel
if strings.HasPrefix(*logOutput, "/") { if strings.HasPrefix(*logOutput, "/") {
......
...@@ -27,7 +27,7 @@ type Meta map[string]*DB ...@@ -27,7 +27,7 @@ type Meta map[string]*DB
// DB 数据库相关的结构体 // DB 数据库相关的结构体
type DB struct { type DB struct {
Name string Name string
Table map[string]*Table // ['table_name']*Table Table map[string]*Table // ['table_name']*TableName
} }
// NewDB 用于初始化*DB // NewDB 用于初始化*DB
...@@ -38,14 +38,14 @@ func NewDB(db string) *DB { ...@@ -38,14 +38,14 @@ func NewDB(db string) *DB {
} }
} }
// Table 含有表的属性 // TableName 含有表的属性
type Table struct { type Table struct {
TableName string TableName string
TableAliases []string TableAliases []string
Column map[string]*Column Column map[string]*Column
} }
// NewTable 初始化*Table // NewTable 初始化*TableName
func NewTable(tb string) *Table { func NewTable(tb string) *Table {
return &Table{ return &Table{
TableName: tb, TableName: tb,
......
...@@ -176,9 +176,9 @@ $$ ...@@ -176,9 +176,9 @@ $$
<p>Typora support <a href="http://jekyllrb.com/docs/frontmatter/">YAML Front Matter</a> now. Input <code>---</code> at the top of the article and then press <code>Enter</code> will introduce one. Or insert one metadata block from the menu.</p> <p>Typora support <a href="http://jekyllrb.com/docs/frontmatter/">YAML Front Matter</a> now. Input <code>---</code> at the top of the article and then press <code>Enter</code> will introduce one. Or insert one metadata block from the menu.</p>
<h3>Table of Contents (TOC)</h3> <h3>TableName of Contents (TOC)</h3>
<p>Input <code>[toc]</code> then press <code>Return</code> key will create a section for “Table of Contents” extracting all headers from one’s writing, its contents will be updated automatically.</p> <p>Input <code>[toc]</code> then press <code>Return</code> key will create a section for “TableName of Contents” extracting all headers from one’s writing, its contents will be updated automatically.</p>
<h3>Diagrams (Sequence, Flowchart and Mermaid)</h3> <h3>Diagrams (Sequence, Flowchart and Mermaid)</h3>
......
...@@ -186,9 +186,9 @@ Input `***` or `---` on a blank line and press `return` will draw a horizontal l ...@@ -186,9 +186,9 @@ Input `***` or `---` on a blank line and press `return` will draw a horizontal l
Typora support [YAML Front Matter](http://jekyllrb.com/docs/frontmatter/) now. Input `---` at the top of the article and then press `Enter` will introduce one. Or insert one metadata block from the menu. Typora support [YAML Front Matter](http://jekyllrb.com/docs/frontmatter/) now. Input `---` at the top of the article and then press `Enter` will introduce one. Or insert one metadata block from the menu.
### Table of Contents (TOC) ### TableName of Contents (TOC)
Input `[toc]` then press `Return` key will create a section for “Table of Contents” extracting all headers from one’s writing, its contents will be updated automatically. Input `[toc]` then press `Return` key will create a section for “TableName of Contents” extracting all headers from one’s writing, its contents will be updated automatically.
### Diagrams (Sequence, Flowchart and Mermaid) ### Diagrams (Sequence, Flowchart and Mermaid)
......
...@@ -267,6 +267,7 @@ var ExplainKeyWords = []string{ ...@@ -267,6 +267,7 @@ var ExplainKeyWords = []string{
"using_temporary_table", "using_temporary_table",
} }
/*
// ExplainColumnIndent EXPLAIN表头 // ExplainColumnIndent EXPLAIN表头
var ExplainColumnIndent = map[string]string{ var ExplainColumnIndent = map[string]string{
"id": "id为SELECT的标识符. 它是在SELECT查询中的顺序编号. 如果这一行表示其他行的union结果, 这个值可以为空. 在这种情况下, table列会显示为形如<union M,N>, 表示它是id为M和N的查询行的联合结果.", "id": "id为SELECT的标识符. 它是在SELECT查询中的顺序编号. 如果这一行表示其他行的union结果, 这个值可以为空. 在这种情况下, table列会显示为形如<union M,N>, 表示它是id为M和N的查询行的联合结果.",
...@@ -281,6 +282,7 @@ var ExplainColumnIndent = map[string]string{ ...@@ -281,6 +282,7 @@ var ExplainColumnIndent = map[string]string{
"filtered": "表示返回结果的行占需要读到的行(rows列的值)的百分比.", "filtered": "表示返回结果的行占需要读到的行(rows列的值)的百分比.",
"Extra": "该列显示MySQL在查询过程中的一些详细信息, MySQL查询优化器执行查询的过程中对查询计划的重要补充信息.", "Extra": "该列显示MySQL在查询过程中的一些详细信息, MySQL查询优化器执行查询的过程中对查询计划的重要补充信息.",
} }
*/
// ExplainSelectType EXPLAIN中SELECT TYPE会出现的类型 // ExplainSelectType EXPLAIN中SELECT TYPE会出现的类型
var ExplainSelectType = map[string]string{ var ExplainSelectType = map[string]string{
...@@ -555,14 +557,16 @@ func (db *Connector) explainAbleSQL(sql string) (string, error) { ...@@ -555,14 +557,16 @@ func (db *Connector) explainAbleSQL(sql string) (string, error) {
} }
// 执行explain请求,返回mysql.Result执行结果 // 执行explain请求,返回mysql.Result执行结果
func (db *Connector) executeExplain(sql string, explainType int, formatType int) (*QueryResult, error) { func (db *Connector) executeExplain(sql string, explainType int, formatType int) (QueryResult, error) {
var res QueryResult
var err error var err error
var explainQuery string
sql, err = db.explainAbleSQL(sql) sql, err = db.explainAbleSQL(sql)
if sql == "" { if sql == "" {
return nil, err return res, err
} }
// 5.6以上支持FORMAT=JSON // 5.6以上支持 FORMAT=JSON
explainFormat := "" explainFormat := ""
switch formatType { switch formatType {
case JSONFormatExplain: case JSONFormatExplain:
...@@ -570,22 +574,23 @@ func (db *Connector) executeExplain(sql string, explainType int, formatType int) ...@@ -570,22 +574,23 @@ func (db *Connector) executeExplain(sql string, explainType int, formatType int)
explainFormat = "FORMAT=JSON" explainFormat = "FORMAT=JSON"
} }
} }
// 执行explain
var res *QueryResult // 执行 explain
switch explainType { switch explainType {
case ExtendedExplainType: case ExtendedExplainType:
// 5.6以上extended关键字已经不推荐使用,8.0废弃了这个关键字 // 5.6以上extended关键字已经不推荐使用,8.0废弃了这个关键字
if common.Config.TestDSN.Version >= 50600 { if common.Config.TestDSN.Version >= 50600 {
res, err = db.Query("explain %s", sql) explainQuery = fmt.Sprintf("explain %s", sql)
} else { } else {
res, err = db.Query("explain extended %s", sql) explainQuery = fmt.Sprintf("explain extended %s", sql)
} }
case PartitionsExplainType: case PartitionsExplainType:
res, err = db.Query("explain partitions %s", sql) explainQuery = fmt.Sprintf("explain partitions %s", sql)
default: default:
res, err = db.Query("explain %s %s", explainFormat, sql) explainQuery = fmt.Sprintf("explain %s %s", explainFormat, sql)
} }
res, err = db.Query(explainQuery)
return res, err return res, err
} }
...@@ -928,76 +933,75 @@ func parseJSONExplainText(content string) (*ExplainJSON, error) { ...@@ -928,76 +933,75 @@ func parseJSONExplainText(content string) (*ExplainJSON, error) {
} }
// ParseExplainResult 分析 mysql 执行 explain 的结果,返回 ExplainInfo 结构化数据 // ParseExplainResult 分析 mysql 执行 explain 的结果,返回 ExplainInfo 结构化数据
func ParseExplainResult(res *QueryResult, formatType int) (exp *ExplainInfo, err error) { func ParseExplainResult(res QueryResult, formatType int) (exp *ExplainInfo, err error) {
exp = &ExplainInfo{ exp = &ExplainInfo{
ExplainFormat: formatType, ExplainFormat: formatType,
} }
// JSON 格式直接调用文本方式解析 // JSON 格式直接调用文本方式解析
if formatType == JSONFormatExplain { if formatType == JSONFormatExplain {
exp.ExplainJSON, err = parseJSONExplainText(res.Rows[0].Str(0)) if res.Rows.Next() {
var explainString string
err = res.Rows.Scan(&explainString)
exp.ExplainJSON, err = parseJSONExplainText(explainString)
}
return exp, err return exp, err
} }
// 生成表头 // Different MySQL version has different columns define
colIdx := make(map[int]string) var possibleKeys string
for i, f := range res.Result.Fields() { expRow := &ExplainRow{}
colIdx[i] = strings.ToLower(f.Name) explainFields := make([]interface{}, 0)
fields := map[string]interface{}{
"id": expRow.ID,
"select_type": expRow.SelectType,
"table": expRow.TableName,
"partitions": expRow.Partitions,
"type": expRow.AccessType,
"possible_keys": &possibleKeys,
"key": expRow.Key,
"key_len": expRow.KeyLen,
"ref": expRow.Ref,
"rows": expRow.Rows,
"filtered": expRow.Filtered,
"extra": expRow.Extra,
} }
cols, err := res.Rows.Columns()
common.LogIfError(err, "")
for _, col := range cols {
explainFields = append(explainFields, fields[col])
}
// 补全 ExplainRows // 补全 ExplainRows
var explainrows []*ExplainRow var explainRows []*ExplainRow
for _, row := range res.Rows {
expRow := &ExplainRow{Partitions: "NULL", Filtered: 0.00} for res.Rows.Next() {
// list 到 map 的转换 res.Rows.Scan(explainFields...)
for i := range row { expRow.PossibleKeys = strings.Split(possibleKeys, ",")
switch colIdx[i] {
case "id": // MySQL bug: https://bugs.mysql.com/bug.php?id=34124
expRow.ID = row.ForceInt(i) if expRow.Filtered > 100.00 {
case "select_type": expRow.Filtered = 100.00
expRow.SelectType = row.Str(i)
case "table":
expRow.TableName = row.Str(i)
if expRow.TableName == "" {
expRow.TableName = "NULL"
}
case "type":
expRow.AccessType = row.Str(i)
if expRow.AccessType == "" {
expRow.AccessType = "NULL"
}
expRow.Scalability = ExplainScalability[expRow.AccessType]
case "possible_keys":
expRow.PossibleKeys = strings.Split(row.Str(i), ",")
case "key":
expRow.Key = row.Str(i)
if expRow.Key == "" {
expRow.Key = "NULL"
}
case "key_len":
expRow.KeyLen = row.Str(i)
case "ref":
expRow.Ref = strings.Split(row.Str(i), ",")
case "rows":
expRow.Rows = row.ForceInt(i)
case "extra":
expRow.Extra = row.Str(i)
if expRow.Extra == "" {
expRow.Extra = "NULL"
}
case "filtered":
expRow.Filtered = row.ForceFloat(i)
// MySQL bug: https://bugs.mysql.com/bug.php?id=34124
if expRow.Filtered > 100.00 {
expRow.Filtered = 100.00
}
}
} }
explainrows = append(explainrows, expRow)
expRow.Scalability = ExplainScalability[expRow.AccessType]
explainRows = append(explainRows, expRow)
} }
exp.ExplainRows = explainrows exp.ExplainRows = explainRows
for _, w := range res.Warning {
// 'EXTENDED' is deprecated and will be removed in a future release. // check explain warning info
if w.Int(1) != 1681 { if common.Config.ShowWarnings {
exp.Warnings = append(exp.Warnings, &ExplainWarning{Level: w.Str(0), Code: w.Int(1), Message: w.Str(2)}) for res.Warning.Next() {
var expWarning *ExplainWarning
res.Warning.Scan(
expWarning.Level,
expWarning.Code,
expWarning.Message,
)
// 'EXTENDED' is deprecated and will be removed in a future release.
if expWarning.Code != 1681 {
exp.Warnings = append(exp.Warnings, expWarning)
}
} }
} }
...@@ -1009,7 +1013,6 @@ func ParseExplainResult(res *QueryResult, formatType int) (exp *ExplainInfo, err ...@@ -1009,7 +1013,6 @@ func ParseExplainResult(res *QueryResult, formatType int) (exp *ExplainInfo, err
// Explain 获取 SQL 的 explain 信息 // Explain 获取 SQL 的 explain 信息
func (db *Connector) Explain(sql string, explainType int, formatType int) (exp *ExplainInfo, err error) { func (db *Connector) Explain(sql string, explainType int, formatType int) (exp *ExplainInfo, err error) {
exp = &ExplainInfo{SQL: sql}
if explainType != TraditionalExplainType { if explainType != TraditionalExplainType {
formatType = TraditionalFormatExplain formatType = TraditionalFormatExplain
} }
...@@ -1025,12 +1028,16 @@ func (db *Connector) Explain(sql string, explainType int, formatType int) (exp * ...@@ -1025,12 +1028,16 @@ func (db *Connector) Explain(sql string, explainType int, formatType int) (exp *
// 执行EXPLAIN请求 // 执行EXPLAIN请求
res, err := db.executeExplain(sql, explainType, formatType) res, err := db.executeExplain(sql, explainType, formatType)
if err != nil || res == nil { if err != nil {
return exp, err return exp, err
} }
if res.Error != nil {
return exp, res.Error
}
// 解析mysql结果,输出ExplainInfo // 解析mysql结果,输出ExplainInfo
exp, err = ParseExplainResult(res, formatType) exp, err = ParseExplainResult(res, formatType)
exp.SQL = sql
return exp, err return exp, err
} }
......
...@@ -17,8 +17,6 @@ ...@@ -17,8 +17,6 @@
package database package database
import ( import (
"fmt"
"os"
"testing" "testing"
"github.com/XiaoMi/soar/common" "github.com/XiaoMi/soar/common"
...@@ -26,23 +24,6 @@ import ( ...@@ -26,23 +24,6 @@ import (
"github.com/kr/pretty" "github.com/kr/pretty"
) )
var connTest *Connector
func init() {
common.BaseDir = common.DevPath
common.ParseConfig("")
connTest = &Connector{
Addr: common.Config.OnlineDSN.Addr,
User: common.Config.OnlineDSN.User,
Pass: common.Config.OnlineDSN.Password,
Database: common.Config.OnlineDSN.Schema,
}
if _, err := connTest.Version(); err != nil {
common.Log.Critical("Test env Error: %v", err)
os.Exit(0)
}
}
var sqls = []string{ var sqls = []string{
`select * from city where country_id = 44;`, `select * from city where country_id = 44;`,
`select * from address where address2 is not null;`, `select * from address where address2 is not null;`,
...@@ -54,10 +35,10 @@ var sqls = []string{ ...@@ -54,10 +35,10 @@ var sqls = []string{
`select * from city where country_id > 31 and city = 'Aden';`, `select * from city where country_id > 31 and city = 'Aden';`,
`select * from address where address_id > 8 and city_id < 400 and district = 'Nantou';`, `select * from address where address_id > 8 and city_id < 400 and district = 'Nantou';`,
`select * from address where address_id > 8 and city_id < 400;`, `select * from address where address_id > 8 and city_id < 400;`,
`select * from actor where last_update='2006-02-15 04:34:33' and last_name='CHASE' group by first_name;`, `select first_name from actor where last_update='2006-02-15 04:34:33' and last_name='CHASE' group by first_name;`,
`select * from address where last_update >='2014-09-25 22:33:47' group by district;`, `select district from address where last_update >='2014-09-25 22:33:47' group by district;`,
`select * from address group by address,district;`, `select address from address group by address,district;`,
`select * from address where last_update='2014-09-25 22:30:27' group by district,(address_id+city_id);`, `select district from address where last_update='2014-09-25 22:30:27' group by district,(address_id+city_id);`,
`select * from customer where active=1 order by last_name limit 10;`, `select * from customer where active=1 order by last_name limit 10;`,
`select * from customer order by last_name limit 10;`, `select * from customer order by last_name limit 10;`,
`select * from customer where address_id > 224 order by address_id limit 10;`, `select * from customer where address_id > 224 order by address_id limit 10;`,
...@@ -2351,16 +2332,23 @@ possible_keys: idx_fk_country_id,idx_country_id_city,idx_all,idx_other ...@@ -2351,16 +2332,23 @@ possible_keys: idx_fk_country_id,idx_country_id_city,idx_all,idx_other
} }
func TestExplain(t *testing.T) { func TestExplain(t *testing.T) {
for _, sql := range sqls { // TraditionalFormatExplain
for idx, sql := range sqls {
exp, err := connTest.Explain(sql, TraditionalExplainType, TraditionalFormatExplain) exp, err := connTest.Explain(sql, TraditionalExplainType, TraditionalFormatExplain)
//exp, err := conn.Explain(sql, TraditionalExplainType, JSONFormatExplain)
fmt.Println("Old: ", sql)
fmt.Println("New: ", exp.SQL)
if err != nil { if err != nil {
fmt.Println(err) t.Error(err)
} }
pretty.Println("No.:", idx, "\nOld: ", sql, "\nNew: ", exp.SQL)
pretty.Println(exp)
}
// JSONFormatExplain
for idx, sql := range sqls {
exp, err := connTest.Explain(sql, TraditionalExplainType, JSONFormatExplain)
if err != nil {
t.Error(err)
}
pretty.Println("No.:", idx, "\nOld: ", sql, "\nNew: ", exp.SQL)
pretty.Println(exp) pretty.Println(exp)
fmt.Println()
} }
} }
...@@ -2400,6 +2388,7 @@ func TestPrintMarkdownExplainTable(t *testing.T) { ...@@ -2400,6 +2388,7 @@ func TestPrintMarkdownExplainTable(t *testing.T) {
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
err = common.GoldenDiff(func() { err = common.GoldenDiff(func() {
PrintMarkdownExplainTable(expInfo) PrintMarkdownExplainTable(expInfo)
}, t.Name(), update) }, t.Name(), update)
......
...@@ -17,21 +17,17 @@ ...@@ -17,21 +17,17 @@
package database package database
import ( import (
"database/sql"
"errors" "errors"
"fmt" "fmt"
"io/ioutil"
"os"
"regexp" "regexp"
"strconv" "strconv"
"strings" "strings"
"time"
"github.com/XiaoMi/soar/ast"
"github.com/XiaoMi/soar/common" "github.com/XiaoMi/soar/common"
"github.com/ziutek/mymysql/mysql" // for database/sql
// mymysql driver _ "github.com/go-sql-driver/mysql"
_ "github.com/ziutek/mymysql/native"
"vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/sqlparser"
) )
...@@ -42,81 +38,62 @@ type Connector struct { ...@@ -42,81 +38,62 @@ type Connector struct {
Pass string Pass string
Database string Database string
Charset string Charset string
Net string
} }
// QueryResult 数据库查询返回值 // QueryResult 数据库查询返回值
type QueryResult struct { type QueryResult struct {
Rows []mysql.Row Rows *sql.Rows
Result mysql.Result
Error error Error error
Warning []mysql.Row Warning *sql.Rows
QueryCost float64 QueryCost float64
} }
// NewConnection 创建新连接 // NewConnection 创建新连接
func (db *Connector) NewConnection() mysql.Conn { func (db *Connector) NewConnection() (*sql.DB, error) {
return mysql.New("tcp", "", db.Addr, db.User, db.Pass, db.Database) dsn := fmt.Sprintf("%s:%s@%s(%s)/%s?parseTime=true&charset=%s", db.User, db.Pass, db.Net, db.Addr, db.Database, db.Charset)
return sql.Open("mysql", dsn)
} }
// Query 执行SQL // Query 执行SQL
func (db *Connector) Query(sql string, params ...interface{}) (*QueryResult, error) { func (db *Connector) Query(sql string, params ...interface{}) (QueryResult, error) {
var res QueryResult
// 测试环境如果检查是关闭的,则SQL不会被执行 // 测试环境如果检查是关闭的,则SQL不会被执行
if common.Config.TestDSN.Disable { if common.Config.TestDSN.Disable {
return nil, errors.New("Dsn Disable") return res, errors.New("dsn is disable")
} }
// 数据库安全性检查:如果 Connector 的 IP 端口与 TEST 环境不一致,则启用SQL白名单 // 数据库安全性检查:如果 Connector 的 IP 端口与 TEST 环境不一致,则启用SQL白名单
// 不在白名单中的SQL不允许执行 // 不在白名单中的SQL不允许执行
// 执行环境与test环境不相同 // 执行环境与test环境不相同
if db.Addr != common.Config.TestDSN.Addr && db.dangerousQuery(sql) { if db.Addr != common.Config.TestDSN.Addr && db.dangerousQuery(sql) {
return nil, fmt.Errorf("query execution deny: execute SQL with DSN(%s/%s) '%s'", return res, fmt.Errorf("query execution deny: execute SQL with DSN(%s/%s) '%s'",
db.Addr, db.Database, fmt.Sprintf(sql, params...)) db.Addr, db.Database, fmt.Sprintf(sql, params...))
} }
common.Log.Debug("Execute SQL with DSN(%s/%s) : %s", db.Addr, db.Database, fmt.Sprintf(sql, params...)) common.Log.Debug("Execute SQL with DSN(%s/%s) : %s", db.Addr, db.Database, fmt.Sprintf(sql, params...))
conn := db.NewConnection() conn, err := db.NewConnection()
// 设置SQL连接超时时间
conn.SetTimeout(time.Duration(common.Config.ConnTimeOut) * time.Second)
defer conn.Close() defer conn.Close()
err := conn.Connect()
if err != nil { if err != nil {
return nil, err return res, err
} }
res.Rows, res.Error = conn.Query(sql, params...)
// 添加SQL执行超时限制 if common.Config.ShowWarnings {
ch := make(chan QueryResult, 1) res.Warning, err = conn.Query("SHOW WARNINGS")
go func() { }
res := QueryResult{}
res.Rows, res.Result, res.Error = conn.Query(sql, params...)
if common.Config.ShowWarnings {
warning, _, err := conn.Query("SHOW WARNINGS")
if err == nil {
res.Warning = warning
}
}
// SHOW WARNINGS 并不会影响 last_query_cost // SHOW WARNINGS 并不会影响 last_query_cost
if common.Config.ShowLastQueryCost { if common.Config.ShowLastQueryCost {
cost, _, err := conn.Query("SHOW SESSION STATUS LIKE 'last_query_cost'") cost, err := conn.Query("SHOW SESSION STATUS LIKE 'last_query_cost'")
if err == nil { if err == nil {
if len(cost) > 0 { if cost.Next() {
res.QueryCost = cost[0].Float(1) err = cost.Scan(res.QueryCost)
}
} }
} }
ch <- res
}()
select {
case res := <-ch:
return &res, res.Error
case <-time.After(time.Duration(common.Config.QueryTimeOut) * time.Second):
return nil, errors.New("query execution timeout")
} }
return res, err
} }
// Version 获取MySQL数据库版本 // Version 获取MySQL数据库版本
...@@ -124,77 +101,43 @@ func (db *Connector) Version() (int, error) { ...@@ -124,77 +101,43 @@ func (db *Connector) Version() (int, error) {
version := 99999 version := 99999
// 从数据库中获取版本信息 // 从数据库中获取版本信息
res, err := db.Query("select @@version") res, err := db.Query("select @@version")
if err != nil { if err != nil || res.Error != nil {
common.Log.Warn("(db *Connector) Version() Error: %v", err) common.Log.Warn("(db *Connector) Version() Error: %v, MySQL Error: %v", err, res.Error)
return version, err return version, err
} }
// MariaDB https://mariadb.com/kb/en/library/comment-syntax/ // MariaDB https://mariadb.com/kb/en/library/comment-syntax/
// MySQL https://dev.mysql.com/doc/refman/8.0/en/comments.html // MySQL https://dev.mysql.com/doc/refman/8.0/en/comments.html
versionStr := strings.Split(res.Rows[0].Str(0), "-")[0] var versionStr string
versionSeg := strings.Split(versionStr, ".") var versionSeg []string
if len(versionSeg) == 3 { for res.Rows.Next() {
versionStr = fmt.Sprintf("%s%02s%02s", versionSeg[0], versionSeg[1], versionSeg[2]) err = res.Rows.Scan(&versionStr)
version, err = strconv.Atoi(versionStr) versionStr = strings.Split(versionStr, "-")[0]
} versionSeg = strings.Split(versionStr, ".")
return version, err if len(versionSeg) == 3 {
} versionStr = fmt.Sprintf("%s%02s%02s", versionSeg[0], versionSeg[1], versionSeg[2])
version, err = strconv.Atoi(versionStr)
// Source execute sql from file
func (db *Connector) Source(file string) ([]*QueryResult, error) {
var sqlCounter int // SQL 计数器
var result []*QueryResult
fd, err := os.Open(file)
defer func() {
err = fd.Close()
if err != nil {
common.Log.Error("(db *Connector) Source(%s) fd.Close failed: %s", file, err.Error())
} }
}() break
if err != nil {
common.Log.Warning("(db *Connector) Source(%s) os.Open failed: %s", file, err.Error())
return nil, err
}
data, err := ioutil.ReadAll(fd)
if err != nil {
common.Log.Critical("ioutil.ReadAll Error: %s", err.Error())
return nil, err
} }
sql := strings.TrimSpace(string(data)) return version, err
buf := strings.TrimSpace(sql)
for ; ; sqlCounter++ {
if buf == "" {
break
}
// 查询请求切分
_, sql, bufBytes := ast.SplitStatement([]byte(buf), []byte(common.Config.Delimiter))
buf = string(bufBytes)
sql = strings.TrimSpace(sql)
common.Log.Debug("Source Query SQL: %s", sql)
res, e := db.Query(sql)
if e != nil {
common.Log.Error("(db *Connector) Source Filename: %s, SQLCounter.: %d", file, sqlCounter)
return result, e
}
result = append(result, res)
}
return result, nil
} }
// SingleIntValue 获取某个int型变量的值 // SingleIntValue 获取某个int型变量的值
func (db *Connector) SingleIntValue(option string) (int, error) { func (db *Connector) SingleIntValue(option string) (int, error) {
// 从数据库中获取信息 // 从数据库中获取信息
res, err := db.Query("select @@%s", option) res, err := db.Query("select @@" + option)
if err != nil { if err != nil {
common.Log.Warn("(db *Connector) SingleIntValue() Error: %v", err) common.Log.Warn("(db *Connector) SingleIntValue() Error: %v", err)
return -1, err return -1, err
} }
return res.Rows[0].Int(0), err var intVal int
if res.Rows.Next() {
err = res.Rows.Scan(&intVal)
}
return intVal, err
} }
// ColumnCardinality 粒度计算 // ColumnCardinality 粒度计算
...@@ -228,13 +171,20 @@ func (db *Connector) ColumnCardinality(tb, col string) float64 { ...@@ -228,13 +171,20 @@ func (db *Connector) ColumnCardinality(tb, col string) float64 {
} }
// 计算该列散粒度 // 计算该列散粒度
res, err := db.Query("select count(distinct `%s`) from `%s`.`%s`", col, db.Database, tb) res, err := db.Query(fmt.Sprintf("select count(distinct `%s`) from `%s`.`%s`", col, db.Database, tb))
if err != nil { if err != nil {
common.Log.Warn("(db *Connector) ColumnCardinality() Query Error: %v", err) common.Log.Warn("(db *Connector) ColumnCardinality() Query Error: %v", err)
return 0 return 0
} }
colNum := res.Rows[0].Float(0) var colNum float64
if res.Rows.Next() {
err = res.Rows.Scan(&colNum)
if err != nil {
common.Log.Warn("(db *Connector) ColumnCardinality() Query Error: %v", err)
return 0
}
}
// 散粒度区间:[0,1] // 散粒度区间:[0,1]
return colNum / float64(rowTotal) return colNum / float64(rowTotal)
...@@ -249,13 +199,12 @@ func (db *Connector) IsView(tbName string) bool { ...@@ -249,13 +199,12 @@ func (db *Connector) IsView(tbName string) bool {
} }
if len(tbStatus.Rows) > 0 { if len(tbStatus.Rows) > 0 {
if tbStatus.Rows[0].Comment == "VIEW" { if string(tbStatus.Rows[0].Comment) == "VIEW" {
return true return true
} }
} }
return false return false
} }
// RemoveSQLComments 去除SQL中的注释 // RemoveSQLComments 去除SQL中的注释
...@@ -281,7 +230,7 @@ func (db *Connector) dangerousQuery(query string) bool { ...@@ -281,7 +230,7 @@ func (db *Connector) dangerousQuery(query string) bool {
return true return true
} }
for _, sql := range queries { for _, query := range queries {
dangerous := true dangerous := true
whiteList := []string{ whiteList := []string{
"select", "select",
...@@ -291,7 +240,7 @@ func (db *Connector) dangerousQuery(query string) bool { ...@@ -291,7 +240,7 @@ func (db *Connector) dangerousQuery(query string) bool {
} }
for _, prefix := range whiteList { for _, prefix := range whiteList {
if strings.HasPrefix(sql, prefix) { if strings.HasPrefix(query, prefix) {
dangerous = false dangerous = false
break break
} }
......
...@@ -17,26 +17,70 @@ ...@@ -17,26 +17,70 @@
package database package database
import ( import (
"flag"
"fmt" "fmt"
"os"
"testing" "testing"
"github.com/XiaoMi/soar/common" "github.com/XiaoMi/soar/common"
"github.com/kr/pretty" "github.com/kr/pretty"
) )
var connTest *Connector
var update = flag.Bool("update", false, "update .golden files")
func init() {
common.BaseDir = common.DevPath
common.ParseConfig("")
connTest = &Connector{
Addr: common.Config.OnlineDSN.Addr,
User: common.Config.OnlineDSN.User,
Pass: common.Config.OnlineDSN.Password,
Database: common.Config.OnlineDSN.Schema,
Charset: common.Config.OnlineDSN.Charset,
}
if _, err := connTest.Version(); err != nil {
common.Log.Critical("Test env Error: %v", err)
os.Exit(0)
}
}
func TestNewConnection(t *testing.T) {
_, err := connTest.NewConnection()
if err != nil {
t.Errorf("TestNewConnection, Error: %s", err.Error())
}
}
// TODO: go test -race不通过待解决 // TODO: go test -race不通过待解决
func TestQuery(t *testing.T) { func TestQuery(t *testing.T) {
common.Config.QueryTimeOut = 1 res, err := connTest.Query("select 0")
_, err := connTest.Query("select sleep(2)") if err != nil {
if err == nil { t.Error(err.Error())
t.Error("connTest.Query not timeout") }
for res.Rows.Next() {
var val int
err = res.Rows.Scan(&val)
if err != nil {
t.Error(err.Error())
}
if val != 0 {
t.Error("should return 0")
}
} }
// TODO: timeout test
} }
func TestColumnCardinality(_ *testing.T) { func TestColumnCardinality(t *testing.T) {
connTest.Database = "information_schema" orgDatabase := connTest.Database
a := connTest.ColumnCardinality("TABLES", "TABLE_SCHEMA") connTest.Database = "sakila"
fmt.Println("TABLES.TABLE_SCHEMA:", a) a := connTest.ColumnCardinality("actor", "first_name")
if a >= 1 || a <= 0 {
t.Error("sakila.actor.first_name cardinality should in (0, 1), now it's", a)
}
connTest.Database = orgDatabase
} }
func TestDangerousSQL(t *testing.T) { func TestDangerousSQL(t *testing.T) {
...@@ -63,11 +107,15 @@ func TestWarningsAndQueryCost(t *testing.T) { ...@@ -63,11 +107,15 @@ func TestWarningsAndQueryCost(t *testing.T) {
if err != nil { if err != nil {
t.Error("Query Error: ", err) t.Error("Query Error: ", err)
} else { } else {
for _, w := range res.Warning { for res.Warning.Next() {
pretty.Println(w.Str(2)) var str string
err = res.Warning.Scan(str)
if err != nil {
t.Error(err.Error())
}
pretty.Println(str)
} }
fmt.Println(res.QueryCost) fmt.Println(res.QueryCost, err)
pretty.Println(err)
} }
} }
...@@ -79,16 +127,6 @@ func TestVersion(t *testing.T) { ...@@ -79,16 +127,6 @@ func TestVersion(t *testing.T) {
fmt.Println(version) fmt.Println(version)
} }
func TestSource(t *testing.T) {
res, err := connTest.Source("testdata/" + t.Name() + ".sql")
if err != nil {
t.Error("Query Error: ", err)
}
if res[0].Rows[0].Int(0) != 1 || res[1].Rows[0].Int(0) != 1 {
t.Error("Source result not match, expect 1, 1")
}
}
func TestRemoveSQLComments(t *testing.T) { func TestRemoveSQLComments(t *testing.T) {
SQLs := []string{ SQLs := []string{
`-- comment`, `-- comment`,
...@@ -109,3 +147,22 @@ comment*/`, ...@@ -109,3 +147,22 @@ comment*/`,
t.Error(err) t.Error(err)
} }
} }
func TestSingleIntValue(t *testing.T) {
val, err := connTest.SingleIntValue("read_only")
if err != nil {
t.Error(err)
}
if val < 0 {
t.Error("SingleIntValue, return should large than zero")
}
}
func TestIsView(t *testing.T) {
originalDatabase := connTest.Database
connTest.Database = "sakila"
if !connTest.IsView("actor_info") {
t.Error("actor_info should be a VIEW")
}
connTest.Database = originalDatabase
}
...@@ -18,6 +18,7 @@ package database ...@@ -18,6 +18,7 @@ package database
import ( import (
"errors" "errors"
"fmt"
"strings" "strings"
"github.com/XiaoMi/soar/common" "github.com/XiaoMi/soar/common"
...@@ -29,8 +30,13 @@ func (db *Connector) CurrentUser() (string, string, error) { ...@@ -29,8 +30,13 @@ func (db *Connector) CurrentUser() (string, string, error) {
if err != nil { if err != nil {
return "", "", err return "", "", err
} }
if len(res.Rows) > 0 { if res.Rows.Next() {
cols := strings.Split(res.Rows[0].Str(0), "@") var currentUser string
err = res.Rows.Scan(&currentUser)
if err != nil {
return "", "", err
}
cols := strings.Split(currentUser, "@")
if len(cols) == 2 { if len(cols) == 2 {
user := strings.Trim(cols[0], "'") user := strings.Trim(cols[0], "'")
host := strings.Trim(cols[1], "'") host := strings.Trim(cols[1], "'")
...@@ -51,14 +57,20 @@ func (db *Connector) HasSelectPrivilege() bool { ...@@ -51,14 +57,20 @@ func (db *Connector) HasSelectPrivilege() bool {
common.Log.Error("User: %s, HasSelectPrivilege: %s", db.User, err.Error()) common.Log.Error("User: %s, HasSelectPrivilege: %s", db.User, err.Error())
return false return false
} }
res, err := db.Query("select Select_priv from mysql.user where user='%s' and host='%s'", user, host) res, err := db.Query(fmt.Sprintf("select Select_priv from mysql.user where user='%s' and host='%s'", user, host))
if err != nil { if err != nil {
common.Log.Error("HasSelectPrivilege, DSN: %s, Error: %s", db.Addr, err.Error()) common.Log.Error("HasSelectPrivilege, DSN: %s, Error: %s", db.Addr, err.Error())
return false return false
} }
// Select_priv // Select_priv
if len(res.Rows) > 0 { if res.Rows.Next() {
if res.Rows[0].Str(0) == "Y" { var selectPrivilege string
err = res.Rows.Scan(&selectPrivilege)
if err != nil {
common.Log.Error("HasSelectPrivilege, Scan Error: %s", err.Error())
return false
}
if selectPrivilege == "Y" {
return true return true
} }
} }
...@@ -79,24 +91,31 @@ func (db *Connector) HasAllPrivilege() bool { ...@@ -79,24 +91,31 @@ func (db *Connector) HasAllPrivilege() bool {
common.Log.Error("HasAllPrivilege, DSN: %s, Error: %s", db.Addr, err.Error()) common.Log.Error("HasAllPrivilege, DSN: %s, Error: %s", db.Addr, err.Error())
return false return false
} }
var priv string var priv string
if len(res.Rows) > 0 { if res.Rows.Next() {
priv = res.Rows[0].Str(0) err = res.Rows.Scan(&priv)
} else { if err != nil {
common.Log.Error("HasAllPrivilege, DSN: %s, get privilege string error", db.Addr) common.Log.Error("HasAllPrivilege, DSN: %s, Scan error", db.Addr)
return false return false
}
} }
// get all privilege status // get all privilege status
res, err = db.Query("select concat("+priv+") from mysql.user where user='%s' and host='%s'", user, host) res, err = db.Query(fmt.Sprintf("select concat("+priv+") from mysql.user where user='%s' and host='%s'", user, host))
if err != nil { if err != nil {
common.Log.Error("HasAllPrivilege, DSN: %s, Error: %s", db.Addr, err.Error()) common.Log.Error("HasAllPrivilege, DSN: %s, Error: %s", db.Addr, err.Error())
return false return false
} }
// %_priv // %_priv
if len(res.Rows) > 0 { if res.Rows.Next() {
if strings.Replace(res.Rows[0].Str(0), "Y", "", -1) == "" { err = res.Rows.Scan(&priv)
if err != nil {
common.Log.Error("HasAllPrivilege, DSN: %s, Scan error", db.Addr)
return false
}
if strings.Replace(priv, "Y", "", -1) == "" {
return true return true
} }
} }
......
...@@ -18,6 +18,16 @@ package database ...@@ -18,6 +18,16 @@ package database
import "testing" import "testing"
func TestCurrentUser(t *testing.T) {
user, host, err := connTest.CurrentUser()
if err != nil {
t.Error(err.Error())
}
if user != "root" || host != "%" {
t.Errorf("Want user: root, host: %%. Get user: %s, host: %s", user, host)
}
}
func TestHasSelectPrivilege(t *testing.T) { func TestHasSelectPrivilege(t *testing.T) {
if !connTest.HasSelectPrivilege() { if !connTest.HasSelectPrivilege() {
t.Errorf("DSN: %s, User: %s, should has select privilege", connTest.Addr, connTest.User) t.Errorf("DSN: %s, User: %s, should has select privilege", connTest.Addr, connTest.User)
......
...@@ -19,9 +19,7 @@ package database ...@@ -19,9 +19,7 @@ package database
import ( import (
"errors" "errors"
"fmt" "fmt"
"io"
"strings" "strings"
"time"
"github.com/XiaoMi/soar/common" "github.com/XiaoMi/soar/common"
...@@ -40,95 +38,79 @@ type ProfilingRow struct { ...@@ -40,95 +38,79 @@ type ProfilingRow struct {
// TODO: 支持show profile all,不过目前看all的信息过多有点眼花缭乱 // TODO: 支持show profile all,不过目前看all的信息过多有点眼花缭乱
} }
// Profiling 执行SQL,并对其Profiling // Profiling 执行SQL,并对其 Profile
func (db *Connector) Profiling(sql string, params ...interface{}) (*QueryResult, error) { func (db *Connector) Profiling(sql string, params ...interface{}) ([]ProfilingRow, error) {
var rows []ProfilingRow
// 过滤不需要 profiling 的 SQL // 过滤不需要 profiling 的 SQL
switch sqlparser.Preview(sql) { switch sqlparser.Preview(sql) {
case sqlparser.StmtSelect, sqlparser.StmtUpdate, sqlparser.StmtDelete: case sqlparser.StmtSelect, sqlparser.StmtUpdate, sqlparser.StmtDelete:
default: default:
return nil, errors.New("no need profiling") return rows, errors.New("no need profiling")
} }
// 测试环境如果检查是关闭的,则SQL不会被执行 // 测试环境如果检查是关闭的,则SQL不会被执行
if common.Config.TestDSN.Disable { if common.Config.TestDSN.Disable {
return nil, errors.New("Dsn Disable") return rows, errors.New("dsn is disable")
} }
// 数据库安全性检查:如果 Connector 的 IP 端口与 TEST 环境不一致,则启用 SQL 白名单 // 数据库安全性检查:如果 Connector 的 IP 端口与 TEST 环境不一致,则启用 SQL 白名单
// 不在白名单中的SQL不允许执行 // 不在白名单中的 SQL 不允许执行
// 执行环境与test环境不相同 // 执行环境与test环境不相同
if db.Addr != common.Config.TestDSN.Addr && db.dangerousQuery(sql) { if db.Addr != common.Config.TestDSN.Addr && db.dangerousQuery(sql) {
return nil, fmt.Errorf("query execution deny: Execute SQL with DSN(%s/%s) '%s'", return rows, fmt.Errorf("query execution deny: Execute SQL with DSN(%s/%s) '%s'",
db.Addr, db.Database, fmt.Sprintf(sql, params...)) db.Addr, db.Database, fmt.Sprintf(sql, params...))
} }
common.Log.Debug("Execute SQL with DSN(%s/%s) : %s", db.Addr, db.Database, sql) common.Log.Debug("Execute SQL with DSN(%s/%s) : %s", db.Addr, db.Database, sql)
conn := db.NewConnection() conn, err := db.NewConnection()
if err != nil {
// 设置SQL连接超时时间 return rows, err
conn.SetTimeout(time.Duration(common.Config.ConnTimeOut) * time.Second) }
defer conn.Close() defer conn.Close()
err := conn.Connect()
// Keep connection
// https://github.com/go-sql-driver/mysql/issues/208
trx, err := conn.Begin()
if err != nil { if err != nil {
return nil, err return rows, err
} }
defer trx.Rollback()
// 开启 Profiling
_, err = trx.Query("set @@profiling=1")
common.LogIfError(err, "")
// 添加SQL执行超时限制 // 执行 SQL,抛弃返回结果
ch := make(chan QueryResult, 1) tmpRes, err := trx.Query(sql, params...)
go func() { if err != nil {
// 开启Profiling return rows, err
_, _, err = conn.Query("set @@profiling=1") }
common.LogIfError(err, "") for tmpRes.Next() {
continue
}
// 执行SQL,抛弃返回结果 // 返回 Profiling 结果
result, err := conn.Start(sql, params...) res, err := trx.Query("show profile")
for res.Next() {
var profileRow ProfilingRow
err := res.Scan(&profileRow.Status, &profileRow.Duration)
if err != nil { if err != nil {
ch <- QueryResult{ common.LogIfError(err, "")
Error: err,
}
return
} }
row := result.MakeRow() rows = append(rows, profileRow)
for {
err = result.ScanRow(row)
if err == io.EOF {
break
}
}
// 返回Profiling结果
res := QueryResult{}
res.Rows, res.Result, res.Error = conn.Query("show profile")
_, _, err = conn.Query("set @@profiling=0")
common.LogIfError(err, "")
ch <- res
}()
select {
case res := <-ch:
return &res, res.Error
case <-time.After(time.Duration(common.Config.QueryTimeOut) * time.Second):
return nil, errors.New("query execution timeout")
} }
}
func getProfiling(res *QueryResult) Profiling { // 关闭 Profiling
var rows []ProfilingRow _, err = trx.Query("set @@profiling=0")
for _, row := range res.Rows { common.LogIfError(err, "")
rows = append(rows, ProfilingRow{ return rows, err
Status: row.Str(0),
Duration: row.Float(1),
})
}
return Profiling{Rows: rows}
} }
// FormatProfiling 格式化输出Profiling信息 // FormatProfiling 格式化输出Profiling信息
func FormatProfiling(res *QueryResult) string { func FormatProfiling(rows []ProfilingRow) string {
profiling := getProfiling(res)
str := []string{"| Status | Duration |"} str := []string{"| Status | Duration |"}
str = append(str, "| --- | --- |") str = append(str, "| --- | --- |")
for _, row := range profiling.Rows { for _, row := range rows {
str = append(str, fmt.Sprintf("| %s | %f |", row.Status, row.Duration)) str = append(str, fmt.Sprintf("| %s | %f |", row.Status, row.Duration))
} }
return strings.Join(str, "\n") return strings.Join(str, "\n")
......
...@@ -19,35 +19,21 @@ package database ...@@ -19,35 +19,21 @@ package database
import ( import (
"testing" "testing"
"github.com/XiaoMi/soar/common"
"github.com/kr/pretty" "github.com/kr/pretty"
) )
func TestProfiling(t *testing.T) { func TestProfiling(t *testing.T) {
common.Config.QueryTimeOut = 1 rows, err := connTest.Profiling("select 1")
res, err := connTest.Profiling("select 1") if err != nil {
if err == nil {
pretty.Println(res)
} else {
t.Error(err) t.Error(err)
} }
pretty.Println(rows)
} }
func TestFormatProfiling(t *testing.T) { func TestFormatProfiling(t *testing.T) {
res, err := connTest.Profiling("select 1") res, err := connTest.Profiling("select 1")
if err == nil { if err != nil {
pretty.Println(FormatProfiling(res))
} else {
t.Error(err)
}
}
func TestGetProfiling(t *testing.T) {
res, err := connTest.Profiling("select 1")
if err == nil {
pretty.Println(getProfiling(res))
} else {
t.Error(err) t.Error(err)
} }
pretty.Println(FormatProfiling(res))
} }
...@@ -17,14 +17,9 @@ ...@@ -17,14 +17,9 @@
package database package database
import ( import (
"database/sql"
"fmt" "fmt"
"io"
"strconv"
"strings"
"time"
"github.com/XiaoMi/soar/common" "github.com/XiaoMi/soar/common"
"github.com/ziutek/mymysql/mysql"
) )
/*-------------------- /*--------------------
...@@ -57,21 +52,17 @@ func (db *Connector) SamplingData(remote Connector, tables ...string) error { ...@@ -57,21 +52,17 @@ func (db *Connector) SamplingData(remote Connector, tables ...string) error {
maxValCount := 200 maxValCount := 200
// 获取数据库连接对象 // 获取数据库连接对象
conn := remote.NewConnection() conn, err := remote.NewConnection()
localConn := db.NewConnection()
// 连接数据库
err := conn.Connect()
defer conn.Close()
if err != nil { if err != nil {
return err return err
} }
defer conn.Close()
err = localConn.Connect() localConn, err := db.NewConnection()
defer localConn.Close()
if err != nil { if err != nil {
return err return err
} }
defer localConn.Close()
for _, table := range tables { for _, table := range tables {
// 表类型检查 // 表类型检查
...@@ -109,119 +100,128 @@ func (db *Connector) SamplingData(remote Connector, tables ...string) error { ...@@ -109,119 +100,128 @@ func (db *Connector) SamplingData(remote Connector, tables ...string) error {
// 开始从环境中泵取数据 // 开始从环境中泵取数据
// 因为涉及到的数据量问题,所以泵取与插入时同时进行的 // 因为涉及到的数据量问题,所以泵取与插入时同时进行的
// TODO 加 ref link // TODO 加 ref link
func startSampling(conn, localConn mysql.Conn, database, table string, factor float64, wants, maxValCount int) error { func startSampling(conn, localConn *sql.DB, database, table string, factor float64, wants, maxValCount int) error {
// 从线上数据库获取所需dump的表中所有列的数据类型,备用 return nil
// 由于测试库中的库表为刚建立的,所以在information_schema中很可能没有这个表的信息 // TODO:
var dataTypes []string /*
q := fmt.Sprintf("select DATA_TYPE from information_schema.COLUMNS where TABLE_SCHEMA='%s' and TABLE_NAME = '%s'", // 从线上数据库获取所需dump的表中所有列的数据类型,备用
database, table) // 由于测试库中的库表为刚建立的,所以在information_schema中很可能没有这个表的信息
common.Log.Debug("Sampling data execute: %s", q) var dataTypes []string
rs, _, err := localConn.Query(q) q := fmt.Sprintf("select DATA_TYPE from information_schema.COLUMNS where TABLE_SCHEMA='%s' and TABLE_NAME = '%s'",
if err != nil { database, table)
common.Log.Debug("Sampling data got data type Err: %v", err) common.Log.Debug("Sampling data execute: %s", q)
} else { rs, err := localConn.Query(q)
for _, r := range rs { if err != nil {
dataTypes = append(dataTypes, r.Str(0)) common.Log.Debug("Sampling data got data type Err: %v", err)
} else {
for rs.Next() {
var dataType string
err = rs.Scan(&dataType)
if err != nil {
return err
}
dataTypes = append(dataTypes, dataType)
}
} }
}
// 生成where条件
where := fmt.Sprintf("where RAND()<=%f", factor)
if factor >= 1 {
where = ""
}
sql := fmt.Sprintf("select * from `%s` %s limit %d;", table, where, wants)
res, err := conn.Start(sql)
if err != nil {
return err
}
// GetRow method allocates a new chunk of memory for every received row. // 生成where条件
row := res.MakeRow() where := fmt.Sprintf("where RAND()<=%f", factor)
rowCount := 0 if factor >= 1 {
valCount := 0 where = ""
// 获取所有的列名
columns := make([]string, len(res.Fields()))
for i, filed := range res.Fields() {
columns[i] = filed.Name
}
colDef := strings.Join(columns, ",")
// 开始填充数据
var valList []string
for {
err := res.ScanRow(row)
if err == io.EOF {
// 扫描结束
if len(valList) > 0 {
// 如果缓存中还存在未插入的数据,则把缓存中的数据刷新到DB中
doSampling(localConn, database, table, colDef, strings.Join(valList, ","))
}
break
} }
sql := fmt.Sprintf("select * from `%s` %s limit %d;", table, where, wants)
res, err := conn.Query(sql)
if err != nil { if err != nil {
return err return err
} }
values := make([]string, len(columns)) // GetRow method allocates a new chunk of memory for every received row.
for i := range row { row := res.MakeRow()
// TODO 不支持坐标类型的导出 rowCount := 0
switch data := row[i].(type) { valCount := 0
case nil:
// str = "" // 获取所有的列名
case []byte: columns := make([]string, len(res.Fields()))
// 先尝试转成数字,如果报错则转换成string for i, filed := range res.Fields() {
if v, err := row.Int64Err(i); err != nil { columns[i] = filed.Name
values[i] = string(data) }
} else { colDef := strings.Join(columns, ",")
values[i] = strconv.FormatInt(v, 10)
// 开始填充数据
var valList []string
for {
err := res.ScanRow(row)
if err == io.EOF {
// 扫描结束
if len(valList) > 0 {
// 如果缓存中还存在未插入的数据,则把缓存中的数据刷新到DB中
doSampling(localConn, database, table, colDef, strings.Join(valList, ","))
} }
case time.Time: break
values[i] = mysql.TimeString(data)
case time.Duration:
values[i] = mysql.DurationString(data)
default:
values[i] = fmt.Sprint(data)
} }
// 非text/varchar类的数据类型,如果dump出的数据为空,则说明该值为null值 if err != nil {
// 应转换其 value 为 null,如果用空('')进行替代,会导致出现语法错误。 return err
if len(dataTypes) == len(res.Fields()) && values[i] == "" &&
(!strings.Contains(dataTypes[i], "char") ||
!strings.Contains(dataTypes[i], "text")) {
values[i] = "null"
} else {
values[i] = "'" + values[i] + "'"
} }
}
valuesStr := fmt.Sprintf(`(%s)`, strings.Join(values, `,`)) values := make([]string, len(columns))
valList = append(valList, valuesStr) for i := range row {
// TODO 不支持坐标类型的导出
switch data := row[i].(type) {
case nil:
// str = ""
case []byte:
// 先尝试转成数字,如果报错则转换成string
if v, err := row.Int64Err(i); err != nil {
values[i] = string(data)
} else {
values[i] = strconv.FormatInt(v, 10)
}
case time.Time:
values[i] = mysql.TimeString(data)
case time.Duration:
values[i] = mysql.DurationString(data)
default:
values[i] = fmt.Sprint(data)
}
// 非text/varchar类的数据类型,如果dump出的数据为空,则说明该值为null值
// 应转换其 value 为 null,如果用空('')进行替代,会导致出现语法错误。
if len(dataTypes) == len(res.Fields()) && values[i] == "" &&
(!strings.Contains(dataTypes[i], "char") ||
!strings.Contains(dataTypes[i], "text")) {
values[i] = "null"
} else {
values[i] = "'" + values[i] + "'"
}
}
rowCount++ valuesStr := fmt.Sprintf(`(%s)`, strings.Join(values, `,`))
valCount++ valList = append(valList, valuesStr)
if rowCount%maxValCount == 0 { rowCount++
doSampling(localConn, database, table, colDef, strings.Join(valList, ",")) valCount++
valCount = 0
valList = make([]string, 0)
if rowCount%maxValCount == 0 {
doSampling(localConn, database, table, colDef, strings.Join(valList, ","))
valCount = 0
valList = make([]string, 0)
}
} }
}
common.Log.Debug("%d rows sampling out", rowCount) common.Log.Debug("%d rows sampling out", rowCount)
return nil return nil
*/
} }
// 将泵取的数据转换成Insert语句并在数据库中执行 // 将泵取的数据转换成Insert语句并在数据库中执行
func doSampling(conn mysql.Conn, dbName, table, colDef, values string) { func doSampling(conn *sql.DB, dbName, table, colDef, values string) {
sql := fmt.Sprintf("Insert into `%s`.`%s`(%s) values%s;", dbName, table, query := fmt.Sprintf("Insert into `%s`.`%s`(%s) values%s;", dbName, table,
colDef, values) colDef, values)
_, _, err := conn.Query(sql) _, err := conn.Query(query)
if err != nil { if err != nil {
common.Log.Error("doSampling Error from %s.%s: %v", dbName, table, err) common.Log.Error("doSampling Error from %s.%s: %v", dbName, table, err)
......
...@@ -32,6 +32,8 @@ func TestSamplingData(t *testing.T) { ...@@ -32,6 +32,8 @@ func TestSamplingData(t *testing.T) {
User: common.Config.OnlineDSN.User, User: common.Config.OnlineDSN.User,
Pass: common.Config.OnlineDSN.Password, Pass: common.Config.OnlineDSN.Password,
Database: common.Config.OnlineDSN.Schema, Database: common.Config.OnlineDSN.Schema,
Charset: common.Config.OnlineDSN.Charset,
Net: common.Config.OnlineDSN.Net,
} }
offline := &Connector{ offline := &Connector{
...@@ -39,6 +41,8 @@ func TestSamplingData(t *testing.T) { ...@@ -39,6 +41,8 @@ func TestSamplingData(t *testing.T) {
User: common.Config.TestDSN.User, User: common.Config.TestDSN.User,
Pass: common.Config.TestDSN.Password, Pass: common.Config.TestDSN.Password,
Database: common.Config.TestDSN.Schema, Database: common.Config.TestDSN.Schema,
Charset: common.Config.TestDSN.Charset,
Net: common.Config.TestDSN.Net,
} }
offline.Database = "test" offline.Database = "test"
......
...@@ -18,12 +18,10 @@ package database ...@@ -18,12 +18,10 @@ package database
import ( import (
"fmt" "fmt"
"github.com/XiaoMi/soar/common"
"regexp" "regexp"
"strconv" "strconv"
"strings" "strings"
"time"
"github.com/XiaoMi/soar/common"
) )
// SHOW TABLE STATUS Syntax // SHOW TABLE STATUS Syntax
...@@ -36,11 +34,12 @@ type TableStatInfo struct { ...@@ -36,11 +34,12 @@ type TableStatInfo struct {
} }
// tableStatusRow 用于 show table status value // tableStatusRow 用于 show table status value
// use []byte instead of string, because []byte allow to be null, string not
type tableStatusRow struct { type tableStatusRow struct {
Name string // 表名 Name string // 表名
Engine string // 该表使用的存储引擎 Engine []byte // 该表使用的存储引擎
Version int // 该表的 .frm 文件版本号 Version []byte // 该表的 .frm 文件版本号
RowFormat string // 该表使用的行存储格式 RowFormat []byte // 该表使用的行存储格式
Rows int64 // 表行数, InnoDB 引擎中为预估值,甚至可能会有40%~50%的数值偏差 Rows int64 // 表行数, InnoDB 引擎中为预估值,甚至可能会有40%~50%的数值偏差
AvgRowLength int // 平均行长度 AvgRowLength int // 平均行长度
...@@ -59,15 +58,15 @@ type tableStatusRow struct { ...@@ -59,15 +58,15 @@ type tableStatusRow struct {
// 其他不同的存储引擎中该值的意义可能不尽相同 // 其他不同的存储引擎中该值的意义可能不尽相同
IndexLength int IndexLength int
DataFree int // 已分配但未使用的字节数 DataFree int // 已分配但未使用的字节数
AutoIncrement int // 下一个自增值 AutoIncrement []byte // 下一个自增值
CreateTime time.Time // 创建时间 CreateTime []byte // 创建时间
UpdateTime time.Time // 最近一次更新时间,该值不准确 UpdateTime []byte // 最近一次更新时间,该值不准确
CheckTime time.Time // 上次检查时间 CheckTime []byte // 上次检查时间
Collation string // 字符集及排序规则信息 Collation []byte // 字符集及排序规则信息
Checksum string // 校验和 Checksum []byte // 校验和
CreateOptions string // 创建表的时候的时候一切其他属性 CreateOptions []byte // 创建表的时候的时候一切其他属性
Comment string // 注释 Comment []byte // 注释
} }
// newTableStat 构造 table Stat 对象 // newTableStat 构造 table Stat 对象
...@@ -83,7 +82,7 @@ func (db *Connector) ShowTables() ([]string, error) { ...@@ -83,7 +82,7 @@ func (db *Connector) ShowTables() ([]string, error) {
defer func() { defer func() {
err := recover() err := recover()
if err != nil { if err != nil {
common.Log.Error("recover ShowTableStatus()", err) common.Log.Error("recover ShowTables()", err)
} }
}() }()
...@@ -92,79 +91,70 @@ func (db *Connector) ShowTables() ([]string, error) { ...@@ -92,79 +91,70 @@ func (db *Connector) ShowTables() ([]string, error) {
if err != nil { if err != nil {
return []string{}, err return []string{}, err
} }
if res.Error != nil {
return []string{}, res.Error
}
// 获取值 // 获取值
var tables []string var tables []string
for _, row := range res.Rows { for res.Rows.Next() {
tables = append(tables, row.Str(0)) var table string
err = res.Rows.Scan(&table)
if err != nil {
return []string{}, err
}
tables = append(tables, table)
} }
return tables, err return tables, err
} }
// ShowTableStatus 执行 show table status // ShowTableStatus 执行 show table status
func (db *Connector) ShowTableStatus(tableName string) (*TableStatInfo, error) { func (db *Connector) ShowTableStatus(tableName string) (*TableStatInfo, error) {
defer func() {
err := recover()
if err != nil {
common.Log.Error("recover ShowTableStatus()", err)
}
}()
// 初始化struct // 初始化struct
ts := newTableStat(tableName) tbStatus := newTableStat(tableName)
// 执行 show table status // 执行 show table status
res, err := db.Query("show table status where name = '%s'", ts.Name) res, err := db.Query(fmt.Sprintf("show table status where name = '%s'", tbStatus.Name))
if err != nil { if err != nil {
return ts, err return tbStatus, err
} }
if res.Error != nil {
rs := res.Result.Map("Rows") return tbStatus, res.Error
name := res.Result.Map("Name") }
df := res.Result.Map("Data_free")
sum := res.Result.Map("Checksum")
engine := res.Result.Map("Engine")
version := res.Result.Map("Version")
comment := res.Result.Map("Comment")
ai := res.Result.Map("Auto_increment")
collation := res.Result.Map("Collation")
rowFormat := res.Result.Map("Row_format")
checkTime := res.Result.Map("Check_time")
dataLength := res.Result.Map("Data_length")
idxLength := res.Result.Map("Index_length")
createTime := res.Result.Map("Create_time")
updateTime := res.Result.Map("Update_time")
options := res.Result.Map("Create_options")
avgRowLength := res.Result.Map("Avg_row_length")
maxDataLength := res.Result.Map("Max_data_length")
ts := tableStatusRow{}
statusFields := make([]interface{}, 0)
fields := map[string]interface{}{
"Name": &ts.Name,
"Engine": &ts.Engine,
"Version": &ts.Version,
"Row_format": &ts.RowFormat,
"Rows": &ts.Rows,
"Avg_row_length": &ts.AvgRowLength,
"Data_length": &ts.DataLength,
"Max_data_length": &ts.MaxDataLength,
"Index_length": &ts.IndexLength,
"Data_free": &ts.DataFree,
"Auto_increment": &ts.AutoIncrement,
"Create_time": &ts.CreateTime,
"Update_time": &ts.UpdateTime,
"Check_time": &ts.CheckTime,
"Collation": &ts.Collation,
"Checksum": &ts.Checksum,
"Create_options": &ts.CreateOptions,
"Comment": &ts.Comment,
}
cols, err := res.Rows.Columns()
common.LogIfError(err, "")
for _, col := range cols {
statusFields = append(statusFields, fields[col])
}
// 获取值 // 获取值
for _, row := range res.Rows { for res.Rows.Next() {
value := tableStatusRow{ res.Rows.Scan(statusFields...)
Name: row.Str(name), tbStatus.Rows = append(tbStatus.Rows, ts)
Engine: row.Str(engine),
Version: row.Int(version),
Rows: row.Int64(rs),
RowFormat: row.Str(rowFormat),
AvgRowLength: row.Int(avgRowLength),
DataLength: row.Int(dataLength),
MaxDataLength: row.Int(maxDataLength),
IndexLength: row.Int(idxLength),
DataFree: row.Int(df),
AutoIncrement: row.Int(ai),
CreateTime: row.Time(createTime, time.Local),
UpdateTime: row.Time(updateTime, time.Local),
CheckTime: row.Time(checkTime, time.Local),
Collation: row.Str(collation),
Checksum: row.Str(sum),
CreateOptions: row.Str(options),
Comment: row.Str(comment),
}
ts.Rows = append(ts.Rows, value)
} }
return tbStatus, err
return ts, err
} }
// https://dev.mysql.com/doc/refman/5.7/en/show-index.html // https://dev.mysql.com/doc/refman/5.7/en/show-index.html
...@@ -172,7 +162,7 @@ func (db *Connector) ShowTableStatus(tableName string) (*TableStatInfo, error) { ...@@ -172,7 +162,7 @@ func (db *Connector) ShowTableStatus(tableName string) (*TableStatInfo, error) {
// TableIndexInfo 用以保存 show index 之后获取的 index 信息 // TableIndexInfo 用以保存 show index 之后获取的 index 信息
type TableIndexInfo struct { type TableIndexInfo struct {
TableName string TableName string
IdxRows []TableIndexRow Rows []TableIndexRow
} }
// TableIndexRow 用以存放show index之后获取的每一条index信息 // TableIndexRow 用以存放show index之后获取的每一条index信息
...@@ -190,13 +180,14 @@ type TableIndexRow struct { ...@@ -190,13 +180,14 @@ type TableIndexRow struct {
IndexType string // BTREE, FULLTEXT, HASH, RTREE IndexType string // BTREE, FULLTEXT, HASH, RTREE
Comment string Comment string
IndexComment string IndexComment string
Visible string
} }
// NewTableIndexInfo 构造 TableIndexInfo // NewTableIndexInfo 构造 TableIndexInfo
func NewTableIndexInfo(tableName string) *TableIndexInfo { func NewTableIndexInfo(tableName string) *TableIndexInfo {
return &TableIndexInfo{ return &TableIndexInfo{
TableName: tableName, TableName: tableName,
IdxRows: make([]TableIndexRow, 0), Rows: make([]TableIndexRow, 0),
} }
} }
...@@ -205,43 +196,32 @@ func (db *Connector) ShowIndex(tableName string) (*TableIndexInfo, error) { ...@@ -205,43 +196,32 @@ func (db *Connector) ShowIndex(tableName string) (*TableIndexInfo, error) {
tbIndex := NewTableIndexInfo(tableName) tbIndex := NewTableIndexInfo(tableName)
// 执行 show create table // 执行 show create table
res, err := db.Query("show index from `%s`.`%s`", db.Database, tableName) res, err := db.Query(fmt.Sprintf("show index from `%s`.`%s`", db.Database, tableName))
if err != nil { if err != nil {
return nil, err return nil, err
} }
if res.Error != nil {
table := res.Result.Map("Table") return nil, res.Error
unique := res.Result.Map("Non_unique") }
keyName := res.Result.Map("Key_name")
seq := res.Result.Map("Seq_in_index")
cName := res.Result.Map("Column_name")
collation := res.Result.Map("Collation")
cardinality := res.Result.Map("Cardinality")
subPart := res.Result.Map("Sub_part")
packed := res.Result.Map("Packed")
null := res.Result.Map("Null")
idxType := res.Result.Map("Index_type")
comment := res.Result.Map("Comment")
idxComment := res.Result.Map("Index_comment")
// 获取值 // 获取值
for _, row := range res.Rows { for res.Rows.Next() {
value := TableIndexRow{ var ti TableIndexRow
Table: row.Str(table), res.Rows.Scan(&ti.Table,
NonUnique: row.Int(unique), &ti.NonUnique,
KeyName: row.Str(keyName), &ti.KeyName,
SeqInIndex: row.Int(seq), &ti.SeqInIndex,
ColumnName: row.Str(cName), &ti.ColumnName,
Collation: row.Str(collation), &ti.Collation,
Cardinality: row.Int(cardinality), &ti.Cardinality,
SubPart: row.Int(subPart), &ti.SubPart,
Packed: row.Int(packed), &ti.Packed,
Null: row.Str(null), &ti.Null,
IndexType: row.Str(idxType), &ti.IndexType,
Comment: row.Str(comment), &ti.Comment,
IndexComment: row.Str(idxComment), &ti.IndexComment,
} &ti.Visible)
tbIndex.IdxRows = append(tbIndex.IdxRows, value) tbIndex.Rows = append(tbIndex.Rows, ti)
} }
return tbIndex, err return tbIndex, err
} }
...@@ -257,7 +237,7 @@ const ( ...@@ -257,7 +237,7 @@ const (
IndexNonUnique = IndexSelectKey("NonUnique") // 唯一索引 IndexNonUnique = IndexSelectKey("NonUnique") // 唯一索引
) )
// FindIndex 获取TableIndexInfo中需要的索引 // FindIndex 获取 TableIndexInfo 中需要的索引
func (tbIndex *TableIndexInfo) FindIndex(arg IndexSelectKey, value string) []TableIndexRow { func (tbIndex *TableIndexInfo) FindIndex(arg IndexSelectKey, value string) []TableIndexRow {
var result []TableIndexRow var result []TableIndexRow
if tbIndex == nil { if tbIndex == nil {
...@@ -268,28 +248,28 @@ func (tbIndex *TableIndexInfo) FindIndex(arg IndexSelectKey, value string) []Tab ...@@ -268,28 +248,28 @@ func (tbIndex *TableIndexInfo) FindIndex(arg IndexSelectKey, value string) []Tab
switch arg { switch arg {
case IndexKeyName: case IndexKeyName:
for _, index := range tbIndex.IdxRows { for _, index := range tbIndex.Rows {
if strings.ToLower(index.KeyName) == value { if strings.ToLower(index.KeyName) == value {
result = append(result, index) result = append(result, index)
} }
} }
case IndexColumnName: case IndexColumnName:
for _, index := range tbIndex.IdxRows { for _, index := range tbIndex.Rows {
if strings.ToLower(index.ColumnName) == value { if strings.ToLower(index.ColumnName) == value {
result = append(result, index) result = append(result, index)
} }
} }
case IndexIndexType: case IndexIndexType:
for _, index := range tbIndex.IdxRows { for _, index := range tbIndex.Rows {
if strings.ToLower(index.IndexType) == value { if strings.ToLower(index.IndexType) == value {
result = append(result, index) result = append(result, index)
} }
} }
case IndexNonUnique: case IndexNonUnique:
for _, index := range tbIndex.IdxRows { for _, index := range tbIndex.Rows {
unique := strconv.Itoa(index.NonUnique) unique := strconv.Itoa(index.NonUnique)
if unique == value { if unique == value {
result = append(result, index) result = append(result, index)
...@@ -316,12 +296,12 @@ type TableDesc struct { ...@@ -316,12 +296,12 @@ type TableDesc struct {
type TableDescValue struct { type TableDescValue struct {
Field string // 列名 Field string // 列名
Type string // 数据类型 Type string // 数据类型
Collation []byte // 字符集
Null string // 是否有NULL(NO、YES) Null string // 是否有NULL(NO、YES)
Collation string // 字符集
Privileges string // 权限s
Key string // 键类型 Key string // 键类型
Default string // 默认值 Default []byte // 默认值
Extra string // 其他 Extra string // 其他
Privileges string // 权限
Comment string // 备注 Comment string // 备注
} }
...@@ -338,35 +318,27 @@ func (db *Connector) ShowColumns(tableName string) (*TableDesc, error) { ...@@ -338,35 +318,27 @@ func (db *Connector) ShowColumns(tableName string) (*TableDesc, error) {
tbDesc := NewTableDesc(tableName) tbDesc := NewTableDesc(tableName)
// 执行 show create table // 执行 show create table
res, err := db.Query("show full columns from `%s`.`%s`", db.Database, tableName) res, err := db.Query(fmt.Sprintf("show full columns from `%s`.`%s`", db.Database, tableName))
if err != nil { if err != nil {
return nil, err return nil, err
} }
if res.Error != nil {
field := res.Result.Map("Field") return nil, res.Error
tp := res.Result.Map("Type") }
null := res.Result.Map("Null")
key := res.Result.Map("Key")
def := res.Result.Map("Default")
extra := res.Result.Map("Extra")
collation := res.Result.Map("Collation")
privileges := res.Result.Map("Privileges")
comm := res.Result.Map("Comment")
// 获取值 // 获取值
for _, row := range res.Rows { for res.Rows.Next() {
value := TableDescValue{ var tc TableDescValue
Field: row.Str(field), res.Rows.Scan(&tc.Field,
Type: row.Str(tp), &tc.Type,
Null: row.Str(null), &tc.Collation,
Key: row.Str(key), &tc.Null,
Default: row.Str(def), &tc.Key,
Extra: row.Str(extra), &tc.Default,
Privileges: row.Str(privileges), &tc.Extra,
Collation: row.Str(collation), &tc.Privileges,
Comment: row.Str(comm), &tc.Comment)
} tbDesc.DescValues = append(tbDesc.DescValues, tc)
tbDesc.DescValues = append(tbDesc.DescValues, value)
} }
return tbDesc, err return tbDesc, err
} }
...@@ -383,18 +355,21 @@ func (td TableDesc) Columns() []string { ...@@ -383,18 +355,21 @@ func (td TableDesc) Columns() []string {
// showCreate show create // showCreate show create
func (db *Connector) showCreate(createType, name string) (string, error) { func (db *Connector) showCreate(createType, name string) (string, error) {
// 执行 show create table // 执行 show create table
res, err := db.Query("show create %s `%s`", createType, name) res, err := db.Query(fmt.Sprintf("show create %s `%s`", createType, name))
if err != nil { if err != nil {
return "", err return "", err
} }
if res.Error != nil {
return "", res.Error
}
// 获取ddl // 获取 CREATE TABLE 语句
var ddl string var tableName, createTable string
for _, row := range res.Rows { for res.Rows.Next() {
ddl = row.Str(1) res.Rows.Scan(&tableName, &createTable)
} }
return ddl, err return createTable, err
} }
// ShowCreateDatabase show create database // ShowCreateDatabase show create database
...@@ -451,6 +426,10 @@ func (db *Connector) FindColumn(name, dbName string, tables ...string) ([]*commo ...@@ -451,6 +426,10 @@ func (db *Connector) FindColumn(name, dbName string, tables ...string) ([]*commo
"c.TABLE_NAME,c.TABLE_SCHEMA,c.COLUMN_TYPE,c.CHARACTER_SET_NAME, c.COLLATION_NAME "+ "c.TABLE_NAME,c.TABLE_SCHEMA,c.COLUMN_TYPE,c.CHARACTER_SET_NAME, c.COLLATION_NAME "+
"FROM `INFORMATION_SCHEMA`.`COLUMNS` as c where c.COLUMN_NAME = '%s' ", name) "FROM `INFORMATION_SCHEMA`.`COLUMNS` as c where c.COLUMN_NAME = '%s' ", name)
if dbName != "" {
sql += fmt.Sprintf(" and c.table_schema = '%s'", dbName)
}
if len(tables) > 0 { if len(tables) > 0 {
var tmp []string var tmp []string
for _, table := range tables { for _, table := range tables {
...@@ -459,32 +438,24 @@ func (db *Connector) FindColumn(name, dbName string, tables ...string) ([]*commo ...@@ -459,32 +438,24 @@ func (db *Connector) FindColumn(name, dbName string, tables ...string) ([]*commo
sql += fmt.Sprintf(" and c.table_name in (%s)", strings.Join(tmp, ",")) sql += fmt.Sprintf(" and c.table_name in (%s)", strings.Join(tmp, ","))
} }
if dbName != "" { common.Log.Debug("FindColumn, execute SQL: %s", sql)
sql += fmt.Sprintf(" and c.table_schema = '%s'", dbName)
}
res, err := db.Query(sql) res, err := db.Query(sql)
if err != nil { if err != nil {
common.Log.Error("(db *Connector) FindColumn Error : ", err) common.Log.Error("(db *Connector) FindColumn Error : ", err)
return columns, err return columns, err
} }
if res.Error != nil {
common.Log.Error("(db *Connector) FindColumn Error : ", res.Error)
return columns, res.Error
}
tbName := res.Result.Map("TABLE_NAME") var col common.Column
schema := res.Result.Map("TABLE_SCHEMA") for res.Rows.Next() {
colTyp := res.Result.Map("COLUMN_TYPE") res.Rows.Scan(&col.Table,
colCharset := res.Result.Map("CHARACTER_SET_NAME") &col.DB,
collation := res.Result.Map("COLLATION_NAME") &col.DataType,
&col.Character,
// 获取ddl &col.Collation)
for _, row := range res.Rows {
col := &common.Column{
Name: name,
Table: row.Str(tbName),
DB: row.Str(schema),
DataType: row.Str(colTyp),
Character: row.Str(colCharset),
Collation: row.Str(collation),
}
// 填充字符集和排序规则 // 填充字符集和排序规则
if col.Character == "" { if col.Character == "" {
...@@ -494,40 +465,56 @@ func (db *Connector) FindColumn(name, dbName string, tables ...string) ([]*commo ...@@ -494,40 +465,56 @@ func (db *Connector) FindColumn(name, dbName string, tables ...string) ([]*commo
sql = fmt.Sprintf("SELECT `t`.`TABLE_COLLATION` FROM `INFORMATION_SCHEMA`.`TABLES` AS `t` "+ sql = fmt.Sprintf("SELECT `t`.`TABLE_COLLATION` FROM `INFORMATION_SCHEMA`.`TABLES` AS `t` "+
"WHERE `t`.`TABLE_NAME`='%s' AND `t`.`TABLE_SCHEMA` = '%s'", col.Table, col.DB) "WHERE `t`.`TABLE_NAME`='%s' AND `t`.`TABLE_SCHEMA` = '%s'", col.Table, col.DB)
var newRes *QueryResult
common.Log.Debug("FindColumn, execute SQL: %s", sql)
var newRes QueryResult
newRes, err = db.Query(sql) newRes, err = db.Query(sql)
if err != nil { if err != nil {
common.Log.Error("(db *Connector) FindColumn Error : ", err) common.Log.Error("(db *Connector) FindColumn Error : ", err)
return columns, err return columns, err
} }
if res.Error != nil {
common.Log.Error("(db *Connector) FindColumn Error : ", res.Error)
return columns, res.Error
}
tbCollation := newRes.Rows[0].Str(0) var tbCollation string
if newRes.Rows.Next() {
newRes.Rows.Scan(&tbCollation)
}
if tbCollation != "" { if tbCollation != "" {
col.Character = strings.Split(tbCollation, "_")[0] col.Character = strings.Split(tbCollation, "_")[0]
col.Collation = tbCollation col.Collation = tbCollation
} }
} }
columns = append(columns, &col)
columns = append(columns, col)
} }
return columns, err return columns, err
} }
// IsFKey 判断列是否是外键 // IsForeignKey 判断列是否是外键
func (db *Connector) IsFKey(dbName, tbName, column string) bool { func (db *Connector) IsForeignKey(dbName, tbName, column string) bool {
sql := fmt.Sprintf("SELECT REFERENCED_COLUMN_NAME FROM INFORMATION_SCHEMA.KEY_COLUMN_USAGE C "+ sql := fmt.Sprintf("SELECT REFERENCED_COLUMN_NAME FROM INFORMATION_SCHEMA.KEY_COLUMN_USAGE C "+
"WHERE REFERENCED_TABLE_SCHEMA <> 'NULL' AND"+ "WHERE REFERENCED_TABLE_SCHEMA <> 'NULL' AND"+
" TABLE_NAME='%s' AND"+ " TABLE_NAME='%s' AND"+
" TABLE_SCHEMA='%s' AND"+ " TABLE_SCHEMA='%s' AND"+
" COLUMN_NAME='%s'", tbName, dbName, column) " COLUMN_NAME='%s'", tbName, dbName, column)
common.Log.Debug("IsForeignKey, execute SQL: %s", sql)
res, err := db.Query(sql) res, err := db.Query(sql)
if err == nil && len(res.Rows) == 0 { if err != nil {
common.Log.Error("IsForeignKey, Error: %s", err.Error())
return false
}
if res.Error != nil {
common.Log.Error("IsForeignKey, Error: %s", res.Error.Error())
return false return false
} }
if res.Rows.Next() {
return true
}
return true return false
} }
// Reference 用于存储关系 // Reference 用于存储关系
...@@ -535,11 +522,11 @@ type Reference map[string][]ReferenceValue ...@@ -535,11 +522,11 @@ type Reference map[string][]ReferenceValue
// ReferenceValue 用于处理表之间的关系 // ReferenceValue 用于处理表之间的关系
type ReferenceValue struct { type ReferenceValue struct {
RefDBName string // 夫表所属数据库 ReferencedTableSchema string // 夫表所属数据库
RefTable string // 父表 ReferencedTableName string // 父表
DBName string // 子表所属数据库 TableSchema string // 子表所属数据库
Table string // 子表 TableName string // 子表
ConstraintName string // 关系名称 ConstraintName string // 关系名称
} }
// ShowReference 查找所有的外键信息 // ShowReference 查找所有的外键信息
...@@ -555,30 +542,26 @@ WHERE C.REFERENCED_TABLE_NAME IS NOT NULL` ...@@ -555,30 +542,26 @@ WHERE C.REFERENCED_TABLE_NAME IS NOT NULL`
sql = sql + extra sql = sql + extra
} }
common.Log.Debug("ShowReference, execute SQL: %s", sql)
// 执行SQL查找外键关联关系 // 执行SQL查找外键关联关系
res, err := db.Query(sql) res, err := db.Query(sql)
if err != nil { if err != nil {
return referenceValues, err return referenceValues, err
} }
if res.Error != nil {
refDb := res.Result.Map("REFERENCED_TABLE_SCHEMA") return referenceValues, res.Error
refTb := res.Result.Map("REFERENCED_TABLE_NAME") }
schema := res.Result.Map("TABLE_SCHEMA")
tb := res.Result.Map("TABLE_NAME")
cName := res.Result.Map("CONSTRAINT_NAME")
// 获取值 // 获取值
for _, row := range res.Rows { for res.Rows.Next() {
value := ReferenceValue{ var rv ReferenceValue
RefDBName: row.Str(refDb), res.Rows.Scan(&rv.ReferencedTableSchema,
RefTable: row.Str(refTb), &rv.ReferencedTableName,
DBName: row.Str(schema), &rv.TableSchema,
Table: row.Str(tb), &rv.TableName,
ConstraintName: row.Str(cName), &rv.ConstraintName)
} referenceValues = append(referenceValues, rv)
referenceValues = append(referenceValues, value)
} }
return referenceValues, err return referenceValues, err
} }
...@@ -20,75 +20,144 @@ import ( ...@@ -20,75 +20,144 @@ import (
"fmt" "fmt"
"testing" "testing"
"github.com/XiaoMi/soar/common"
"github.com/kr/pretty" "github.com/kr/pretty"
"vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/sqlparser"
) )
func TestShowTableStatus(t *testing.T) { func TestShowTableStatus(t *testing.T) {
connTest.Database = "information_schema" orgDatabase := connTest.Database
ts, err := connTest.ShowTableStatus("TABLES") connTest.Database = "sakila"
ts, err := connTest.ShowTableStatus("film")
if err != nil { if err != nil {
t.Error("ShowTableStatus Error: ", err) t.Error("ShowTableStatus Error: ", err)
} }
if string(ts.Rows[0].Engine) != "InnoDB" {
t.Error("film table should be InnoDB engine")
}
pretty.Println(ts) pretty.Println(ts)
connTest.Database = "sakila"
ts, err = connTest.ShowTableStatus("actor_info")
if err != nil {
t.Error("ShowTableStatus Error: ", err)
}
if string(ts.Rows[0].Comment) != "VIEW" {
t.Error("actor_info should be VIEW")
}
pretty.Println(ts)
connTest.Database = orgDatabase
} }
func TestShowTables(t *testing.T) { func TestShowTables(t *testing.T) {
connTest.Database = "information_schema" orgDatabase := connTest.Database
connTest.Database = "sakila"
ts, err := connTest.ShowTables() ts, err := connTest.ShowTables()
if err != nil { if err != nil {
t.Error("ShowTableStatus Error: ", err) t.Error("ShowTableStatus Error: ", err)
} }
pretty.Println(ts)
err = common.GoldenDiff(func() {
for _, table := range ts {
fmt.Println(table)
}
}, t.Name(), update)
if err != nil {
t.Error(err)
}
connTest.Database = orgDatabase
} }
func TestShowCreateTable(t *testing.T) { func TestShowCreateTable(t *testing.T) {
connTest.Database = "information_schema" orgDatabase := connTest.Database
ts, err := connTest.ShowCreateTable("TABLES") connTest.Database = "sakila"
ts, err := connTest.ShowCreateTable("film")
if err != nil { if err != nil {
t.Error("ShowCreateTable Error: ", err) t.Error("ShowCreateTable Error: ", err)
} }
fmt.Println(ts)
stmt, err := sqlparser.Parse(ts) err = common.GoldenDiff(func() {
pretty.Println(stmt, err) fmt.Println(ts)
stmt, err := sqlparser.Parse(ts)
if err != nil {
t.Error(err.Error())
}
pretty.Println(stmt, err)
}, t.Name(), update)
if err != nil {
t.Error(err)
}
connTest.Database = orgDatabase
} }
func TestShowIndex(t *testing.T) { func TestShowIndex(t *testing.T) {
connTest.Database = "information_schema" orgDatabase := connTest.Database
ti, err := connTest.ShowIndex("TABLES") connTest.Database = "sakila"
ti, err := connTest.ShowIndex("film")
if err != nil { if err != nil {
t.Error("ShowIndex Error: ", err) t.Error("ShowIndex Error: ", err)
} }
pretty.Println(ti.FindIndex(IndexKeyName, "idx_store_id_film_id"))
err = common.GoldenDiff(func() {
pretty.Println(ti)
pretty.Println(ti.FindIndex(IndexKeyName, "idx_title"))
}, t.Name(), update)
if err != nil {
t.Error(err)
}
connTest.Database = orgDatabase
} }
func TestShowColumns(t *testing.T) { func TestShowColumns(t *testing.T) {
connTest.Database = "information_schema" orgDatabase := connTest.Database
ti, err := connTest.ShowColumns("TABLES") connTest.Database = "sakila"
ti, err := connTest.ShowColumns("film")
if err != nil { if err != nil {
t.Error("ShowColumns Error: ", err) t.Error("ShowColumns Error: ", err)
} }
pretty.Println(ti)
err = common.GoldenDiff(func() {
pretty.Println(ti)
}, t.Name(), update)
if err != nil {
t.Error(err)
}
connTest.Database = orgDatabase
} }
func TestFindColumn(t *testing.T) { func TestFindColumn(t *testing.T) {
ti, err := connTest.FindColumn("id", "") ti, err := connTest.FindColumn("film_id", "sakila", "film")
if err != nil { if err != nil {
t.Error("FindColumn Error: ", err) t.Error("FindColumn Error: ", err)
} }
pretty.Println(ti) err = common.GoldenDiff(func() {
pretty.Println(ti)
}, t.Name(), update)
if err != nil {
t.Error(err)
}
}
func TestIsFKey(t *testing.T) {
if !connTest.IsForeignKey("sakila", "film", "language_id") {
t.Error("want True. got false")
}
} }
func TestShowReference(t *testing.T) { func TestShowReference(t *testing.T) {
rv, err := connTest.ShowReference("test2", "homeImg") rv, err := connTest.ShowReference("sakila", "film")
if err != nil { if err != nil {
t.Error("ShowReference Error: ", err) t.Error("ShowReference Error: ", err)
} }
pretty.Println(rv)
}
func TestIsFKey(t *testing.T) { err = common.GoldenDiff(func() {
if !connTest.IsFKey("sakila", "film", "language_id") { pretty.Println(rv)
t.Error("want True. got false") }, t.Name(), update)
if err != nil {
t.Error(err)
} }
} }
[]*common.Column{
&common.Column{
Name: "",
Alias: nil,
Table: "film",
DB: "sakila",
DataType: "smallint(5) unsigned",
Character: "utf8",
Collation: "utf8_general_ci",
Cardinality: 0,
Null: "",
Key: "",
Default: "",
Extra: "",
Comment: "",
Privileges: "",
},
}
```sql
select 1
```
```json
{
"steps": [
{
"join_preparation": {
"select#": 1,
"steps": [
{
"expanded_query": "/* select#1 */ select 1 AS `1`"
}
]
}
},
{
"join_optimization": {
"select#": 1,
"steps": [
]
}
},
{
"join_explain": {
"select#": 1,
"steps": [
]
}
}
]
}
```
&database.TableDesc{
Name: "film",
DescValues: {
{
Field: "film_id",
Type: "smallint(5) unsigned",
Collation: nil,
Null: "NO",
Key: "PRI",
Default: nil,
Extra: "auto_increment",
Privileges: "select,insert,update,references",
Comment: "",
},
{
Field: "title",
Type: "varchar(255)",
Collation: {0x75, 0x74, 0x66, 0x38, 0x5f, 0x67, 0x65, 0x6e, 0x65, 0x72, 0x61, 0x6c, 0x5f, 0x63, 0x69},
Null: "NO",
Key: "MUL",
Default: nil,
Extra: "",
Privileges: "select,insert,update,references",
Comment: "",
},
{
Field: "description",
Type: "text",
Collation: {0x75, 0x74, 0x66, 0x38, 0x5f, 0x67, 0x65, 0x6e, 0x65, 0x72, 0x61, 0x6c, 0x5f, 0x63, 0x69},
Null: "YES",
Key: "",
Default: nil,
Extra: "",
Privileges: "select,insert,update,references",
Comment: "",
},
{
Field: "release_year",
Type: "year(4)",
Collation: nil,
Null: "YES",
Key: "",
Default: nil,
Extra: "",
Privileges: "select,insert,update,references",
Comment: "",
},
{
Field: "language_id",
Type: "tinyint(3) unsigned",
Collation: nil,
Null: "NO",
Key: "MUL",
Default: nil,
Extra: "",
Privileges: "select,insert,update,references",
Comment: "",
},
{
Field: "original_language_id",
Type: "tinyint(3) unsigned",
Collation: nil,
Null: "YES",
Key: "MUL",
Default: nil,
Extra: "",
Privileges: "select,insert,update,references",
Comment: "",
},
{
Field: "rental_duration",
Type: "tinyint(3) unsigned",
Collation: nil,
Null: "NO",
Key: "",
Default: {0x33},
Extra: "",
Privileges: "select,insert,update,references",
Comment: "",
},
{
Field: "rental_rate",
Type: "decimal(4,2)",
Collation: nil,
Null: "NO",
Key: "",
Default: {0x34, 0x2e, 0x39, 0x39},
Extra: "",
Privileges: "select,insert,update,references",
Comment: "",
},
{
Field: "length",
Type: "smallint(5) unsigned",
Collation: nil,
Null: "YES",
Key: "",
Default: nil,
Extra: "",
Privileges: "select,insert,update,references",
Comment: "",
},
{
Field: "replacement_cost",
Type: "decimal(5,2)",
Collation: nil,
Null: "NO",
Key: "",
Default: {0x31, 0x39, 0x2e, 0x39, 0x39},
Extra: "",
Privileges: "select,insert,update,references",
Comment: "",
},
{
Field: "rating",
Type: "enum('G','PG','PG-13','R','NC-17')",
Collation: {0x75, 0x74, 0x66, 0x38, 0x5f, 0x67, 0x65, 0x6e, 0x65, 0x72, 0x61, 0x6c, 0x5f, 0x63, 0x69},
Null: "YES",
Key: "",
Default: {0x47},
Extra: "",
Privileges: "select,insert,update,references",
Comment: "",
},
{
Field: "special_features",
Type: "set('Trailers','Commentaries','Deleted Scenes','Behind the Scenes')",
Collation: {0x75, 0x74, 0x66, 0x38, 0x5f, 0x67, 0x65, 0x6e, 0x65, 0x72, 0x61, 0x6c, 0x5f, 0x63, 0x69},
Null: "YES",
Key: "",
Default: nil,
Extra: "",
Privileges: "select,insert,update,references",
Comment: "",
},
{
Field: "last_update",
Type: "timestamp",
Collation: nil,
Null: "NO",
Key: "",
Default: {0x43, 0x55, 0x52, 0x52, 0x45, 0x4e, 0x54, 0x5f, 0x54, 0x49, 0x4d, 0x45, 0x53, 0x54, 0x41, 0x4d, 0x50},
Extra: "on update CURRENT_TIMESTAMP",
Privileges: "select,insert,update,references",
Comment: "",
},
},
}
CREATE TABLE `film` (
`film_id` smallint(5) unsigned NOT NULL AUTO_INCREMENT,
`title` varchar(255) NOT NULL,
`description` text,
`release_year` year(4) DEFAULT NULL,
`language_id` tinyint(3) unsigned NOT NULL,
`original_language_id` tinyint(3) unsigned DEFAULT NULL,
`rental_duration` tinyint(3) unsigned NOT NULL DEFAULT '3',
`rental_rate` decimal(4,2) NOT NULL DEFAULT '4.99',
`length` smallint(5) unsigned DEFAULT NULL,
`replacement_cost` decimal(5,2) NOT NULL DEFAULT '19.99',
`rating` enum('G','PG','PG-13','R','NC-17') DEFAULT 'G',
`special_features` set('Trailers','Commentaries','Deleted Scenes','Behind the Scenes') DEFAULT NULL,
`last_update` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
PRIMARY KEY (`film_id`),
KEY `idx_title` (`title`),
KEY `idx_fk_language_id` (`language_id`),
KEY `idx_fk_original_language_id` (`original_language_id`)
) ENGINE=InnoDB AUTO_INCREMENT=1001 DEFAULT CHARSET=utf8
&sqlparser.DDL{
Action: "create",
FromTables: nil,
ToTables: nil,
Table: sqlparser.TableName{
Name: sqlparser.TableIdent{v:"film"},
Qualifier: sqlparser.TableIdent{},
},
IfExists: false,
TableSpec: (*sqlparser.TableSpec)(nil),
OptLike: (*sqlparser.OptLike)(nil),
PartitionSpec: (*sqlparser.PartitionSpec)(nil),
VindexSpec: (*sqlparser.VindexSpec)(nil),
VindexCols: nil,
} nil
&database.TableIndexInfo{
TableName: "film",
Rows: {
{Table:"film", NonUnique:0, KeyName:"PRIMARY", SeqInIndex:1, ColumnName:"film_id", Collation:"A", Cardinality:1000, SubPart:0, Packed:0, Null:"", IndexType:"", Comment:"", IndexComment:"", Visible:""},
{Table:"film", NonUnique:1, KeyName:"idx_title", SeqInIndex:1, ColumnName:"title", Collation:"A", Cardinality:1000, SubPart:0, Packed:0, Null:"", IndexType:"", Comment:"", IndexComment:"", Visible:""},
{Table:"film", NonUnique:1, KeyName:"idx_fk_language_id", SeqInIndex:1, ColumnName:"language_id", Collation:"A", Cardinality:1, SubPart:0, Packed:0, Null:"", IndexType:"", Comment:"", IndexComment:"", Visible:""},
{Table:"film", NonUnique:1, KeyName:"idx_fk_original_language_id", SeqInIndex:1, ColumnName:"original_language_id", Collation:"A", Cardinality:1, SubPart:0, Packed:0, Null:"", IndexType:"", Comment:"", IndexComment:"", Visible:""},
},
}
[]database.TableIndexRow{
{Table:"film", NonUnique:1, KeyName:"idx_title", SeqInIndex:1, ColumnName:"title", Collation:"A", Cardinality:1000, SubPart:0, Packed:0, Null:"", IndexType:"", Comment:"", IndexComment:"", Visible:""},
}
[]database.ReferenceValue{
{ReferencedTableSchema:"sakila", ReferencedTableName:"language", TableSchema:"sakila", TableName:"film", ConstraintName:"fk_film_language"},
{ReferencedTableSchema:"sakila", ReferencedTableName:"language", TableSchema:"sakila", TableName:"film", ConstraintName:"fk_film_language_original"},
}
actor
actor_info
address
category
city
country
customer
customer_list
film
film_actor
film_category
film_list
film_text
inventory
language
nicer_but_slower_film_list
payment
rental
sales_by_film_category
sales_by_store
staff
staff_list
store
[]database.TraceRow{
{Query:"explain select 1", Trace:"{\n \"steps\": [\n {\n \"join_preparation\": {\n \"select#\": 1,\n \"steps\": [\n {\n \"expanded_query\": \"/* select#1 */ select 1 AS `1`\"\n }\n ]\n }\n },\n {\n \"join_optimization\": {\n \"select#\": 1,\n \"steps\": [\n ]\n }\n },\n {\n \"join_explain\": {\n \"select#\": 1,\n \"steps\": [\n ]\n }\n }\n ]\n}", MissingBytesBeyondMaxMemSize:0, InsufficientPrivileges:0},
}
...@@ -19,12 +19,9 @@ package database ...@@ -19,12 +19,9 @@ package database
import ( import (
"errors" "errors"
"fmt" "fmt"
"io" "github.com/XiaoMi/soar/common"
"regexp" "regexp"
"strings" "strings"
"time"
"github.com/XiaoMi/soar/common"
"vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/sqlparser"
) )
...@@ -43,10 +40,11 @@ type TraceRow struct { ...@@ -43,10 +40,11 @@ type TraceRow struct {
} }
// Trace 执行SQL,并对其Trace // Trace 执行SQL,并对其Trace
func (db *Connector) Trace(sql string, params ...interface{}) (*QueryResult, error) { func (db *Connector) Trace(sql string, params ...interface{}) ([]TraceRow, error) {
common.Log.Debug("Trace SQL: %s", sql) common.Log.Debug("Trace SQL: %s", sql)
var rows []TraceRow
if common.Config.TestDSN.Version < 50600 { if common.Config.TestDSN.Version < 50600 {
return nil, errors.New("version < 5.6, not support trace") return rows, errors.New("version < 5.6, not support trace")
} }
// 过滤不需要 Trace 的 SQL // 过滤不需要 Trace 的 SQL
...@@ -55,98 +53,71 @@ func (db *Connector) Trace(sql string, params ...interface{}) (*QueryResult, err ...@@ -55,98 +53,71 @@ func (db *Connector) Trace(sql string, params ...interface{}) (*QueryResult, err
sql = "explain " + sql sql = "explain " + sql
case sqlparser.EXPLAIN: case sqlparser.EXPLAIN:
default: default:
return nil, errors.New("no need trace") return rows, errors.New("no need trace")
} }
// 测试环境如果检查是关闭的,则SQL不会被执行 // 测试环境如果检查是关闭的,则SQL不会被执行
if common.Config.TestDSN.Disable { if common.Config.TestDSN.Disable {
return nil, errors.New("Dsn Disable") return rows, errors.New("dsn is disable")
} }
// 数据库安全性检查:如果 Connector 的 IP 端口与 TEST 环境不一致,则启用SQL白名单 // 数据库安全性检查:如果 Connector 的 IP 端口与 TEST 环境不一致,则启用SQL白名单
// 不在白名单中的SQL不允许执行 // 不在白名单中的SQL不允许执行
// 执行环境与test环境不相同 // 执行环境与test环境不相同
if db.Addr != common.Config.TestDSN.Addr && db.dangerousQuery(sql) { if db.Addr != common.Config.TestDSN.Addr && db.dangerousQuery(sql) {
return nil, fmt.Errorf("query Execution Deny: Execute SQL with DSN(%s/%s) '%s'", return rows, fmt.Errorf("query Execution Deny: Execute SQL with DSN(%s/%s) '%s'",
db.Addr, db.Database, fmt.Sprintf(sql, params...)) db.Addr, db.Database, fmt.Sprintf(sql, params...))
} }
common.Log.Debug("Execute SQL with DSN(%s/%s) : %s", db.Addr, db.Database, sql) common.Log.Debug("Execute SQL with DSN(%s/%s) : %s", db.Addr, db.Database, sql)
conn := db.NewConnection() conn, err := db.NewConnection()
// 设置SQL连接超时时间
conn.SetTimeout(time.Duration(common.Config.ConnTimeOut) * time.Second)
defer conn.Close()
err := conn.Connect()
if err != nil { if err != nil {
return nil, err return rows, err
} }
defer conn.Close()
// 添加SQL执行超时限制 // 开启Trace
ch := make(chan QueryResult, 1) common.Log.Debug("SET SESSION OPTIMIZER_TRACE='enabled=on'")
go func() { trx, err := conn.Begin()
// 开启Trace if err != nil {
common.Log.Debug("SET SESSION OPTIMIZER_TRACE='enabled=on'") return rows, err
_, _, err = conn.Query("SET SESSION OPTIMIZER_TRACE='enabled=on'") }
common.LogIfError(err, "") defer trx.Rollback()
_, err = trx.Query("SET SESSION OPTIMIZER_TRACE='enabled=on'")
// 执行SQL,抛弃返回结果 common.LogIfError(err, "")
result, err := conn.Start(sql, params...)
if err != nil {
ch <- QueryResult{
Error: err,
}
return
}
row := result.MakeRow()
for {
err = result.ScanRow(row)
if err == io.EOF {
break
}
}
// 返回Trace结果 // 执行SQL,抛弃返回结果
res := QueryResult{} tmpRes, err := trx.Query(sql, params...)
res.Rows, res.Result, res.Error = conn.Query("SELECT * FROM information_schema.OPTIMIZER_TRACE") if err != nil {
return rows, err
}
for tmpRes.Next() {
continue
}
// 关闭Trace // 返回Trace结果
common.Log.Debug("SET SESSION OPTIMIZER_TRACE='enabled=off'") res, err := trx.Query("SELECT * FROM information_schema.OPTIMIZER_TRACE")
_, _, err = conn.Query("SET SESSION OPTIMIZER_TRACE='enabled=off'") for res.Next() {
var traceRow TraceRow
err = res.Scan(&traceRow.Query, &traceRow.Trace, &traceRow.MissingBytesBeyondMaxMemSize, &traceRow.InsufficientPrivileges)
if err != nil { if err != nil {
fmt.Println(err.Error()) common.LogIfError(err, "")
} }
ch <- res rows = append(rows, traceRow)
}()
select {
case res := <-ch:
return &res, res.Error
case <-time.After(time.Duration(common.Config.QueryTimeOut) * time.Second):
return nil, errors.New("query execution timeout")
} }
}
// getTrace 获取trace信息 // 关闭Trace
func getTrace(res *QueryResult) Trace { common.Log.Debug("SET SESSION OPTIMIZER_TRACE='enabled=off'")
var rows []TraceRow _, err = trx.Query("SET SESSION OPTIMIZER_TRACE='enabled=off'")
for _, row := range res.Rows { common.LogIfError(err, "")
rows = append(rows, TraceRow{ return rows, err
Query: row.Str(0),
Trace: row.Str(1),
MissingBytesBeyondMaxMemSize: row.Int(2),
InsufficientPrivileges: row.Int(3),
})
}
return Trace{Rows: rows}
} }
// FormatTrace 格式化输出Trace信息 // FormatTrace 格式化输出Trace信息
func FormatTrace(res *QueryResult) string { func FormatTrace(rows []TraceRow) string {
explainReg := regexp.MustCompile(`(?i)^explain\s+`) explainReg := regexp.MustCompile(`(?i)^explain\s+`)
trace := getTrace(res)
str := []string{""} str := []string{""}
for _, row := range trace.Rows { for _, row := range rows {
str = append(str, "```sql") str = append(str, "```sql")
sql := explainReg.ReplaceAllString(row.Query, "") sql := explainReg.ReplaceAllString(row.Query, "")
str = append(str, sql) str = append(str, sql)
......
...@@ -17,7 +17,6 @@ ...@@ -17,7 +17,6 @@
package database package database
import ( import (
"flag"
"testing" "testing"
"github.com/XiaoMi/soar/common" "github.com/XiaoMi/soar/common"
...@@ -25,34 +24,30 @@ import ( ...@@ -25,34 +24,30 @@ import (
"github.com/kr/pretty" "github.com/kr/pretty"
) )
var update = flag.Bool("update", false, "update .golden files")
func TestTrace(t *testing.T) { func TestTrace(t *testing.T) {
common.Config.QueryTimeOut = 1
res, err := connTest.Trace("select 1") res, err := connTest.Trace("select 1")
if err == nil { if err != nil {
common.GoldenDiff(func() { t.Error(err)
pretty.Println(res) }
}, t.Name(), update)
} else { err = common.GoldenDiff(func() {
pretty.Println(res)
}, t.Name(), update)
if err != nil {
t.Error(err) t.Error(err)
} }
} }
func TestFormatTrace(t *testing.T) { func TestFormatTrace(t *testing.T) {
res, err := connTest.Trace("select 1") res, err := connTest.Trace("select 1")
if err == nil { if err != nil {
pretty.Println(FormatTrace(res))
} else {
t.Error(err) t.Error(err)
} }
}
func TestGetTrace(t *testing.T) { err = common.GoldenDiff(func() {
res, err := connTest.Trace("select 1") pretty.Println(FormatTrace(res))
if err == nil { }, t.Name(), update)
pretty.Println(getTrace(res)) if err != nil {
} else {
t.Error(err) t.Error(err)
} }
} }
...@@ -161,8 +161,9 @@ func (ve *VirtualEnv) CleanupTestDatabase() { ...@@ -161,8 +161,9 @@ func (ve *VirtualEnv) CleanupTestDatabase() {
// TODO: 1 hour should be config-able // TODO: 1 hour should be config-able
minHour := 1 minHour := 1
for _, row := range dbs.Rows { for dbs.Rows.Next() {
testDatabase := row.Str(0) var testDatabase string
dbs.Rows.Scan(&testDatabase)
// test temporary database format `optimizer_YYMMDDHHmmss_randomString(16)` // test temporary database format `optimizer_YYMMDDHHmmss_randomString(16)`
if len(testDatabase) != 39 { if len(testDatabase) != 39 {
common.Log.Debug("CleanupTestDatabase by pass %s", testDatabase) common.Log.Debug("CleanupTestDatabase by pass %s", testDatabase)
...@@ -218,7 +219,7 @@ func (ve *VirtualEnv) BuildVirtualEnv(rEnv *database.Connector, SQLs ...string) ...@@ -218,7 +219,7 @@ func (ve *VirtualEnv) BuildVirtualEnv(rEnv *database.Connector, SQLs ...string)
meta := make(map[string]*common.DB) meta := make(map[string]*common.DB)
for _, sql := range SQLs { for _, sql := range SQLs {
common.Log.Debug("BuildVirtualEnv Database&Table Mapping, SQL: %s", sql) common.Log.Debug("BuildVirtualEnv Database&TableName Mapping, SQL: %s", sql)
stmt, err = sqlparser.Parse(sql) stmt, err = sqlparser.Parse(sql)
if err != nil { if err != nil {
...@@ -320,7 +321,7 @@ func (ve *VirtualEnv) BuildVirtualEnv(rEnv *database.Connector, SQLs ...string) ...@@ -320,7 +321,7 @@ func (ve *VirtualEnv) BuildVirtualEnv(rEnv *database.Connector, SQLs ...string)
} }
// 如果是视图,解析语句 // 如果是视图,解析语句
if len(tbStatus.Rows) > 0 && tbStatus.Rows[0].Comment == "VIEW" { if len(tbStatus.Rows) > 0 && string(tbStatus.Rows[0].Comment) == "VIEW" {
tmpEnv.Database = db tmpEnv.Database = db
var viewDDL string var viewDDL string
viewDDL, err = tmpEnv.ShowCreateTable(tb.TableName) viewDDL, err = tmpEnv.ShowCreateTable(tb.TableName)
...@@ -418,7 +419,7 @@ func (ve VirtualEnv) createTable(rEnv database.Connector, dbName, tbName string) ...@@ -418,7 +419,7 @@ func (ve VirtualEnv) createTable(rEnv database.Connector, dbName, tbName string)
return nil return nil
} }
common.Log.Debug("createTable, Database: %s, Table: %s", dbName, tbName) common.Log.Debug("createTable, Database: %s, TableName: %s", dbName, tbName)
// TODO:查看是否有外键关联(done),对外键的支持 (未解决循环依赖的问题) // TODO:查看是否有外键关联(done),对外键的支持 (未解决循环依赖的问题)
...@@ -506,9 +507,9 @@ func (ve *VirtualEnv) GenTableColumns(meta common.Meta) common.TableColumns { ...@@ -506,9 +507,9 @@ func (ve *VirtualEnv) GenTableColumns(meta common.Meta) common.TableColumns {
DB: dbName, DB: dbName,
Table: tb.TableName, Table: tb.TableName,
DataType: colInfo.Type, DataType: colInfo.Type,
Character: colInfo.Collation, Character: string(colInfo.Collation),
Key: colInfo.Key, Key: colInfo.Key,
Default: colInfo.Default, Default: string(colInfo.Default),
Extra: colInfo.Extra, Extra: colInfo.Extra,
Comment: colInfo.Comment, Comment: colInfo.Comment,
Privileges: colInfo.Privileges, Privileges: colInfo.Privileges,
...@@ -525,9 +526,9 @@ func (ve *VirtualEnv) GenTableColumns(meta common.Meta) common.TableColumns { ...@@ -525,9 +526,9 @@ func (ve *VirtualEnv) GenTableColumns(meta common.Meta) common.TableColumns {
col.DB = dbName col.DB = dbName
col.Table = tb.TableName col.Table = tb.TableName
col.DataType = colInfo.Type col.DataType = colInfo.Type
col.Character = colInfo.Collation col.Character = string(colInfo.Collation)
col.Key = colInfo.Key col.Key = colInfo.Key
col.Default = colInfo.Default col.Default = string(colInfo.Default)
col.Extra = colInfo.Extra col.Extra = colInfo.Extra
col.Comment = colInfo.Comment col.Comment = colInfo.Comment
col.Privileges = colInfo.Privileges col.Privileges = colInfo.Privileges
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册