diff --git a/advisor/index_test.go b/advisor/index_test.go index fc9ecd206490eb9a02a1073602dfff979be7d9b6..0a355bb636fe28e835294a67623032f9e0f1402d 100644 --- a/advisor/index_test.go +++ b/advisor/index_test.go @@ -17,25 +17,30 @@ package advisor import ( + "flag" "fmt" "os" "strings" "testing" "github.com/XiaoMi/soar/common" + "github.com/XiaoMi/soar/database" "github.com/XiaoMi/soar/env" "github.com/kr/pretty" "vitess.io/vitess/go/vt/sqlparser" ) +var update = flag.Bool("update", false, "update .golden files") +var vEnv *env.VirtualEnv +var rEnv *database.Connector + func init() { common.BaseDir = common.DevPath err := common.ParseConfig("") - if err != nil { - fmt.Println(err.Error()) - } - vEnv, rEnv := env.BuildEnv() + common.LogIfError(err, "init ParseConfig") + common.Log.Debug("advisor_test init") + vEnv, rEnv = env.BuildEnv() if _, err = vEnv.Version(); err != nil { fmt.Println(err.Error(), ", By pass all advisor test cases") os.Exit(0) @@ -45,6 +50,7 @@ func init() { fmt.Println(err.Error(), ", By pass all advisor test cases") os.Exit(0) } + defer vEnv.CleanUp() } // ARG.003 @@ -52,8 +58,6 @@ func TestRuleImplicitConversion(t *testing.T) { common.Log.Debug("Entering function: %s", common.GetFunctionName()) dsn := common.Config.OnlineDSN common.Config.OnlineDSN = common.Config.TestDSN - vEnv, rEnv := env.BuildEnv() - defer vEnv.CleanUp() initSQLs := []string{ `CREATE TABLE t1 (id int, title varchar(255) CHARSET utf8 COLLATE utf8_general_ci);`, @@ -104,9 +108,6 @@ func TestRuleImpossibleOuterJoin(t *testing.T) { `select city_id, city, country from city left outer join country on city.country_id=country.country_id WHERE city.city_id IS NULL`, } - vEnv, rEnv := env.BuildEnv() - defer vEnv.CleanUp() - for _, sql := range sqls { stmt, syntaxErr := sqlparser.Parse(sql) if syntaxErr != nil { @@ -146,9 +147,6 @@ func TestIndexAdvisorRuleGroupByConst(t *testing.T) { }, } - vEnv, rEnv := env.BuildEnv() - defer vEnv.CleanUp() - for _, sql := range sqls[0] { stmt, syntaxErr := sqlparser.Parse(sql) if syntaxErr != nil { @@ -211,9 +209,6 @@ func TestIndexAdvisorRuleOrderByConst(t *testing.T) { }, } - vEnv, rEnv := env.BuildEnv() - defer vEnv.CleanUp() - for _, sql := range sqls[0] { stmt, syntaxErr := sqlparser.Parse(sql) if syntaxErr != nil { @@ -275,9 +270,6 @@ func TestRuleUpdatePrimaryKey(t *testing.T) { }, } - vEnv, rEnv := env.BuildEnv() - defer vEnv.CleanUp() - for _, sql := range sqls[0] { stmt, syntaxErr := sqlparser.Parse(sql) if syntaxErr != nil { @@ -328,8 +320,6 @@ func TestRuleUpdatePrimaryKey(t *testing.T) { func TestIndexAdvise(t *testing.T) { common.Log.Debug("Entering function: %s", common.GetFunctionName()) - vEnv, rEnv := env.BuildEnv() - defer vEnv.CleanUp() for _, sql := range common.TestSQLs { stmt, syntaxErr := sqlparser.Parse(sql) @@ -360,8 +350,6 @@ func TestIndexAdviseNoEnv(t *testing.T) { common.Log.Debug("Entering function: %s", common.GetFunctionName()) orgOnlineDSNStatus := common.Config.OnlineDSN.Disable common.Config.OnlineDSN.Disable = true - vEnv, rEnv := env.BuildEnv() - defer vEnv.CleanUp() for _, sql := range common.TestSQLs { stmt, syntaxErr := sqlparser.Parse(sql) @@ -391,7 +379,6 @@ func TestIndexAdviseNoEnv(t *testing.T) { func TestDuplicateKeyChecker(t *testing.T) { common.Log.Debug("Entering function: %s", common.GetFunctionName()) - _, rEnv := env.BuildEnv() rule := DuplicateKeyChecker(rEnv, "sakila") if len(rule) != 0 { t.Errorf("got rules: %s", pretty.Sprint(rule)) @@ -426,9 +413,6 @@ func TestIdxColsTypeCheck(t *testing.T) { `select city_id, city, country from city left outer join country using(country_id) WHERE city.city_id=59 and country.country='Algeria'`, } - vEnv, rEnv := env.BuildEnv() - defer vEnv.CleanUp() - for _, sql := range sqls { stmt, syntaxErr := sqlparser.Parse(sql) if syntaxErr != nil { diff --git a/advisor/rules_test.go b/advisor/rules_test.go index 938a54601813c974ca9607aadd04ba141153e436..b97e912ce9a8065a43a6501f1e660c257e50695f 100644 --- a/advisor/rules_test.go +++ b/advisor/rules_test.go @@ -17,14 +17,11 @@ package advisor import ( - "flag" "testing" "github.com/XiaoMi/soar/common" ) -var update = flag.Bool("update", false, "update .golden files") - func TestListTestSQLs(t *testing.T) { err := common.GoldenDiff(func() { ListTestSQLs() }, t.Name(), update) if nil != err { diff --git a/cmd/soar/soar.go b/cmd/soar/soar.go index 38ad820d8866af7ff206b15f029f7fd46438ebeb..dc8b7ea33d1a4e446951a5bde310c9a6a31dfc28 100644 --- a/cmd/soar/soar.go +++ b/cmd/soar/soar.go @@ -68,7 +68,7 @@ func main() { // 当程序卡死的时候,或者由于某些原因程序没有退出,可以通过捕获信号量的形式让程序优雅退出并且清理测试环境 common.HandleSignal(func() { - shutdown(vEnv) + shutdown(vEnv, rEnv) }) // 对指定的库表进行索引重复检查 diff --git a/cmd/soar/tool.go b/cmd/soar/tool.go index dd977d69a50df46f076e6f2681c8f18ed11ea9cf..bc124b931395a38e5d929934779729f0273b12af 100644 --- a/cmd/soar/tool.go +++ b/cmd/soar/tool.go @@ -67,52 +67,48 @@ func initConfig() { // if error found return non-zero, no error return zero func checkConfig() int { // TestDSN connection check - testConn := &database.Connector{ - Addr: common.Config.TestDSN.Addr, - User: common.Config.TestDSN.User, - Pass: common.Config.TestDSN.Password, - Database: common.Config.TestDSN.Schema, - Charset: common.Config.TestDSN.Charset, - } - testVersion, err := testConn.Version() + connTest, err := database.NewConnector(common.Config.TestDSN) + if err != nil { + fmt.Println("test-dsn:", common.Config.TestDSN.Addr, err.Error()) + return 1 + } + testVersion, err := connTest.Version() if err != nil && !common.Config.TestDSN.Disable { - fmt.Println("test-dsn:", testConn, err.Error()) + fmt.Println("test-dsn:", connTest, err.Error()) return 1 } if common.Config.Verbose { if err == nil { - fmt.Println("test-dsn", testConn, "Version:", testVersion) + fmt.Println("test-dsn", connTest, "Version:", testVersion) } else { fmt.Println("test-dsn", common.Config.TestDSN) } } - if !testConn.HasAllPrivilege() { + if !connTest.HasAllPrivilege() { fmt.Printf("test-dsn: %s, need all privileges", common.FormatDSN(common.Config.TestDSN)) return 1 } // OnlineDSN connection check - onlineConn := &database.Connector{ - Addr: common.Config.OnlineDSN.Addr, - User: common.Config.OnlineDSN.User, - Pass: common.Config.OnlineDSN.Password, - Database: common.Config.OnlineDSN.Schema, - Charset: common.Config.OnlineDSN.Charset, - } - onlineVersion, err := onlineConn.Version() + connOnline, err := database.NewConnector(common.Config.OnlineDSN) + if err != nil { + fmt.Println("test-dsn:", common.Config.OnlineDSN.Addr, err.Error()) + return 1 + } + onlineVersion, err := connOnline.Version() if err != nil && !common.Config.OnlineDSN.Disable { - fmt.Println("online-dsn:", onlineConn, err.Error()) + fmt.Println("online-dsn:", connOnline, err.Error()) return 1 } if common.Config.Verbose { if err == nil { - fmt.Println("online-dsn", onlineConn, "Version:", onlineVersion) + fmt.Println("online-dsn", connOnline, "Version:", onlineVersion) } else { fmt.Println("online-dsn", common.Config.OnlineDSN) } } - if !onlineConn.HasSelectPrivilege() { + if !connOnline.HasSelectPrivilege() { fmt.Printf("online-dsn: %s, need all privileges", common.FormatDSN(common.Config.OnlineDSN)) return 1 } @@ -229,9 +225,11 @@ func initQuery(query string) string { return query } -func shutdown(vEnv *env.VirtualEnv) { +func shutdown(vEnv *env.VirtualEnv, rEnv *database.Connector) { if common.Config.DropTestTemporary { vEnv.CleanUp() } + vEnv.Conn.Close() + rEnv.Conn.Close() os.Exit(0) } diff --git a/common/config.go b/common/config.go index cdfacb2c7c4e5d29a69c93e8cd8bac70e27eb1fb..b74f9728668df5a7d013f655cb8949addcaa53ba 100644 --- a/common/config.go +++ b/common/config.go @@ -48,8 +48,8 @@ var ( // Configuration 配置文件定义结构体 type Configuration struct { // +++++++++++++++测试环境+++++++++++++++++ - OnlineDSN *dsn `yaml:"online-dsn"` // 线上环境数据库配置 - TestDSN *dsn `yaml:"test-dsn"` // 测试环境数据库配置 + OnlineDSN *Dsn `yaml:"online-dsn"` // 线上环境数据库配置 + TestDSN *Dsn `yaml:"test-dsn"` // 测试环境数据库配置 AllowOnlineAsTest bool `yaml:"allow-online-as-test"` // 允许 Online 环境也可以当作 Test 环境 DropTestTemporary bool `yaml:"drop-test-temporary"` // 是否清理Test环境产生的临时库表 CleanupTestDatabase bool `yaml:"cleanup-test-database"` // 清理残余的测试数据库(程序异常退出或未开启drop-test-temporary) issue #48 @@ -132,14 +132,14 @@ type Configuration struct { // Config 默认设置 var Config = &Configuration{ - OnlineDSN: &dsn{ + OnlineDSN: &Dsn{ Net: "tcp", Schema: "information_schema", Charset: "utf8mb4", Disable: true, Version: 99999, }, - TestDSN: &dsn{ + TestDSN: &Dsn{ Net: "tcp", Schema: "information_schema", Charset: "utf8mb4", @@ -224,7 +224,8 @@ var Config = &Configuration{ MaxPrettySQLLength: 1024, } -type dsn struct { +// Dsn Data source name +type Dsn struct { Net string `yaml:"net"` Addr string `yaml:"addr"` Schema string `yaml:"schema"` @@ -243,7 +244,7 @@ type dsn struct { } // 解析命令行DSN输入 -func parseDSN(odbc string, d *dsn) *dsn { +func parseDSN(odbc string, d *Dsn) *Dsn { var addr, user, password, schema, charset string if odbc == FormatDSN(d) { return d @@ -260,7 +261,7 @@ func parseDSN(odbc string, d *dsn) *dsn { // 设置为空表示禁用环境 odbc = strings.TrimSpace(odbc) if odbc == "" { - return &dsn{Disable: true} + return &Dsn{Disable: true} } // username:password@ip:port/database @@ -354,7 +355,7 @@ func parseDSN(odbc string, d *dsn) *dsn { charset = "utf8mb4" } - dsn := &dsn{ + dsn := &Dsn{ Addr: addr, User: user, Password: password, @@ -367,7 +368,7 @@ func parseDSN(odbc string, d *dsn) *dsn { } // FormatDSN 格式化打印DSN -func FormatDSN(env *dsn) string { +func FormatDSN(env *Dsn) string { if env == nil || env.Disable { return "" } diff --git a/common/example_test.go b/common/example_test.go index 8baeb7100c19a78bc3e58731a5fe58824b8bcef5..a0d0877035e4927b094d34de7f2513aa8966b4b4 100644 --- a/common/example_test.go +++ b/common/example_test.go @@ -19,7 +19,7 @@ package common import "fmt" func ExampleFormatDSN() { - dsxExp := &dsn{ + dsxExp := &Dsn{ Addr: "127.0.0.1:3306", Schema: "mysql", User: "root", diff --git a/common/meta.go b/common/meta.go index e95bb37d91351fa100bb67fe66c4c50ed96fe043..1b2882991c79ac696ded2f3f826be58f2a1acc1a 100644 --- a/common/meta.go +++ b/common/meta.go @@ -38,7 +38,7 @@ func NewDB(db string) *DB { } } -// TableName 含有表的属性 +// Table 含有表的属性 type Table struct { TableName string TableAliases []string diff --git a/database/explain.go b/database/explain.go index 07e4ce35a5fa51cac83b29961b604d2b443aec58..fce96bd0842d03eade3770363301e64e4c832fa8 100644 --- a/database/explain.go +++ b/database/explain.go @@ -556,14 +556,12 @@ func (db *Connector) explainAbleSQL(sql string) (string, error) { return "", nil } -// 执行explain请求,返回mysql.Result执行结果 -func (db *Connector) executeExplain(sql string, explainType int, formatType int) (QueryResult, error) { - var res QueryResult +// explainQuery 生成可执行的 explain 查询请求 +func (db *Connector) explainQuery(sql string, explainType int, formatType int) string { var err error - var explainQuery string sql, err = db.explainAbleSQL(sql) - if sql == "" { - return res, err + if sql == "" || err != nil { + return sql } // 5.6以上支持 FORMAT=JSON @@ -580,18 +578,17 @@ func (db *Connector) executeExplain(sql string, explainType int, formatType int) case ExtendedExplainType: // 5.6以上extended关键字已经不推荐使用,8.0废弃了这个关键字 if common.Config.TestDSN.Version >= 50600 { - explainQuery = fmt.Sprintf("explain %s", sql) + sql = fmt.Sprintf("explain %s", sql) } else { - explainQuery = fmt.Sprintf("explain extended %s", sql) + sql = fmt.Sprintf("explain extended %s", sql) } case PartitionsExplainType: - explainQuery = fmt.Sprintf("explain partitions %s", sql) + sql = fmt.Sprintf("explain partitions %s", sql) default: - explainQuery = fmt.Sprintf("explain %s %s", explainFormat, sql) + sql = fmt.Sprintf("explain %s %s", explainFormat, sql) } - res, err = db.Query(explainQuery) - return res, err + return sql } // MySQLExplainWarnings WARNINGS信息中包含的优化器信息 @@ -944,6 +941,7 @@ func ParseExplainResult(res QueryResult, formatType int) (exp *ExplainInfo, err err = res.Rows.Scan(&explainString) exp.ExplainJSON, err = parseJSONExplainText(explainString) } + res.Rows.Close() return exp, err } @@ -992,23 +990,24 @@ func ParseExplainResult(res QueryResult, formatType int) (exp *ExplainInfo, err expRow.Scalability = ExplainScalability[expRow.AccessType] explainRows = append(explainRows, expRow) } + res.Rows.Close() exp.ExplainRows = explainRows // check explain warning info if common.Config.ShowWarnings { for res.Warning.Next() { var expWarning *ExplainWarning - res.Warning.Scan( - expWarning.Level, - expWarning.Code, - expWarning.Message, - ) + err = res.Warning.Scan(expWarning.Level, expWarning.Code, expWarning.Message) + if err != nil { + break + } // 'EXTENDED' is deprecated and will be removed in a future release. if expWarning.Code != 1681 { exp.Warnings = append(exp.Warnings, expWarning) } } + res.Warning.Close() } // 添加 last_query_cost @@ -1022,6 +1021,7 @@ func (db *Connector) Explain(sql string, explainType int, formatType int) (exp * if explainType != TraditionalExplainType { formatType = TraditionalFormatExplain } + defer func() { if e := recover(); e != nil { const size = 4096 @@ -1033,10 +1033,8 @@ func (db *Connector) Explain(sql string, explainType int, formatType int) (exp * }() // 执行EXPLAIN请求 - res, err := db.executeExplain(sql, explainType, formatType) - if err != nil { - return exp, err - } + sql = db.explainQuery(sql, explainType, formatType) + res, err := db.Query(sql) // 解析mysql结果,输出ExplainInfo exp, err = ParseExplainResult(res, formatType) diff --git a/database/mysql.go b/database/mysql.go index 51baa9f701b634d3d6939f7de98a0e1d3fd752a8..0db911e0b7cbf710f65508e090182f23f8452dca 100644 --- a/database/mysql.go +++ b/database/mysql.go @@ -39,6 +39,7 @@ type Connector struct { Database string Charset string Net string + Conn *sql.DB } // QueryResult 数据库查询返回值 @@ -49,20 +50,38 @@ type QueryResult struct { QueryCost float64 } -// NewConnection 创建新连接 -func (db *Connector) NewConnection() (*sql.DB, error) { - dsn := fmt.Sprintf("%s:%s@%s(%s)/%s?parseTime=true&charset=%s", db.User, db.Pass, db.Net, db.Addr, db.Database, db.Charset) - return sql.Open("mysql", dsn) +// NewConnector 创建新连接 +func NewConnector(dsn *common.Dsn) (*Connector, error) { + conn, err := sql.Open("mysql", fmt.Sprintf("%s:%s@%s(%s)/%s?parseTime=true&charset=%s", + dsn.User, + dsn.Password, + dsn.Net, + dsn.Addr, + dsn.Schema, + dsn.Charset, + )) + if err != nil { + return nil, err + } + connector := &Connector{ + Addr: dsn.Addr, + User: dsn.User, + Pass: dsn.Password, + Database: dsn.Schema, + Charset: dsn.Charset, + Conn: conn, + } + return connector, err } // Query 执行SQL func (db *Connector) Query(sql string, params ...interface{}) (QueryResult, error) { var res QueryResult + var err error // 测试环境如果检查是关闭的,则SQL不会被执行 if common.Config.TestDSN.Disable { return res, errors.New("dsn is disable") } - // 数据库安全性检查:如果 Connector 的 IP 端口与 TEST 环境不一致,则启用SQL白名单 // 不在白名单中的SQL不允许执行 // 执行环境与test环境不相同 @@ -72,23 +91,25 @@ func (db *Connector) Query(sql string, params ...interface{}) (QueryResult, erro } common.Log.Debug("Execute SQL with DSN(%s/%s) : %s", db.Addr, db.Database, fmt.Sprintf(sql, params...)) - conn, err := db.NewConnection() - defer conn.Close() - if err != nil { - return res, err - } - res.Rows, res.Error = conn.Query(sql, params...) + _, err = db.Conn.Exec("USE " + db.Database) + common.LogIfError(err, "") + res.Rows, res.Error = db.Conn.Query(sql, params...) if common.Config.ShowWarnings { - res.Warning, err = conn.Query("SHOW WARNINGS") + res.Warning, err = db.Conn.Query("SHOW WARNINGS") + common.LogIfError(err, "") } // SHOW WARNINGS 并不会影响 last_query_cost if common.Config.ShowLastQueryCost { - cost, err := conn.Query("SHOW SESSION STATUS LIKE 'last_query_cost'") + cost, err := db.Conn.Query("SHOW SESSION STATUS LIKE 'last_query_cost'") if err == nil { if cost.Next() { err = cost.Scan(res.QueryCost) + common.LogIfError(err, "") + } + if err := cost.Close(); err != nil { + common.Log.Error(err.Error()) } } } @@ -115,6 +136,9 @@ func (db *Connector) Version() (int, error) { var versionSeg []string for res.Rows.Next() { err = res.Rows.Scan(&versionStr) + if err != nil { + break + } versionStr = strings.Split(versionStr, "-")[0] versionSeg = strings.Split(versionStr, ".") if len(versionSeg) == 3 { @@ -123,7 +147,9 @@ func (db *Connector) Version() (int, error) { } break } - + if err := res.Rows.Close(); err != nil { + common.Log.Error(err.Error()) + } return version, err } @@ -140,6 +166,9 @@ func (db *Connector) SingleIntValue(option string) (int, error) { if res.Rows.Next() { err = res.Rows.Scan(&intVal) } + if err := res.Rows.Close(); err != nil { + common.Log.Error(err.Error()) + } return intVal, err } @@ -188,6 +217,7 @@ func (db *Connector) ColumnCardinality(tb, col string) float64 { return 0 } } + res.Rows.Close() // 当table status元数据不准确时 rowTotal 可能远小于count(*),导致散粒度大于1 if colNum > float64(rowTotal) { diff --git a/database/mysql_test.go b/database/mysql_test.go index 1688ba12cefe9794b164130d478a745578de0c11..9f4827e77dbbbfdbb2011f4ddca4ff237914aaf9 100644 --- a/database/mysql_test.go +++ b/database/mysql_test.go @@ -28,31 +28,23 @@ import ( ) var connTest *Connector - var update = flag.Bool("update", false, "update .golden files") func init() { common.BaseDir = common.DevPath - common.ParseConfig("") - connTest = &Connector{ - Addr: common.Config.OnlineDSN.Addr, - User: common.Config.OnlineDSN.User, - Pass: common.Config.OnlineDSN.Password, - Database: common.Config.OnlineDSN.Schema, - Charset: common.Config.OnlineDSN.Charset, - } - if _, err := connTest.Version(); err != nil { + err := common.ParseConfig("") + common.LogIfError(err, "init ParseConfig") + common.Log.Debug("mysql_test init") + connTest, err = NewConnector(common.Config.TestDSN) + if err != nil { common.Log.Critical("Test env Error: %v", err) os.Exit(0) } -} -func TestNewConnection(t *testing.T) { - conn, err := connTest.NewConnection() - if err != nil { - t.Errorf("TestNewConnection, Error: %s", err.Error()) + if _, err := connTest.Version(); err != nil { + common.Log.Critical("Test env Error: %v", err) + os.Exit(0) } - defer conn.Close() } func TestQuery(t *testing.T) { @@ -70,6 +62,7 @@ func TestQuery(t *testing.T) { t.Error("should return 0") } } + res.Rows.Close() // TODO: timeout test } @@ -115,6 +108,7 @@ func TestWarningsAndQueryCost(t *testing.T) { } pretty.Println(str) } + res.Warning.Close() fmt.Println(res.QueryCost, err) } } diff --git a/database/privilege.go b/database/privilege.go index dca833e3b9d11636675dd87ad18220f449f5dfe5..3ff22fb5e33a6ac577c2469d93886b2c9fad940c 100644 --- a/database/privilege.go +++ b/database/privilege.go @@ -26,28 +26,31 @@ import ( // CurrentUser get current user with current_user() function func (db *Connector) CurrentUser() (string, string, error) { + var user, host string res, err := db.Query("select current_user()") if err != nil { - return "", "", err + return user, host, err } if res.Rows.Next() { var currentUser string err = res.Rows.Scan(¤tUser) if err != nil { - return "", "", err + return user, host, err } + res.Rows.Close() + cols := strings.Split(currentUser, "@") if len(cols) == 2 { - user := strings.Trim(cols[0], "'") - host := strings.Trim(cols[1], "'") + user = strings.Trim(cols[0], "'") + host = strings.Trim(cols[1], "'") if strings.Contains(user, "'") || strings.Contains(host, "'") { return "", "", errors.New("user or host contains irregular character") } return user, host, nil } - return "", "", errors.New("user or host contains irregular character") + return user, host, errors.New("user or host contains irregular character") } - return "", "", errors.New("no privilege info") + return user, host, errors.New("no privilege info") } // HasSelectPrivilege if user has select privilege @@ -70,6 +73,8 @@ func (db *Connector) HasSelectPrivilege() bool { common.Log.Error("HasSelectPrivilege, Scan Error: %s", err.Error()) return false } + res.Rows.Close() + if selectPrivilege == "Y" { return true } @@ -99,6 +104,7 @@ func (db *Connector) HasAllPrivilege() bool { common.Log.Error("HasAllPrivilege, DSN: %s, Scan error", db.Addr) return false } + res.Rows.Close() } // get all privilege status @@ -115,6 +121,7 @@ func (db *Connector) HasAllPrivilege() bool { common.Log.Error("HasAllPrivilege, DSN: %s, Scan error", db.Addr) return false } + res.Rows.Close() if strings.Replace(priv, "Y", "", -1) == "" { return true } diff --git a/database/profiling.go b/database/profiling.go index f85a6b0231bfd259660db702f13b144e4a7193b6..db566973228dad4f529fc621ccbc3d896050dbf1 100644 --- a/database/profiling.go +++ b/database/profiling.go @@ -62,15 +62,9 @@ func (db *Connector) Profiling(sql string, params ...interface{}) ([]ProfilingRo } common.Log.Debug("Execute SQL with DSN(%s/%s) : %s", db.Addr, db.Database, sql) - conn, err := db.NewConnection() - if err != nil { - return rows, err - } - defer conn.Close() - // Keep connection // https://github.com/go-sql-driver/mysql/issues/208 - trx, err := conn.Begin() + trx, err := db.Conn.Begin() if err != nil { return rows, err } @@ -91,14 +85,20 @@ func (db *Connector) Profiling(sql string, params ...interface{}) ([]ProfilingRo // 返回 Profiling 结果 res, err := trx.Query("show profile") + if err != nil { + trx.Rollback() + return rows, err + } + var profileRow ProfilingRow for res.Next() { - var profileRow ProfilingRow - err := res.Scan(&profileRow.Status, &profileRow.Duration) + err = res.Scan(&profileRow.Status, &profileRow.Duration) if err != nil { common.LogIfError(err, "") + break } rows = append(rows, profileRow) } + res.Close() // 关闭 Profiling _, err = trx.Query("set @@profiling=0") diff --git a/database/sampling.go b/database/sampling.go index c9c3e60f4db366daeb022d6e94ba1b030c71c500..3e280db78a44eeb7e484d35338fc54ca8b6ba4d0 100644 --- a/database/sampling.go +++ b/database/sampling.go @@ -19,6 +19,7 @@ package database import ( "database/sql" "fmt" + "github.com/XiaoMi/soar/common" ) @@ -43,7 +44,7 @@ import ( */ // SamplingData 将数据从Remote拉取到 db 中 -func (db *Connector) SamplingData(remote Connector, tables ...string) error { +func (db *Connector) SamplingData(remote *Connector, tables ...string) error { // 计算需要泵取的数据量 wantRowsCount := 300 * common.Config.SamplingStatisticTarget @@ -51,19 +52,6 @@ func (db *Connector) SamplingData(remote Connector, tables ...string) error { // 该数值越大,在内存中缓存的data就越多,但相对的,插入时速度就越快 maxValCount := 200 - // 获取数据库连接对象 - conn, err := remote.NewConnection() - if err != nil { - return err - } - defer conn.Close() - - localConn, err := db.NewConnection() - if err != nil { - return err - } - defer localConn.Close() - for _, table := range tables { // 表类型检查 if remote.IsView(table) { @@ -89,7 +77,7 @@ func (db *Connector) SamplingData(remote Connector, tables ...string) error { factor := float64(wantRowsCount) / float64(tableRows) common.Log.Debug("SamplingData, tableRows: %d, wantRowsCount: %d, factor: %f", tableRows, wantRowsCount, factor) - err = startSampling(conn, localConn, db.Database, table, factor, wantRowsCount, maxValCount) + err = startSampling(remote.Conn, db.Conn, db.Database, table, factor, wantRowsCount, maxValCount) if err != nil { common.Log.Error("(db *Connector) SamplingData Error : %v", err) } diff --git a/database/sampling_test.go b/database/sampling_test.go index d164b0b936e692cbdb976bc1f417c24733ffbfda..f8525dbcc5bc58ee67544cc23046c443024bec44 100644 --- a/database/sampling_test.go +++ b/database/sampling_test.go @@ -27,27 +27,12 @@ func init() { } func TestSamplingData(t *testing.T) { - online := &Connector{ - Addr: common.Config.OnlineDSN.Addr, - User: common.Config.OnlineDSN.User, - Pass: common.Config.OnlineDSN.Password, - Database: common.Config.OnlineDSN.Schema, - Charset: common.Config.OnlineDSN.Charset, - Net: common.Config.OnlineDSN.Net, - } - - offline := &Connector{ - Addr: common.Config.TestDSN.Addr, - User: common.Config.TestDSN.User, - Pass: common.Config.TestDSN.Password, - Database: common.Config.TestDSN.Schema, - Charset: common.Config.TestDSN.Charset, - Net: common.Config.TestDSN.Net, + connOnline, err := NewConnector(common.Config.OnlineDSN) + if err != nil { + t.Error(err) } - offline.Database = "test" - - err := connTest.SamplingData(*online, "film") + err = connTest.SamplingData(connOnline, "film") if err != nil { t.Error(err) } diff --git a/database/show.go b/database/show.go index 2cf3d8c88e175ac0aaf7f0569cb79563071fec5e..75ca837c461b8845d49c0b64cbae910374e11fee 100644 --- a/database/show.go +++ b/database/show.go @@ -99,10 +99,11 @@ func (db *Connector) ShowTables() ([]string, error) { var table string err = res.Rows.Scan(&table) if err != nil { - return []string{}, err + break } tables = append(tables, table) } + res.Rows.Close() return tables, err } @@ -156,6 +157,7 @@ func (db *Connector) ShowTableStatus(tableName string) (*TableStatInfo, error) { res.Rows.Scan(statusFields...) tbStatus.Rows = append(tbStatus.Rows, ts) } + res.Rows.Close() return tbStatus, err } @@ -244,6 +246,7 @@ func (db *Connector) ShowIndex(tableName string) (*TableIndexInfo, error) { res.Rows.Scan(indexFields...) tbIndex.Rows = append(tbIndex.Rows, ti) } + res.Rows.Close() return tbIndex, err } @@ -374,6 +377,7 @@ func (db *Connector) ShowColumns(tableName string) (*TableDesc, error) { res.Rows.Scan(columnFields...) tbDesc.DescValues = append(tbDesc.DescValues, tc) } + res.Rows.Close() return tbDesc, err } @@ -428,7 +432,7 @@ func (db *Connector) showCreate(createType, name string) (string, error) { for res.Rows.Next() { res.Rows.Scan(createFields...) } - + res.Rows.Close() return create, err } @@ -508,13 +512,10 @@ func (db *Connector) FindColumn(name, dbName string, tables ...string) ([]*commo var col common.Column for res.Rows.Next() { var character, collation []byte - res.Rows.Scan( - &col.Table, - &col.DB, - &col.DataType, - &character, - &collation, - ) + err = res.Rows.Scan(&col.Table, &col.DB, &col.DataType, &character, &collation) + if err != nil { + break + } col.Name = name col.Character = string(character) col.Collation = string(collation) @@ -537,8 +538,12 @@ func (db *Connector) FindColumn(name, dbName string, tables ...string) ([]*commo var tbCollation []byte if newRes.Rows.Next() { - newRes.Rows.Scan(&tbCollation) + err = newRes.Rows.Scan(&tbCollation) + if err != nil { + break + } } + newRes.Rows.Close() if string(tbCollation) != "" { col.Character = strings.Split(string(tbCollation), "_")[0] col.Collation = string(tbCollation) @@ -546,6 +551,7 @@ func (db *Connector) FindColumn(name, dbName string, tables ...string) ([]*commo } columns = append(columns, &col) } + res.Rows.Close() return columns, err } @@ -603,13 +609,12 @@ func (db *Connector) ShowReference(dbName string, tbName ...string) ([]Reference // 获取值 for res.Rows.Next() { var rv ReferenceValue - res.Rows.Scan(&rv.ReferencedTableSchema, - &rv.ReferencedTableName, - &rv.TableSchema, - &rv.TableName, - &rv.ConstraintName) + err = res.Rows.Scan(&rv.ReferencedTableSchema, &rv.ReferencedTableName, &rv.TableSchema, &rv.TableName, &rv.ConstraintName) + if err != nil { + break + } referenceValues = append(referenceValues, rv) } - + res.Rows.Close() return referenceValues, err } diff --git a/database/show_test.go b/database/show_test.go index 49b2441674fb3c8babab042ffc9bf365446ced82..ceea0abffc4883dd4d4a679603c7561c86ebead7 100644 --- a/database/show_test.go +++ b/database/show_test.go @@ -68,6 +68,15 @@ func TestShowTables(t *testing.T) { connTest.Database = orgDatabase } +func TestShowCreateDatabase(t *testing.T) { + err := common.GoldenDiff(func() { + fmt.Println(connTest.ShowCreateDatabase("sakila")) + }, t.Name(), update) + if err != nil { + t.Error(err) + } +} + func TestShowCreateTable(t *testing.T) { orgDatabase := connTest.Database connTest.Database = "sakila" @@ -87,7 +96,6 @@ func TestShowCreateTable(t *testing.T) { if err != nil { t.Error(err) } - connTest.Database = orgDatabase } diff --git a/database/testdata/TestShowCreateDatabase.golden b/database/testdata/TestShowCreateDatabase.golden new file mode 100644 index 0000000000000000000000000000000000000000..85dc9c85fce4f33025f7263a1f242dee642743fe --- /dev/null +++ b/database/testdata/TestShowCreateDatabase.golden @@ -0,0 +1 @@ +CREATE DATABASE `sakila` /*!40100 DEFAULT CHARACTER SET utf8mb4 COLLATE utf8mb4_0900_ai_ci */ diff --git a/database/trace.go b/database/trace.go index 7c2d2b5a1fe58a642d134848a1309856fbdfef3f..bbf3160a4a6ef5c20ef17ff0e944d21ff486920e 100644 --- a/database/trace.go +++ b/database/trace.go @@ -19,10 +19,11 @@ package database import ( "errors" "fmt" - "github.com/XiaoMi/soar/common" "regexp" "strings" + "github.com/XiaoMi/soar/common" + "vitess.io/vitess/go/vt/sqlparser" ) @@ -70,15 +71,9 @@ func (db *Connector) Trace(sql string, params ...interface{}) ([]TraceRow, error } common.Log.Debug("Execute SQL with DSN(%s/%s) : %s", db.Addr, db.Database, sql) - conn, err := db.NewConnection() - if err != nil { - return rows, err - } - defer conn.Close() - // 开启Trace common.Log.Debug("SET SESSION OPTIMIZER_TRACE='enabled=on'") - trx, err := conn.Begin() + trx, err := db.Conn.Begin() if err != nil { return rows, err } @@ -97,14 +92,20 @@ func (db *Connector) Trace(sql string, params ...interface{}) ([]TraceRow, error // 返回Trace结果 res, err := trx.Query("SELECT * FROM information_schema.OPTIMIZER_TRACE") + if err != nil { + trx.Rollback() + return rows, err + } for res.Next() { var traceRow TraceRow err = res.Scan(&traceRow.Query, &traceRow.Trace, &traceRow.MissingBytesBeyondMaxMemSize, &traceRow.InsufficientPrivileges) if err != nil { common.LogIfError(err, "") + break } rows = append(rows, traceRow) } + res.Close() // 关闭Trace common.Log.Debug("SET SESSION OPTIMIZER_TRACE='enabled=off'") diff --git a/env/env.go b/env/env.go index bc7491bd042aae3edcb4802fcca3501539996abf..204b62be3225ae002aeaa3262b8f9a768c79efef 100644 --- a/env/env.go +++ b/env/env.go @@ -57,14 +57,10 @@ func NewVirtualEnv(vEnv *database.Connector) *VirtualEnv { // @output *VirtualEnv 测试环境 // @output *database.Connector 线上环境连接句柄 func BuildEnv() (*VirtualEnv, *database.Connector) { + connTest, err := database.NewConnector(common.Config.TestDSN) + common.LogIfError(err, "") // 生成测试环境 - vEnv := NewVirtualEnv(&database.Connector{ - Addr: common.Config.TestDSN.Addr, - User: common.Config.TestDSN.User, - Pass: common.Config.TestDSN.Password, - Database: common.Config.TestDSN.Schema, - Charset: common.Config.TestDSN.Charset, - }) + vEnv := NewVirtualEnv(connTest) // 检查测试环境可用性,并记录数据库版本 vEnvVersion, err := vEnv.Version() @@ -82,13 +78,8 @@ func BuildEnv() (*VirtualEnv, *database.Connector) { vEnv.User, vEnv.Addr, vEnv.Database) common.Config.OnlineDSN = common.Config.TestDSN } - conn := &database.Connector{ - Addr: common.Config.OnlineDSN.Addr, - User: common.Config.OnlineDSN.User, - Pass: common.Config.OnlineDSN.Password, - Database: common.Config.OnlineDSN.Schema, - Charset: common.Config.OnlineDSN.Charset, - } + connOnline, err := database.NewConnector(common.Config.OnlineDSN) + common.LogIfError(err, "") // 检查线上环境可用性版本 rEnvVersion, err := vEnv.Version() @@ -114,7 +105,7 @@ func BuildEnv() (*VirtualEnv, *database.Connector) { common.Config.TestDSN.Disable = true } - return vEnv, conn + return vEnv, connOnline } // RealDB 从测试环境中获取通过hash后的DB @@ -163,7 +154,10 @@ func (ve *VirtualEnv) CleanupTestDatabase() { minHour := 1 for dbs.Rows.Next() { var testDatabase string - dbs.Rows.Scan(&testDatabase) + err = dbs.Rows.Scan(&testDatabase) + if err != nil { + break + } // test temporary database format `optimizer_YYMMDDHHmmss_randomString(16)` if len(testDatabase) != 39 { common.Log.Debug("CleanupTestDatabase by pass %s", testDatabase) @@ -187,7 +181,8 @@ func (ve *VirtualEnv) CleanupTestDatabase() { } common.Log.Debug("CleanupTestDatabase by pass database %s, %.2f less than %d hours", testDatabase, subHour, minHour) } - + err = dbs.Rows.Close() + common.LogIfError(err, "") common.Log.Debug("CleanupTestDatabase done") } @@ -200,7 +195,7 @@ func (ve *VirtualEnv) BuildVirtualEnv(rEnv *database.Connector, SQLs ...string) // 置空错误信息 ve.Error = nil // 检测是否已经创建初始数据库,如果未创建则创建一个名称hash过的映射数据库 - err = ve.createDatabase(*rEnv, rEnv.Database) + err = ve.createDatabase(rEnv, rEnv.Database) common.LogIfWarn(err, "") // 测试环境检测 @@ -237,7 +232,7 @@ func (ve *VirtualEnv) BuildVirtualEnv(rEnv *database.Connector, SQLs ...string) rEnv.Database = stmt.DBName.String() // use DB 后检查 DB是否已经创建,如果没有创建则创建DB - err = ve.createDatabase(*rEnv, rEnv.Database) + err = ve.createDatabase(rEnv, rEnv.Database) common.LogIfWarn(err, "") } return true @@ -271,7 +266,7 @@ func (ve *VirtualEnv) BuildVirtualEnv(rEnv *database.Connector, SQLs ...string) // 拉取表结构 table := stmt.Table.Name.String() if table != "" { - err = ve.createTable(*rEnv, rEnv.Database, table) + err = ve.createTable(rEnv, rEnv.Database, table) // 这里如果报错可能有两种可能: // 1. SQL 是 Create 语句,线上环境并没有相关的库表结构 // 2. 在测试环境中执行 SQL 报错 @@ -303,7 +298,7 @@ func (ve *VirtualEnv) BuildVirtualEnv(rEnv *database.Connector, SQLs ...string) if db == "" { db = rEnv.Database } - tmpEnv := *rEnv + tmpEnv := rEnv tmpEnv.Database = db // 创建数据库环境 @@ -336,7 +331,7 @@ func (ve *VirtualEnv) BuildVirtualEnv(rEnv *database.Connector, SQLs ...string) return false } viewDDL = viewDDL[startIdx+2:] - if !ve.BuildVirtualEnv(&tmpEnv, viewDDL) { + if !ve.BuildVirtualEnv(tmpEnv, viewDDL) { return false } } @@ -352,7 +347,7 @@ func (ve *VirtualEnv) BuildVirtualEnv(rEnv *database.Connector, SQLs ...string) return true } -func (ve VirtualEnv) createDatabase(rEnv database.Connector, dbName string) error { +func (ve VirtualEnv) createDatabase(rEnv *database.Connector, dbName string) error { // 生成映射关系 if _, ok := ve.DBRef[dbName]; ok { common.Log.Debug("createDatabase, Database `%s` created", dbName) @@ -369,6 +364,9 @@ func (ve VirtualEnv) createDatabase(rEnv database.Connector, dbName string) erro } ddl = strings.Replace(ddl, dbName, dbHash, -1) + if ddl == "" { + return fmt.Errorf("dbName: '%s' get create info error", dbName) + } _, err = ve.Query(ddl) if err != nil { common.Log.Warning("createDatabase, Error : %v", err) @@ -401,7 +399,7 @@ func (ve VirtualEnv) createDatabase(rEnv database.Connector, dbName string) erro soar 能够做出判断并进行 session 级别的修改,但是这一阶段可用性保证应该是由用户提供两个完全相同(或测试环境兼容线上环境) 的数据库环境来实现的。 */ -func (ve VirtualEnv) createTable(rEnv database.Connector, dbName, tbName string) error { +func (ve VirtualEnv) createTable(rEnv *database.Connector, dbName, tbName string) error { if dbName == "" { dbName = rEnv.Database @@ -470,7 +468,7 @@ func (ve VirtualEnv) createTable(rEnv database.Connector, dbName, tbName string) return nil } -// GenTableColumns 为Rewrite提供的结构体初始化 +// GenTableColumns 为 Rewrite 提供的结构体初始化 func (ve *VirtualEnv) GenTableColumns(meta common.Meta) common.TableColumns { tableColumns := make(common.TableColumns) for dbName, db := range meta { diff --git a/env/env_test.go b/env/env_test.go index 5e3e43dd24f9a76715258ecad50253596bf5a041..6f48a6310b5223c721fdcc07679879e8bc89c76c 100644 --- a/env/env_test.go +++ b/env/env_test.go @@ -18,10 +18,12 @@ package env import ( "flag" + "os" "testing" "github.com/XiaoMi/soar/common" "github.com/XiaoMi/soar/database" + "github.com/go-sql-driver/mysql" "github.com/kr/pretty" ) @@ -33,12 +35,16 @@ func init() { common.BaseDir = common.DevPath err := common.ParseConfig("") common.LogIfError(err, "init ParseConfig") - connTest = &database.Connector{ - Addr: common.Config.TestDSN.Addr, - User: common.Config.TestDSN.User, - Pass: common.Config.TestDSN.Password, - Database: common.Config.TestDSN.Schema, - Charset: common.Config.TestDSN.Charset, + common.Log.Debug("env_test init") + connTest, err = database.NewConnector(common.Config.TestDSN) + if err != nil { + common.Log.Critical("Test env Error: %v", err) + os.Exit(0) + } + + if _, err := connTest.Version(); err != nil { + common.Log.Critical("Test env Error: %v", err) + os.Exit(0) } }