提交 16df1e34 编写于 作者: martianzhang's avatar martianzhang

NewConnector

上级 78477138
......@@ -17,25 +17,30 @@
package advisor
import (
"flag"
"fmt"
"os"
"strings"
"testing"
"github.com/XiaoMi/soar/common"
"github.com/XiaoMi/soar/database"
"github.com/XiaoMi/soar/env"
"github.com/kr/pretty"
"vitess.io/vitess/go/vt/sqlparser"
)
var update = flag.Bool("update", false, "update .golden files")
var vEnv *env.VirtualEnv
var rEnv *database.Connector
func init() {
common.BaseDir = common.DevPath
err := common.ParseConfig("")
if err != nil {
fmt.Println(err.Error())
}
vEnv, rEnv := env.BuildEnv()
common.LogIfError(err, "init ParseConfig")
common.Log.Debug("advisor_test init")
vEnv, rEnv = env.BuildEnv()
if _, err = vEnv.Version(); err != nil {
fmt.Println(err.Error(), ", By pass all advisor test cases")
os.Exit(0)
......@@ -45,6 +50,7 @@ func init() {
fmt.Println(err.Error(), ", By pass all advisor test cases")
os.Exit(0)
}
defer vEnv.CleanUp()
}
// ARG.003
......@@ -52,8 +58,6 @@ func TestRuleImplicitConversion(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
dsn := common.Config.OnlineDSN
common.Config.OnlineDSN = common.Config.TestDSN
vEnv, rEnv := env.BuildEnv()
defer vEnv.CleanUp()
initSQLs := []string{
`CREATE TABLE t1 (id int, title varchar(255) CHARSET utf8 COLLATE utf8_general_ci);`,
......@@ -104,9 +108,6 @@ func TestRuleImpossibleOuterJoin(t *testing.T) {
`select city_id, city, country from city left outer join country on city.country_id=country.country_id WHERE city.city_id IS NULL`,
}
vEnv, rEnv := env.BuildEnv()
defer vEnv.CleanUp()
for _, sql := range sqls {
stmt, syntaxErr := sqlparser.Parse(sql)
if syntaxErr != nil {
......@@ -146,9 +147,6 @@ func TestIndexAdvisorRuleGroupByConst(t *testing.T) {
},
}
vEnv, rEnv := env.BuildEnv()
defer vEnv.CleanUp()
for _, sql := range sqls[0] {
stmt, syntaxErr := sqlparser.Parse(sql)
if syntaxErr != nil {
......@@ -211,9 +209,6 @@ func TestIndexAdvisorRuleOrderByConst(t *testing.T) {
},
}
vEnv, rEnv := env.BuildEnv()
defer vEnv.CleanUp()
for _, sql := range sqls[0] {
stmt, syntaxErr := sqlparser.Parse(sql)
if syntaxErr != nil {
......@@ -275,9 +270,6 @@ func TestRuleUpdatePrimaryKey(t *testing.T) {
},
}
vEnv, rEnv := env.BuildEnv()
defer vEnv.CleanUp()
for _, sql := range sqls[0] {
stmt, syntaxErr := sqlparser.Parse(sql)
if syntaxErr != nil {
......@@ -328,8 +320,6 @@ func TestRuleUpdatePrimaryKey(t *testing.T) {
func TestIndexAdvise(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
vEnv, rEnv := env.BuildEnv()
defer vEnv.CleanUp()
for _, sql := range common.TestSQLs {
stmt, syntaxErr := sqlparser.Parse(sql)
......@@ -360,8 +350,6 @@ func TestIndexAdviseNoEnv(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
orgOnlineDSNStatus := common.Config.OnlineDSN.Disable
common.Config.OnlineDSN.Disable = true
vEnv, rEnv := env.BuildEnv()
defer vEnv.CleanUp()
for _, sql := range common.TestSQLs {
stmt, syntaxErr := sqlparser.Parse(sql)
......@@ -391,7 +379,6 @@ func TestIndexAdviseNoEnv(t *testing.T) {
func TestDuplicateKeyChecker(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
_, rEnv := env.BuildEnv()
rule := DuplicateKeyChecker(rEnv, "sakila")
if len(rule) != 0 {
t.Errorf("got rules: %s", pretty.Sprint(rule))
......@@ -426,9 +413,6 @@ func TestIdxColsTypeCheck(t *testing.T) {
`select city_id, city, country from city left outer join country using(country_id) WHERE city.city_id=59 and country.country='Algeria'`,
}
vEnv, rEnv := env.BuildEnv()
defer vEnv.CleanUp()
for _, sql := range sqls {
stmt, syntaxErr := sqlparser.Parse(sql)
if syntaxErr != nil {
......
......@@ -17,14 +17,11 @@
package advisor
import (
"flag"
"testing"
"github.com/XiaoMi/soar/common"
)
var update = flag.Bool("update", false, "update .golden files")
func TestListTestSQLs(t *testing.T) {
err := common.GoldenDiff(func() { ListTestSQLs() }, t.Name(), update)
if nil != err {
......
......@@ -68,7 +68,7 @@ func main() {
// 当程序卡死的时候,或者由于某些原因程序没有退出,可以通过捕获信号量的形式让程序优雅退出并且清理测试环境
common.HandleSignal(func() {
shutdown(vEnv)
shutdown(vEnv, rEnv)
})
// 对指定的库表进行索引重复检查
......
......@@ -67,52 +67,48 @@ func initConfig() {
// if error found return non-zero, no error return zero
func checkConfig() int {
// TestDSN connection check
testConn := &database.Connector{
Addr: common.Config.TestDSN.Addr,
User: common.Config.TestDSN.User,
Pass: common.Config.TestDSN.Password,
Database: common.Config.TestDSN.Schema,
Charset: common.Config.TestDSN.Charset,
}
testVersion, err := testConn.Version()
connTest, err := database.NewConnector(common.Config.TestDSN)
if err != nil {
fmt.Println("test-dsn:", common.Config.TestDSN.Addr, err.Error())
return 1
}
testVersion, err := connTest.Version()
if err != nil && !common.Config.TestDSN.Disable {
fmt.Println("test-dsn:", testConn, err.Error())
fmt.Println("test-dsn:", connTest, err.Error())
return 1
}
if common.Config.Verbose {
if err == nil {
fmt.Println("test-dsn", testConn, "Version:", testVersion)
fmt.Println("test-dsn", connTest, "Version:", testVersion)
} else {
fmt.Println("test-dsn", common.Config.TestDSN)
}
}
if !testConn.HasAllPrivilege() {
if !connTest.HasAllPrivilege() {
fmt.Printf("test-dsn: %s, need all privileges", common.FormatDSN(common.Config.TestDSN))
return 1
}
// OnlineDSN connection check
onlineConn := &database.Connector{
Addr: common.Config.OnlineDSN.Addr,
User: common.Config.OnlineDSN.User,
Pass: common.Config.OnlineDSN.Password,
Database: common.Config.OnlineDSN.Schema,
Charset: common.Config.OnlineDSN.Charset,
}
onlineVersion, err := onlineConn.Version()
connOnline, err := database.NewConnector(common.Config.OnlineDSN)
if err != nil {
fmt.Println("test-dsn:", common.Config.OnlineDSN.Addr, err.Error())
return 1
}
onlineVersion, err := connOnline.Version()
if err != nil && !common.Config.OnlineDSN.Disable {
fmt.Println("online-dsn:", onlineConn, err.Error())
fmt.Println("online-dsn:", connOnline, err.Error())
return 1
}
if common.Config.Verbose {
if err == nil {
fmt.Println("online-dsn", onlineConn, "Version:", onlineVersion)
fmt.Println("online-dsn", connOnline, "Version:", onlineVersion)
} else {
fmt.Println("online-dsn", common.Config.OnlineDSN)
}
}
if !onlineConn.HasSelectPrivilege() {
if !connOnline.HasSelectPrivilege() {
fmt.Printf("online-dsn: %s, need all privileges", common.FormatDSN(common.Config.OnlineDSN))
return 1
}
......@@ -229,9 +225,11 @@ func initQuery(query string) string {
return query
}
func shutdown(vEnv *env.VirtualEnv) {
func shutdown(vEnv *env.VirtualEnv, rEnv *database.Connector) {
if common.Config.DropTestTemporary {
vEnv.CleanUp()
}
vEnv.Conn.Close()
rEnv.Conn.Close()
os.Exit(0)
}
......@@ -48,8 +48,8 @@ var (
// Configuration 配置文件定义结构体
type Configuration struct {
// +++++++++++++++测试环境+++++++++++++++++
OnlineDSN *dsn `yaml:"online-dsn"` // 线上环境数据库配置
TestDSN *dsn `yaml:"test-dsn"` // 测试环境数据库配置
OnlineDSN *Dsn `yaml:"online-dsn"` // 线上环境数据库配置
TestDSN *Dsn `yaml:"test-dsn"` // 测试环境数据库配置
AllowOnlineAsTest bool `yaml:"allow-online-as-test"` // 允许 Online 环境也可以当作 Test 环境
DropTestTemporary bool `yaml:"drop-test-temporary"` // 是否清理Test环境产生的临时库表
CleanupTestDatabase bool `yaml:"cleanup-test-database"` // 清理残余的测试数据库(程序异常退出或未开启drop-test-temporary) issue #48
......@@ -132,14 +132,14 @@ type Configuration struct {
// Config 默认设置
var Config = &Configuration{
OnlineDSN: &dsn{
OnlineDSN: &Dsn{
Net: "tcp",
Schema: "information_schema",
Charset: "utf8mb4",
Disable: true,
Version: 99999,
},
TestDSN: &dsn{
TestDSN: &Dsn{
Net: "tcp",
Schema: "information_schema",
Charset: "utf8mb4",
......@@ -224,7 +224,8 @@ var Config = &Configuration{
MaxPrettySQLLength: 1024,
}
type dsn struct {
// Dsn Data source name
type Dsn struct {
Net string `yaml:"net"`
Addr string `yaml:"addr"`
Schema string `yaml:"schema"`
......@@ -243,7 +244,7 @@ type dsn struct {
}
// 解析命令行DSN输入
func parseDSN(odbc string, d *dsn) *dsn {
func parseDSN(odbc string, d *Dsn) *Dsn {
var addr, user, password, schema, charset string
if odbc == FormatDSN(d) {
return d
......@@ -260,7 +261,7 @@ func parseDSN(odbc string, d *dsn) *dsn {
// 设置为空表示禁用环境
odbc = strings.TrimSpace(odbc)
if odbc == "" {
return &dsn{Disable: true}
return &Dsn{Disable: true}
}
// username:password@ip:port/database
......@@ -354,7 +355,7 @@ func parseDSN(odbc string, d *dsn) *dsn {
charset = "utf8mb4"
}
dsn := &dsn{
dsn := &Dsn{
Addr: addr,
User: user,
Password: password,
......@@ -367,7 +368,7 @@ func parseDSN(odbc string, d *dsn) *dsn {
}
// FormatDSN 格式化打印DSN
func FormatDSN(env *dsn) string {
func FormatDSN(env *Dsn) string {
if env == nil || env.Disable {
return ""
}
......
......@@ -19,7 +19,7 @@ package common
import "fmt"
func ExampleFormatDSN() {
dsxExp := &dsn{
dsxExp := &Dsn{
Addr: "127.0.0.1:3306",
Schema: "mysql",
User: "root",
......
......@@ -38,7 +38,7 @@ func NewDB(db string) *DB {
}
}
// TableName 含有表的属性
// Table 含有表的属性
type Table struct {
TableName string
TableAliases []string
......
......@@ -556,14 +556,12 @@ func (db *Connector) explainAbleSQL(sql string) (string, error) {
return "", nil
}
// 执行explain请求,返回mysql.Result执行结果
func (db *Connector) executeExplain(sql string, explainType int, formatType int) (QueryResult, error) {
var res QueryResult
// explainQuery 生成可执行的 explain 查询请求
func (db *Connector) explainQuery(sql string, explainType int, formatType int) string {
var err error
var explainQuery string
sql, err = db.explainAbleSQL(sql)
if sql == "" {
return res, err
if sql == "" || err != nil {
return sql
}
// 5.6以上支持 FORMAT=JSON
......@@ -580,18 +578,17 @@ func (db *Connector) executeExplain(sql string, explainType int, formatType int)
case ExtendedExplainType:
// 5.6以上extended关键字已经不推荐使用,8.0废弃了这个关键字
if common.Config.TestDSN.Version >= 50600 {
explainQuery = fmt.Sprintf("explain %s", sql)
sql = fmt.Sprintf("explain %s", sql)
} else {
explainQuery = fmt.Sprintf("explain extended %s", sql)
sql = fmt.Sprintf("explain extended %s", sql)
}
case PartitionsExplainType:
explainQuery = fmt.Sprintf("explain partitions %s", sql)
sql = fmt.Sprintf("explain partitions %s", sql)
default:
explainQuery = fmt.Sprintf("explain %s %s", explainFormat, sql)
sql = fmt.Sprintf("explain %s %s", explainFormat, sql)
}
res, err = db.Query(explainQuery)
return res, err
return sql
}
// MySQLExplainWarnings WARNINGS信息中包含的优化器信息
......@@ -944,6 +941,7 @@ func ParseExplainResult(res QueryResult, formatType int) (exp *ExplainInfo, err
err = res.Rows.Scan(&explainString)
exp.ExplainJSON, err = parseJSONExplainText(explainString)
}
res.Rows.Close()
return exp, err
}
......@@ -992,23 +990,24 @@ func ParseExplainResult(res QueryResult, formatType int) (exp *ExplainInfo, err
expRow.Scalability = ExplainScalability[expRow.AccessType]
explainRows = append(explainRows, expRow)
}
res.Rows.Close()
exp.ExplainRows = explainRows
// check explain warning info
if common.Config.ShowWarnings {
for res.Warning.Next() {
var expWarning *ExplainWarning
res.Warning.Scan(
expWarning.Level,
expWarning.Code,
expWarning.Message,
)
err = res.Warning.Scan(expWarning.Level, expWarning.Code, expWarning.Message)
if err != nil {
break
}
// 'EXTENDED' is deprecated and will be removed in a future release.
if expWarning.Code != 1681 {
exp.Warnings = append(exp.Warnings, expWarning)
}
}
res.Warning.Close()
}
// 添加 last_query_cost
......@@ -1022,6 +1021,7 @@ func (db *Connector) Explain(sql string, explainType int, formatType int) (exp *
if explainType != TraditionalExplainType {
formatType = TraditionalFormatExplain
}
defer func() {
if e := recover(); e != nil {
const size = 4096
......@@ -1033,10 +1033,8 @@ func (db *Connector) Explain(sql string, explainType int, formatType int) (exp *
}()
// 执行EXPLAIN请求
res, err := db.executeExplain(sql, explainType, formatType)
if err != nil {
return exp, err
}
sql = db.explainQuery(sql, explainType, formatType)
res, err := db.Query(sql)
// 解析mysql结果,输出ExplainInfo
exp, err = ParseExplainResult(res, formatType)
......
......@@ -39,6 +39,7 @@ type Connector struct {
Database string
Charset string
Net string
Conn *sql.DB
}
// QueryResult 数据库查询返回值
......@@ -49,20 +50,38 @@ type QueryResult struct {
QueryCost float64
}
// NewConnection 创建新连接
func (db *Connector) NewConnection() (*sql.DB, error) {
dsn := fmt.Sprintf("%s:%s@%s(%s)/%s?parseTime=true&charset=%s", db.User, db.Pass, db.Net, db.Addr, db.Database, db.Charset)
return sql.Open("mysql", dsn)
// NewConnector 创建新连接
func NewConnector(dsn *common.Dsn) (*Connector, error) {
conn, err := sql.Open("mysql", fmt.Sprintf("%s:%s@%s(%s)/%s?parseTime=true&charset=%s",
dsn.User,
dsn.Password,
dsn.Net,
dsn.Addr,
dsn.Schema,
dsn.Charset,
))
if err != nil {
return nil, err
}
connector := &Connector{
Addr: dsn.Addr,
User: dsn.User,
Pass: dsn.Password,
Database: dsn.Schema,
Charset: dsn.Charset,
Conn: conn,
}
return connector, err
}
// Query 执行SQL
func (db *Connector) Query(sql string, params ...interface{}) (QueryResult, error) {
var res QueryResult
var err error
// 测试环境如果检查是关闭的,则SQL不会被执行
if common.Config.TestDSN.Disable {
return res, errors.New("dsn is disable")
}
// 数据库安全性检查:如果 Connector 的 IP 端口与 TEST 环境不一致,则启用SQL白名单
// 不在白名单中的SQL不允许执行
// 执行环境与test环境不相同
......@@ -72,23 +91,25 @@ func (db *Connector) Query(sql string, params ...interface{}) (QueryResult, erro
}
common.Log.Debug("Execute SQL with DSN(%s/%s) : %s", db.Addr, db.Database, fmt.Sprintf(sql, params...))
conn, err := db.NewConnection()
defer conn.Close()
if err != nil {
return res, err
}
res.Rows, res.Error = conn.Query(sql, params...)
_, err = db.Conn.Exec("USE " + db.Database)
common.LogIfError(err, "")
res.Rows, res.Error = db.Conn.Query(sql, params...)
if common.Config.ShowWarnings {
res.Warning, err = conn.Query("SHOW WARNINGS")
res.Warning, err = db.Conn.Query("SHOW WARNINGS")
common.LogIfError(err, "")
}
// SHOW WARNINGS 并不会影响 last_query_cost
if common.Config.ShowLastQueryCost {
cost, err := conn.Query("SHOW SESSION STATUS LIKE 'last_query_cost'")
cost, err := db.Conn.Query("SHOW SESSION STATUS LIKE 'last_query_cost'")
if err == nil {
if cost.Next() {
err = cost.Scan(res.QueryCost)
common.LogIfError(err, "")
}
if err := cost.Close(); err != nil {
common.Log.Error(err.Error())
}
}
}
......@@ -115,6 +136,9 @@ func (db *Connector) Version() (int, error) {
var versionSeg []string
for res.Rows.Next() {
err = res.Rows.Scan(&versionStr)
if err != nil {
break
}
versionStr = strings.Split(versionStr, "-")[0]
versionSeg = strings.Split(versionStr, ".")
if len(versionSeg) == 3 {
......@@ -123,7 +147,9 @@ func (db *Connector) Version() (int, error) {
}
break
}
if err := res.Rows.Close(); err != nil {
common.Log.Error(err.Error())
}
return version, err
}
......@@ -140,6 +166,9 @@ func (db *Connector) SingleIntValue(option string) (int, error) {
if res.Rows.Next() {
err = res.Rows.Scan(&intVal)
}
if err := res.Rows.Close(); err != nil {
common.Log.Error(err.Error())
}
return intVal, err
}
......@@ -188,6 +217,7 @@ func (db *Connector) ColumnCardinality(tb, col string) float64 {
return 0
}
}
res.Rows.Close()
// 当table status元数据不准确时 rowTotal 可能远小于count(*),导致散粒度大于1
if colNum > float64(rowTotal) {
......
......@@ -28,31 +28,23 @@ import (
)
var connTest *Connector
var update = flag.Bool("update", false, "update .golden files")
func init() {
common.BaseDir = common.DevPath
common.ParseConfig("")
connTest = &Connector{
Addr: common.Config.OnlineDSN.Addr,
User: common.Config.OnlineDSN.User,
Pass: common.Config.OnlineDSN.Password,
Database: common.Config.OnlineDSN.Schema,
Charset: common.Config.OnlineDSN.Charset,
}
if _, err := connTest.Version(); err != nil {
err := common.ParseConfig("")
common.LogIfError(err, "init ParseConfig")
common.Log.Debug("mysql_test init")
connTest, err = NewConnector(common.Config.TestDSN)
if err != nil {
common.Log.Critical("Test env Error: %v", err)
os.Exit(0)
}
}
func TestNewConnection(t *testing.T) {
conn, err := connTest.NewConnection()
if err != nil {
t.Errorf("TestNewConnection, Error: %s", err.Error())
if _, err := connTest.Version(); err != nil {
common.Log.Critical("Test env Error: %v", err)
os.Exit(0)
}
defer conn.Close()
}
func TestQuery(t *testing.T) {
......@@ -70,6 +62,7 @@ func TestQuery(t *testing.T) {
t.Error("should return 0")
}
}
res.Rows.Close()
// TODO: timeout test
}
......@@ -115,6 +108,7 @@ func TestWarningsAndQueryCost(t *testing.T) {
}
pretty.Println(str)
}
res.Warning.Close()
fmt.Println(res.QueryCost, err)
}
}
......
......@@ -26,28 +26,31 @@ import (
// CurrentUser get current user with current_user() function
func (db *Connector) CurrentUser() (string, string, error) {
var user, host string
res, err := db.Query("select current_user()")
if err != nil {
return "", "", err
return user, host, err
}
if res.Rows.Next() {
var currentUser string
err = res.Rows.Scan(&currentUser)
if err != nil {
return "", "", err
return user, host, err
}
res.Rows.Close()
cols := strings.Split(currentUser, "@")
if len(cols) == 2 {
user := strings.Trim(cols[0], "'")
host := strings.Trim(cols[1], "'")
user = strings.Trim(cols[0], "'")
host = strings.Trim(cols[1], "'")
if strings.Contains(user, "'") || strings.Contains(host, "'") {
return "", "", errors.New("user or host contains irregular character")
}
return user, host, nil
}
return "", "", errors.New("user or host contains irregular character")
return user, host, errors.New("user or host contains irregular character")
}
return "", "", errors.New("no privilege info")
return user, host, errors.New("no privilege info")
}
// HasSelectPrivilege if user has select privilege
......@@ -70,6 +73,8 @@ func (db *Connector) HasSelectPrivilege() bool {
common.Log.Error("HasSelectPrivilege, Scan Error: %s", err.Error())
return false
}
res.Rows.Close()
if selectPrivilege == "Y" {
return true
}
......@@ -99,6 +104,7 @@ func (db *Connector) HasAllPrivilege() bool {
common.Log.Error("HasAllPrivilege, DSN: %s, Scan error", db.Addr)
return false
}
res.Rows.Close()
}
// get all privilege status
......@@ -115,6 +121,7 @@ func (db *Connector) HasAllPrivilege() bool {
common.Log.Error("HasAllPrivilege, DSN: %s, Scan error", db.Addr)
return false
}
res.Rows.Close()
if strings.Replace(priv, "Y", "", -1) == "" {
return true
}
......
......@@ -62,15 +62,9 @@ func (db *Connector) Profiling(sql string, params ...interface{}) ([]ProfilingRo
}
common.Log.Debug("Execute SQL with DSN(%s/%s) : %s", db.Addr, db.Database, sql)
conn, err := db.NewConnection()
if err != nil {
return rows, err
}
defer conn.Close()
// Keep connection
// https://github.com/go-sql-driver/mysql/issues/208
trx, err := conn.Begin()
trx, err := db.Conn.Begin()
if err != nil {
return rows, err
}
......@@ -91,14 +85,20 @@ func (db *Connector) Profiling(sql string, params ...interface{}) ([]ProfilingRo
// 返回 Profiling 结果
res, err := trx.Query("show profile")
if err != nil {
trx.Rollback()
return rows, err
}
var profileRow ProfilingRow
for res.Next() {
var profileRow ProfilingRow
err := res.Scan(&profileRow.Status, &profileRow.Duration)
err = res.Scan(&profileRow.Status, &profileRow.Duration)
if err != nil {
common.LogIfError(err, "")
break
}
rows = append(rows, profileRow)
}
res.Close()
// 关闭 Profiling
_, err = trx.Query("set @@profiling=0")
......
......@@ -19,6 +19,7 @@ package database
import (
"database/sql"
"fmt"
"github.com/XiaoMi/soar/common"
)
......@@ -43,7 +44,7 @@ import (
*/
// SamplingData 将数据从Remote拉取到 db 中
func (db *Connector) SamplingData(remote Connector, tables ...string) error {
func (db *Connector) SamplingData(remote *Connector, tables ...string) error {
// 计算需要泵取的数据量
wantRowsCount := 300 * common.Config.SamplingStatisticTarget
......@@ -51,19 +52,6 @@ func (db *Connector) SamplingData(remote Connector, tables ...string) error {
// 该数值越大,在内存中缓存的data就越多,但相对的,插入时速度就越快
maxValCount := 200
// 获取数据库连接对象
conn, err := remote.NewConnection()
if err != nil {
return err
}
defer conn.Close()
localConn, err := db.NewConnection()
if err != nil {
return err
}
defer localConn.Close()
for _, table := range tables {
// 表类型检查
if remote.IsView(table) {
......@@ -89,7 +77,7 @@ func (db *Connector) SamplingData(remote Connector, tables ...string) error {
factor := float64(wantRowsCount) / float64(tableRows)
common.Log.Debug("SamplingData, tableRows: %d, wantRowsCount: %d, factor: %f", tableRows, wantRowsCount, factor)
err = startSampling(conn, localConn, db.Database, table, factor, wantRowsCount, maxValCount)
err = startSampling(remote.Conn, db.Conn, db.Database, table, factor, wantRowsCount, maxValCount)
if err != nil {
common.Log.Error("(db *Connector) SamplingData Error : %v", err)
}
......
......@@ -27,27 +27,12 @@ func init() {
}
func TestSamplingData(t *testing.T) {
online := &Connector{
Addr: common.Config.OnlineDSN.Addr,
User: common.Config.OnlineDSN.User,
Pass: common.Config.OnlineDSN.Password,
Database: common.Config.OnlineDSN.Schema,
Charset: common.Config.OnlineDSN.Charset,
Net: common.Config.OnlineDSN.Net,
}
offline := &Connector{
Addr: common.Config.TestDSN.Addr,
User: common.Config.TestDSN.User,
Pass: common.Config.TestDSN.Password,
Database: common.Config.TestDSN.Schema,
Charset: common.Config.TestDSN.Charset,
Net: common.Config.TestDSN.Net,
connOnline, err := NewConnector(common.Config.OnlineDSN)
if err != nil {
t.Error(err)
}
offline.Database = "test"
err := connTest.SamplingData(*online, "film")
err = connTest.SamplingData(connOnline, "film")
if err != nil {
t.Error(err)
}
......
......@@ -99,10 +99,11 @@ func (db *Connector) ShowTables() ([]string, error) {
var table string
err = res.Rows.Scan(&table)
if err != nil {
return []string{}, err
break
}
tables = append(tables, table)
}
res.Rows.Close()
return tables, err
}
......@@ -156,6 +157,7 @@ func (db *Connector) ShowTableStatus(tableName string) (*TableStatInfo, error) {
res.Rows.Scan(statusFields...)
tbStatus.Rows = append(tbStatus.Rows, ts)
}
res.Rows.Close()
return tbStatus, err
}
......@@ -244,6 +246,7 @@ func (db *Connector) ShowIndex(tableName string) (*TableIndexInfo, error) {
res.Rows.Scan(indexFields...)
tbIndex.Rows = append(tbIndex.Rows, ti)
}
res.Rows.Close()
return tbIndex, err
}
......@@ -374,6 +377,7 @@ func (db *Connector) ShowColumns(tableName string) (*TableDesc, error) {
res.Rows.Scan(columnFields...)
tbDesc.DescValues = append(tbDesc.DescValues, tc)
}
res.Rows.Close()
return tbDesc, err
}
......@@ -428,7 +432,7 @@ func (db *Connector) showCreate(createType, name string) (string, error) {
for res.Rows.Next() {
res.Rows.Scan(createFields...)
}
res.Rows.Close()
return create, err
}
......@@ -508,13 +512,10 @@ func (db *Connector) FindColumn(name, dbName string, tables ...string) ([]*commo
var col common.Column
for res.Rows.Next() {
var character, collation []byte
res.Rows.Scan(
&col.Table,
&col.DB,
&col.DataType,
&character,
&collation,
)
err = res.Rows.Scan(&col.Table, &col.DB, &col.DataType, &character, &collation)
if err != nil {
break
}
col.Name = name
col.Character = string(character)
col.Collation = string(collation)
......@@ -537,8 +538,12 @@ func (db *Connector) FindColumn(name, dbName string, tables ...string) ([]*commo
var tbCollation []byte
if newRes.Rows.Next() {
newRes.Rows.Scan(&tbCollation)
err = newRes.Rows.Scan(&tbCollation)
if err != nil {
break
}
}
newRes.Rows.Close()
if string(tbCollation) != "" {
col.Character = strings.Split(string(tbCollation), "_")[0]
col.Collation = string(tbCollation)
......@@ -546,6 +551,7 @@ func (db *Connector) FindColumn(name, dbName string, tables ...string) ([]*commo
}
columns = append(columns, &col)
}
res.Rows.Close()
return columns, err
}
......@@ -603,13 +609,12 @@ func (db *Connector) ShowReference(dbName string, tbName ...string) ([]Reference
// 获取值
for res.Rows.Next() {
var rv ReferenceValue
res.Rows.Scan(&rv.ReferencedTableSchema,
&rv.ReferencedTableName,
&rv.TableSchema,
&rv.TableName,
&rv.ConstraintName)
err = res.Rows.Scan(&rv.ReferencedTableSchema, &rv.ReferencedTableName, &rv.TableSchema, &rv.TableName, &rv.ConstraintName)
if err != nil {
break
}
referenceValues = append(referenceValues, rv)
}
res.Rows.Close()
return referenceValues, err
}
......@@ -68,6 +68,15 @@ func TestShowTables(t *testing.T) {
connTest.Database = orgDatabase
}
func TestShowCreateDatabase(t *testing.T) {
err := common.GoldenDiff(func() {
fmt.Println(connTest.ShowCreateDatabase("sakila"))
}, t.Name(), update)
if err != nil {
t.Error(err)
}
}
func TestShowCreateTable(t *testing.T) {
orgDatabase := connTest.Database
connTest.Database = "sakila"
......@@ -87,7 +96,6 @@ func TestShowCreateTable(t *testing.T) {
if err != nil {
t.Error(err)
}
connTest.Database = orgDatabase
}
......
CREATE DATABASE `sakila` /*!40100 DEFAULT CHARACTER SET utf8mb4 COLLATE utf8mb4_0900_ai_ci */ <nil>
......@@ -19,10 +19,11 @@ package database
import (
"errors"
"fmt"
"github.com/XiaoMi/soar/common"
"regexp"
"strings"
"github.com/XiaoMi/soar/common"
"vitess.io/vitess/go/vt/sqlparser"
)
......@@ -70,15 +71,9 @@ func (db *Connector) Trace(sql string, params ...interface{}) ([]TraceRow, error
}
common.Log.Debug("Execute SQL with DSN(%s/%s) : %s", db.Addr, db.Database, sql)
conn, err := db.NewConnection()
if err != nil {
return rows, err
}
defer conn.Close()
// 开启Trace
common.Log.Debug("SET SESSION OPTIMIZER_TRACE='enabled=on'")
trx, err := conn.Begin()
trx, err := db.Conn.Begin()
if err != nil {
return rows, err
}
......@@ -97,14 +92,20 @@ func (db *Connector) Trace(sql string, params ...interface{}) ([]TraceRow, error
// 返回Trace结果
res, err := trx.Query("SELECT * FROM information_schema.OPTIMIZER_TRACE")
if err != nil {
trx.Rollback()
return rows, err
}
for res.Next() {
var traceRow TraceRow
err = res.Scan(&traceRow.Query, &traceRow.Trace, &traceRow.MissingBytesBeyondMaxMemSize, &traceRow.InsufficientPrivileges)
if err != nil {
common.LogIfError(err, "")
break
}
rows = append(rows, traceRow)
}
res.Close()
// 关闭Trace
common.Log.Debug("SET SESSION OPTIMIZER_TRACE='enabled=off'")
......
......@@ -57,14 +57,10 @@ func NewVirtualEnv(vEnv *database.Connector) *VirtualEnv {
// @output *VirtualEnv 测试环境
// @output *database.Connector 线上环境连接句柄
func BuildEnv() (*VirtualEnv, *database.Connector) {
connTest, err := database.NewConnector(common.Config.TestDSN)
common.LogIfError(err, "")
// 生成测试环境
vEnv := NewVirtualEnv(&database.Connector{
Addr: common.Config.TestDSN.Addr,
User: common.Config.TestDSN.User,
Pass: common.Config.TestDSN.Password,
Database: common.Config.TestDSN.Schema,
Charset: common.Config.TestDSN.Charset,
})
vEnv := NewVirtualEnv(connTest)
// 检查测试环境可用性,并记录数据库版本
vEnvVersion, err := vEnv.Version()
......@@ -82,13 +78,8 @@ func BuildEnv() (*VirtualEnv, *database.Connector) {
vEnv.User, vEnv.Addr, vEnv.Database)
common.Config.OnlineDSN = common.Config.TestDSN
}
conn := &database.Connector{
Addr: common.Config.OnlineDSN.Addr,
User: common.Config.OnlineDSN.User,
Pass: common.Config.OnlineDSN.Password,
Database: common.Config.OnlineDSN.Schema,
Charset: common.Config.OnlineDSN.Charset,
}
connOnline, err := database.NewConnector(common.Config.OnlineDSN)
common.LogIfError(err, "")
// 检查线上环境可用性版本
rEnvVersion, err := vEnv.Version()
......@@ -114,7 +105,7 @@ func BuildEnv() (*VirtualEnv, *database.Connector) {
common.Config.TestDSN.Disable = true
}
return vEnv, conn
return vEnv, connOnline
}
// RealDB 从测试环境中获取通过hash后的DB
......@@ -163,7 +154,10 @@ func (ve *VirtualEnv) CleanupTestDatabase() {
minHour := 1
for dbs.Rows.Next() {
var testDatabase string
dbs.Rows.Scan(&testDatabase)
err = dbs.Rows.Scan(&testDatabase)
if err != nil {
break
}
// test temporary database format `optimizer_YYMMDDHHmmss_randomString(16)`
if len(testDatabase) != 39 {
common.Log.Debug("CleanupTestDatabase by pass %s", testDatabase)
......@@ -187,7 +181,8 @@ func (ve *VirtualEnv) CleanupTestDatabase() {
}
common.Log.Debug("CleanupTestDatabase by pass database %s, %.2f less than %d hours", testDatabase, subHour, minHour)
}
err = dbs.Rows.Close()
common.LogIfError(err, "")
common.Log.Debug("CleanupTestDatabase done")
}
......@@ -200,7 +195,7 @@ func (ve *VirtualEnv) BuildVirtualEnv(rEnv *database.Connector, SQLs ...string)
// 置空错误信息
ve.Error = nil
// 检测是否已经创建初始数据库,如果未创建则创建一个名称hash过的映射数据库
err = ve.createDatabase(*rEnv, rEnv.Database)
err = ve.createDatabase(rEnv, rEnv.Database)
common.LogIfWarn(err, "")
// 测试环境检测
......@@ -237,7 +232,7 @@ func (ve *VirtualEnv) BuildVirtualEnv(rEnv *database.Connector, SQLs ...string)
rEnv.Database = stmt.DBName.String()
// use DB 后检查 DB是否已经创建,如果没有创建则创建DB
err = ve.createDatabase(*rEnv, rEnv.Database)
err = ve.createDatabase(rEnv, rEnv.Database)
common.LogIfWarn(err, "")
}
return true
......@@ -271,7 +266,7 @@ func (ve *VirtualEnv) BuildVirtualEnv(rEnv *database.Connector, SQLs ...string)
// 拉取表结构
table := stmt.Table.Name.String()
if table != "" {
err = ve.createTable(*rEnv, rEnv.Database, table)
err = ve.createTable(rEnv, rEnv.Database, table)
// 这里如果报错可能有两种可能:
// 1. SQL 是 Create 语句,线上环境并没有相关的库表结构
// 2. 在测试环境中执行 SQL 报错
......@@ -303,7 +298,7 @@ func (ve *VirtualEnv) BuildVirtualEnv(rEnv *database.Connector, SQLs ...string)
if db == "" {
db = rEnv.Database
}
tmpEnv := *rEnv
tmpEnv := rEnv
tmpEnv.Database = db
// 创建数据库环境
......@@ -336,7 +331,7 @@ func (ve *VirtualEnv) BuildVirtualEnv(rEnv *database.Connector, SQLs ...string)
return false
}
viewDDL = viewDDL[startIdx+2:]
if !ve.BuildVirtualEnv(&tmpEnv, viewDDL) {
if !ve.BuildVirtualEnv(tmpEnv, viewDDL) {
return false
}
}
......@@ -352,7 +347,7 @@ func (ve *VirtualEnv) BuildVirtualEnv(rEnv *database.Connector, SQLs ...string)
return true
}
func (ve VirtualEnv) createDatabase(rEnv database.Connector, dbName string) error {
func (ve VirtualEnv) createDatabase(rEnv *database.Connector, dbName string) error {
// 生成映射关系
if _, ok := ve.DBRef[dbName]; ok {
common.Log.Debug("createDatabase, Database `%s` created", dbName)
......@@ -369,6 +364,9 @@ func (ve VirtualEnv) createDatabase(rEnv database.Connector, dbName string) erro
}
ddl = strings.Replace(ddl, dbName, dbHash, -1)
if ddl == "" {
return fmt.Errorf("dbName: '%s' get create info error", dbName)
}
_, err = ve.Query(ddl)
if err != nil {
common.Log.Warning("createDatabase, Error : %v", err)
......@@ -401,7 +399,7 @@ func (ve VirtualEnv) createDatabase(rEnv database.Connector, dbName string) erro
soar 能够做出判断并进行 session 级别的修改,但是这一阶段可用性保证应该是由用户提供两个完全相同(或测试环境兼容线上环境)
的数据库环境来实现的。
*/
func (ve VirtualEnv) createTable(rEnv database.Connector, dbName, tbName string) error {
func (ve VirtualEnv) createTable(rEnv *database.Connector, dbName, tbName string) error {
if dbName == "" {
dbName = rEnv.Database
......@@ -470,7 +468,7 @@ func (ve VirtualEnv) createTable(rEnv database.Connector, dbName, tbName string)
return nil
}
// GenTableColumns 为Rewrite提供的结构体初始化
// GenTableColumns 为 Rewrite 提供的结构体初始化
func (ve *VirtualEnv) GenTableColumns(meta common.Meta) common.TableColumns {
tableColumns := make(common.TableColumns)
for dbName, db := range meta {
......
......@@ -18,10 +18,12 @@ package env
import (
"flag"
"os"
"testing"
"github.com/XiaoMi/soar/common"
"github.com/XiaoMi/soar/database"
"github.com/go-sql-driver/mysql"
"github.com/kr/pretty"
)
......@@ -33,12 +35,16 @@ func init() {
common.BaseDir = common.DevPath
err := common.ParseConfig("")
common.LogIfError(err, "init ParseConfig")
connTest = &database.Connector{
Addr: common.Config.TestDSN.Addr,
User: common.Config.TestDSN.User,
Pass: common.Config.TestDSN.Password,
Database: common.Config.TestDSN.Schema,
Charset: common.Config.TestDSN.Charset,
common.Log.Debug("env_test init")
connTest, err = database.NewConnector(common.Config.TestDSN)
if err != nil {
common.Log.Critical("Test env Error: %v", err)
os.Exit(0)
}
if _, err := connTest.Version(); err != nil {
common.Log.Critical("Test env Error: %v", err)
os.Exit(0)
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册