diff --git a/Makefile b/Makefile
index 885f2350505e40f5162c681d1be6bb9158e22cd7..5e6ce8fd650706eaba9cfaec3524b7d3db250ff3 100644
--- a/Makefile
+++ b/Makefile
@@ -181,7 +181,7 @@ docker:
.PHONY: connect
connect:
- mysql -h 127.0.0.1 -u root -p1tIsB1g3rt -c
+ mysql -h 127.0.0.1 -u root -p1tIsB1g3rt sakila -c
.PHONY: main_test
main_test: install
diff --git a/advisor/heuristic.go b/advisor/heuristic.go
index 1ccab69d1bdcd092f87bbfe19e15bdd8134bee1a..c8feba434285e91c5fadcc80b02fef1298d91861 100644
--- a/advisor/heuristic.go
+++ b/advisor/heuristic.go
@@ -1990,7 +1990,7 @@ func (idxAdv *IndexAdvisor) RuleUpdatePrimaryKey() Rule {
if idxMeta == nil {
return rule
}
- for _, idx := range idxMeta.IdxRows {
+ for _, idx := range idxMeta.Rows {
if idx.KeyName == "PRIMARY" {
if col.Name == idx.ColumnName {
rule = HeuristicRules["CLA.016"]
diff --git a/advisor/index.go b/advisor/index.go
index 0c00818eddc5ef96b2e0dc6018f028a4fc75a129..9a2ffc4fedcffe2d1c1c4c1efbc3ee8899876516 100644
--- a/advisor/index.go
+++ b/advisor/index.go
@@ -914,7 +914,7 @@ func (idxAdv *IndexAdvisor) calcCardinality(cols []*common.Column) []*common.Col
// 检查对应列是否为主键或单列唯一索引,如果满足直接返回1,不再重复计算,提高效率
// 多列复合唯一索引不能跳过计算,单列普通索引不能跳过计算
- for _, index := range idxAdv.IndexMeta[realDB][col.Table].IdxRows {
+ for _, index := range idxAdv.IndexMeta[realDB][col.Table].Rows {
// 根据索引的名称判断该索引包含的列数,列数大于1即为复合索引
columnCount := len(idxAdv.IndexMeta[realDB][col.Table].FindIndex(database.IndexKeyName, index.KeyName))
if col.Name == index.ColumnName {
@@ -1079,7 +1079,7 @@ func DuplicateKeyChecker(conn *database.Connector, databases ...string) map[stri
}
// 枚举所有的索引信息,提取用到的列
- for _, idx := range idxInfo.IdxRows {
+ for _, idx := range idxInfo.Rows {
if _, ok := idxMap[idx.KeyName]; !ok {
idxMap[idx.KeyName] = make([]*common.Column, 0)
for _, col := range idxInfo.FindIndex(database.IndexKeyName, idx.KeyName) {
diff --git a/advisor/rules.go b/advisor/rules.go
index dc553b424ccb1ede4a00ee0022fb255cbdbce362..404e84518a6a4e4ad3aa30f8ca884f66ac88c3df 100644
--- a/advisor/rules.go
+++ b/advisor/rules.go
@@ -102,7 +102,7 @@ type Rule struct {
* SEC Security
* STA Standard
* SUB Subquery
-* TBL Table
+* TBL TableName
* TRA Trace, 由trace模块给
*/
diff --git a/common/cases.go b/common/cases.go
index 25763f7d831b9aa29ec9adc107fa0eb3e1fbdfc9..b49e9b76cf64983b8a69002a45989c68643a8e09 100644
--- a/common/cases.go
+++ b/common/cases.go
@@ -67,7 +67,7 @@ func init() {
"SELECT * FROM customer WHERE address_id in (224,510) ORDER BY last_name;", // INDEX(address_id)
"SELECT * FROM film WHERE release_year = 2016 AND length != 1 ORDER BY title;", // INDEX(`release_year`, `length`, `title`)
- // "Covering" IdxRows
+ // "Covering" Rows
"SELECT title FROM film WHERE release_year = 1995;", // INDEX(release_year, title)",
"SELECT title, replacement_cost FROM film WHERE language_id = 5 AND length = 70;", // INDEX(language_id, length, title, replacement_cos film ), title, replacement_cost顺序无关,language_id, length顺序视散粒度情况.
"SELECT title FROM film WHERE language_id > 5 AND length > 70;", // INDEX(language_id, length, title) language_id or length first (that's as far as the Algorithm goes), then the other two fields afterwards.
diff --git a/common/config.go b/common/config.go
index 4f84116c407b4870c96943b1559912f4354ad480..cdfacb2c7c4e5d29a69c93e8cd8bac70e27eb1fb 100644
--- a/common/config.go
+++ b/common/config.go
@@ -59,8 +59,6 @@ type Configuration struct {
Profiling bool `yaml:"profiling"` // 在开启数据采样的情况下,在测试环境执行进行profile
Trace bool `yaml:"trace"` // 在开启数据采样的情况下,在测试环境执行进行Trace
Explain bool `yaml:"explain"` // Explain开关
- ConnTimeOut int `yaml:"conn-time-out"` // 数据库连接超时时间,单位秒
- QueryTimeOut int `yaml:"query-time-out"` // 数据库SQL执行超时时间,单位秒
Delimiter string `yaml:"delimiter"` // SQL分隔符
// +++++++++++++++日志相关+++++++++++++++++
@@ -97,8 +95,8 @@ type Configuration struct {
MaxInCount int `yaml:"max-in-count"` // IN()最大数量
MaxIdxBytesPerColumn int `yaml:"max-index-bytes-percolumn"` // 索引中单列最大字节数,默认767
MaxIdxBytes int `yaml:"max-index-bytes"` // 索引总长度限制,默认3072
- TableAllowCharsets []string `yaml:"table-allow-charsets"` // Table 允许使用的 DEFAULT CHARSET
- TableAllowEngines []string `yaml:"table-allow-engines"` // Table 允许使用的 Engine
+ TableAllowCharsets []string `yaml:"table-allow-charsets"` // TableName 允许使用的 DEFAULT CHARSET
+ TableAllowEngines []string `yaml:"table-allow-engines"` // TableName 允许使用的 Engine
MaxIdxCount int `yaml:"max-index-count"` // 单张表允许最多索引数
MaxColCount int `yaml:"max-column-count"` // 单张表允许最大列数
MaxValueCount int `yaml:"max-value-count"` // INSERT/REPLACE 单次允许批量写入的行数
@@ -135,12 +133,14 @@ type Configuration struct {
// Config 默认设置
var Config = &Configuration{
OnlineDSN: &dsn{
+ Net: "tcp",
Schema: "information_schema",
Charset: "utf8mb4",
Disable: true,
Version: 99999,
},
TestDSN: &dsn{
+ Net: "tcp",
Schema: "information_schema",
Charset: "utf8mb4",
Disable: true,
@@ -156,8 +156,6 @@ var Config = &Configuration{
Profiling: false,
Trace: false,
Explain: true,
- ConnTimeOut: 3,
- QueryTimeOut: 30,
Delimiter: ";",
MaxJoinTableCount: 5,
@@ -227,6 +225,7 @@ var Config = &Configuration{
}
type dsn struct {
+ Net string `yaml:"net"`
Addr string `yaml:"addr"`
Schema string `yaml:"schema"`
@@ -236,6 +235,10 @@ type dsn struct {
Charset string `yaml:"charset"`
Disable bool `yaml:"disable"`
+ Timeout int `yaml:"timeout"`
+ ReadTimeout int `yaml:"read-timeout"`
+ WriteTimeout int `yaml:"write-timeout"`
+
Version int `yaml:"-"` // 版本自动检查,不可配置
}
@@ -502,8 +505,6 @@ func readCmdFlags() error {
explain := flag.Bool("explain", Config.Explain, "Explain, 是否开启Explain执行计划分析")
sampling := flag.Bool("sampling", Config.Sampling, "Sampling, 数据采样开关")
samplingStatisticTarget := flag.Int("sampling-statistic-target", Config.SamplingStatisticTarget, "SamplingStatisticTarget, 数据采样因子,对应 PostgreSQL 的 default_statistics_target")
- connTimeOut := flag.Int("conn-time-out", Config.ConnTimeOut, "ConnTimeOut, 数据库连接超时时间,单位秒")
- queryTimeOut := flag.Int("query-time-out", Config.QueryTimeOut, "QueryTimeOut, 数据库SQL执行超时时间,单位秒")
delimiter := flag.String("delimiter", Config.Delimiter, "Delimiter, SQL分隔符")
// +++++++++++++++日志相关+++++++++++++++++
logLevel := flag.Int("log-level", Config.LogLevel, "LogLevel, 日志级别, [0:Emergency, 1:Alert, 2:Critical, 3:Error, 4:Warning, 5:Notice, 6:Informational, 7:Debug]")
@@ -583,8 +584,6 @@ func readCmdFlags() error {
Config.Explain = *explain
Config.Sampling = *sampling
Config.SamplingStatisticTarget = *samplingStatisticTarget
- Config.ConnTimeOut = *connTimeOut
- Config.QueryTimeOut = *queryTimeOut
Config.LogLevel = *logLevel
if strings.HasPrefix(*logOutput, "/") {
diff --git a/common/meta.go b/common/meta.go
index 433704eff8c0059df662a48542a877a2165707e1..e95bb37d91351fa100bb67fe66c4c50ed96fe043 100644
--- a/common/meta.go
+++ b/common/meta.go
@@ -27,7 +27,7 @@ type Meta map[string]*DB
// DB 数据库相关的结构体
type DB struct {
Name string
- Table map[string]*Table // ['table_name']*Table
+ Table map[string]*Table // ['table_name']*TableName
}
// NewDB 用于初始化*DB
@@ -38,14 +38,14 @@ func NewDB(db string) *DB {
}
}
-// Table 含有表的属性
+// TableName 含有表的属性
type Table struct {
TableName string
TableAliases []string
Column map[string]*Column
}
-// NewTable 初始化*Table
+// NewTable 初始化*TableName
func NewTable(tb string) *Table {
return &Table{
TableName: tb,
diff --git a/common/testdata/TestMarkdown2Html.golden b/common/testdata/TestMarkdown2Html.golden
index 8e3b056be373d08649288e305d11e07e528e68b4..6c7f231029daa9daf28080fc16a4e4ec6a025385 100644
--- a/common/testdata/TestMarkdown2Html.golden
+++ b/common/testdata/TestMarkdown2Html.golden
@@ -176,9 +176,9 @@ $$
Typora support YAML Front Matter now. Input ---
at the top of the article and then press Enter
will introduce one. Or insert one metadata block from the menu.
-Table of Contents (TOC)
+TableName of Contents (TOC)
-Input [toc]
then press Return
key will create a section for “Table of Contents” extracting all headers from one’s writing, its contents will be updated automatically.
+Input [toc]
then press Return
key will create a section for “TableName of Contents” extracting all headers from one’s writing, its contents will be updated automatically.
Diagrams (Sequence, Flowchart and Mermaid)
diff --git a/common/testdata/TestMarkdown2Html.md b/common/testdata/TestMarkdown2Html.md
index 654b7f3f760d63f436ac81561ec7a7d24ad95a7e..1ab82ea6e490f6f052a1346374eead46c111f2be 100644
--- a/common/testdata/TestMarkdown2Html.md
+++ b/common/testdata/TestMarkdown2Html.md
@@ -186,9 +186,9 @@ Input `***` or `---` on a blank line and press `return` will draw a horizontal l
Typora support [YAML Front Matter](http://jekyllrb.com/docs/frontmatter/) now. Input `---` at the top of the article and then press `Enter` will introduce one. Or insert one metadata block from the menu.
-### Table of Contents (TOC)
+### TableName of Contents (TOC)
-Input `[toc]` then press `Return` key will create a section for “Table of Contents” extracting all headers from one’s writing, its contents will be updated automatically.
+Input `[toc]` then press `Return` key will create a section for “TableName of Contents” extracting all headers from one’s writing, its contents will be updated automatically.
### Diagrams (Sequence, Flowchart and Mermaid)
diff --git a/database/explain.go b/database/explain.go
index 0e17e6479ca31dab4331487e06a6d9834392ed50..1a0d8173c979635465a1f2fd19060ff964a9a124 100644
--- a/database/explain.go
+++ b/database/explain.go
@@ -267,6 +267,7 @@ var ExplainKeyWords = []string{
"using_temporary_table",
}
+/*
// ExplainColumnIndent EXPLAIN表头
var ExplainColumnIndent = map[string]string{
"id": "id为SELECT的标识符. 它是在SELECT查询中的顺序编号. 如果这一行表示其他行的union结果, 这个值可以为空. 在这种情况下, table列会显示为形如, 表示它是id为M和N的查询行的联合结果.",
@@ -281,6 +282,7 @@ var ExplainColumnIndent = map[string]string{
"filtered": "表示返回结果的行占需要读到的行(rows列的值)的百分比.",
"Extra": "该列显示MySQL在查询过程中的一些详细信息, MySQL查询优化器执行查询的过程中对查询计划的重要补充信息.",
}
+*/
// ExplainSelectType EXPLAIN中SELECT TYPE会出现的类型
var ExplainSelectType = map[string]string{
@@ -555,14 +557,16 @@ func (db *Connector) explainAbleSQL(sql string) (string, error) {
}
// 执行explain请求,返回mysql.Result执行结果
-func (db *Connector) executeExplain(sql string, explainType int, formatType int) (*QueryResult, error) {
+func (db *Connector) executeExplain(sql string, explainType int, formatType int) (QueryResult, error) {
+ var res QueryResult
var err error
+ var explainQuery string
sql, err = db.explainAbleSQL(sql)
if sql == "" {
- return nil, err
+ return res, err
}
- // 5.6以上支持FORMAT=JSON
+ // 5.6以上支持 FORMAT=JSON
explainFormat := ""
switch formatType {
case JSONFormatExplain:
@@ -570,22 +574,23 @@ func (db *Connector) executeExplain(sql string, explainType int, formatType int)
explainFormat = "FORMAT=JSON"
}
}
- // 执行explain
- var res *QueryResult
+
+ // 执行 explain
switch explainType {
case ExtendedExplainType:
// 5.6以上extended关键字已经不推荐使用,8.0废弃了这个关键字
if common.Config.TestDSN.Version >= 50600 {
- res, err = db.Query("explain %s", sql)
+ explainQuery = fmt.Sprintf("explain %s", sql)
} else {
- res, err = db.Query("explain extended %s", sql)
+ explainQuery = fmt.Sprintf("explain extended %s", sql)
}
case PartitionsExplainType:
- res, err = db.Query("explain partitions %s", sql)
+ explainQuery = fmt.Sprintf("explain partitions %s", sql)
default:
- res, err = db.Query("explain %s %s", explainFormat, sql)
+ explainQuery = fmt.Sprintf("explain %s %s", explainFormat, sql)
}
+ res, err = db.Query(explainQuery)
return res, err
}
@@ -928,76 +933,75 @@ func parseJSONExplainText(content string) (*ExplainJSON, error) {
}
// ParseExplainResult 分析 mysql 执行 explain 的结果,返回 ExplainInfo 结构化数据
-func ParseExplainResult(res *QueryResult, formatType int) (exp *ExplainInfo, err error) {
+func ParseExplainResult(res QueryResult, formatType int) (exp *ExplainInfo, err error) {
exp = &ExplainInfo{
ExplainFormat: formatType,
}
// JSON 格式直接调用文本方式解析
if formatType == JSONFormatExplain {
- exp.ExplainJSON, err = parseJSONExplainText(res.Rows[0].Str(0))
+ if res.Rows.Next() {
+ var explainString string
+ err = res.Rows.Scan(&explainString)
+ exp.ExplainJSON, err = parseJSONExplainText(explainString)
+ }
return exp, err
}
- // 生成表头
- colIdx := make(map[int]string)
- for i, f := range res.Result.Fields() {
- colIdx[i] = strings.ToLower(f.Name)
+ // Different MySQL version has different columns define
+ var possibleKeys string
+ expRow := &ExplainRow{}
+ explainFields := make([]interface{}, 0)
+ fields := map[string]interface{}{
+ "id": expRow.ID,
+ "select_type": expRow.SelectType,
+ "table": expRow.TableName,
+ "partitions": expRow.Partitions,
+ "type": expRow.AccessType,
+ "possible_keys": &possibleKeys,
+ "key": expRow.Key,
+ "key_len": expRow.KeyLen,
+ "ref": expRow.Ref,
+ "rows": expRow.Rows,
+ "filtered": expRow.Filtered,
+ "extra": expRow.Extra,
}
+ cols, err := res.Rows.Columns()
+ common.LogIfError(err, "")
+ for _, col := range cols {
+ explainFields = append(explainFields, fields[col])
+ }
+
// 补全 ExplainRows
- var explainrows []*ExplainRow
- for _, row := range res.Rows {
- expRow := &ExplainRow{Partitions: "NULL", Filtered: 0.00}
- // list 到 map 的转换
- for i := range row {
- switch colIdx[i] {
- case "id":
- expRow.ID = row.ForceInt(i)
- case "select_type":
- expRow.SelectType = row.Str(i)
- case "table":
- expRow.TableName = row.Str(i)
- if expRow.TableName == "" {
- expRow.TableName = "NULL"
- }
- case "type":
- expRow.AccessType = row.Str(i)
- if expRow.AccessType == "" {
- expRow.AccessType = "NULL"
- }
- expRow.Scalability = ExplainScalability[expRow.AccessType]
- case "possible_keys":
- expRow.PossibleKeys = strings.Split(row.Str(i), ",")
- case "key":
- expRow.Key = row.Str(i)
- if expRow.Key == "" {
- expRow.Key = "NULL"
- }
- case "key_len":
- expRow.KeyLen = row.Str(i)
- case "ref":
- expRow.Ref = strings.Split(row.Str(i), ",")
- case "rows":
- expRow.Rows = row.ForceInt(i)
- case "extra":
- expRow.Extra = row.Str(i)
- if expRow.Extra == "" {
- expRow.Extra = "NULL"
- }
- case "filtered":
- expRow.Filtered = row.ForceFloat(i)
- // MySQL bug: https://bugs.mysql.com/bug.php?id=34124
- if expRow.Filtered > 100.00 {
- expRow.Filtered = 100.00
- }
- }
+ var explainRows []*ExplainRow
+
+ for res.Rows.Next() {
+ res.Rows.Scan(explainFields...)
+ expRow.PossibleKeys = strings.Split(possibleKeys, ",")
+
+ // MySQL bug: https://bugs.mysql.com/bug.php?id=34124
+ if expRow.Filtered > 100.00 {
+ expRow.Filtered = 100.00
}
- explainrows = append(explainrows, expRow)
+
+ expRow.Scalability = ExplainScalability[expRow.AccessType]
+ explainRows = append(explainRows, expRow)
}
- exp.ExplainRows = explainrows
- for _, w := range res.Warning {
- // 'EXTENDED' is deprecated and will be removed in a future release.
- if w.Int(1) != 1681 {
- exp.Warnings = append(exp.Warnings, &ExplainWarning{Level: w.Str(0), Code: w.Int(1), Message: w.Str(2)})
+ exp.ExplainRows = explainRows
+
+ // check explain warning info
+ if common.Config.ShowWarnings {
+ for res.Warning.Next() {
+ var expWarning *ExplainWarning
+ res.Warning.Scan(
+ expWarning.Level,
+ expWarning.Code,
+ expWarning.Message,
+ )
+
+ // 'EXTENDED' is deprecated and will be removed in a future release.
+ if expWarning.Code != 1681 {
+ exp.Warnings = append(exp.Warnings, expWarning)
+ }
}
}
@@ -1009,7 +1013,6 @@ func ParseExplainResult(res *QueryResult, formatType int) (exp *ExplainInfo, err
// Explain 获取 SQL 的 explain 信息
func (db *Connector) Explain(sql string, explainType int, formatType int) (exp *ExplainInfo, err error) {
- exp = &ExplainInfo{SQL: sql}
if explainType != TraditionalExplainType {
formatType = TraditionalFormatExplain
}
@@ -1025,12 +1028,16 @@ func (db *Connector) Explain(sql string, explainType int, formatType int) (exp *
// 执行EXPLAIN请求
res, err := db.executeExplain(sql, explainType, formatType)
- if err != nil || res == nil {
+ if err != nil {
return exp, err
}
+ if res.Error != nil {
+ return exp, res.Error
+ }
// 解析mysql结果,输出ExplainInfo
exp, err = ParseExplainResult(res, formatType)
+ exp.SQL = sql
return exp, err
}
diff --git a/database/explain_test.go b/database/explain_test.go
index 01ef50f923322d55c2fb48bdf1d12207e39537b8..9366a18ac3afa1cb864100d9b308207aca021e28 100644
--- a/database/explain_test.go
+++ b/database/explain_test.go
@@ -17,8 +17,6 @@
package database
import (
- "fmt"
- "os"
"testing"
"github.com/XiaoMi/soar/common"
@@ -26,23 +24,6 @@ import (
"github.com/kr/pretty"
)
-var connTest *Connector
-
-func init() {
- common.BaseDir = common.DevPath
- common.ParseConfig("")
- connTest = &Connector{
- Addr: common.Config.OnlineDSN.Addr,
- User: common.Config.OnlineDSN.User,
- Pass: common.Config.OnlineDSN.Password,
- Database: common.Config.OnlineDSN.Schema,
- }
- if _, err := connTest.Version(); err != nil {
- common.Log.Critical("Test env Error: %v", err)
- os.Exit(0)
- }
-}
-
var sqls = []string{
`select * from city where country_id = 44;`,
`select * from address where address2 is not null;`,
@@ -54,10 +35,10 @@ var sqls = []string{
`select * from city where country_id > 31 and city = 'Aden';`,
`select * from address where address_id > 8 and city_id < 400 and district = 'Nantou';`,
`select * from address where address_id > 8 and city_id < 400;`,
- `select * from actor where last_update='2006-02-15 04:34:33' and last_name='CHASE' group by first_name;`,
- `select * from address where last_update >='2014-09-25 22:33:47' group by district;`,
- `select * from address group by address,district;`,
- `select * from address where last_update='2014-09-25 22:30:27' group by district,(address_id+city_id);`,
+ `select first_name from actor where last_update='2006-02-15 04:34:33' and last_name='CHASE' group by first_name;`,
+ `select district from address where last_update >='2014-09-25 22:33:47' group by district;`,
+ `select address from address group by address,district;`,
+ `select district from address where last_update='2014-09-25 22:30:27' group by district,(address_id+city_id);`,
`select * from customer where active=1 order by last_name limit 10;`,
`select * from customer order by last_name limit 10;`,
`select * from customer where address_id > 224 order by address_id limit 10;`,
@@ -2351,16 +2332,23 @@ possible_keys: idx_fk_country_id,idx_country_id_city,idx_all,idx_other
}
func TestExplain(t *testing.T) {
- for _, sql := range sqls {
+ // TraditionalFormatExplain
+ for idx, sql := range sqls {
exp, err := connTest.Explain(sql, TraditionalExplainType, TraditionalFormatExplain)
- //exp, err := conn.Explain(sql, TraditionalExplainType, JSONFormatExplain)
- fmt.Println("Old: ", sql)
- fmt.Println("New: ", exp.SQL)
if err != nil {
- fmt.Println(err)
+ t.Error(err)
}
+ pretty.Println("No.:", idx, "\nOld: ", sql, "\nNew: ", exp.SQL)
+ pretty.Println(exp)
+ }
+ // JSONFormatExplain
+ for idx, sql := range sqls {
+ exp, err := connTest.Explain(sql, TraditionalExplainType, JSONFormatExplain)
+ if err != nil {
+ t.Error(err)
+ }
+ pretty.Println("No.:", idx, "\nOld: ", sql, "\nNew: ", exp.SQL)
pretty.Println(exp)
- fmt.Println()
}
}
@@ -2400,6 +2388,7 @@ func TestPrintMarkdownExplainTable(t *testing.T) {
if err != nil {
t.Error(err)
}
+
err = common.GoldenDiff(func() {
PrintMarkdownExplainTable(expInfo)
}, t.Name(), update)
diff --git a/database/mysql.go b/database/mysql.go
index fa6927d46d17252829fc5e594f7b35d3beb0c843..68aa971e80f121b6168a46bb2438f0feafd3b88b 100644
--- a/database/mysql.go
+++ b/database/mysql.go
@@ -17,21 +17,17 @@
package database
import (
+ "database/sql"
"errors"
"fmt"
- "io/ioutil"
- "os"
"regexp"
"strconv"
"strings"
- "time"
- "github.com/XiaoMi/soar/ast"
"github.com/XiaoMi/soar/common"
- "github.com/ziutek/mymysql/mysql"
- // mymysql driver
- _ "github.com/ziutek/mymysql/native"
+ // for database/sql
+ _ "github.com/go-sql-driver/mysql"
"vitess.io/vitess/go/vt/sqlparser"
)
@@ -42,81 +38,62 @@ type Connector struct {
Pass string
Database string
Charset string
+ Net string
}
// QueryResult 数据库查询返回值
type QueryResult struct {
- Rows []mysql.Row
- Result mysql.Result
+ Rows *sql.Rows
Error error
- Warning []mysql.Row
+ Warning *sql.Rows
QueryCost float64
}
// NewConnection 创建新连接
-func (db *Connector) NewConnection() mysql.Conn {
- return mysql.New("tcp", "", db.Addr, db.User, db.Pass, db.Database)
+func (db *Connector) NewConnection() (*sql.DB, error) {
+ dsn := fmt.Sprintf("%s:%s@%s(%s)/%s?parseTime=true&charset=%s", db.User, db.Pass, db.Net, db.Addr, db.Database, db.Charset)
+ return sql.Open("mysql", dsn)
}
// Query 执行SQL
-func (db *Connector) Query(sql string, params ...interface{}) (*QueryResult, error) {
+func (db *Connector) Query(sql string, params ...interface{}) (QueryResult, error) {
+ var res QueryResult
// 测试环境如果检查是关闭的,则SQL不会被执行
if common.Config.TestDSN.Disable {
- return nil, errors.New("Dsn Disable")
+ return res, errors.New("dsn is disable")
}
// 数据库安全性检查:如果 Connector 的 IP 端口与 TEST 环境不一致,则启用SQL白名单
// 不在白名单中的SQL不允许执行
// 执行环境与test环境不相同
if db.Addr != common.Config.TestDSN.Addr && db.dangerousQuery(sql) {
- return nil, fmt.Errorf("query execution deny: execute SQL with DSN(%s/%s) '%s'",
+ return res, fmt.Errorf("query execution deny: execute SQL with DSN(%s/%s) '%s'",
db.Addr, db.Database, fmt.Sprintf(sql, params...))
}
common.Log.Debug("Execute SQL with DSN(%s/%s) : %s", db.Addr, db.Database, fmt.Sprintf(sql, params...))
- conn := db.NewConnection()
-
- // 设置SQL连接超时时间
- conn.SetTimeout(time.Duration(common.Config.ConnTimeOut) * time.Second)
+ conn, err := db.NewConnection()
defer conn.Close()
- err := conn.Connect()
if err != nil {
- return nil, err
+ return res, err
}
+ res.Rows, res.Error = conn.Query(sql, params...)
- // 添加SQL执行超时限制
- ch := make(chan QueryResult, 1)
- go func() {
- res := QueryResult{}
- res.Rows, res.Result, res.Error = conn.Query(sql, params...)
-
- if common.Config.ShowWarnings {
- warning, _, err := conn.Query("SHOW WARNINGS")
- if err == nil {
- res.Warning = warning
- }
- }
+ if common.Config.ShowWarnings {
+ res.Warning, err = conn.Query("SHOW WARNINGS")
+ }
- // SHOW WARNINGS 并不会影响 last_query_cost
- if common.Config.ShowLastQueryCost {
- cost, _, err := conn.Query("SHOW SESSION STATUS LIKE 'last_query_cost'")
- if err == nil {
- if len(cost) > 0 {
- res.QueryCost = cost[0].Float(1)
- }
+ // SHOW WARNINGS 并不会影响 last_query_cost
+ if common.Config.ShowLastQueryCost {
+ cost, err := conn.Query("SHOW SESSION STATUS LIKE 'last_query_cost'")
+ if err == nil {
+ if cost.Next() {
+ err = cost.Scan(res.QueryCost)
}
}
-
- ch <- res
- }()
-
- select {
- case res := <-ch:
- return &res, res.Error
- case <-time.After(time.Duration(common.Config.QueryTimeOut) * time.Second):
- return nil, errors.New("query execution timeout")
}
+ return res, err
}
// Version 获取MySQL数据库版本
@@ -124,77 +101,43 @@ func (db *Connector) Version() (int, error) {
version := 99999
// 从数据库中获取版本信息
res, err := db.Query("select @@version")
- if err != nil {
- common.Log.Warn("(db *Connector) Version() Error: %v", err)
+ if err != nil || res.Error != nil {
+ common.Log.Warn("(db *Connector) Version() Error: %v, MySQL Error: %v", err, res.Error)
return version, err
}
// MariaDB https://mariadb.com/kb/en/library/comment-syntax/
// MySQL https://dev.mysql.com/doc/refman/8.0/en/comments.html
- versionStr := strings.Split(res.Rows[0].Str(0), "-")[0]
- versionSeg := strings.Split(versionStr, ".")
- if len(versionSeg) == 3 {
- versionStr = fmt.Sprintf("%s%02s%02s", versionSeg[0], versionSeg[1], versionSeg[2])
- version, err = strconv.Atoi(versionStr)
- }
- return version, err
-}
-
-// Source execute sql from file
-func (db *Connector) Source(file string) ([]*QueryResult, error) {
- var sqlCounter int // SQL 计数器
- var result []*QueryResult
-
- fd, err := os.Open(file)
- defer func() {
- err = fd.Close()
- if err != nil {
- common.Log.Error("(db *Connector) Source(%s) fd.Close failed: %s", file, err.Error())
+ var versionStr string
+ var versionSeg []string
+ for res.Rows.Next() {
+ err = res.Rows.Scan(&versionStr)
+ versionStr = strings.Split(versionStr, "-")[0]
+ versionSeg = strings.Split(versionStr, ".")
+ if len(versionSeg) == 3 {
+ versionStr = fmt.Sprintf("%s%02s%02s", versionSeg[0], versionSeg[1], versionSeg[2])
+ version, err = strconv.Atoi(versionStr)
}
- }()
- if err != nil {
- common.Log.Warning("(db *Connector) Source(%s) os.Open failed: %s", file, err.Error())
- return nil, err
- }
- data, err := ioutil.ReadAll(fd)
- if err != nil {
- common.Log.Critical("ioutil.ReadAll Error: %s", err.Error())
- return nil, err
+ break
}
- sql := strings.TrimSpace(string(data))
- buf := strings.TrimSpace(sql)
- for ; ; sqlCounter++ {
- if buf == "" {
- break
- }
-
- // 查询请求切分
- _, sql, bufBytes := ast.SplitStatement([]byte(buf), []byte(common.Config.Delimiter))
- buf = string(bufBytes)
- sql = strings.TrimSpace(sql)
- common.Log.Debug("Source Query SQL: %s", sql)
-
- res, e := db.Query(sql)
- if e != nil {
- common.Log.Error("(db *Connector) Source Filename: %s, SQLCounter.: %d", file, sqlCounter)
- return result, e
- }
- result = append(result, res)
- }
- return result, nil
+ return version, err
}
// SingleIntValue 获取某个int型变量的值
func (db *Connector) SingleIntValue(option string) (int, error) {
// 从数据库中获取信息
- res, err := db.Query("select @@%s", option)
+ res, err := db.Query("select @@" + option)
if err != nil {
common.Log.Warn("(db *Connector) SingleIntValue() Error: %v", err)
return -1, err
}
- return res.Rows[0].Int(0), err
+ var intVal int
+ if res.Rows.Next() {
+ err = res.Rows.Scan(&intVal)
+ }
+ return intVal, err
}
// ColumnCardinality 粒度计算
@@ -228,13 +171,20 @@ func (db *Connector) ColumnCardinality(tb, col string) float64 {
}
// 计算该列散粒度
- res, err := db.Query("select count(distinct `%s`) from `%s`.`%s`", col, db.Database, tb)
+ res, err := db.Query(fmt.Sprintf("select count(distinct `%s`) from `%s`.`%s`", col, db.Database, tb))
if err != nil {
common.Log.Warn("(db *Connector) ColumnCardinality() Query Error: %v", err)
return 0
}
- colNum := res.Rows[0].Float(0)
+ var colNum float64
+ if res.Rows.Next() {
+ err = res.Rows.Scan(&colNum)
+ if err != nil {
+ common.Log.Warn("(db *Connector) ColumnCardinality() Query Error: %v", err)
+ return 0
+ }
+ }
// 散粒度区间:[0,1]
return colNum / float64(rowTotal)
@@ -249,13 +199,12 @@ func (db *Connector) IsView(tbName string) bool {
}
if len(tbStatus.Rows) > 0 {
- if tbStatus.Rows[0].Comment == "VIEW" {
+ if string(tbStatus.Rows[0].Comment) == "VIEW" {
return true
}
}
return false
-
}
// RemoveSQLComments 去除SQL中的注释
@@ -281,7 +230,7 @@ func (db *Connector) dangerousQuery(query string) bool {
return true
}
- for _, sql := range queries {
+ for _, query := range queries {
dangerous := true
whiteList := []string{
"select",
@@ -291,7 +240,7 @@ func (db *Connector) dangerousQuery(query string) bool {
}
for _, prefix := range whiteList {
- if strings.HasPrefix(sql, prefix) {
+ if strings.HasPrefix(query, prefix) {
dangerous = false
break
}
diff --git a/database/mysql_test.go b/database/mysql_test.go
index 79ea7bdd2a84c6254086dbdeb665d08bd528391e..f3239a5e5bbe7de0cd5e6f13fc2bf78071752b79 100644
--- a/database/mysql_test.go
+++ b/database/mysql_test.go
@@ -17,26 +17,70 @@
package database
import (
+ "flag"
"fmt"
+ "os"
"testing"
"github.com/XiaoMi/soar/common"
+
"github.com/kr/pretty"
)
+var connTest *Connector
+
+var update = flag.Bool("update", false, "update .golden files")
+
+func init() {
+ common.BaseDir = common.DevPath
+ common.ParseConfig("")
+ connTest = &Connector{
+ Addr: common.Config.OnlineDSN.Addr,
+ User: common.Config.OnlineDSN.User,
+ Pass: common.Config.OnlineDSN.Password,
+ Database: common.Config.OnlineDSN.Schema,
+ Charset: common.Config.OnlineDSN.Charset,
+ }
+ if _, err := connTest.Version(); err != nil {
+ common.Log.Critical("Test env Error: %v", err)
+ os.Exit(0)
+ }
+}
+
+func TestNewConnection(t *testing.T) {
+ _, err := connTest.NewConnection()
+ if err != nil {
+ t.Errorf("TestNewConnection, Error: %s", err.Error())
+ }
+}
+
// TODO: go test -race不通过待解决
func TestQuery(t *testing.T) {
- common.Config.QueryTimeOut = 1
- _, err := connTest.Query("select sleep(2)")
- if err == nil {
- t.Error("connTest.Query not timeout")
+ res, err := connTest.Query("select 0")
+ if err != nil {
+ t.Error(err.Error())
+ }
+ for res.Rows.Next() {
+ var val int
+ err = res.Rows.Scan(&val)
+ if err != nil {
+ t.Error(err.Error())
+ }
+ if val != 0 {
+ t.Error("should return 0")
+ }
}
+ // TODO: timeout test
}
-func TestColumnCardinality(_ *testing.T) {
- connTest.Database = "information_schema"
- a := connTest.ColumnCardinality("TABLES", "TABLE_SCHEMA")
- fmt.Println("TABLES.TABLE_SCHEMA:", a)
+func TestColumnCardinality(t *testing.T) {
+ orgDatabase := connTest.Database
+ connTest.Database = "sakila"
+ a := connTest.ColumnCardinality("actor", "first_name")
+ if a >= 1 || a <= 0 {
+ t.Error("sakila.actor.first_name cardinality should in (0, 1), now it's", a)
+ }
+ connTest.Database = orgDatabase
}
func TestDangerousSQL(t *testing.T) {
@@ -63,11 +107,15 @@ func TestWarningsAndQueryCost(t *testing.T) {
if err != nil {
t.Error("Query Error: ", err)
} else {
- for _, w := range res.Warning {
- pretty.Println(w.Str(2))
+ for res.Warning.Next() {
+ var str string
+ err = res.Warning.Scan(str)
+ if err != nil {
+ t.Error(err.Error())
+ }
+ pretty.Println(str)
}
- fmt.Println(res.QueryCost)
- pretty.Println(err)
+ fmt.Println(res.QueryCost, err)
}
}
@@ -79,16 +127,6 @@ func TestVersion(t *testing.T) {
fmt.Println(version)
}
-func TestSource(t *testing.T) {
- res, err := connTest.Source("testdata/" + t.Name() + ".sql")
- if err != nil {
- t.Error("Query Error: ", err)
- }
- if res[0].Rows[0].Int(0) != 1 || res[1].Rows[0].Int(0) != 1 {
- t.Error("Source result not match, expect 1, 1")
- }
-}
-
func TestRemoveSQLComments(t *testing.T) {
SQLs := []string{
`-- comment`,
@@ -109,3 +147,22 @@ comment*/`,
t.Error(err)
}
}
+
+func TestSingleIntValue(t *testing.T) {
+ val, err := connTest.SingleIntValue("read_only")
+ if err != nil {
+ t.Error(err)
+ }
+ if val < 0 {
+ t.Error("SingleIntValue, return should large than zero")
+ }
+}
+
+func TestIsView(t *testing.T) {
+ originalDatabase := connTest.Database
+ connTest.Database = "sakila"
+ if !connTest.IsView("actor_info") {
+ t.Error("actor_info should be a VIEW")
+ }
+ connTest.Database = originalDatabase
+}
diff --git a/database/privilege.go b/database/privilege.go
index 762de2898d046ae29aca169a4c92ee971fd74f00..dca833e3b9d11636675dd87ad18220f449f5dfe5 100644
--- a/database/privilege.go
+++ b/database/privilege.go
@@ -18,6 +18,7 @@ package database
import (
"errors"
+ "fmt"
"strings"
"github.com/XiaoMi/soar/common"
@@ -29,8 +30,13 @@ func (db *Connector) CurrentUser() (string, string, error) {
if err != nil {
return "", "", err
}
- if len(res.Rows) > 0 {
- cols := strings.Split(res.Rows[0].Str(0), "@")
+ if res.Rows.Next() {
+ var currentUser string
+ err = res.Rows.Scan(¤tUser)
+ if err != nil {
+ return "", "", err
+ }
+ cols := strings.Split(currentUser, "@")
if len(cols) == 2 {
user := strings.Trim(cols[0], "'")
host := strings.Trim(cols[1], "'")
@@ -51,14 +57,20 @@ func (db *Connector) HasSelectPrivilege() bool {
common.Log.Error("User: %s, HasSelectPrivilege: %s", db.User, err.Error())
return false
}
- res, err := db.Query("select Select_priv from mysql.user where user='%s' and host='%s'", user, host)
+ res, err := db.Query(fmt.Sprintf("select Select_priv from mysql.user where user='%s' and host='%s'", user, host))
if err != nil {
common.Log.Error("HasSelectPrivilege, DSN: %s, Error: %s", db.Addr, err.Error())
return false
}
// Select_priv
- if len(res.Rows) > 0 {
- if res.Rows[0].Str(0) == "Y" {
+ if res.Rows.Next() {
+ var selectPrivilege string
+ err = res.Rows.Scan(&selectPrivilege)
+ if err != nil {
+ common.Log.Error("HasSelectPrivilege, Scan Error: %s", err.Error())
+ return false
+ }
+ if selectPrivilege == "Y" {
return true
}
}
@@ -79,24 +91,31 @@ func (db *Connector) HasAllPrivilege() bool {
common.Log.Error("HasAllPrivilege, DSN: %s, Error: %s", db.Addr, err.Error())
return false
}
+
var priv string
- if len(res.Rows) > 0 {
- priv = res.Rows[0].Str(0)
- } else {
- common.Log.Error("HasAllPrivilege, DSN: %s, get privilege string error", db.Addr)
- return false
+ if res.Rows.Next() {
+ err = res.Rows.Scan(&priv)
+ if err != nil {
+ common.Log.Error("HasAllPrivilege, DSN: %s, Scan error", db.Addr)
+ return false
+ }
}
// get all privilege status
- res, err = db.Query("select concat("+priv+") from mysql.user where user='%s' and host='%s'", user, host)
+ res, err = db.Query(fmt.Sprintf("select concat("+priv+") from mysql.user where user='%s' and host='%s'", user, host))
if err != nil {
common.Log.Error("HasAllPrivilege, DSN: %s, Error: %s", db.Addr, err.Error())
return false
}
// %_priv
- if len(res.Rows) > 0 {
- if strings.Replace(res.Rows[0].Str(0), "Y", "", -1) == "" {
+ if res.Rows.Next() {
+ err = res.Rows.Scan(&priv)
+ if err != nil {
+ common.Log.Error("HasAllPrivilege, DSN: %s, Scan error", db.Addr)
+ return false
+ }
+ if strings.Replace(priv, "Y", "", -1) == "" {
return true
}
}
diff --git a/database/privilege_test.go b/database/privilege_test.go
index 4ed0fbc609d70550cca119c59b6e7583975e0fdc..3690fd3dc5c99ac6fc51051f535e63fd4c8489db 100644
--- a/database/privilege_test.go
+++ b/database/privilege_test.go
@@ -18,6 +18,16 @@ package database
import "testing"
+func TestCurrentUser(t *testing.T) {
+ user, host, err := connTest.CurrentUser()
+ if err != nil {
+ t.Error(err.Error())
+ }
+ if user != "root" || host != "%" {
+ t.Errorf("Want user: root, host: %%. Get user: %s, host: %s", user, host)
+ }
+}
+
func TestHasSelectPrivilege(t *testing.T) {
if !connTest.HasSelectPrivilege() {
t.Errorf("DSN: %s, User: %s, should has select privilege", connTest.Addr, connTest.User)
diff --git a/database/profiling.go b/database/profiling.go
index a6072313e09e554ecb04ac2fd89aa91bac3a9be1..f85a6b0231bfd259660db702f13b144e4a7193b6 100644
--- a/database/profiling.go
+++ b/database/profiling.go
@@ -19,9 +19,7 @@ package database
import (
"errors"
"fmt"
- "io"
"strings"
- "time"
"github.com/XiaoMi/soar/common"
@@ -40,95 +38,79 @@ type ProfilingRow struct {
// TODO: 支持show profile all,不过目前看all的信息过多有点眼花缭乱
}
-// Profiling 执行SQL,并对其Profiling
-func (db *Connector) Profiling(sql string, params ...interface{}) (*QueryResult, error) {
+// Profiling 执行SQL,并对其 Profile
+func (db *Connector) Profiling(sql string, params ...interface{}) ([]ProfilingRow, error) {
+ var rows []ProfilingRow
// 过滤不需要 profiling 的 SQL
switch sqlparser.Preview(sql) {
case sqlparser.StmtSelect, sqlparser.StmtUpdate, sqlparser.StmtDelete:
default:
- return nil, errors.New("no need profiling")
+ return rows, errors.New("no need profiling")
}
// 测试环境如果检查是关闭的,则SQL不会被执行
if common.Config.TestDSN.Disable {
- return nil, errors.New("Dsn Disable")
+ return rows, errors.New("dsn is disable")
}
// 数据库安全性检查:如果 Connector 的 IP 端口与 TEST 环境不一致,则启用 SQL 白名单
- // 不在白名单中的SQL不允许执行
+ // 不在白名单中的 SQL 不允许执行
// 执行环境与test环境不相同
if db.Addr != common.Config.TestDSN.Addr && db.dangerousQuery(sql) {
- return nil, fmt.Errorf("query execution deny: Execute SQL with DSN(%s/%s) '%s'",
+ return rows, fmt.Errorf("query execution deny: Execute SQL with DSN(%s/%s) '%s'",
db.Addr, db.Database, fmt.Sprintf(sql, params...))
}
common.Log.Debug("Execute SQL with DSN(%s/%s) : %s", db.Addr, db.Database, sql)
- conn := db.NewConnection()
-
- // 设置SQL连接超时时间
- conn.SetTimeout(time.Duration(common.Config.ConnTimeOut) * time.Second)
+ conn, err := db.NewConnection()
+ if err != nil {
+ return rows, err
+ }
defer conn.Close()
- err := conn.Connect()
+
+ // Keep connection
+ // https://github.com/go-sql-driver/mysql/issues/208
+ trx, err := conn.Begin()
if err != nil {
- return nil, err
+ return rows, err
}
+ defer trx.Rollback()
+
+ // 开启 Profiling
+ _, err = trx.Query("set @@profiling=1")
+ common.LogIfError(err, "")
- // 添加SQL执行超时限制
- ch := make(chan QueryResult, 1)
- go func() {
- // 开启Profiling
- _, _, err = conn.Query("set @@profiling=1")
- common.LogIfError(err, "")
+ // 执行 SQL,抛弃返回结果
+ tmpRes, err := trx.Query(sql, params...)
+ if err != nil {
+ return rows, err
+ }
+ for tmpRes.Next() {
+ continue
+ }
- // 执行SQL,抛弃返回结果
- result, err := conn.Start(sql, params...)
+ // 返回 Profiling 结果
+ res, err := trx.Query("show profile")
+ for res.Next() {
+ var profileRow ProfilingRow
+ err := res.Scan(&profileRow.Status, &profileRow.Duration)
if err != nil {
- ch <- QueryResult{
- Error: err,
- }
- return
+ common.LogIfError(err, "")
}
- row := result.MakeRow()
- for {
- err = result.ScanRow(row)
- if err == io.EOF {
- break
- }
- }
-
- // 返回Profiling结果
- res := QueryResult{}
- res.Rows, res.Result, res.Error = conn.Query("show profile")
- _, _, err = conn.Query("set @@profiling=0")
- common.LogIfError(err, "")
- ch <- res
- }()
-
- select {
- case res := <-ch:
- return &res, res.Error
- case <-time.After(time.Duration(common.Config.QueryTimeOut) * time.Second):
- return nil, errors.New("query execution timeout")
+ rows = append(rows, profileRow)
}
-}
-func getProfiling(res *QueryResult) Profiling {
- var rows []ProfilingRow
- for _, row := range res.Rows {
- rows = append(rows, ProfilingRow{
- Status: row.Str(0),
- Duration: row.Float(1),
- })
- }
- return Profiling{Rows: rows}
+ // 关闭 Profiling
+ _, err = trx.Query("set @@profiling=0")
+ common.LogIfError(err, "")
+ return rows, err
}
// FormatProfiling 格式化输出Profiling信息
-func FormatProfiling(res *QueryResult) string {
- profiling := getProfiling(res)
+func FormatProfiling(rows []ProfilingRow) string {
str := []string{"| Status | Duration |"}
str = append(str, "| --- | --- |")
- for _, row := range profiling.Rows {
+ for _, row := range rows {
str = append(str, fmt.Sprintf("| %s | %f |", row.Status, row.Duration))
}
return strings.Join(str, "\n")
diff --git a/database/profiling_test.go b/database/profiling_test.go
index 4b226ba90e4db8f482f3f55fc6f16c92a247ed8a..78d28fef46876fb1aca1cb2155cf74bab0886bae 100644
--- a/database/profiling_test.go
+++ b/database/profiling_test.go
@@ -19,35 +19,21 @@ package database
import (
"testing"
- "github.com/XiaoMi/soar/common"
-
"github.com/kr/pretty"
)
func TestProfiling(t *testing.T) {
- common.Config.QueryTimeOut = 1
- res, err := connTest.Profiling("select 1")
- if err == nil {
- pretty.Println(res)
- } else {
+ rows, err := connTest.Profiling("select 1")
+ if err != nil {
t.Error(err)
}
+ pretty.Println(rows)
}
func TestFormatProfiling(t *testing.T) {
res, err := connTest.Profiling("select 1")
- if err == nil {
- pretty.Println(FormatProfiling(res))
- } else {
- t.Error(err)
- }
-}
-
-func TestGetProfiling(t *testing.T) {
- res, err := connTest.Profiling("select 1")
- if err == nil {
- pretty.Println(getProfiling(res))
- } else {
+ if err != nil {
t.Error(err)
}
+ pretty.Println(FormatProfiling(res))
}
diff --git a/database/sampling.go b/database/sampling.go
index f8565b4cd610fc6259ccc089729418a6832c6af7..c9c3e60f4db366daeb022d6e94ba1b030c71c500 100644
--- a/database/sampling.go
+++ b/database/sampling.go
@@ -17,14 +17,9 @@
package database
import (
+ "database/sql"
"fmt"
- "io"
- "strconv"
- "strings"
- "time"
-
"github.com/XiaoMi/soar/common"
- "github.com/ziutek/mymysql/mysql"
)
/*--------------------
@@ -57,21 +52,17 @@ func (db *Connector) SamplingData(remote Connector, tables ...string) error {
maxValCount := 200
// 获取数据库连接对象
- conn := remote.NewConnection()
- localConn := db.NewConnection()
-
- // 连接数据库
- err := conn.Connect()
- defer conn.Close()
+ conn, err := remote.NewConnection()
if err != nil {
return err
}
+ defer conn.Close()
- err = localConn.Connect()
- defer localConn.Close()
+ localConn, err := db.NewConnection()
if err != nil {
return err
}
+ defer localConn.Close()
for _, table := range tables {
// 表类型检查
@@ -109,119 +100,128 @@ func (db *Connector) SamplingData(remote Connector, tables ...string) error {
// 开始从环境中泵取数据
// 因为涉及到的数据量问题,所以泵取与插入时同时进行的
// TODO 加 ref link
-func startSampling(conn, localConn mysql.Conn, database, table string, factor float64, wants, maxValCount int) error {
- // 从线上数据库获取所需dump的表中所有列的数据类型,备用
- // 由于测试库中的库表为刚建立的,所以在information_schema中很可能没有这个表的信息
- var dataTypes []string
- q := fmt.Sprintf("select DATA_TYPE from information_schema.COLUMNS where TABLE_SCHEMA='%s' and TABLE_NAME = '%s'",
- database, table)
- common.Log.Debug("Sampling data execute: %s", q)
- rs, _, err := localConn.Query(q)
- if err != nil {
- common.Log.Debug("Sampling data got data type Err: %v", err)
- } else {
- for _, r := range rs {
- dataTypes = append(dataTypes, r.Str(0))
+func startSampling(conn, localConn *sql.DB, database, table string, factor float64, wants, maxValCount int) error {
+ return nil
+ // TODO:
+ /*
+ // 从线上数据库获取所需dump的表中所有列的数据类型,备用
+ // 由于测试库中的库表为刚建立的,所以在information_schema中很可能没有这个表的信息
+ var dataTypes []string
+ q := fmt.Sprintf("select DATA_TYPE from information_schema.COLUMNS where TABLE_SCHEMA='%s' and TABLE_NAME = '%s'",
+ database, table)
+ common.Log.Debug("Sampling data execute: %s", q)
+ rs, err := localConn.Query(q)
+ if err != nil {
+ common.Log.Debug("Sampling data got data type Err: %v", err)
+ } else {
+ for rs.Next() {
+ var dataType string
+ err = rs.Scan(&dataType)
+ if err != nil {
+ return err
+ }
+ dataTypes = append(dataTypes, dataType)
+ }
}
- }
-
- // 生成where条件
- where := fmt.Sprintf("where RAND()<=%f", factor)
- if factor >= 1 {
- where = ""
- }
-
- sql := fmt.Sprintf("select * from `%s` %s limit %d;", table, where, wants)
- res, err := conn.Start(sql)
- if err != nil {
- return err
- }
- // GetRow method allocates a new chunk of memory for every received row.
- row := res.MakeRow()
- rowCount := 0
- valCount := 0
-
- // 获取所有的列名
- columns := make([]string, len(res.Fields()))
- for i, filed := range res.Fields() {
- columns[i] = filed.Name
- }
- colDef := strings.Join(columns, ",")
-
- // 开始填充数据
- var valList []string
- for {
- err := res.ScanRow(row)
- if err == io.EOF {
- // 扫描结束
- if len(valList) > 0 {
- // 如果缓存中还存在未插入的数据,则把缓存中的数据刷新到DB中
- doSampling(localConn, database, table, colDef, strings.Join(valList, ","))
- }
- break
+ // 生成where条件
+ where := fmt.Sprintf("where RAND()<=%f", factor)
+ if factor >= 1 {
+ where = ""
}
+ sql := fmt.Sprintf("select * from `%s` %s limit %d;", table, where, wants)
+ res, err := conn.Query(sql)
if err != nil {
return err
}
- values := make([]string, len(columns))
- for i := range row {
- // TODO 不支持坐标类型的导出
- switch data := row[i].(type) {
- case nil:
- // str = ""
- case []byte:
- // 先尝试转成数字,如果报错则转换成string
- if v, err := row.Int64Err(i); err != nil {
- values[i] = string(data)
- } else {
- values[i] = strconv.FormatInt(v, 10)
+ // GetRow method allocates a new chunk of memory for every received row.
+ row := res.MakeRow()
+ rowCount := 0
+ valCount := 0
+
+ // 获取所有的列名
+ columns := make([]string, len(res.Fields()))
+ for i, filed := range res.Fields() {
+ columns[i] = filed.Name
+ }
+ colDef := strings.Join(columns, ",")
+
+ // 开始填充数据
+ var valList []string
+ for {
+ err := res.ScanRow(row)
+ if err == io.EOF {
+ // 扫描结束
+ if len(valList) > 0 {
+ // 如果缓存中还存在未插入的数据,则把缓存中的数据刷新到DB中
+ doSampling(localConn, database, table, colDef, strings.Join(valList, ","))
}
- case time.Time:
- values[i] = mysql.TimeString(data)
- case time.Duration:
- values[i] = mysql.DurationString(data)
- default:
- values[i] = fmt.Sprint(data)
+ break
}
- // 非text/varchar类的数据类型,如果dump出的数据为空,则说明该值为null值
- // 应转换其 value 为 null,如果用空('')进行替代,会导致出现语法错误。
- if len(dataTypes) == len(res.Fields()) && values[i] == "" &&
- (!strings.Contains(dataTypes[i], "char") ||
- !strings.Contains(dataTypes[i], "text")) {
- values[i] = "null"
- } else {
- values[i] = "'" + values[i] + "'"
+ if err != nil {
+ return err
}
- }
- valuesStr := fmt.Sprintf(`(%s)`, strings.Join(values, `,`))
- valList = append(valList, valuesStr)
+ values := make([]string, len(columns))
+ for i := range row {
+ // TODO 不支持坐标类型的导出
+ switch data := row[i].(type) {
+ case nil:
+ // str = ""
+ case []byte:
+ // 先尝试转成数字,如果报错则转换成string
+ if v, err := row.Int64Err(i); err != nil {
+ values[i] = string(data)
+ } else {
+ values[i] = strconv.FormatInt(v, 10)
+ }
+ case time.Time:
+ values[i] = mysql.TimeString(data)
+ case time.Duration:
+ values[i] = mysql.DurationString(data)
+ default:
+ values[i] = fmt.Sprint(data)
+ }
+
+ // 非text/varchar类的数据类型,如果dump出的数据为空,则说明该值为null值
+ // 应转换其 value 为 null,如果用空('')进行替代,会导致出现语法错误。
+ if len(dataTypes) == len(res.Fields()) && values[i] == "" &&
+ (!strings.Contains(dataTypes[i], "char") ||
+ !strings.Contains(dataTypes[i], "text")) {
+ values[i] = "null"
+ } else {
+ values[i] = "'" + values[i] + "'"
+ }
+ }
- rowCount++
- valCount++
+ valuesStr := fmt.Sprintf(`(%s)`, strings.Join(values, `,`))
+ valList = append(valList, valuesStr)
- if rowCount%maxValCount == 0 {
- doSampling(localConn, database, table, colDef, strings.Join(valList, ","))
- valCount = 0
- valList = make([]string, 0)
+ rowCount++
+ valCount++
+ if rowCount%maxValCount == 0 {
+ doSampling(localConn, database, table, colDef, strings.Join(valList, ","))
+ valCount = 0
+ valList = make([]string, 0)
+
+ }
}
- }
- common.Log.Debug("%d rows sampling out", rowCount)
- return nil
+ common.Log.Debug("%d rows sampling out", rowCount)
+ return nil
+ */
}
// 将泵取的数据转换成Insert语句并在数据库中执行
-func doSampling(conn mysql.Conn, dbName, table, colDef, values string) {
- sql := fmt.Sprintf("Insert into `%s`.`%s`(%s) values%s;", dbName, table,
+func doSampling(conn *sql.DB, dbName, table, colDef, values string) {
+ query := fmt.Sprintf("Insert into `%s`.`%s`(%s) values%s;", dbName, table,
colDef, values)
- _, _, err := conn.Query(sql)
+ _, err := conn.Query(query)
if err != nil {
common.Log.Error("doSampling Error from %s.%s: %v", dbName, table, err)
diff --git a/database/sampling_test.go b/database/sampling_test.go
index 082d5e9fc3bd03202da2fb7a3897f318e782b7ac..d164b0b936e692cbdb976bc1f417c24733ffbfda 100644
--- a/database/sampling_test.go
+++ b/database/sampling_test.go
@@ -32,6 +32,8 @@ func TestSamplingData(t *testing.T) {
User: common.Config.OnlineDSN.User,
Pass: common.Config.OnlineDSN.Password,
Database: common.Config.OnlineDSN.Schema,
+ Charset: common.Config.OnlineDSN.Charset,
+ Net: common.Config.OnlineDSN.Net,
}
offline := &Connector{
@@ -39,6 +41,8 @@ func TestSamplingData(t *testing.T) {
User: common.Config.TestDSN.User,
Pass: common.Config.TestDSN.Password,
Database: common.Config.TestDSN.Schema,
+ Charset: common.Config.TestDSN.Charset,
+ Net: common.Config.TestDSN.Net,
}
offline.Database = "test"
diff --git a/database/show.go b/database/show.go
index a14a5fbe90e555bb8a3f513e95333657eab7e70d..7628bea9f910b6e58b9a0861dc2a749a029505c1 100644
--- a/database/show.go
+++ b/database/show.go
@@ -18,12 +18,10 @@ package database
import (
"fmt"
+ "github.com/XiaoMi/soar/common"
"regexp"
"strconv"
"strings"
- "time"
-
- "github.com/XiaoMi/soar/common"
)
// SHOW TABLE STATUS Syntax
@@ -36,11 +34,12 @@ type TableStatInfo struct {
}
// tableStatusRow 用于 show table status value
+// use []byte instead of string, because []byte allow to be null, string not
type tableStatusRow struct {
Name string // 表名
- Engine string // 该表使用的存储引擎
- Version int // 该表的 .frm 文件版本号
- RowFormat string // 该表使用的行存储格式
+ Engine []byte // 该表使用的存储引擎
+ Version []byte // 该表的 .frm 文件版本号
+ RowFormat []byte // 该表使用的行存储格式
Rows int64 // 表行数, InnoDB 引擎中为预估值,甚至可能会有40%~50%的数值偏差
AvgRowLength int // 平均行长度
@@ -59,15 +58,15 @@ type tableStatusRow struct {
// 其他不同的存储引擎中该值的意义可能不尽相同
IndexLength int
- DataFree int // 已分配但未使用的字节数
- AutoIncrement int // 下一个自增值
- CreateTime time.Time // 创建时间
- UpdateTime time.Time // 最近一次更新时间,该值不准确
- CheckTime time.Time // 上次检查时间
- Collation string // 字符集及排序规则信息
- Checksum string // 校验和
- CreateOptions string // 创建表的时候的时候一切其他属性
- Comment string // 注释
+ DataFree int // 已分配但未使用的字节数
+ AutoIncrement []byte // 下一个自增值
+ CreateTime []byte // 创建时间
+ UpdateTime []byte // 最近一次更新时间,该值不准确
+ CheckTime []byte // 上次检查时间
+ Collation []byte // 字符集及排序规则信息
+ Checksum []byte // 校验和
+ CreateOptions []byte // 创建表的时候的时候一切其他属性
+ Comment []byte // 注释
}
// newTableStat 构造 table Stat 对象
@@ -83,7 +82,7 @@ func (db *Connector) ShowTables() ([]string, error) {
defer func() {
err := recover()
if err != nil {
- common.Log.Error("recover ShowTableStatus()", err)
+ common.Log.Error("recover ShowTables()", err)
}
}()
@@ -92,79 +91,70 @@ func (db *Connector) ShowTables() ([]string, error) {
if err != nil {
return []string{}, err
}
+ if res.Error != nil {
+ return []string{}, res.Error
+ }
// 获取值
var tables []string
- for _, row := range res.Rows {
- tables = append(tables, row.Str(0))
+ for res.Rows.Next() {
+ var table string
+ err = res.Rows.Scan(&table)
+ if err != nil {
+ return []string{}, err
+ }
+ tables = append(tables, table)
}
-
return tables, err
}
// ShowTableStatus 执行 show table status
func (db *Connector) ShowTableStatus(tableName string) (*TableStatInfo, error) {
- defer func() {
- err := recover()
- if err != nil {
- common.Log.Error("recover ShowTableStatus()", err)
- }
- }()
-
// 初始化struct
- ts := newTableStat(tableName)
+ tbStatus := newTableStat(tableName)
// 执行 show table status
- res, err := db.Query("show table status where name = '%s'", ts.Name)
+ res, err := db.Query(fmt.Sprintf("show table status where name = '%s'", tbStatus.Name))
if err != nil {
- return ts, err
- }
-
- rs := res.Result.Map("Rows")
- name := res.Result.Map("Name")
- df := res.Result.Map("Data_free")
- sum := res.Result.Map("Checksum")
- engine := res.Result.Map("Engine")
- version := res.Result.Map("Version")
- comment := res.Result.Map("Comment")
- ai := res.Result.Map("Auto_increment")
- collation := res.Result.Map("Collation")
- rowFormat := res.Result.Map("Row_format")
- checkTime := res.Result.Map("Check_time")
- dataLength := res.Result.Map("Data_length")
- idxLength := res.Result.Map("Index_length")
- createTime := res.Result.Map("Create_time")
- updateTime := res.Result.Map("Update_time")
- options := res.Result.Map("Create_options")
- avgRowLength := res.Result.Map("Avg_row_length")
- maxDataLength := res.Result.Map("Max_data_length")
+ return tbStatus, err
+ }
+ if res.Error != nil {
+ return tbStatus, res.Error
+ }
+ ts := tableStatusRow{}
+ statusFields := make([]interface{}, 0)
+ fields := map[string]interface{}{
+ "Name": &ts.Name,
+ "Engine": &ts.Engine,
+ "Version": &ts.Version,
+ "Row_format": &ts.RowFormat,
+ "Rows": &ts.Rows,
+ "Avg_row_length": &ts.AvgRowLength,
+ "Data_length": &ts.DataLength,
+ "Max_data_length": &ts.MaxDataLength,
+ "Index_length": &ts.IndexLength,
+ "Data_free": &ts.DataFree,
+ "Auto_increment": &ts.AutoIncrement,
+ "Create_time": &ts.CreateTime,
+ "Update_time": &ts.UpdateTime,
+ "Check_time": &ts.CheckTime,
+ "Collation": &ts.Collation,
+ "Checksum": &ts.Checksum,
+ "Create_options": &ts.CreateOptions,
+ "Comment": &ts.Comment,
+ }
+ cols, err := res.Rows.Columns()
+ common.LogIfError(err, "")
+ for _, col := range cols {
+ statusFields = append(statusFields, fields[col])
+ }
// 获取值
- for _, row := range res.Rows {
- value := tableStatusRow{
- Name: row.Str(name),
- Engine: row.Str(engine),
- Version: row.Int(version),
- Rows: row.Int64(rs),
- RowFormat: row.Str(rowFormat),
- AvgRowLength: row.Int(avgRowLength),
- DataLength: row.Int(dataLength),
- MaxDataLength: row.Int(maxDataLength),
- IndexLength: row.Int(idxLength),
- DataFree: row.Int(df),
- AutoIncrement: row.Int(ai),
- CreateTime: row.Time(createTime, time.Local),
- UpdateTime: row.Time(updateTime, time.Local),
- CheckTime: row.Time(checkTime, time.Local),
- Collation: row.Str(collation),
- Checksum: row.Str(sum),
- CreateOptions: row.Str(options),
- Comment: row.Str(comment),
- }
- ts.Rows = append(ts.Rows, value)
+ for res.Rows.Next() {
+ res.Rows.Scan(statusFields...)
+ tbStatus.Rows = append(tbStatus.Rows, ts)
}
-
- return ts, err
+ return tbStatus, err
}
// https://dev.mysql.com/doc/refman/5.7/en/show-index.html
@@ -172,7 +162,7 @@ func (db *Connector) ShowTableStatus(tableName string) (*TableStatInfo, error) {
// TableIndexInfo 用以保存 show index 之后获取的 index 信息
type TableIndexInfo struct {
TableName string
- IdxRows []TableIndexRow
+ Rows []TableIndexRow
}
// TableIndexRow 用以存放show index之后获取的每一条index信息
@@ -190,13 +180,14 @@ type TableIndexRow struct {
IndexType string // BTREE, FULLTEXT, HASH, RTREE
Comment string
IndexComment string
+ Visible string
}
// NewTableIndexInfo 构造 TableIndexInfo
func NewTableIndexInfo(tableName string) *TableIndexInfo {
return &TableIndexInfo{
TableName: tableName,
- IdxRows: make([]TableIndexRow, 0),
+ Rows: make([]TableIndexRow, 0),
}
}
@@ -205,43 +196,32 @@ func (db *Connector) ShowIndex(tableName string) (*TableIndexInfo, error) {
tbIndex := NewTableIndexInfo(tableName)
// 执行 show create table
- res, err := db.Query("show index from `%s`.`%s`", db.Database, tableName)
+ res, err := db.Query(fmt.Sprintf("show index from `%s`.`%s`", db.Database, tableName))
if err != nil {
return nil, err
}
-
- table := res.Result.Map("Table")
- unique := res.Result.Map("Non_unique")
- keyName := res.Result.Map("Key_name")
- seq := res.Result.Map("Seq_in_index")
- cName := res.Result.Map("Column_name")
- collation := res.Result.Map("Collation")
- cardinality := res.Result.Map("Cardinality")
- subPart := res.Result.Map("Sub_part")
- packed := res.Result.Map("Packed")
- null := res.Result.Map("Null")
- idxType := res.Result.Map("Index_type")
- comment := res.Result.Map("Comment")
- idxComment := res.Result.Map("Index_comment")
+ if res.Error != nil {
+ return nil, res.Error
+ }
// 获取值
- for _, row := range res.Rows {
- value := TableIndexRow{
- Table: row.Str(table),
- NonUnique: row.Int(unique),
- KeyName: row.Str(keyName),
- SeqInIndex: row.Int(seq),
- ColumnName: row.Str(cName),
- Collation: row.Str(collation),
- Cardinality: row.Int(cardinality),
- SubPart: row.Int(subPart),
- Packed: row.Int(packed),
- Null: row.Str(null),
- IndexType: row.Str(idxType),
- Comment: row.Str(comment),
- IndexComment: row.Str(idxComment),
- }
- tbIndex.IdxRows = append(tbIndex.IdxRows, value)
+ for res.Rows.Next() {
+ var ti TableIndexRow
+ res.Rows.Scan(&ti.Table,
+ &ti.NonUnique,
+ &ti.KeyName,
+ &ti.SeqInIndex,
+ &ti.ColumnName,
+ &ti.Collation,
+ &ti.Cardinality,
+ &ti.SubPart,
+ &ti.Packed,
+ &ti.Null,
+ &ti.IndexType,
+ &ti.Comment,
+ &ti.IndexComment,
+ &ti.Visible)
+ tbIndex.Rows = append(tbIndex.Rows, ti)
}
return tbIndex, err
}
@@ -257,7 +237,7 @@ const (
IndexNonUnique = IndexSelectKey("NonUnique") // 唯一索引
)
-// FindIndex 获取TableIndexInfo中需要的索引
+// FindIndex 获取 TableIndexInfo 中需要的索引
func (tbIndex *TableIndexInfo) FindIndex(arg IndexSelectKey, value string) []TableIndexRow {
var result []TableIndexRow
if tbIndex == nil {
@@ -268,28 +248,28 @@ func (tbIndex *TableIndexInfo) FindIndex(arg IndexSelectKey, value string) []Tab
switch arg {
case IndexKeyName:
- for _, index := range tbIndex.IdxRows {
+ for _, index := range tbIndex.Rows {
if strings.ToLower(index.KeyName) == value {
result = append(result, index)
}
}
case IndexColumnName:
- for _, index := range tbIndex.IdxRows {
+ for _, index := range tbIndex.Rows {
if strings.ToLower(index.ColumnName) == value {
result = append(result, index)
}
}
case IndexIndexType:
- for _, index := range tbIndex.IdxRows {
+ for _, index := range tbIndex.Rows {
if strings.ToLower(index.IndexType) == value {
result = append(result, index)
}
}
case IndexNonUnique:
- for _, index := range tbIndex.IdxRows {
+ for _, index := range tbIndex.Rows {
unique := strconv.Itoa(index.NonUnique)
if unique == value {
result = append(result, index)
@@ -316,12 +296,12 @@ type TableDesc struct {
type TableDescValue struct {
Field string // 列名
Type string // 数据类型
+ Collation []byte // 字符集
Null string // 是否有NULL(NO、YES)
- Collation string // 字符集
- Privileges string // 权限s
Key string // 键类型
- Default string // 默认值
+ Default []byte // 默认值
Extra string // 其他
+ Privileges string // 权限
Comment string // 备注
}
@@ -338,35 +318,27 @@ func (db *Connector) ShowColumns(tableName string) (*TableDesc, error) {
tbDesc := NewTableDesc(tableName)
// 执行 show create table
- res, err := db.Query("show full columns from `%s`.`%s`", db.Database, tableName)
+ res, err := db.Query(fmt.Sprintf("show full columns from `%s`.`%s`", db.Database, tableName))
if err != nil {
return nil, err
}
-
- field := res.Result.Map("Field")
- tp := res.Result.Map("Type")
- null := res.Result.Map("Null")
- key := res.Result.Map("Key")
- def := res.Result.Map("Default")
- extra := res.Result.Map("Extra")
- collation := res.Result.Map("Collation")
- privileges := res.Result.Map("Privileges")
- comm := res.Result.Map("Comment")
+ if res.Error != nil {
+ return nil, res.Error
+ }
// 获取值
- for _, row := range res.Rows {
- value := TableDescValue{
- Field: row.Str(field),
- Type: row.Str(tp),
- Null: row.Str(null),
- Key: row.Str(key),
- Default: row.Str(def),
- Extra: row.Str(extra),
- Privileges: row.Str(privileges),
- Collation: row.Str(collation),
- Comment: row.Str(comm),
- }
- tbDesc.DescValues = append(tbDesc.DescValues, value)
+ for res.Rows.Next() {
+ var tc TableDescValue
+ res.Rows.Scan(&tc.Field,
+ &tc.Type,
+ &tc.Collation,
+ &tc.Null,
+ &tc.Key,
+ &tc.Default,
+ &tc.Extra,
+ &tc.Privileges,
+ &tc.Comment)
+ tbDesc.DescValues = append(tbDesc.DescValues, tc)
}
return tbDesc, err
}
@@ -383,18 +355,21 @@ func (td TableDesc) Columns() []string {
// showCreate show create
func (db *Connector) showCreate(createType, name string) (string, error) {
// 执行 show create table
- res, err := db.Query("show create %s `%s`", createType, name)
+ res, err := db.Query(fmt.Sprintf("show create %s `%s`", createType, name))
if err != nil {
return "", err
}
+ if res.Error != nil {
+ return "", res.Error
+ }
- // 获取ddl
- var ddl string
- for _, row := range res.Rows {
- ddl = row.Str(1)
+ // 获取 CREATE TABLE 语句
+ var tableName, createTable string
+ for res.Rows.Next() {
+ res.Rows.Scan(&tableName, &createTable)
}
- return ddl, err
+ return createTable, err
}
// ShowCreateDatabase show create database
@@ -451,6 +426,10 @@ func (db *Connector) FindColumn(name, dbName string, tables ...string) ([]*commo
"c.TABLE_NAME,c.TABLE_SCHEMA,c.COLUMN_TYPE,c.CHARACTER_SET_NAME, c.COLLATION_NAME "+
"FROM `INFORMATION_SCHEMA`.`COLUMNS` as c where c.COLUMN_NAME = '%s' ", name)
+ if dbName != "" {
+ sql += fmt.Sprintf(" and c.table_schema = '%s'", dbName)
+ }
+
if len(tables) > 0 {
var tmp []string
for _, table := range tables {
@@ -459,32 +438,24 @@ func (db *Connector) FindColumn(name, dbName string, tables ...string) ([]*commo
sql += fmt.Sprintf(" and c.table_name in (%s)", strings.Join(tmp, ","))
}
- if dbName != "" {
- sql += fmt.Sprintf(" and c.table_schema = '%s'", dbName)
- }
-
+ common.Log.Debug("FindColumn, execute SQL: %s", sql)
res, err := db.Query(sql)
if err != nil {
common.Log.Error("(db *Connector) FindColumn Error : ", err)
return columns, err
}
+ if res.Error != nil {
+ common.Log.Error("(db *Connector) FindColumn Error : ", res.Error)
+ return columns, res.Error
+ }
- tbName := res.Result.Map("TABLE_NAME")
- schema := res.Result.Map("TABLE_SCHEMA")
- colTyp := res.Result.Map("COLUMN_TYPE")
- colCharset := res.Result.Map("CHARACTER_SET_NAME")
- collation := res.Result.Map("COLLATION_NAME")
-
- // 获取ddl
- for _, row := range res.Rows {
- col := &common.Column{
- Name: name,
- Table: row.Str(tbName),
- DB: row.Str(schema),
- DataType: row.Str(colTyp),
- Character: row.Str(colCharset),
- Collation: row.Str(collation),
- }
+ var col common.Column
+ for res.Rows.Next() {
+ res.Rows.Scan(&col.Table,
+ &col.DB,
+ &col.DataType,
+ &col.Character,
+ &col.Collation)
// 填充字符集和排序规则
if col.Character == "" {
@@ -494,40 +465,56 @@ func (db *Connector) FindColumn(name, dbName string, tables ...string) ([]*commo
sql = fmt.Sprintf("SELECT `t`.`TABLE_COLLATION` FROM `INFORMATION_SCHEMA`.`TABLES` AS `t` "+
"WHERE `t`.`TABLE_NAME`='%s' AND `t`.`TABLE_SCHEMA` = '%s'", col.Table, col.DB)
- var newRes *QueryResult
+
+ common.Log.Debug("FindColumn, execute SQL: %s", sql)
+ var newRes QueryResult
newRes, err = db.Query(sql)
if err != nil {
common.Log.Error("(db *Connector) FindColumn Error : ", err)
return columns, err
}
+ if res.Error != nil {
+ common.Log.Error("(db *Connector) FindColumn Error : ", res.Error)
+ return columns, res.Error
+ }
- tbCollation := newRes.Rows[0].Str(0)
+ var tbCollation string
+ if newRes.Rows.Next() {
+ newRes.Rows.Scan(&tbCollation)
+ }
if tbCollation != "" {
col.Character = strings.Split(tbCollation, "_")[0]
col.Collation = tbCollation
}
}
-
- columns = append(columns, col)
+ columns = append(columns, &col)
}
-
return columns, err
}
-// IsFKey 判断列是否是外键
-func (db *Connector) IsFKey(dbName, tbName, column string) bool {
+// IsForeignKey 判断列是否是外键
+func (db *Connector) IsForeignKey(dbName, tbName, column string) bool {
sql := fmt.Sprintf("SELECT REFERENCED_COLUMN_NAME FROM INFORMATION_SCHEMA.KEY_COLUMN_USAGE C "+
"WHERE REFERENCED_TABLE_SCHEMA <> 'NULL' AND"+
" TABLE_NAME='%s' AND"+
" TABLE_SCHEMA='%s' AND"+
" COLUMN_NAME='%s'", tbName, dbName, column)
+ common.Log.Debug("IsForeignKey, execute SQL: %s", sql)
res, err := db.Query(sql)
- if err == nil && len(res.Rows) == 0 {
+ if err != nil {
+ common.Log.Error("IsForeignKey, Error: %s", err.Error())
+ return false
+ }
+ if res.Error != nil {
+ common.Log.Error("IsForeignKey, Error: %s", res.Error.Error())
return false
}
+ if res.Rows.Next() {
+ return true
+ }
- return true
+ return false
}
// Reference 用于存储关系
@@ -535,11 +522,11 @@ type Reference map[string][]ReferenceValue
// ReferenceValue 用于处理表之间的关系
type ReferenceValue struct {
- RefDBName string // 夫表所属数据库
- RefTable string // 父表
- DBName string // 子表所属数据库
- Table string // 子表
- ConstraintName string // 关系名称
+ ReferencedTableSchema string // 夫表所属数据库
+ ReferencedTableName string // 父表
+ TableSchema string // 子表所属数据库
+ TableName string // 子表
+ ConstraintName string // 关系名称
}
// ShowReference 查找所有的外键信息
@@ -555,30 +542,26 @@ WHERE C.REFERENCED_TABLE_NAME IS NOT NULL`
sql = sql + extra
}
+ common.Log.Debug("ShowReference, execute SQL: %s", sql)
// 执行SQL查找外键关联关系
res, err := db.Query(sql)
if err != nil {
return referenceValues, err
}
-
- refDb := res.Result.Map("REFERENCED_TABLE_SCHEMA")
- refTb := res.Result.Map("REFERENCED_TABLE_NAME")
- schema := res.Result.Map("TABLE_SCHEMA")
- tb := res.Result.Map("TABLE_NAME")
- cName := res.Result.Map("CONSTRAINT_NAME")
+ if res.Error != nil {
+ return referenceValues, res.Error
+ }
// 获取值
- for _, row := range res.Rows {
- value := ReferenceValue{
- RefDBName: row.Str(refDb),
- RefTable: row.Str(refTb),
- DBName: row.Str(schema),
- Table: row.Str(tb),
- ConstraintName: row.Str(cName),
- }
- referenceValues = append(referenceValues, value)
+ for res.Rows.Next() {
+ var rv ReferenceValue
+ res.Rows.Scan(&rv.ReferencedTableSchema,
+ &rv.ReferencedTableName,
+ &rv.TableSchema,
+ &rv.TableName,
+ &rv.ConstraintName)
+ referenceValues = append(referenceValues, rv)
}
return referenceValues, err
-
}
diff --git a/database/show_test.go b/database/show_test.go
index 66f98d4714b43f5dae50f0611891569105c1ebfe..68ae97d0290f06223e4de54ab93cf4f5678363da 100644
--- a/database/show_test.go
+++ b/database/show_test.go
@@ -20,75 +20,144 @@ import (
"fmt"
"testing"
+ "github.com/XiaoMi/soar/common"
+
"github.com/kr/pretty"
"vitess.io/vitess/go/vt/sqlparser"
)
func TestShowTableStatus(t *testing.T) {
- connTest.Database = "information_schema"
- ts, err := connTest.ShowTableStatus("TABLES")
+ orgDatabase := connTest.Database
+ connTest.Database = "sakila"
+ ts, err := connTest.ShowTableStatus("film")
if err != nil {
t.Error("ShowTableStatus Error: ", err)
}
+ if string(ts.Rows[0].Engine) != "InnoDB" {
+ t.Error("film table should be InnoDB engine")
+ }
pretty.Println(ts)
+
+ connTest.Database = "sakila"
+ ts, err = connTest.ShowTableStatus("actor_info")
+ if err != nil {
+ t.Error("ShowTableStatus Error: ", err)
+ }
+ if string(ts.Rows[0].Comment) != "VIEW" {
+ t.Error("actor_info should be VIEW")
+ }
+ pretty.Println(ts)
+ connTest.Database = orgDatabase
}
func TestShowTables(t *testing.T) {
- connTest.Database = "information_schema"
+ orgDatabase := connTest.Database
+ connTest.Database = "sakila"
ts, err := connTest.ShowTables()
if err != nil {
t.Error("ShowTableStatus Error: ", err)
}
- pretty.Println(ts)
+
+ err = common.GoldenDiff(func() {
+ for _, table := range ts {
+ fmt.Println(table)
+ }
+ }, t.Name(), update)
+ if err != nil {
+ t.Error(err)
+ }
+ connTest.Database = orgDatabase
}
func TestShowCreateTable(t *testing.T) {
- connTest.Database = "information_schema"
- ts, err := connTest.ShowCreateTable("TABLES")
+ orgDatabase := connTest.Database
+ connTest.Database = "sakila"
+ ts, err := connTest.ShowCreateTable("film")
if err != nil {
t.Error("ShowCreateTable Error: ", err)
}
- fmt.Println(ts)
- stmt, err := sqlparser.Parse(ts)
- pretty.Println(stmt, err)
+
+ err = common.GoldenDiff(func() {
+ fmt.Println(ts)
+ stmt, err := sqlparser.Parse(ts)
+ if err != nil {
+ t.Error(err.Error())
+ }
+ pretty.Println(stmt, err)
+ }, t.Name(), update)
+ if err != nil {
+ t.Error(err)
+ }
+
+ connTest.Database = orgDatabase
}
func TestShowIndex(t *testing.T) {
- connTest.Database = "information_schema"
- ti, err := connTest.ShowIndex("TABLES")
+ orgDatabase := connTest.Database
+ connTest.Database = "sakila"
+ ti, err := connTest.ShowIndex("film")
if err != nil {
t.Error("ShowIndex Error: ", err)
}
- pretty.Println(ti.FindIndex(IndexKeyName, "idx_store_id_film_id"))
+
+ err = common.GoldenDiff(func() {
+ pretty.Println(ti)
+ pretty.Println(ti.FindIndex(IndexKeyName, "idx_title"))
+ }, t.Name(), update)
+ if err != nil {
+ t.Error(err)
+ }
+
+ connTest.Database = orgDatabase
}
func TestShowColumns(t *testing.T) {
- connTest.Database = "information_schema"
- ti, err := connTest.ShowColumns("TABLES")
+ orgDatabase := connTest.Database
+ connTest.Database = "sakila"
+ ti, err := connTest.ShowColumns("film")
if err != nil {
t.Error("ShowColumns Error: ", err)
}
- pretty.Println(ti)
+
+ err = common.GoldenDiff(func() {
+ pretty.Println(ti)
+ }, t.Name(), update)
+ if err != nil {
+ t.Error(err)
+ }
+
+ connTest.Database = orgDatabase
}
func TestFindColumn(t *testing.T) {
- ti, err := connTest.FindColumn("id", "")
+ ti, err := connTest.FindColumn("film_id", "sakila", "film")
if err != nil {
t.Error("FindColumn Error: ", err)
}
- pretty.Println(ti)
+ err = common.GoldenDiff(func() {
+ pretty.Println(ti)
+ }, t.Name(), update)
+ if err != nil {
+ t.Error(err)
+ }
+}
+
+func TestIsFKey(t *testing.T) {
+ if !connTest.IsForeignKey("sakila", "film", "language_id") {
+ t.Error("want True. got false")
+ }
}
func TestShowReference(t *testing.T) {
- rv, err := connTest.ShowReference("test2", "homeImg")
+ rv, err := connTest.ShowReference("sakila", "film")
if err != nil {
t.Error("ShowReference Error: ", err)
}
- pretty.Println(rv)
-}
-func TestIsFKey(t *testing.T) {
- if !connTest.IsFKey("sakila", "film", "language_id") {
- t.Error("want True. got false")
+ err = common.GoldenDiff(func() {
+ pretty.Println(rv)
+ }, t.Name(), update)
+ if err != nil {
+ t.Error(err)
}
}
diff --git a/database/testdata/TestFindColumn.golden b/database/testdata/TestFindColumn.golden
new file mode 100644
index 0000000000000000000000000000000000000000..c7deffee353e83627e818875dd3bf283304ce80d
--- /dev/null
+++ b/database/testdata/TestFindColumn.golden
@@ -0,0 +1,18 @@
+[]*common.Column{
+ &common.Column{
+ Name: "",
+ Alias: nil,
+ Table: "film",
+ DB: "sakila",
+ DataType: "smallint(5) unsigned",
+ Character: "utf8",
+ Collation: "utf8_general_ci",
+ Cardinality: 0,
+ Null: "",
+ Key: "",
+ Default: "",
+ Extra: "",
+ Comment: "",
+ Privileges: "",
+ },
+}
diff --git a/database/testdata/TestFormatTrace.golden b/database/testdata/TestFormatTrace.golden
new file mode 100644
index 0000000000000000000000000000000000000000..b216399779df18fb839bc7e5bcd44a4fc7982b90
--- /dev/null
+++ b/database/testdata/TestFormatTrace.golden
@@ -0,0 +1,36 @@
+
+```sql
+select 1
+```
+
+```json
+{
+ "steps": [
+ {
+ "join_preparation": {
+ "select#": 1,
+ "steps": [
+ {
+ "expanded_query": "/* select#1 */ select 1 AS `1`"
+ }
+ ]
+ }
+ },
+ {
+ "join_optimization": {
+ "select#": 1,
+ "steps": [
+ ]
+ }
+ },
+ {
+ "join_explain": {
+ "select#": 1,
+ "steps": [
+ ]
+ }
+ }
+ ]
+}
+```
+
diff --git a/database/testdata/TestShowColumns.golden b/database/testdata/TestShowColumns.golden
new file mode 100644
index 0000000000000000000000000000000000000000..782a412be2134426a1fa811deeade89911ec6452
--- /dev/null
+++ b/database/testdata/TestShowColumns.golden
@@ -0,0 +1,148 @@
+&database.TableDesc{
+ Name: "film",
+ DescValues: {
+ {
+ Field: "film_id",
+ Type: "smallint(5) unsigned",
+ Collation: nil,
+ Null: "NO",
+ Key: "PRI",
+ Default: nil,
+ Extra: "auto_increment",
+ Privileges: "select,insert,update,references",
+ Comment: "",
+ },
+ {
+ Field: "title",
+ Type: "varchar(255)",
+ Collation: {0x75, 0x74, 0x66, 0x38, 0x5f, 0x67, 0x65, 0x6e, 0x65, 0x72, 0x61, 0x6c, 0x5f, 0x63, 0x69},
+ Null: "NO",
+ Key: "MUL",
+ Default: nil,
+ Extra: "",
+ Privileges: "select,insert,update,references",
+ Comment: "",
+ },
+ {
+ Field: "description",
+ Type: "text",
+ Collation: {0x75, 0x74, 0x66, 0x38, 0x5f, 0x67, 0x65, 0x6e, 0x65, 0x72, 0x61, 0x6c, 0x5f, 0x63, 0x69},
+ Null: "YES",
+ Key: "",
+ Default: nil,
+ Extra: "",
+ Privileges: "select,insert,update,references",
+ Comment: "",
+ },
+ {
+ Field: "release_year",
+ Type: "year(4)",
+ Collation: nil,
+ Null: "YES",
+ Key: "",
+ Default: nil,
+ Extra: "",
+ Privileges: "select,insert,update,references",
+ Comment: "",
+ },
+ {
+ Field: "language_id",
+ Type: "tinyint(3) unsigned",
+ Collation: nil,
+ Null: "NO",
+ Key: "MUL",
+ Default: nil,
+ Extra: "",
+ Privileges: "select,insert,update,references",
+ Comment: "",
+ },
+ {
+ Field: "original_language_id",
+ Type: "tinyint(3) unsigned",
+ Collation: nil,
+ Null: "YES",
+ Key: "MUL",
+ Default: nil,
+ Extra: "",
+ Privileges: "select,insert,update,references",
+ Comment: "",
+ },
+ {
+ Field: "rental_duration",
+ Type: "tinyint(3) unsigned",
+ Collation: nil,
+ Null: "NO",
+ Key: "",
+ Default: {0x33},
+ Extra: "",
+ Privileges: "select,insert,update,references",
+ Comment: "",
+ },
+ {
+ Field: "rental_rate",
+ Type: "decimal(4,2)",
+ Collation: nil,
+ Null: "NO",
+ Key: "",
+ Default: {0x34, 0x2e, 0x39, 0x39},
+ Extra: "",
+ Privileges: "select,insert,update,references",
+ Comment: "",
+ },
+ {
+ Field: "length",
+ Type: "smallint(5) unsigned",
+ Collation: nil,
+ Null: "YES",
+ Key: "",
+ Default: nil,
+ Extra: "",
+ Privileges: "select,insert,update,references",
+ Comment: "",
+ },
+ {
+ Field: "replacement_cost",
+ Type: "decimal(5,2)",
+ Collation: nil,
+ Null: "NO",
+ Key: "",
+ Default: {0x31, 0x39, 0x2e, 0x39, 0x39},
+ Extra: "",
+ Privileges: "select,insert,update,references",
+ Comment: "",
+ },
+ {
+ Field: "rating",
+ Type: "enum('G','PG','PG-13','R','NC-17')",
+ Collation: {0x75, 0x74, 0x66, 0x38, 0x5f, 0x67, 0x65, 0x6e, 0x65, 0x72, 0x61, 0x6c, 0x5f, 0x63, 0x69},
+ Null: "YES",
+ Key: "",
+ Default: {0x47},
+ Extra: "",
+ Privileges: "select,insert,update,references",
+ Comment: "",
+ },
+ {
+ Field: "special_features",
+ Type: "set('Trailers','Commentaries','Deleted Scenes','Behind the Scenes')",
+ Collation: {0x75, 0x74, 0x66, 0x38, 0x5f, 0x67, 0x65, 0x6e, 0x65, 0x72, 0x61, 0x6c, 0x5f, 0x63, 0x69},
+ Null: "YES",
+ Key: "",
+ Default: nil,
+ Extra: "",
+ Privileges: "select,insert,update,references",
+ Comment: "",
+ },
+ {
+ Field: "last_update",
+ Type: "timestamp",
+ Collation: nil,
+ Null: "NO",
+ Key: "",
+ Default: {0x43, 0x55, 0x52, 0x52, 0x45, 0x4e, 0x54, 0x5f, 0x54, 0x49, 0x4d, 0x45, 0x53, 0x54, 0x41, 0x4d, 0x50},
+ Extra: "on update CURRENT_TIMESTAMP",
+ Privileges: "select,insert,update,references",
+ Comment: "",
+ },
+ },
+}
diff --git a/database/testdata/TestShowCreateTable.golden b/database/testdata/TestShowCreateTable.golden
new file mode 100644
index 0000000000000000000000000000000000000000..377ef8d32fc8da8685f4b7e4268ba613371b489a
--- /dev/null
+++ b/database/testdata/TestShowCreateTable.golden
@@ -0,0 +1,34 @@
+CREATE TABLE `film` (
+ `film_id` smallint(5) unsigned NOT NULL AUTO_INCREMENT,
+ `title` varchar(255) NOT NULL,
+ `description` text,
+ `release_year` year(4) DEFAULT NULL,
+ `language_id` tinyint(3) unsigned NOT NULL,
+ `original_language_id` tinyint(3) unsigned DEFAULT NULL,
+ `rental_duration` tinyint(3) unsigned NOT NULL DEFAULT '3',
+ `rental_rate` decimal(4,2) NOT NULL DEFAULT '4.99',
+ `length` smallint(5) unsigned DEFAULT NULL,
+ `replacement_cost` decimal(5,2) NOT NULL DEFAULT '19.99',
+ `rating` enum('G','PG','PG-13','R','NC-17') DEFAULT 'G',
+ `special_features` set('Trailers','Commentaries','Deleted Scenes','Behind the Scenes') DEFAULT NULL,
+ `last_update` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+ PRIMARY KEY (`film_id`),
+ KEY `idx_title` (`title`),
+ KEY `idx_fk_language_id` (`language_id`),
+ KEY `idx_fk_original_language_id` (`original_language_id`)
+) ENGINE=InnoDB AUTO_INCREMENT=1001 DEFAULT CHARSET=utf8
+&sqlparser.DDL{
+ Action: "create",
+ FromTables: nil,
+ ToTables: nil,
+ Table: sqlparser.TableName{
+ Name: sqlparser.TableIdent{v:"film"},
+ Qualifier: sqlparser.TableIdent{},
+ },
+ IfExists: false,
+ TableSpec: (*sqlparser.TableSpec)(nil),
+ OptLike: (*sqlparser.OptLike)(nil),
+ PartitionSpec: (*sqlparser.PartitionSpec)(nil),
+ VindexSpec: (*sqlparser.VindexSpec)(nil),
+ VindexCols: nil,
+} nil
diff --git a/database/testdata/TestShowIndex.golden b/database/testdata/TestShowIndex.golden
new file mode 100644
index 0000000000000000000000000000000000000000..48a9f8cb5e40db045185f7f9cf87a088dd21eb8b
--- /dev/null
+++ b/database/testdata/TestShowIndex.golden
@@ -0,0 +1,12 @@
+&database.TableIndexInfo{
+ TableName: "film",
+ Rows: {
+ {Table:"film", NonUnique:0, KeyName:"PRIMARY", SeqInIndex:1, ColumnName:"film_id", Collation:"A", Cardinality:1000, SubPart:0, Packed:0, Null:"", IndexType:"", Comment:"", IndexComment:"", Visible:""},
+ {Table:"film", NonUnique:1, KeyName:"idx_title", SeqInIndex:1, ColumnName:"title", Collation:"A", Cardinality:1000, SubPart:0, Packed:0, Null:"", IndexType:"", Comment:"", IndexComment:"", Visible:""},
+ {Table:"film", NonUnique:1, KeyName:"idx_fk_language_id", SeqInIndex:1, ColumnName:"language_id", Collation:"A", Cardinality:1, SubPart:0, Packed:0, Null:"", IndexType:"", Comment:"", IndexComment:"", Visible:""},
+ {Table:"film", NonUnique:1, KeyName:"idx_fk_original_language_id", SeqInIndex:1, ColumnName:"original_language_id", Collation:"A", Cardinality:1, SubPart:0, Packed:0, Null:"", IndexType:"", Comment:"", IndexComment:"", Visible:""},
+ },
+}
+[]database.TableIndexRow{
+ {Table:"film", NonUnique:1, KeyName:"idx_title", SeqInIndex:1, ColumnName:"title", Collation:"A", Cardinality:1000, SubPart:0, Packed:0, Null:"", IndexType:"", Comment:"", IndexComment:"", Visible:""},
+}
diff --git a/database/testdata/TestShowReference.golden b/database/testdata/TestShowReference.golden
new file mode 100644
index 0000000000000000000000000000000000000000..b9c79ef096e00ae871530048b68b082fb09515ef
--- /dev/null
+++ b/database/testdata/TestShowReference.golden
@@ -0,0 +1,4 @@
+[]database.ReferenceValue{
+ {ReferencedTableSchema:"sakila", ReferencedTableName:"language", TableSchema:"sakila", TableName:"film", ConstraintName:"fk_film_language"},
+ {ReferencedTableSchema:"sakila", ReferencedTableName:"language", TableSchema:"sakila", TableName:"film", ConstraintName:"fk_film_language_original"},
+}
diff --git a/database/testdata/TestShowTables.golden b/database/testdata/TestShowTables.golden
new file mode 100644
index 0000000000000000000000000000000000000000..27a6ad5d1f84e6e101e9dfec5f3baf6fccfe182a
--- /dev/null
+++ b/database/testdata/TestShowTables.golden
@@ -0,0 +1,23 @@
+actor
+actor_info
+address
+category
+city
+country
+customer
+customer_list
+film
+film_actor
+film_category
+film_list
+film_text
+inventory
+language
+nicer_but_slower_film_list
+payment
+rental
+sales_by_film_category
+sales_by_store
+staff
+staff_list
+store
diff --git a/database/testdata/TestSource.sql b/database/testdata/TestSource.sql
deleted file mode 100644
index 2bba4d129a6a838eb13a07583d2ca0eb013843ce..0000000000000000000000000000000000000000
--- a/database/testdata/TestSource.sql
+++ /dev/null
@@ -1,2 +0,0 @@
-select 1;
-select 1;
diff --git a/database/testdata/TestTrace.golden b/database/testdata/TestTrace.golden
new file mode 100644
index 0000000000000000000000000000000000000000..b76661057d72c5ae3ff5a60adb1d0342fcde21ce
--- /dev/null
+++ b/database/testdata/TestTrace.golden
@@ -0,0 +1,3 @@
+[]database.TraceRow{
+ {Query:"explain select 1", Trace:"{\n \"steps\": [\n {\n \"join_preparation\": {\n \"select#\": 1,\n \"steps\": [\n {\n \"expanded_query\": \"/* select#1 */ select 1 AS `1`\"\n }\n ]\n }\n },\n {\n \"join_optimization\": {\n \"select#\": 1,\n \"steps\": [\n ]\n }\n },\n {\n \"join_explain\": {\n \"select#\": 1,\n \"steps\": [\n ]\n }\n }\n ]\n}", MissingBytesBeyondMaxMemSize:0, InsufficientPrivileges:0},
+}
diff --git a/database/trace.go b/database/trace.go
index 0ba611aba68fc5f62bc41aa2518f7011b763da27..7c2d2b5a1fe58a642d134848a1309856fbdfef3f 100644
--- a/database/trace.go
+++ b/database/trace.go
@@ -19,12 +19,9 @@ package database
import (
"errors"
"fmt"
- "io"
+ "github.com/XiaoMi/soar/common"
"regexp"
"strings"
- "time"
-
- "github.com/XiaoMi/soar/common"
"vitess.io/vitess/go/vt/sqlparser"
)
@@ -43,10 +40,11 @@ type TraceRow struct {
}
// Trace 执行SQL,并对其Trace
-func (db *Connector) Trace(sql string, params ...interface{}) (*QueryResult, error) {
+func (db *Connector) Trace(sql string, params ...interface{}) ([]TraceRow, error) {
common.Log.Debug("Trace SQL: %s", sql)
+ var rows []TraceRow
if common.Config.TestDSN.Version < 50600 {
- return nil, errors.New("version < 5.6, not support trace")
+ return rows, errors.New("version < 5.6, not support trace")
}
// 过滤不需要 Trace 的 SQL
@@ -55,98 +53,71 @@ func (db *Connector) Trace(sql string, params ...interface{}) (*QueryResult, err
sql = "explain " + sql
case sqlparser.EXPLAIN:
default:
- return nil, errors.New("no need trace")
+ return rows, errors.New("no need trace")
}
// 测试环境如果检查是关闭的,则SQL不会被执行
if common.Config.TestDSN.Disable {
- return nil, errors.New("Dsn Disable")
+ return rows, errors.New("dsn is disable")
}
// 数据库安全性检查:如果 Connector 的 IP 端口与 TEST 环境不一致,则启用SQL白名单
// 不在白名单中的SQL不允许执行
// 执行环境与test环境不相同
if db.Addr != common.Config.TestDSN.Addr && db.dangerousQuery(sql) {
- return nil, fmt.Errorf("query Execution Deny: Execute SQL with DSN(%s/%s) '%s'",
+ return rows, fmt.Errorf("query Execution Deny: Execute SQL with DSN(%s/%s) '%s'",
db.Addr, db.Database, fmt.Sprintf(sql, params...))
}
common.Log.Debug("Execute SQL with DSN(%s/%s) : %s", db.Addr, db.Database, sql)
- conn := db.NewConnection()
-
- // 设置SQL连接超时时间
- conn.SetTimeout(time.Duration(common.Config.ConnTimeOut) * time.Second)
- defer conn.Close()
- err := conn.Connect()
+ conn, err := db.NewConnection()
if err != nil {
- return nil, err
+ return rows, err
}
+ defer conn.Close()
- // 添加SQL执行超时限制
- ch := make(chan QueryResult, 1)
- go func() {
- // 开启Trace
- common.Log.Debug("SET SESSION OPTIMIZER_TRACE='enabled=on'")
- _, _, err = conn.Query("SET SESSION OPTIMIZER_TRACE='enabled=on'")
- common.LogIfError(err, "")
-
- // 执行SQL,抛弃返回结果
- result, err := conn.Start(sql, params...)
- if err != nil {
- ch <- QueryResult{
- Error: err,
- }
- return
- }
- row := result.MakeRow()
- for {
- err = result.ScanRow(row)
- if err == io.EOF {
- break
- }
- }
+ // 开启Trace
+ common.Log.Debug("SET SESSION OPTIMIZER_TRACE='enabled=on'")
+ trx, err := conn.Begin()
+ if err != nil {
+ return rows, err
+ }
+ defer trx.Rollback()
+ _, err = trx.Query("SET SESSION OPTIMIZER_TRACE='enabled=on'")
+ common.LogIfError(err, "")
- // 返回Trace结果
- res := QueryResult{}
- res.Rows, res.Result, res.Error = conn.Query("SELECT * FROM information_schema.OPTIMIZER_TRACE")
+ // 执行SQL,抛弃返回结果
+ tmpRes, err := trx.Query(sql, params...)
+ if err != nil {
+ return rows, err
+ }
+ for tmpRes.Next() {
+ continue
+ }
- // 关闭Trace
- common.Log.Debug("SET SESSION OPTIMIZER_TRACE='enabled=off'")
- _, _, err = conn.Query("SET SESSION OPTIMIZER_TRACE='enabled=off'")
+ // 返回Trace结果
+ res, err := trx.Query("SELECT * FROM information_schema.OPTIMIZER_TRACE")
+ for res.Next() {
+ var traceRow TraceRow
+ err = res.Scan(&traceRow.Query, &traceRow.Trace, &traceRow.MissingBytesBeyondMaxMemSize, &traceRow.InsufficientPrivileges)
if err != nil {
- fmt.Println(err.Error())
+ common.LogIfError(err, "")
}
- ch <- res
- }()
-
- select {
- case res := <-ch:
- return &res, res.Error
- case <-time.After(time.Duration(common.Config.QueryTimeOut) * time.Second):
- return nil, errors.New("query execution timeout")
+ rows = append(rows, traceRow)
}
-}
-// getTrace 获取trace信息
-func getTrace(res *QueryResult) Trace {
- var rows []TraceRow
- for _, row := range res.Rows {
- rows = append(rows, TraceRow{
- Query: row.Str(0),
- Trace: row.Str(1),
- MissingBytesBeyondMaxMemSize: row.Int(2),
- InsufficientPrivileges: row.Int(3),
- })
- }
- return Trace{Rows: rows}
+ // 关闭Trace
+ common.Log.Debug("SET SESSION OPTIMIZER_TRACE='enabled=off'")
+ _, err = trx.Query("SET SESSION OPTIMIZER_TRACE='enabled=off'")
+ common.LogIfError(err, "")
+ return rows, err
}
// FormatTrace 格式化输出Trace信息
-func FormatTrace(res *QueryResult) string {
+func FormatTrace(rows []TraceRow) string {
explainReg := regexp.MustCompile(`(?i)^explain\s+`)
- trace := getTrace(res)
str := []string{""}
- for _, row := range trace.Rows {
+ for _, row := range rows {
str = append(str, "```sql")
sql := explainReg.ReplaceAllString(row.Query, "")
str = append(str, sql)
diff --git a/database/trace_test.go b/database/trace_test.go
index 8dea2d7ffa8270facee7c4003566625e5ac04221..535d3f58cf948fbb54c1a8407a63bb3317225753 100644
--- a/database/trace_test.go
+++ b/database/trace_test.go
@@ -17,7 +17,6 @@
package database
import (
- "flag"
"testing"
"github.com/XiaoMi/soar/common"
@@ -25,34 +24,30 @@ import (
"github.com/kr/pretty"
)
-var update = flag.Bool("update", false, "update .golden files")
-
func TestTrace(t *testing.T) {
- common.Config.QueryTimeOut = 1
res, err := connTest.Trace("select 1")
- if err == nil {
- common.GoldenDiff(func() {
- pretty.Println(res)
- }, t.Name(), update)
- } else {
+ if err != nil {
+ t.Error(err)
+ }
+
+ err = common.GoldenDiff(func() {
+ pretty.Println(res)
+ }, t.Name(), update)
+ if err != nil {
t.Error(err)
}
}
func TestFormatTrace(t *testing.T) {
res, err := connTest.Trace("select 1")
- if err == nil {
- pretty.Println(FormatTrace(res))
- } else {
+ if err != nil {
t.Error(err)
}
-}
-func TestGetTrace(t *testing.T) {
- res, err := connTest.Trace("select 1")
- if err == nil {
- pretty.Println(getTrace(res))
- } else {
+ err = common.GoldenDiff(func() {
+ pretty.Println(FormatTrace(res))
+ }, t.Name(), update)
+ if err != nil {
t.Error(err)
}
}
diff --git a/env/env.go b/env/env.go
index 9b3f6ea0910b441df8a4b121d1a3ba7f0ed8e2b2..cc21a03ccae7bc946621569ebed1b6c9ee33a7a3 100644
--- a/env/env.go
+++ b/env/env.go
@@ -161,8 +161,9 @@ func (ve *VirtualEnv) CleanupTestDatabase() {
// TODO: 1 hour should be config-able
minHour := 1
- for _, row := range dbs.Rows {
- testDatabase := row.Str(0)
+ for dbs.Rows.Next() {
+ var testDatabase string
+ dbs.Rows.Scan(&testDatabase)
// test temporary database format `optimizer_YYMMDDHHmmss_randomString(16)`
if len(testDatabase) != 39 {
common.Log.Debug("CleanupTestDatabase by pass %s", testDatabase)
@@ -218,7 +219,7 @@ func (ve *VirtualEnv) BuildVirtualEnv(rEnv *database.Connector, SQLs ...string)
meta := make(map[string]*common.DB)
for _, sql := range SQLs {
- common.Log.Debug("BuildVirtualEnv Database&Table Mapping, SQL: %s", sql)
+ common.Log.Debug("BuildVirtualEnv Database&TableName Mapping, SQL: %s", sql)
stmt, err = sqlparser.Parse(sql)
if err != nil {
@@ -320,7 +321,7 @@ func (ve *VirtualEnv) BuildVirtualEnv(rEnv *database.Connector, SQLs ...string)
}
// 如果是视图,解析语句
- if len(tbStatus.Rows) > 0 && tbStatus.Rows[0].Comment == "VIEW" {
+ if len(tbStatus.Rows) > 0 && string(tbStatus.Rows[0].Comment) == "VIEW" {
tmpEnv.Database = db
var viewDDL string
viewDDL, err = tmpEnv.ShowCreateTable(tb.TableName)
@@ -418,7 +419,7 @@ func (ve VirtualEnv) createTable(rEnv database.Connector, dbName, tbName string)
return nil
}
- common.Log.Debug("createTable, Database: %s, Table: %s", dbName, tbName)
+ common.Log.Debug("createTable, Database: %s, TableName: %s", dbName, tbName)
// TODO:查看是否有外键关联(done),对外键的支持 (未解决循环依赖的问题)
@@ -506,9 +507,9 @@ func (ve *VirtualEnv) GenTableColumns(meta common.Meta) common.TableColumns {
DB: dbName,
Table: tb.TableName,
DataType: colInfo.Type,
- Character: colInfo.Collation,
+ Character: string(colInfo.Collation),
Key: colInfo.Key,
- Default: colInfo.Default,
+ Default: string(colInfo.Default),
Extra: colInfo.Extra,
Comment: colInfo.Comment,
Privileges: colInfo.Privileges,
@@ -525,9 +526,9 @@ func (ve *VirtualEnv) GenTableColumns(meta common.Meta) common.TableColumns {
col.DB = dbName
col.Table = tb.TableName
col.DataType = colInfo.Type
- col.Character = colInfo.Collation
+ col.Character = string(colInfo.Collation)
col.Key = colInfo.Key
- col.Default = colInfo.Default
+ col.Default = string(colInfo.Default)
col.Extra = colInfo.Extra
col.Comment = colInfo.Comment
col.Privileges = colInfo.Privileges