提交 16df1e34 编写于 作者: martianzhang's avatar martianzhang

NewConnector

上级 78477138
...@@ -17,25 +17,30 @@ ...@@ -17,25 +17,30 @@
package advisor package advisor
import ( import (
"flag"
"fmt" "fmt"
"os" "os"
"strings" "strings"
"testing" "testing"
"github.com/XiaoMi/soar/common" "github.com/XiaoMi/soar/common"
"github.com/XiaoMi/soar/database"
"github.com/XiaoMi/soar/env" "github.com/XiaoMi/soar/env"
"github.com/kr/pretty" "github.com/kr/pretty"
"vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/sqlparser"
) )
var update = flag.Bool("update", false, "update .golden files")
var vEnv *env.VirtualEnv
var rEnv *database.Connector
func init() { func init() {
common.BaseDir = common.DevPath common.BaseDir = common.DevPath
err := common.ParseConfig("") err := common.ParseConfig("")
if err != nil { common.LogIfError(err, "init ParseConfig")
fmt.Println(err.Error()) common.Log.Debug("advisor_test init")
} vEnv, rEnv = env.BuildEnv()
vEnv, rEnv := env.BuildEnv()
if _, err = vEnv.Version(); err != nil { if _, err = vEnv.Version(); err != nil {
fmt.Println(err.Error(), ", By pass all advisor test cases") fmt.Println(err.Error(), ", By pass all advisor test cases")
os.Exit(0) os.Exit(0)
...@@ -45,6 +50,7 @@ func init() { ...@@ -45,6 +50,7 @@ func init() {
fmt.Println(err.Error(), ", By pass all advisor test cases") fmt.Println(err.Error(), ", By pass all advisor test cases")
os.Exit(0) os.Exit(0)
} }
defer vEnv.CleanUp()
} }
// ARG.003 // ARG.003
...@@ -52,8 +58,6 @@ func TestRuleImplicitConversion(t *testing.T) { ...@@ -52,8 +58,6 @@ func TestRuleImplicitConversion(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName()) common.Log.Debug("Entering function: %s", common.GetFunctionName())
dsn := common.Config.OnlineDSN dsn := common.Config.OnlineDSN
common.Config.OnlineDSN = common.Config.TestDSN common.Config.OnlineDSN = common.Config.TestDSN
vEnv, rEnv := env.BuildEnv()
defer vEnv.CleanUp()
initSQLs := []string{ initSQLs := []string{
`CREATE TABLE t1 (id int, title varchar(255) CHARSET utf8 COLLATE utf8_general_ci);`, `CREATE TABLE t1 (id int, title varchar(255) CHARSET utf8 COLLATE utf8_general_ci);`,
...@@ -104,9 +108,6 @@ func TestRuleImpossibleOuterJoin(t *testing.T) { ...@@ -104,9 +108,6 @@ func TestRuleImpossibleOuterJoin(t *testing.T) {
`select city_id, city, country from city left outer join country on city.country_id=country.country_id WHERE city.city_id IS NULL`, `select city_id, city, country from city left outer join country on city.country_id=country.country_id WHERE city.city_id IS NULL`,
} }
vEnv, rEnv := env.BuildEnv()
defer vEnv.CleanUp()
for _, sql := range sqls { for _, sql := range sqls {
stmt, syntaxErr := sqlparser.Parse(sql) stmt, syntaxErr := sqlparser.Parse(sql)
if syntaxErr != nil { if syntaxErr != nil {
...@@ -146,9 +147,6 @@ func TestIndexAdvisorRuleGroupByConst(t *testing.T) { ...@@ -146,9 +147,6 @@ func TestIndexAdvisorRuleGroupByConst(t *testing.T) {
}, },
} }
vEnv, rEnv := env.BuildEnv()
defer vEnv.CleanUp()
for _, sql := range sqls[0] { for _, sql := range sqls[0] {
stmt, syntaxErr := sqlparser.Parse(sql) stmt, syntaxErr := sqlparser.Parse(sql)
if syntaxErr != nil { if syntaxErr != nil {
...@@ -211,9 +209,6 @@ func TestIndexAdvisorRuleOrderByConst(t *testing.T) { ...@@ -211,9 +209,6 @@ func TestIndexAdvisorRuleOrderByConst(t *testing.T) {
}, },
} }
vEnv, rEnv := env.BuildEnv()
defer vEnv.CleanUp()
for _, sql := range sqls[0] { for _, sql := range sqls[0] {
stmt, syntaxErr := sqlparser.Parse(sql) stmt, syntaxErr := sqlparser.Parse(sql)
if syntaxErr != nil { if syntaxErr != nil {
...@@ -275,9 +270,6 @@ func TestRuleUpdatePrimaryKey(t *testing.T) { ...@@ -275,9 +270,6 @@ func TestRuleUpdatePrimaryKey(t *testing.T) {
}, },
} }
vEnv, rEnv := env.BuildEnv()
defer vEnv.CleanUp()
for _, sql := range sqls[0] { for _, sql := range sqls[0] {
stmt, syntaxErr := sqlparser.Parse(sql) stmt, syntaxErr := sqlparser.Parse(sql)
if syntaxErr != nil { if syntaxErr != nil {
...@@ -328,8 +320,6 @@ func TestRuleUpdatePrimaryKey(t *testing.T) { ...@@ -328,8 +320,6 @@ func TestRuleUpdatePrimaryKey(t *testing.T) {
func TestIndexAdvise(t *testing.T) { func TestIndexAdvise(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName()) common.Log.Debug("Entering function: %s", common.GetFunctionName())
vEnv, rEnv := env.BuildEnv()
defer vEnv.CleanUp()
for _, sql := range common.TestSQLs { for _, sql := range common.TestSQLs {
stmt, syntaxErr := sqlparser.Parse(sql) stmt, syntaxErr := sqlparser.Parse(sql)
...@@ -360,8 +350,6 @@ func TestIndexAdviseNoEnv(t *testing.T) { ...@@ -360,8 +350,6 @@ func TestIndexAdviseNoEnv(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName()) common.Log.Debug("Entering function: %s", common.GetFunctionName())
orgOnlineDSNStatus := common.Config.OnlineDSN.Disable orgOnlineDSNStatus := common.Config.OnlineDSN.Disable
common.Config.OnlineDSN.Disable = true common.Config.OnlineDSN.Disable = true
vEnv, rEnv := env.BuildEnv()
defer vEnv.CleanUp()
for _, sql := range common.TestSQLs { for _, sql := range common.TestSQLs {
stmt, syntaxErr := sqlparser.Parse(sql) stmt, syntaxErr := sqlparser.Parse(sql)
...@@ -391,7 +379,6 @@ func TestIndexAdviseNoEnv(t *testing.T) { ...@@ -391,7 +379,6 @@ func TestIndexAdviseNoEnv(t *testing.T) {
func TestDuplicateKeyChecker(t *testing.T) { func TestDuplicateKeyChecker(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName()) common.Log.Debug("Entering function: %s", common.GetFunctionName())
_, rEnv := env.BuildEnv()
rule := DuplicateKeyChecker(rEnv, "sakila") rule := DuplicateKeyChecker(rEnv, "sakila")
if len(rule) != 0 { if len(rule) != 0 {
t.Errorf("got rules: %s", pretty.Sprint(rule)) t.Errorf("got rules: %s", pretty.Sprint(rule))
...@@ -426,9 +413,6 @@ func TestIdxColsTypeCheck(t *testing.T) { ...@@ -426,9 +413,6 @@ func TestIdxColsTypeCheck(t *testing.T) {
`select city_id, city, country from city left outer join country using(country_id) WHERE city.city_id=59 and country.country='Algeria'`, `select city_id, city, country from city left outer join country using(country_id) WHERE city.city_id=59 and country.country='Algeria'`,
} }
vEnv, rEnv := env.BuildEnv()
defer vEnv.CleanUp()
for _, sql := range sqls { for _, sql := range sqls {
stmt, syntaxErr := sqlparser.Parse(sql) stmt, syntaxErr := sqlparser.Parse(sql)
if syntaxErr != nil { if syntaxErr != nil {
......
...@@ -17,14 +17,11 @@ ...@@ -17,14 +17,11 @@
package advisor package advisor
import ( import (
"flag"
"testing" "testing"
"github.com/XiaoMi/soar/common" "github.com/XiaoMi/soar/common"
) )
var update = flag.Bool("update", false, "update .golden files")
func TestListTestSQLs(t *testing.T) { func TestListTestSQLs(t *testing.T) {
err := common.GoldenDiff(func() { ListTestSQLs() }, t.Name(), update) err := common.GoldenDiff(func() { ListTestSQLs() }, t.Name(), update)
if nil != err { if nil != err {
......
...@@ -68,7 +68,7 @@ func main() { ...@@ -68,7 +68,7 @@ func main() {
// 当程序卡死的时候,或者由于某些原因程序没有退出,可以通过捕获信号量的形式让程序优雅退出并且清理测试环境 // 当程序卡死的时候,或者由于某些原因程序没有退出,可以通过捕获信号量的形式让程序优雅退出并且清理测试环境
common.HandleSignal(func() { common.HandleSignal(func() {
shutdown(vEnv) shutdown(vEnv, rEnv)
}) })
// 对指定的库表进行索引重复检查 // 对指定的库表进行索引重复检查
......
...@@ -67,52 +67,48 @@ func initConfig() { ...@@ -67,52 +67,48 @@ func initConfig() {
// if error found return non-zero, no error return zero // if error found return non-zero, no error return zero
func checkConfig() int { func checkConfig() int {
// TestDSN connection check // TestDSN connection check
testConn := &database.Connector{ connTest, err := database.NewConnector(common.Config.TestDSN)
Addr: common.Config.TestDSN.Addr, if err != nil {
User: common.Config.TestDSN.User, fmt.Println("test-dsn:", common.Config.TestDSN.Addr, err.Error())
Pass: common.Config.TestDSN.Password, return 1
Database: common.Config.TestDSN.Schema, }
Charset: common.Config.TestDSN.Charset, testVersion, err := connTest.Version()
}
testVersion, err := testConn.Version()
if err != nil && !common.Config.TestDSN.Disable { if err != nil && !common.Config.TestDSN.Disable {
fmt.Println("test-dsn:", testConn, err.Error()) fmt.Println("test-dsn:", connTest, err.Error())
return 1 return 1
} }
if common.Config.Verbose { if common.Config.Verbose {
if err == nil { if err == nil {
fmt.Println("test-dsn", testConn, "Version:", testVersion) fmt.Println("test-dsn", connTest, "Version:", testVersion)
} else { } else {
fmt.Println("test-dsn", common.Config.TestDSN) fmt.Println("test-dsn", common.Config.TestDSN)
} }
} }
if !testConn.HasAllPrivilege() { if !connTest.HasAllPrivilege() {
fmt.Printf("test-dsn: %s, need all privileges", common.FormatDSN(common.Config.TestDSN)) fmt.Printf("test-dsn: %s, need all privileges", common.FormatDSN(common.Config.TestDSN))
return 1 return 1
} }
// OnlineDSN connection check // OnlineDSN connection check
onlineConn := &database.Connector{ connOnline, err := database.NewConnector(common.Config.OnlineDSN)
Addr: common.Config.OnlineDSN.Addr, if err != nil {
User: common.Config.OnlineDSN.User, fmt.Println("test-dsn:", common.Config.OnlineDSN.Addr, err.Error())
Pass: common.Config.OnlineDSN.Password, return 1
Database: common.Config.OnlineDSN.Schema, }
Charset: common.Config.OnlineDSN.Charset, onlineVersion, err := connOnline.Version()
}
onlineVersion, err := onlineConn.Version()
if err != nil && !common.Config.OnlineDSN.Disable { if err != nil && !common.Config.OnlineDSN.Disable {
fmt.Println("online-dsn:", onlineConn, err.Error()) fmt.Println("online-dsn:", connOnline, err.Error())
return 1 return 1
} }
if common.Config.Verbose { if common.Config.Verbose {
if err == nil { if err == nil {
fmt.Println("online-dsn", onlineConn, "Version:", onlineVersion) fmt.Println("online-dsn", connOnline, "Version:", onlineVersion)
} else { } else {
fmt.Println("online-dsn", common.Config.OnlineDSN) fmt.Println("online-dsn", common.Config.OnlineDSN)
} }
} }
if !onlineConn.HasSelectPrivilege() { if !connOnline.HasSelectPrivilege() {
fmt.Printf("online-dsn: %s, need all privileges", common.FormatDSN(common.Config.OnlineDSN)) fmt.Printf("online-dsn: %s, need all privileges", common.FormatDSN(common.Config.OnlineDSN))
return 1 return 1
} }
...@@ -229,9 +225,11 @@ func initQuery(query string) string { ...@@ -229,9 +225,11 @@ func initQuery(query string) string {
return query return query
} }
func shutdown(vEnv *env.VirtualEnv) { func shutdown(vEnv *env.VirtualEnv, rEnv *database.Connector) {
if common.Config.DropTestTemporary { if common.Config.DropTestTemporary {
vEnv.CleanUp() vEnv.CleanUp()
} }
vEnv.Conn.Close()
rEnv.Conn.Close()
os.Exit(0) os.Exit(0)
} }
...@@ -48,8 +48,8 @@ var ( ...@@ -48,8 +48,8 @@ var (
// Configuration 配置文件定义结构体 // Configuration 配置文件定义结构体
type Configuration struct { type Configuration struct {
// +++++++++++++++测试环境+++++++++++++++++ // +++++++++++++++测试环境+++++++++++++++++
OnlineDSN *dsn `yaml:"online-dsn"` // 线上环境数据库配置 OnlineDSN *Dsn `yaml:"online-dsn"` // 线上环境数据库配置
TestDSN *dsn `yaml:"test-dsn"` // 测试环境数据库配置 TestDSN *Dsn `yaml:"test-dsn"` // 测试环境数据库配置
AllowOnlineAsTest bool `yaml:"allow-online-as-test"` // 允许 Online 环境也可以当作 Test 环境 AllowOnlineAsTest bool `yaml:"allow-online-as-test"` // 允许 Online 环境也可以当作 Test 环境
DropTestTemporary bool `yaml:"drop-test-temporary"` // 是否清理Test环境产生的临时库表 DropTestTemporary bool `yaml:"drop-test-temporary"` // 是否清理Test环境产生的临时库表
CleanupTestDatabase bool `yaml:"cleanup-test-database"` // 清理残余的测试数据库(程序异常退出或未开启drop-test-temporary) issue #48 CleanupTestDatabase bool `yaml:"cleanup-test-database"` // 清理残余的测试数据库(程序异常退出或未开启drop-test-temporary) issue #48
...@@ -132,14 +132,14 @@ type Configuration struct { ...@@ -132,14 +132,14 @@ type Configuration struct {
// Config 默认设置 // Config 默认设置
var Config = &Configuration{ var Config = &Configuration{
OnlineDSN: &dsn{ OnlineDSN: &Dsn{
Net: "tcp", Net: "tcp",
Schema: "information_schema", Schema: "information_schema",
Charset: "utf8mb4", Charset: "utf8mb4",
Disable: true, Disable: true,
Version: 99999, Version: 99999,
}, },
TestDSN: &dsn{ TestDSN: &Dsn{
Net: "tcp", Net: "tcp",
Schema: "information_schema", Schema: "information_schema",
Charset: "utf8mb4", Charset: "utf8mb4",
...@@ -224,7 +224,8 @@ var Config = &Configuration{ ...@@ -224,7 +224,8 @@ var Config = &Configuration{
MaxPrettySQLLength: 1024, MaxPrettySQLLength: 1024,
} }
type dsn struct { // Dsn Data source name
type Dsn struct {
Net string `yaml:"net"` Net string `yaml:"net"`
Addr string `yaml:"addr"` Addr string `yaml:"addr"`
Schema string `yaml:"schema"` Schema string `yaml:"schema"`
...@@ -243,7 +244,7 @@ type dsn struct { ...@@ -243,7 +244,7 @@ type dsn struct {
} }
// 解析命令行DSN输入 // 解析命令行DSN输入
func parseDSN(odbc string, d *dsn) *dsn { func parseDSN(odbc string, d *Dsn) *Dsn {
var addr, user, password, schema, charset string var addr, user, password, schema, charset string
if odbc == FormatDSN(d) { if odbc == FormatDSN(d) {
return d return d
...@@ -260,7 +261,7 @@ func parseDSN(odbc string, d *dsn) *dsn { ...@@ -260,7 +261,7 @@ func parseDSN(odbc string, d *dsn) *dsn {
// 设置为空表示禁用环境 // 设置为空表示禁用环境
odbc = strings.TrimSpace(odbc) odbc = strings.TrimSpace(odbc)
if odbc == "" { if odbc == "" {
return &dsn{Disable: true} return &Dsn{Disable: true}
} }
// username:password@ip:port/database // username:password@ip:port/database
...@@ -354,7 +355,7 @@ func parseDSN(odbc string, d *dsn) *dsn { ...@@ -354,7 +355,7 @@ func parseDSN(odbc string, d *dsn) *dsn {
charset = "utf8mb4" charset = "utf8mb4"
} }
dsn := &dsn{ dsn := &Dsn{
Addr: addr, Addr: addr,
User: user, User: user,
Password: password, Password: password,
...@@ -367,7 +368,7 @@ func parseDSN(odbc string, d *dsn) *dsn { ...@@ -367,7 +368,7 @@ func parseDSN(odbc string, d *dsn) *dsn {
} }
// FormatDSN 格式化打印DSN // FormatDSN 格式化打印DSN
func FormatDSN(env *dsn) string { func FormatDSN(env *Dsn) string {
if env == nil || env.Disable { if env == nil || env.Disable {
return "" return ""
} }
......
...@@ -19,7 +19,7 @@ package common ...@@ -19,7 +19,7 @@ package common
import "fmt" import "fmt"
func ExampleFormatDSN() { func ExampleFormatDSN() {
dsxExp := &dsn{ dsxExp := &Dsn{
Addr: "127.0.0.1:3306", Addr: "127.0.0.1:3306",
Schema: "mysql", Schema: "mysql",
User: "root", User: "root",
......
...@@ -38,7 +38,7 @@ func NewDB(db string) *DB { ...@@ -38,7 +38,7 @@ func NewDB(db string) *DB {
} }
} }
// TableName 含有表的属性 // Table 含有表的属性
type Table struct { type Table struct {
TableName string TableName string
TableAliases []string TableAliases []string
......
...@@ -556,14 +556,12 @@ func (db *Connector) explainAbleSQL(sql string) (string, error) { ...@@ -556,14 +556,12 @@ func (db *Connector) explainAbleSQL(sql string) (string, error) {
return "", nil return "", nil
} }
// 执行explain请求,返回mysql.Result执行结果 // explainQuery 生成可执行的 explain 查询请求
func (db *Connector) executeExplain(sql string, explainType int, formatType int) (QueryResult, error) { func (db *Connector) explainQuery(sql string, explainType int, formatType int) string {
var res QueryResult
var err error var err error
var explainQuery string
sql, err = db.explainAbleSQL(sql) sql, err = db.explainAbleSQL(sql)
if sql == "" { if sql == "" || err != nil {
return res, err return sql
} }
// 5.6以上支持 FORMAT=JSON // 5.6以上支持 FORMAT=JSON
...@@ -580,18 +578,17 @@ func (db *Connector) executeExplain(sql string, explainType int, formatType int) ...@@ -580,18 +578,17 @@ func (db *Connector) executeExplain(sql string, explainType int, formatType int)
case ExtendedExplainType: case ExtendedExplainType:
// 5.6以上extended关键字已经不推荐使用,8.0废弃了这个关键字 // 5.6以上extended关键字已经不推荐使用,8.0废弃了这个关键字
if common.Config.TestDSN.Version >= 50600 { if common.Config.TestDSN.Version >= 50600 {
explainQuery = fmt.Sprintf("explain %s", sql) sql = fmt.Sprintf("explain %s", sql)
} else { } else {
explainQuery = fmt.Sprintf("explain extended %s", sql) sql = fmt.Sprintf("explain extended %s", sql)
} }
case PartitionsExplainType: case PartitionsExplainType:
explainQuery = fmt.Sprintf("explain partitions %s", sql) sql = fmt.Sprintf("explain partitions %s", sql)
default: default:
explainQuery = fmt.Sprintf("explain %s %s", explainFormat, sql) sql = fmt.Sprintf("explain %s %s", explainFormat, sql)
} }
res, err = db.Query(explainQuery) return sql
return res, err
} }
// MySQLExplainWarnings WARNINGS信息中包含的优化器信息 // MySQLExplainWarnings WARNINGS信息中包含的优化器信息
...@@ -944,6 +941,7 @@ func ParseExplainResult(res QueryResult, formatType int) (exp *ExplainInfo, err ...@@ -944,6 +941,7 @@ func ParseExplainResult(res QueryResult, formatType int) (exp *ExplainInfo, err
err = res.Rows.Scan(&explainString) err = res.Rows.Scan(&explainString)
exp.ExplainJSON, err = parseJSONExplainText(explainString) exp.ExplainJSON, err = parseJSONExplainText(explainString)
} }
res.Rows.Close()
return exp, err return exp, err
} }
...@@ -992,23 +990,24 @@ func ParseExplainResult(res QueryResult, formatType int) (exp *ExplainInfo, err ...@@ -992,23 +990,24 @@ func ParseExplainResult(res QueryResult, formatType int) (exp *ExplainInfo, err
expRow.Scalability = ExplainScalability[expRow.AccessType] expRow.Scalability = ExplainScalability[expRow.AccessType]
explainRows = append(explainRows, expRow) explainRows = append(explainRows, expRow)
} }
res.Rows.Close()
exp.ExplainRows = explainRows exp.ExplainRows = explainRows
// check explain warning info // check explain warning info
if common.Config.ShowWarnings { if common.Config.ShowWarnings {
for res.Warning.Next() { for res.Warning.Next() {
var expWarning *ExplainWarning var expWarning *ExplainWarning
res.Warning.Scan( err = res.Warning.Scan(expWarning.Level, expWarning.Code, expWarning.Message)
expWarning.Level, if err != nil {
expWarning.Code, break
expWarning.Message, }
)
// 'EXTENDED' is deprecated and will be removed in a future release. // 'EXTENDED' is deprecated and will be removed in a future release.
if expWarning.Code != 1681 { if expWarning.Code != 1681 {
exp.Warnings = append(exp.Warnings, expWarning) exp.Warnings = append(exp.Warnings, expWarning)
} }
} }
res.Warning.Close()
} }
// 添加 last_query_cost // 添加 last_query_cost
...@@ -1022,6 +1021,7 @@ func (db *Connector) Explain(sql string, explainType int, formatType int) (exp * ...@@ -1022,6 +1021,7 @@ func (db *Connector) Explain(sql string, explainType int, formatType int) (exp *
if explainType != TraditionalExplainType { if explainType != TraditionalExplainType {
formatType = TraditionalFormatExplain formatType = TraditionalFormatExplain
} }
defer func() { defer func() {
if e := recover(); e != nil { if e := recover(); e != nil {
const size = 4096 const size = 4096
...@@ -1033,10 +1033,8 @@ func (db *Connector) Explain(sql string, explainType int, formatType int) (exp * ...@@ -1033,10 +1033,8 @@ func (db *Connector) Explain(sql string, explainType int, formatType int) (exp *
}() }()
// 执行EXPLAIN请求 // 执行EXPLAIN请求
res, err := db.executeExplain(sql, explainType, formatType) sql = db.explainQuery(sql, explainType, formatType)
if err != nil { res, err := db.Query(sql)
return exp, err
}
// 解析mysql结果,输出ExplainInfo // 解析mysql结果,输出ExplainInfo
exp, err = ParseExplainResult(res, formatType) exp, err = ParseExplainResult(res, formatType)
......
...@@ -39,6 +39,7 @@ type Connector struct { ...@@ -39,6 +39,7 @@ type Connector struct {
Database string Database string
Charset string Charset string
Net string Net string
Conn *sql.DB
} }
// QueryResult 数据库查询返回值 // QueryResult 数据库查询返回值
...@@ -49,20 +50,38 @@ type QueryResult struct { ...@@ -49,20 +50,38 @@ type QueryResult struct {
QueryCost float64 QueryCost float64
} }
// NewConnection 创建新连接 // NewConnector 创建新连接
func (db *Connector) NewConnection() (*sql.DB, error) { func NewConnector(dsn *common.Dsn) (*Connector, error) {
dsn := fmt.Sprintf("%s:%s@%s(%s)/%s?parseTime=true&charset=%s", db.User, db.Pass, db.Net, db.Addr, db.Database, db.Charset) conn, err := sql.Open("mysql", fmt.Sprintf("%s:%s@%s(%s)/%s?parseTime=true&charset=%s",
return sql.Open("mysql", dsn) dsn.User,
dsn.Password,
dsn.Net,
dsn.Addr,
dsn.Schema,
dsn.Charset,
))
if err != nil {
return nil, err
}
connector := &Connector{
Addr: dsn.Addr,
User: dsn.User,
Pass: dsn.Password,
Database: dsn.Schema,
Charset: dsn.Charset,
Conn: conn,
}
return connector, err
} }
// Query 执行SQL // Query 执行SQL
func (db *Connector) Query(sql string, params ...interface{}) (QueryResult, error) { func (db *Connector) Query(sql string, params ...interface{}) (QueryResult, error) {
var res QueryResult var res QueryResult
var err error
// 测试环境如果检查是关闭的,则SQL不会被执行 // 测试环境如果检查是关闭的,则SQL不会被执行
if common.Config.TestDSN.Disable { if common.Config.TestDSN.Disable {
return res, errors.New("dsn is disable") return res, errors.New("dsn is disable")
} }
// 数据库安全性检查:如果 Connector 的 IP 端口与 TEST 环境不一致,则启用SQL白名单 // 数据库安全性检查:如果 Connector 的 IP 端口与 TEST 环境不一致,则启用SQL白名单
// 不在白名单中的SQL不允许执行 // 不在白名单中的SQL不允许执行
// 执行环境与test环境不相同 // 执行环境与test环境不相同
...@@ -72,23 +91,25 @@ func (db *Connector) Query(sql string, params ...interface{}) (QueryResult, erro ...@@ -72,23 +91,25 @@ func (db *Connector) Query(sql string, params ...interface{}) (QueryResult, erro
} }
common.Log.Debug("Execute SQL with DSN(%s/%s) : %s", db.Addr, db.Database, fmt.Sprintf(sql, params...)) common.Log.Debug("Execute SQL with DSN(%s/%s) : %s", db.Addr, db.Database, fmt.Sprintf(sql, params...))
conn, err := db.NewConnection() _, err = db.Conn.Exec("USE " + db.Database)
defer conn.Close() common.LogIfError(err, "")
if err != nil { res.Rows, res.Error = db.Conn.Query(sql, params...)
return res, err
}
res.Rows, res.Error = conn.Query(sql, params...)
if common.Config.ShowWarnings { if common.Config.ShowWarnings {
res.Warning, err = conn.Query("SHOW WARNINGS") res.Warning, err = db.Conn.Query("SHOW WARNINGS")
common.LogIfError(err, "")
} }
// SHOW WARNINGS 并不会影响 last_query_cost // SHOW WARNINGS 并不会影响 last_query_cost
if common.Config.ShowLastQueryCost { if common.Config.ShowLastQueryCost {
cost, err := conn.Query("SHOW SESSION STATUS LIKE 'last_query_cost'") cost, err := db.Conn.Query("SHOW SESSION STATUS LIKE 'last_query_cost'")
if err == nil { if err == nil {
if cost.Next() { if cost.Next() {
err = cost.Scan(res.QueryCost) err = cost.Scan(res.QueryCost)
common.LogIfError(err, "")
}
if err := cost.Close(); err != nil {
common.Log.Error(err.Error())
} }
} }
} }
...@@ -115,6 +136,9 @@ func (db *Connector) Version() (int, error) { ...@@ -115,6 +136,9 @@ func (db *Connector) Version() (int, error) {
var versionSeg []string var versionSeg []string
for res.Rows.Next() { for res.Rows.Next() {
err = res.Rows.Scan(&versionStr) err = res.Rows.Scan(&versionStr)
if err != nil {
break
}
versionStr = strings.Split(versionStr, "-")[0] versionStr = strings.Split(versionStr, "-")[0]
versionSeg = strings.Split(versionStr, ".") versionSeg = strings.Split(versionStr, ".")
if len(versionSeg) == 3 { if len(versionSeg) == 3 {
...@@ -123,7 +147,9 @@ func (db *Connector) Version() (int, error) { ...@@ -123,7 +147,9 @@ func (db *Connector) Version() (int, error) {
} }
break break
} }
if err := res.Rows.Close(); err != nil {
common.Log.Error(err.Error())
}
return version, err return version, err
} }
...@@ -140,6 +166,9 @@ func (db *Connector) SingleIntValue(option string) (int, error) { ...@@ -140,6 +166,9 @@ func (db *Connector) SingleIntValue(option string) (int, error) {
if res.Rows.Next() { if res.Rows.Next() {
err = res.Rows.Scan(&intVal) err = res.Rows.Scan(&intVal)
} }
if err := res.Rows.Close(); err != nil {
common.Log.Error(err.Error())
}
return intVal, err return intVal, err
} }
...@@ -188,6 +217,7 @@ func (db *Connector) ColumnCardinality(tb, col string) float64 { ...@@ -188,6 +217,7 @@ func (db *Connector) ColumnCardinality(tb, col string) float64 {
return 0 return 0
} }
} }
res.Rows.Close()
// 当table status元数据不准确时 rowTotal 可能远小于count(*),导致散粒度大于1 // 当table status元数据不准确时 rowTotal 可能远小于count(*),导致散粒度大于1
if colNum > float64(rowTotal) { if colNum > float64(rowTotal) {
......
...@@ -28,31 +28,23 @@ import ( ...@@ -28,31 +28,23 @@ import (
) )
var connTest *Connector var connTest *Connector
var update = flag.Bool("update", false, "update .golden files") var update = flag.Bool("update", false, "update .golden files")
func init() { func init() {
common.BaseDir = common.DevPath common.BaseDir = common.DevPath
common.ParseConfig("") err := common.ParseConfig("")
connTest = &Connector{ common.LogIfError(err, "init ParseConfig")
Addr: common.Config.OnlineDSN.Addr, common.Log.Debug("mysql_test init")
User: common.Config.OnlineDSN.User, connTest, err = NewConnector(common.Config.TestDSN)
Pass: common.Config.OnlineDSN.Password, if err != nil {
Database: common.Config.OnlineDSN.Schema,
Charset: common.Config.OnlineDSN.Charset,
}
if _, err := connTest.Version(); err != nil {
common.Log.Critical("Test env Error: %v", err) common.Log.Critical("Test env Error: %v", err)
os.Exit(0) os.Exit(0)
} }
}
func TestNewConnection(t *testing.T) { if _, err := connTest.Version(); err != nil {
conn, err := connTest.NewConnection() common.Log.Critical("Test env Error: %v", err)
if err != nil { os.Exit(0)
t.Errorf("TestNewConnection, Error: %s", err.Error())
} }
defer conn.Close()
} }
func TestQuery(t *testing.T) { func TestQuery(t *testing.T) {
...@@ -70,6 +62,7 @@ func TestQuery(t *testing.T) { ...@@ -70,6 +62,7 @@ func TestQuery(t *testing.T) {
t.Error("should return 0") t.Error("should return 0")
} }
} }
res.Rows.Close()
// TODO: timeout test // TODO: timeout test
} }
...@@ -115,6 +108,7 @@ func TestWarningsAndQueryCost(t *testing.T) { ...@@ -115,6 +108,7 @@ func TestWarningsAndQueryCost(t *testing.T) {
} }
pretty.Println(str) pretty.Println(str)
} }
res.Warning.Close()
fmt.Println(res.QueryCost, err) fmt.Println(res.QueryCost, err)
} }
} }
......
...@@ -26,28 +26,31 @@ import ( ...@@ -26,28 +26,31 @@ import (
// CurrentUser get current user with current_user() function // CurrentUser get current user with current_user() function
func (db *Connector) CurrentUser() (string, string, error) { func (db *Connector) CurrentUser() (string, string, error) {
var user, host string
res, err := db.Query("select current_user()") res, err := db.Query("select current_user()")
if err != nil { if err != nil {
return "", "", err return user, host, err
} }
if res.Rows.Next() { if res.Rows.Next() {
var currentUser string var currentUser string
err = res.Rows.Scan(&currentUser) err = res.Rows.Scan(&currentUser)
if err != nil { if err != nil {
return "", "", err return user, host, err
} }
res.Rows.Close()
cols := strings.Split(currentUser, "@") cols := strings.Split(currentUser, "@")
if len(cols) == 2 { if len(cols) == 2 {
user := strings.Trim(cols[0], "'") user = strings.Trim(cols[0], "'")
host := strings.Trim(cols[1], "'") host = strings.Trim(cols[1], "'")
if strings.Contains(user, "'") || strings.Contains(host, "'") { if strings.Contains(user, "'") || strings.Contains(host, "'") {
return "", "", errors.New("user or host contains irregular character") return "", "", errors.New("user or host contains irregular character")
} }
return user, host, nil return user, host, nil
} }
return "", "", errors.New("user or host contains irregular character") return user, host, errors.New("user or host contains irregular character")
} }
return "", "", errors.New("no privilege info") return user, host, errors.New("no privilege info")
} }
// HasSelectPrivilege if user has select privilege // HasSelectPrivilege if user has select privilege
...@@ -70,6 +73,8 @@ func (db *Connector) HasSelectPrivilege() bool { ...@@ -70,6 +73,8 @@ func (db *Connector) HasSelectPrivilege() bool {
common.Log.Error("HasSelectPrivilege, Scan Error: %s", err.Error()) common.Log.Error("HasSelectPrivilege, Scan Error: %s", err.Error())
return false return false
} }
res.Rows.Close()
if selectPrivilege == "Y" { if selectPrivilege == "Y" {
return true return true
} }
...@@ -99,6 +104,7 @@ func (db *Connector) HasAllPrivilege() bool { ...@@ -99,6 +104,7 @@ func (db *Connector) HasAllPrivilege() bool {
common.Log.Error("HasAllPrivilege, DSN: %s, Scan error", db.Addr) common.Log.Error("HasAllPrivilege, DSN: %s, Scan error", db.Addr)
return false return false
} }
res.Rows.Close()
} }
// get all privilege status // get all privilege status
...@@ -115,6 +121,7 @@ func (db *Connector) HasAllPrivilege() bool { ...@@ -115,6 +121,7 @@ func (db *Connector) HasAllPrivilege() bool {
common.Log.Error("HasAllPrivilege, DSN: %s, Scan error", db.Addr) common.Log.Error("HasAllPrivilege, DSN: %s, Scan error", db.Addr)
return false return false
} }
res.Rows.Close()
if strings.Replace(priv, "Y", "", -1) == "" { if strings.Replace(priv, "Y", "", -1) == "" {
return true return true
} }
......
...@@ -62,15 +62,9 @@ func (db *Connector) Profiling(sql string, params ...interface{}) ([]ProfilingRo ...@@ -62,15 +62,9 @@ func (db *Connector) Profiling(sql string, params ...interface{}) ([]ProfilingRo
} }
common.Log.Debug("Execute SQL with DSN(%s/%s) : %s", db.Addr, db.Database, sql) common.Log.Debug("Execute SQL with DSN(%s/%s) : %s", db.Addr, db.Database, sql)
conn, err := db.NewConnection()
if err != nil {
return rows, err
}
defer conn.Close()
// Keep connection // Keep connection
// https://github.com/go-sql-driver/mysql/issues/208 // https://github.com/go-sql-driver/mysql/issues/208
trx, err := conn.Begin() trx, err := db.Conn.Begin()
if err != nil { if err != nil {
return rows, err return rows, err
} }
...@@ -91,14 +85,20 @@ func (db *Connector) Profiling(sql string, params ...interface{}) ([]ProfilingRo ...@@ -91,14 +85,20 @@ func (db *Connector) Profiling(sql string, params ...interface{}) ([]ProfilingRo
// 返回 Profiling 结果 // 返回 Profiling 结果
res, err := trx.Query("show profile") res, err := trx.Query("show profile")
if err != nil {
trx.Rollback()
return rows, err
}
var profileRow ProfilingRow
for res.Next() { for res.Next() {
var profileRow ProfilingRow err = res.Scan(&profileRow.Status, &profileRow.Duration)
err := res.Scan(&profileRow.Status, &profileRow.Duration)
if err != nil { if err != nil {
common.LogIfError(err, "") common.LogIfError(err, "")
break
} }
rows = append(rows, profileRow) rows = append(rows, profileRow)
} }
res.Close()
// 关闭 Profiling // 关闭 Profiling
_, err = trx.Query("set @@profiling=0") _, err = trx.Query("set @@profiling=0")
......
...@@ -19,6 +19,7 @@ package database ...@@ -19,6 +19,7 @@ package database
import ( import (
"database/sql" "database/sql"
"fmt" "fmt"
"github.com/XiaoMi/soar/common" "github.com/XiaoMi/soar/common"
) )
...@@ -43,7 +44,7 @@ import ( ...@@ -43,7 +44,7 @@ import (
*/ */
// SamplingData 将数据从Remote拉取到 db 中 // SamplingData 将数据从Remote拉取到 db 中
func (db *Connector) SamplingData(remote Connector, tables ...string) error { func (db *Connector) SamplingData(remote *Connector, tables ...string) error {
// 计算需要泵取的数据量 // 计算需要泵取的数据量
wantRowsCount := 300 * common.Config.SamplingStatisticTarget wantRowsCount := 300 * common.Config.SamplingStatisticTarget
...@@ -51,19 +52,6 @@ func (db *Connector) SamplingData(remote Connector, tables ...string) error { ...@@ -51,19 +52,6 @@ func (db *Connector) SamplingData(remote Connector, tables ...string) error {
// 该数值越大,在内存中缓存的data就越多,但相对的,插入时速度就越快 // 该数值越大,在内存中缓存的data就越多,但相对的,插入时速度就越快
maxValCount := 200 maxValCount := 200
// 获取数据库连接对象
conn, err := remote.NewConnection()
if err != nil {
return err
}
defer conn.Close()
localConn, err := db.NewConnection()
if err != nil {
return err
}
defer localConn.Close()
for _, table := range tables { for _, table := range tables {
// 表类型检查 // 表类型检查
if remote.IsView(table) { if remote.IsView(table) {
...@@ -89,7 +77,7 @@ func (db *Connector) SamplingData(remote Connector, tables ...string) error { ...@@ -89,7 +77,7 @@ func (db *Connector) SamplingData(remote Connector, tables ...string) error {
factor := float64(wantRowsCount) / float64(tableRows) factor := float64(wantRowsCount) / float64(tableRows)
common.Log.Debug("SamplingData, tableRows: %d, wantRowsCount: %d, factor: %f", tableRows, wantRowsCount, factor) common.Log.Debug("SamplingData, tableRows: %d, wantRowsCount: %d, factor: %f", tableRows, wantRowsCount, factor)
err = startSampling(conn, localConn, db.Database, table, factor, wantRowsCount, maxValCount) err = startSampling(remote.Conn, db.Conn, db.Database, table, factor, wantRowsCount, maxValCount)
if err != nil { if err != nil {
common.Log.Error("(db *Connector) SamplingData Error : %v", err) common.Log.Error("(db *Connector) SamplingData Error : %v", err)
} }
......
...@@ -27,27 +27,12 @@ func init() { ...@@ -27,27 +27,12 @@ func init() {
} }
func TestSamplingData(t *testing.T) { func TestSamplingData(t *testing.T) {
online := &Connector{ connOnline, err := NewConnector(common.Config.OnlineDSN)
Addr: common.Config.OnlineDSN.Addr, if err != nil {
User: common.Config.OnlineDSN.User, t.Error(err)
Pass: common.Config.OnlineDSN.Password,
Database: common.Config.OnlineDSN.Schema,
Charset: common.Config.OnlineDSN.Charset,
Net: common.Config.OnlineDSN.Net,
}
offline := &Connector{
Addr: common.Config.TestDSN.Addr,
User: common.Config.TestDSN.User,
Pass: common.Config.TestDSN.Password,
Database: common.Config.TestDSN.Schema,
Charset: common.Config.TestDSN.Charset,
Net: common.Config.TestDSN.Net,
} }
offline.Database = "test" err = connTest.SamplingData(connOnline, "film")
err := connTest.SamplingData(*online, "film")
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
......
...@@ -99,10 +99,11 @@ func (db *Connector) ShowTables() ([]string, error) { ...@@ -99,10 +99,11 @@ func (db *Connector) ShowTables() ([]string, error) {
var table string var table string
err = res.Rows.Scan(&table) err = res.Rows.Scan(&table)
if err != nil { if err != nil {
return []string{}, err break
} }
tables = append(tables, table) tables = append(tables, table)
} }
res.Rows.Close()
return tables, err return tables, err
} }
...@@ -156,6 +157,7 @@ func (db *Connector) ShowTableStatus(tableName string) (*TableStatInfo, error) { ...@@ -156,6 +157,7 @@ func (db *Connector) ShowTableStatus(tableName string) (*TableStatInfo, error) {
res.Rows.Scan(statusFields...) res.Rows.Scan(statusFields...)
tbStatus.Rows = append(tbStatus.Rows, ts) tbStatus.Rows = append(tbStatus.Rows, ts)
} }
res.Rows.Close()
return tbStatus, err return tbStatus, err
} }
...@@ -244,6 +246,7 @@ func (db *Connector) ShowIndex(tableName string) (*TableIndexInfo, error) { ...@@ -244,6 +246,7 @@ func (db *Connector) ShowIndex(tableName string) (*TableIndexInfo, error) {
res.Rows.Scan(indexFields...) res.Rows.Scan(indexFields...)
tbIndex.Rows = append(tbIndex.Rows, ti) tbIndex.Rows = append(tbIndex.Rows, ti)
} }
res.Rows.Close()
return tbIndex, err return tbIndex, err
} }
...@@ -374,6 +377,7 @@ func (db *Connector) ShowColumns(tableName string) (*TableDesc, error) { ...@@ -374,6 +377,7 @@ func (db *Connector) ShowColumns(tableName string) (*TableDesc, error) {
res.Rows.Scan(columnFields...) res.Rows.Scan(columnFields...)
tbDesc.DescValues = append(tbDesc.DescValues, tc) tbDesc.DescValues = append(tbDesc.DescValues, tc)
} }
res.Rows.Close()
return tbDesc, err return tbDesc, err
} }
...@@ -428,7 +432,7 @@ func (db *Connector) showCreate(createType, name string) (string, error) { ...@@ -428,7 +432,7 @@ func (db *Connector) showCreate(createType, name string) (string, error) {
for res.Rows.Next() { for res.Rows.Next() {
res.Rows.Scan(createFields...) res.Rows.Scan(createFields...)
} }
res.Rows.Close()
return create, err return create, err
} }
...@@ -508,13 +512,10 @@ func (db *Connector) FindColumn(name, dbName string, tables ...string) ([]*commo ...@@ -508,13 +512,10 @@ func (db *Connector) FindColumn(name, dbName string, tables ...string) ([]*commo
var col common.Column var col common.Column
for res.Rows.Next() { for res.Rows.Next() {
var character, collation []byte var character, collation []byte
res.Rows.Scan( err = res.Rows.Scan(&col.Table, &col.DB, &col.DataType, &character, &collation)
&col.Table, if err != nil {
&col.DB, break
&col.DataType, }
&character,
&collation,
)
col.Name = name col.Name = name
col.Character = string(character) col.Character = string(character)
col.Collation = string(collation) col.Collation = string(collation)
...@@ -537,8 +538,12 @@ func (db *Connector) FindColumn(name, dbName string, tables ...string) ([]*commo ...@@ -537,8 +538,12 @@ func (db *Connector) FindColumn(name, dbName string, tables ...string) ([]*commo
var tbCollation []byte var tbCollation []byte
if newRes.Rows.Next() { if newRes.Rows.Next() {
newRes.Rows.Scan(&tbCollation) err = newRes.Rows.Scan(&tbCollation)
if err != nil {
break
}
} }
newRes.Rows.Close()
if string(tbCollation) != "" { if string(tbCollation) != "" {
col.Character = strings.Split(string(tbCollation), "_")[0] col.Character = strings.Split(string(tbCollation), "_")[0]
col.Collation = string(tbCollation) col.Collation = string(tbCollation)
...@@ -546,6 +551,7 @@ func (db *Connector) FindColumn(name, dbName string, tables ...string) ([]*commo ...@@ -546,6 +551,7 @@ func (db *Connector) FindColumn(name, dbName string, tables ...string) ([]*commo
} }
columns = append(columns, &col) columns = append(columns, &col)
} }
res.Rows.Close()
return columns, err return columns, err
} }
...@@ -603,13 +609,12 @@ func (db *Connector) ShowReference(dbName string, tbName ...string) ([]Reference ...@@ -603,13 +609,12 @@ func (db *Connector) ShowReference(dbName string, tbName ...string) ([]Reference
// 获取值 // 获取值
for res.Rows.Next() { for res.Rows.Next() {
var rv ReferenceValue var rv ReferenceValue
res.Rows.Scan(&rv.ReferencedTableSchema, err = res.Rows.Scan(&rv.ReferencedTableSchema, &rv.ReferencedTableName, &rv.TableSchema, &rv.TableName, &rv.ConstraintName)
&rv.ReferencedTableName, if err != nil {
&rv.TableSchema, break
&rv.TableName, }
&rv.ConstraintName)
referenceValues = append(referenceValues, rv) referenceValues = append(referenceValues, rv)
} }
res.Rows.Close()
return referenceValues, err return referenceValues, err
} }
...@@ -68,6 +68,15 @@ func TestShowTables(t *testing.T) { ...@@ -68,6 +68,15 @@ func TestShowTables(t *testing.T) {
connTest.Database = orgDatabase connTest.Database = orgDatabase
} }
func TestShowCreateDatabase(t *testing.T) {
err := common.GoldenDiff(func() {
fmt.Println(connTest.ShowCreateDatabase("sakila"))
}, t.Name(), update)
if err != nil {
t.Error(err)
}
}
func TestShowCreateTable(t *testing.T) { func TestShowCreateTable(t *testing.T) {
orgDatabase := connTest.Database orgDatabase := connTest.Database
connTest.Database = "sakila" connTest.Database = "sakila"
...@@ -87,7 +96,6 @@ func TestShowCreateTable(t *testing.T) { ...@@ -87,7 +96,6 @@ func TestShowCreateTable(t *testing.T) {
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
connTest.Database = orgDatabase connTest.Database = orgDatabase
} }
......
CREATE DATABASE `sakila` /*!40100 DEFAULT CHARACTER SET utf8mb4 COLLATE utf8mb4_0900_ai_ci */ <nil>
...@@ -19,10 +19,11 @@ package database ...@@ -19,10 +19,11 @@ package database
import ( import (
"errors" "errors"
"fmt" "fmt"
"github.com/XiaoMi/soar/common"
"regexp" "regexp"
"strings" "strings"
"github.com/XiaoMi/soar/common"
"vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/sqlparser"
) )
...@@ -70,15 +71,9 @@ func (db *Connector) Trace(sql string, params ...interface{}) ([]TraceRow, error ...@@ -70,15 +71,9 @@ func (db *Connector) Trace(sql string, params ...interface{}) ([]TraceRow, error
} }
common.Log.Debug("Execute SQL with DSN(%s/%s) : %s", db.Addr, db.Database, sql) common.Log.Debug("Execute SQL with DSN(%s/%s) : %s", db.Addr, db.Database, sql)
conn, err := db.NewConnection()
if err != nil {
return rows, err
}
defer conn.Close()
// 开启Trace // 开启Trace
common.Log.Debug("SET SESSION OPTIMIZER_TRACE='enabled=on'") common.Log.Debug("SET SESSION OPTIMIZER_TRACE='enabled=on'")
trx, err := conn.Begin() trx, err := db.Conn.Begin()
if err != nil { if err != nil {
return rows, err return rows, err
} }
...@@ -97,14 +92,20 @@ func (db *Connector) Trace(sql string, params ...interface{}) ([]TraceRow, error ...@@ -97,14 +92,20 @@ func (db *Connector) Trace(sql string, params ...interface{}) ([]TraceRow, error
// 返回Trace结果 // 返回Trace结果
res, err := trx.Query("SELECT * FROM information_schema.OPTIMIZER_TRACE") res, err := trx.Query("SELECT * FROM information_schema.OPTIMIZER_TRACE")
if err != nil {
trx.Rollback()
return rows, err
}
for res.Next() { for res.Next() {
var traceRow TraceRow var traceRow TraceRow
err = res.Scan(&traceRow.Query, &traceRow.Trace, &traceRow.MissingBytesBeyondMaxMemSize, &traceRow.InsufficientPrivileges) err = res.Scan(&traceRow.Query, &traceRow.Trace, &traceRow.MissingBytesBeyondMaxMemSize, &traceRow.InsufficientPrivileges)
if err != nil { if err != nil {
common.LogIfError(err, "") common.LogIfError(err, "")
break
} }
rows = append(rows, traceRow) rows = append(rows, traceRow)
} }
res.Close()
// 关闭Trace // 关闭Trace
common.Log.Debug("SET SESSION OPTIMIZER_TRACE='enabled=off'") common.Log.Debug("SET SESSION OPTIMIZER_TRACE='enabled=off'")
......
...@@ -57,14 +57,10 @@ func NewVirtualEnv(vEnv *database.Connector) *VirtualEnv { ...@@ -57,14 +57,10 @@ func NewVirtualEnv(vEnv *database.Connector) *VirtualEnv {
// @output *VirtualEnv 测试环境 // @output *VirtualEnv 测试环境
// @output *database.Connector 线上环境连接句柄 // @output *database.Connector 线上环境连接句柄
func BuildEnv() (*VirtualEnv, *database.Connector) { func BuildEnv() (*VirtualEnv, *database.Connector) {
connTest, err := database.NewConnector(common.Config.TestDSN)
common.LogIfError(err, "")
// 生成测试环境 // 生成测试环境
vEnv := NewVirtualEnv(&database.Connector{ vEnv := NewVirtualEnv(connTest)
Addr: common.Config.TestDSN.Addr,
User: common.Config.TestDSN.User,
Pass: common.Config.TestDSN.Password,
Database: common.Config.TestDSN.Schema,
Charset: common.Config.TestDSN.Charset,
})
// 检查测试环境可用性,并记录数据库版本 // 检查测试环境可用性,并记录数据库版本
vEnvVersion, err := vEnv.Version() vEnvVersion, err := vEnv.Version()
...@@ -82,13 +78,8 @@ func BuildEnv() (*VirtualEnv, *database.Connector) { ...@@ -82,13 +78,8 @@ func BuildEnv() (*VirtualEnv, *database.Connector) {
vEnv.User, vEnv.Addr, vEnv.Database) vEnv.User, vEnv.Addr, vEnv.Database)
common.Config.OnlineDSN = common.Config.TestDSN common.Config.OnlineDSN = common.Config.TestDSN
} }
conn := &database.Connector{ connOnline, err := database.NewConnector(common.Config.OnlineDSN)
Addr: common.Config.OnlineDSN.Addr, common.LogIfError(err, "")
User: common.Config.OnlineDSN.User,
Pass: common.Config.OnlineDSN.Password,
Database: common.Config.OnlineDSN.Schema,
Charset: common.Config.OnlineDSN.Charset,
}
// 检查线上环境可用性版本 // 检查线上环境可用性版本
rEnvVersion, err := vEnv.Version() rEnvVersion, err := vEnv.Version()
...@@ -114,7 +105,7 @@ func BuildEnv() (*VirtualEnv, *database.Connector) { ...@@ -114,7 +105,7 @@ func BuildEnv() (*VirtualEnv, *database.Connector) {
common.Config.TestDSN.Disable = true common.Config.TestDSN.Disable = true
} }
return vEnv, conn return vEnv, connOnline
} }
// RealDB 从测试环境中获取通过hash后的DB // RealDB 从测试环境中获取通过hash后的DB
...@@ -163,7 +154,10 @@ func (ve *VirtualEnv) CleanupTestDatabase() { ...@@ -163,7 +154,10 @@ func (ve *VirtualEnv) CleanupTestDatabase() {
minHour := 1 minHour := 1
for dbs.Rows.Next() { for dbs.Rows.Next() {
var testDatabase string var testDatabase string
dbs.Rows.Scan(&testDatabase) err = dbs.Rows.Scan(&testDatabase)
if err != nil {
break
}
// test temporary database format `optimizer_YYMMDDHHmmss_randomString(16)` // test temporary database format `optimizer_YYMMDDHHmmss_randomString(16)`
if len(testDatabase) != 39 { if len(testDatabase) != 39 {
common.Log.Debug("CleanupTestDatabase by pass %s", testDatabase) common.Log.Debug("CleanupTestDatabase by pass %s", testDatabase)
...@@ -187,7 +181,8 @@ func (ve *VirtualEnv) CleanupTestDatabase() { ...@@ -187,7 +181,8 @@ func (ve *VirtualEnv) CleanupTestDatabase() {
} }
common.Log.Debug("CleanupTestDatabase by pass database %s, %.2f less than %d hours", testDatabase, subHour, minHour) common.Log.Debug("CleanupTestDatabase by pass database %s, %.2f less than %d hours", testDatabase, subHour, minHour)
} }
err = dbs.Rows.Close()
common.LogIfError(err, "")
common.Log.Debug("CleanupTestDatabase done") common.Log.Debug("CleanupTestDatabase done")
} }
...@@ -200,7 +195,7 @@ func (ve *VirtualEnv) BuildVirtualEnv(rEnv *database.Connector, SQLs ...string) ...@@ -200,7 +195,7 @@ func (ve *VirtualEnv) BuildVirtualEnv(rEnv *database.Connector, SQLs ...string)
// 置空错误信息 // 置空错误信息
ve.Error = nil ve.Error = nil
// 检测是否已经创建初始数据库,如果未创建则创建一个名称hash过的映射数据库 // 检测是否已经创建初始数据库,如果未创建则创建一个名称hash过的映射数据库
err = ve.createDatabase(*rEnv, rEnv.Database) err = ve.createDatabase(rEnv, rEnv.Database)
common.LogIfWarn(err, "") common.LogIfWarn(err, "")
// 测试环境检测 // 测试环境检测
...@@ -237,7 +232,7 @@ func (ve *VirtualEnv) BuildVirtualEnv(rEnv *database.Connector, SQLs ...string) ...@@ -237,7 +232,7 @@ func (ve *VirtualEnv) BuildVirtualEnv(rEnv *database.Connector, SQLs ...string)
rEnv.Database = stmt.DBName.String() rEnv.Database = stmt.DBName.String()
// use DB 后检查 DB是否已经创建,如果没有创建则创建DB // use DB 后检查 DB是否已经创建,如果没有创建则创建DB
err = ve.createDatabase(*rEnv, rEnv.Database) err = ve.createDatabase(rEnv, rEnv.Database)
common.LogIfWarn(err, "") common.LogIfWarn(err, "")
} }
return true return true
...@@ -271,7 +266,7 @@ func (ve *VirtualEnv) BuildVirtualEnv(rEnv *database.Connector, SQLs ...string) ...@@ -271,7 +266,7 @@ func (ve *VirtualEnv) BuildVirtualEnv(rEnv *database.Connector, SQLs ...string)
// 拉取表结构 // 拉取表结构
table := stmt.Table.Name.String() table := stmt.Table.Name.String()
if table != "" { if table != "" {
err = ve.createTable(*rEnv, rEnv.Database, table) err = ve.createTable(rEnv, rEnv.Database, table)
// 这里如果报错可能有两种可能: // 这里如果报错可能有两种可能:
// 1. SQL 是 Create 语句,线上环境并没有相关的库表结构 // 1. SQL 是 Create 语句,线上环境并没有相关的库表结构
// 2. 在测试环境中执行 SQL 报错 // 2. 在测试环境中执行 SQL 报错
...@@ -303,7 +298,7 @@ func (ve *VirtualEnv) BuildVirtualEnv(rEnv *database.Connector, SQLs ...string) ...@@ -303,7 +298,7 @@ func (ve *VirtualEnv) BuildVirtualEnv(rEnv *database.Connector, SQLs ...string)
if db == "" { if db == "" {
db = rEnv.Database db = rEnv.Database
} }
tmpEnv := *rEnv tmpEnv := rEnv
tmpEnv.Database = db tmpEnv.Database = db
// 创建数据库环境 // 创建数据库环境
...@@ -336,7 +331,7 @@ func (ve *VirtualEnv) BuildVirtualEnv(rEnv *database.Connector, SQLs ...string) ...@@ -336,7 +331,7 @@ func (ve *VirtualEnv) BuildVirtualEnv(rEnv *database.Connector, SQLs ...string)
return false return false
} }
viewDDL = viewDDL[startIdx+2:] viewDDL = viewDDL[startIdx+2:]
if !ve.BuildVirtualEnv(&tmpEnv, viewDDL) { if !ve.BuildVirtualEnv(tmpEnv, viewDDL) {
return false return false
} }
} }
...@@ -352,7 +347,7 @@ func (ve *VirtualEnv) BuildVirtualEnv(rEnv *database.Connector, SQLs ...string) ...@@ -352,7 +347,7 @@ func (ve *VirtualEnv) BuildVirtualEnv(rEnv *database.Connector, SQLs ...string)
return true return true
} }
func (ve VirtualEnv) createDatabase(rEnv database.Connector, dbName string) error { func (ve VirtualEnv) createDatabase(rEnv *database.Connector, dbName string) error {
// 生成映射关系 // 生成映射关系
if _, ok := ve.DBRef[dbName]; ok { if _, ok := ve.DBRef[dbName]; ok {
common.Log.Debug("createDatabase, Database `%s` created", dbName) common.Log.Debug("createDatabase, Database `%s` created", dbName)
...@@ -369,6 +364,9 @@ func (ve VirtualEnv) createDatabase(rEnv database.Connector, dbName string) erro ...@@ -369,6 +364,9 @@ func (ve VirtualEnv) createDatabase(rEnv database.Connector, dbName string) erro
} }
ddl = strings.Replace(ddl, dbName, dbHash, -1) ddl = strings.Replace(ddl, dbName, dbHash, -1)
if ddl == "" {
return fmt.Errorf("dbName: '%s' get create info error", dbName)
}
_, err = ve.Query(ddl) _, err = ve.Query(ddl)
if err != nil { if err != nil {
common.Log.Warning("createDatabase, Error : %v", err) common.Log.Warning("createDatabase, Error : %v", err)
...@@ -401,7 +399,7 @@ func (ve VirtualEnv) createDatabase(rEnv database.Connector, dbName string) erro ...@@ -401,7 +399,7 @@ func (ve VirtualEnv) createDatabase(rEnv database.Connector, dbName string) erro
soar 能够做出判断并进行 session 级别的修改,但是这一阶段可用性保证应该是由用户提供两个完全相同(或测试环境兼容线上环境) soar 能够做出判断并进行 session 级别的修改,但是这一阶段可用性保证应该是由用户提供两个完全相同(或测试环境兼容线上环境)
的数据库环境来实现的。 的数据库环境来实现的。
*/ */
func (ve VirtualEnv) createTable(rEnv database.Connector, dbName, tbName string) error { func (ve VirtualEnv) createTable(rEnv *database.Connector, dbName, tbName string) error {
if dbName == "" { if dbName == "" {
dbName = rEnv.Database dbName = rEnv.Database
...@@ -470,7 +468,7 @@ func (ve VirtualEnv) createTable(rEnv database.Connector, dbName, tbName string) ...@@ -470,7 +468,7 @@ func (ve VirtualEnv) createTable(rEnv database.Connector, dbName, tbName string)
return nil return nil
} }
// GenTableColumns 为Rewrite提供的结构体初始化 // GenTableColumns 为 Rewrite 提供的结构体初始化
func (ve *VirtualEnv) GenTableColumns(meta common.Meta) common.TableColumns { func (ve *VirtualEnv) GenTableColumns(meta common.Meta) common.TableColumns {
tableColumns := make(common.TableColumns) tableColumns := make(common.TableColumns)
for dbName, db := range meta { for dbName, db := range meta {
......
...@@ -18,10 +18,12 @@ package env ...@@ -18,10 +18,12 @@ package env
import ( import (
"flag" "flag"
"os"
"testing" "testing"
"github.com/XiaoMi/soar/common" "github.com/XiaoMi/soar/common"
"github.com/XiaoMi/soar/database" "github.com/XiaoMi/soar/database"
"github.com/go-sql-driver/mysql" "github.com/go-sql-driver/mysql"
"github.com/kr/pretty" "github.com/kr/pretty"
) )
...@@ -33,12 +35,16 @@ func init() { ...@@ -33,12 +35,16 @@ func init() {
common.BaseDir = common.DevPath common.BaseDir = common.DevPath
err := common.ParseConfig("") err := common.ParseConfig("")
common.LogIfError(err, "init ParseConfig") common.LogIfError(err, "init ParseConfig")
connTest = &database.Connector{ common.Log.Debug("env_test init")
Addr: common.Config.TestDSN.Addr, connTest, err = database.NewConnector(common.Config.TestDSN)
User: common.Config.TestDSN.User, if err != nil {
Pass: common.Config.TestDSN.Password, common.Log.Critical("Test env Error: %v", err)
Database: common.Config.TestDSN.Schema, os.Exit(0)
Charset: common.Config.TestDSN.Charset, }
if _, err := connTest.Version(); err != nil {
common.Log.Critical("Test env Error: %v", err)
os.Exit(0)
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册