diff --git a/common/config.go b/common/config.go index b74f9728668df5a7d013f655cb8949addcaa53ba..e796fbfd59c6c6074781ed4b7bc4331cd296a3cb 100644 --- a/common/config.go +++ b/common/config.go @@ -56,6 +56,7 @@ type Configuration struct { OnlySyntaxCheck bool `yaml:"only-syntax-check"` // 只做语法检查不输出优化建议 SamplingStatisticTarget int `yaml:"sampling-statistic-target"` // 数据采样因子,对应 PostgreSQL 的 default_statistics_target Sampling bool `yaml:"sampling"` // 数据采样开关 + SamplingCondition string `yaml:"sampling-condition"` // 指定采样条件,如:WHERE xxx LIMIT xxx; Profiling bool `yaml:"profiling"` // 在开启数据采样的情况下,在测试环境执行进行profile Trace bool `yaml:"trace"` // 在开启数据采样的情况下,在测试环境执行进行Trace Explain bool `yaml:"explain"` // Explain开关 @@ -506,6 +507,7 @@ func readCmdFlags() error { explain := flag.Bool("explain", Config.Explain, "Explain, 是否开启Explain执行计划分析") sampling := flag.Bool("sampling", Config.Sampling, "Sampling, 数据采样开关") samplingStatisticTarget := flag.Int("sampling-statistic-target", Config.SamplingStatisticTarget, "SamplingStatisticTarget, 数据采样因子,对应 PostgreSQL 的 default_statistics_target") + samplingCondition := flag.String("sampling-condition", Config.SamplingCondition, "SamplingCondition, 数据采样条件,如: WHERE xxx LIMIT xxx") delimiter := flag.String("delimiter", Config.Delimiter, "Delimiter, SQL分隔符") // +++++++++++++++日志相关+++++++++++++++++ logLevel := flag.Int("log-level", Config.LogLevel, "LogLevel, 日志级别, [0:Emergency, 1:Alert, 2:Critical, 3:Error, 4:Warning, 5:Notice, 6:Informational, 7:Debug]") @@ -585,6 +587,7 @@ func readCmdFlags() error { Config.Explain = *explain Config.Sampling = *sampling Config.SamplingStatisticTarget = *samplingStatisticTarget + Config.SamplingCondition = *samplingCondition Config.LogLevel = *logLevel if strings.HasPrefix(*logOutput, "/") { diff --git a/database/sampling.go b/database/sampling.go index fab82402c78742e241f2d011198a74f0cdd605b8..478b8177e024dec86e99ab6fa07d72c3ebdebbcf 100644 --- a/database/sampling.go +++ b/database/sampling.go @@ -17,11 +17,14 @@ package database import ( - "database/sql" "fmt" + "time" - "github.com/XiaoMi/soar/common" "strings" + + "database/sql" + "github.com/XiaoMi/soar/common" + "github.com/ziutek/mymysql/mysql" ) /*-------------------- @@ -44,99 +47,109 @@ import ( *-------------------- */ -// SamplingData 将数据从Remote拉取到 db 中 -func (db *Connector) SamplingData(remote *Connector, tables ...string) error { +// SamplingData 将数据从 onlineConn 拉取到 db 中 +func (db *Connector) SamplingData(onlineConn *Connector, database string, tables ...string) error { + var err error + if database == db.Database { + return fmt.Errorf("SamplingData the same database, From: %s/%s, To: %s/%s", onlineConn.Addr, database, db.Addr, db.Database) + } + // 计算需要泵取的数据量 wantRowsCount := 300 * common.Config.SamplingStatisticTarget - // 设置数据采样单条 SQL 中 value 的数量 - // 该数值越大,在内存中缓存的data就越多,但相对的,插入时速度就越快 - maxValCount := 200 - for _, table := range tables { // 表类型检查 - if remote.IsView(table) { + if onlineConn.IsView(table) { return nil } - tableStatus, err := remote.ShowTableStatus(table) - if err != nil { - return err - } - - if len(tableStatus.Rows) == 0 { - common.Log.Info("SamplingData, Table %s with no data, stop sampling", table) - return nil + // generate where condition + var where string + if common.Config.SamplingCondition == "" { + tableStatus, err := onlineConn.ShowTableStatus(table) + if err != nil { + return err + } + + if len(tableStatus.Rows) == 0 { + common.Log.Info("SamplingData, Table %s with no data, stop sampling", table) + return nil + } + + tableRows := tableStatus.Rows[0].Rows + if tableRows == 0 { + common.Log.Info("SamplingData, Table %s with no data, stop sampling", table) + return nil + } + + factor := float64(wantRowsCount) / float64(tableRows) + common.Log.Debug("SamplingData, tableRows: %d, wantRowsCount: %d, factor: %f", tableRows, wantRowsCount, factor) + where = fmt.Sprintf("WHERE RAND() <= %f LIMIT %d", factor, wantRowsCount) + if factor >= 1 { + where = "" + } + } else { + where = common.Config.SamplingCondition } - tableRows := tableStatus.Rows[0].Rows - if tableRows == 0 { - common.Log.Info("SamplingData, Table %s with no data, stop sampling", table) - return nil - } - - factor := float64(wantRowsCount) / float64(tableRows) - common.Log.Debug("SamplingData, tableRows: %d, wantRowsCount: %d, factor: %f", tableRows, wantRowsCount, factor) - - err = startSampling(remote.Conn, db.Conn, db.Database, table, factor, wantRowsCount, maxValCount) - if err != nil { - common.Log.Error("(db *Connector) SamplingData Error : %v", err) - } + err = db.startSampling(onlineConn.Conn, database, table, where) } - return nil + return err } // startSampling sampling data from OnlineDSN to TestDSN -// 因为涉及到的数据量问题,所以泵取与插入时同时进行的 -// TODO: 加 ref link -func startSampling(conn, localConn *sql.DB, database, table string, factor float64, wants, maxValCount int) error { - // generate where condition - where := fmt.Sprintf("WHERE RAND() <= %f", factor) - if factor >= 1 { - where = "" - } - - res, err := conn.Query(fmt.Sprintf("SELECT * FROM `%s`.`%s` %s LIMIT %d;", database, table, where, wants)) +func (db *Connector) startSampling(onlineConn *sql.DB, database, table string, where string) error { + samplingQuery := fmt.Sprintf("SELECT * FROM `%s`.`%s` %s", database, table, where) + common.Log.Debug("startSampling with Query: %s", samplingQuery) + res, err := onlineConn.Query(samplingQuery) if err != nil { return err } - // column info + // columns list columns, err := res.Columns() if err != nil { return err } - row := make(map[string][]byte, len(columns)) + row := make([][]byte, len(columns)) tableFields := make([]interface{}, 0) - for _, col := range columns { - if _, ok := row[col]; ok { - tableFields = append(tableFields, row[col]) - } + for i := range columns { + tableFields = append(tableFields, &row[i]) + } + columnTypes, err := res.ColumnTypes() + if err != nil { + return err } // sampling data - var valuesStr string var values []string columnsStr := "`" + strings.Join(columns, "`,`") + "`" for res.Next() { res.Scan(tableFields...) - for _, val := range row { - values = append(values, fmt.Sprintf(`unhex("%s")`, fmt.Sprintf("%x", val))) + for i, val := range row { + if val == nil { + values = append(values, "NULL") + } else { + switch columnTypes[i].DatabaseTypeName() { + case "TIMESTAMP", "DATETIME": + t, err := time.Parse(time.RFC3339, string(val)) + common.LogIfWarn(err, "") + values = append(values, fmt.Sprintf(`"%s"`, mysql.TimeString(t))) + default: + values = append(values, fmt.Sprintf(`unhex("%s")`, fmt.Sprintf("%x", val))) + } + } } - valuesStr = fmt.Sprintf(`(%s)`, strings.Join(values, `,`)) - doSampling(localConn, database, table, columnsStr, valuesStr) + err = db.doSampling(table, columnsStr, strings.Join(values, `,`)) } res.Close() - return nil + return err } -// 将泵取的数据转换成Insert语句并在数据库中执行 -func doSampling(conn *sql.DB, dbName, table, colDef, values string) { - query := fmt.Sprintf("INSERT INTO `%s`.`%s` (%s) VALUES %s;", dbName, table, - colDef, values) - - _, err := conn.Exec(query) - if err != nil { - common.Log.Error("doSampling Error from %s.%s: %v", dbName, table, err) - } +// 将泵取的数据转换成 insert 语句并在 testConn 数据库中执行 +func (db *Connector) doSampling(table, colDef, values string) error { + // db.Database is hashed database name + query := fmt.Sprintf("INSERT INTO `%s`.`%s` (%s) VALUES (%s);", db.Database, table, colDef, values) + _, err := db.Query(query) + return err } diff --git a/database/sampling_test.go b/database/sampling_test.go deleted file mode 100644 index f8525dbcc5bc58ee67544cc23046c443024bec44..0000000000000000000000000000000000000000 --- a/database/sampling_test.go +++ /dev/null @@ -1,39 +0,0 @@ -/* - * Copyright 2018 Xiaomi, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package database - -import ( - "testing" - - "github.com/XiaoMi/soar/common" -) - -func init() { - common.BaseDir = common.DevPath -} - -func TestSamplingData(t *testing.T) { - connOnline, err := NewConnector(common.Config.OnlineDSN) - if err != nil { - t.Error(err) - } - - err = connTest.SamplingData(connOnline, "film") - if err != nil { - t.Error(err) - } -} diff --git a/database/show.go b/database/show.go index 75ca837c461b8845d49c0b64cbae910374e11fee..214781cd9547fa12e4d84f60c39062996814690d 100644 --- a/database/show.go +++ b/database/show.go @@ -459,27 +459,27 @@ func (db *Connector) ShowCreateTable(tableName string) (string, error) { ddl, err := db.showCreate("table", tableName) // 去除外键关联条件 - var noConstraint []string - relationReg, _ := regexp.Compile("CONSTRAINT") - for _, line := range strings.Split(ddl, "\n") { - - if relationReg.Match([]byte(line)) { - continue - } - - // 去除外键语句会使DDL中多一个','导致语法错误,要把多余的逗号去除 - if strings.Index(line, ")") == 0 { - lineWrongSyntax := noConstraint[len(noConstraint)-1] - // 如果')'前一句的末尾是',' 删除 ',' 保证语法正确性 - if strings.Index(lineWrongSyntax, ",") == len(lineWrongSyntax)-1 { - noConstraint[len(noConstraint)-1] = lineWrongSyntax[:len(lineWrongSyntax)-1] + lines := strings.Split(ddl, "\n") + // CREATE VIEW ONLY 1 LINE + if len(lines) > 2 { + var noConstraint []string + relationReg, _ := regexp.Compile("CONSTRAINT") + for _, line := range lines[1 : len(lines)-1] { + if relationReg.Match([]byte(line)) { + continue } + line = strings.TrimSuffix(line, ",") + noConstraint = append(noConstraint, line) } - noConstraint = append(noConstraint, line) + // 去除外键语句会使DDL中多一个','导致语法错误,要把多余的逗号去除 + ddl = fmt.Sprint( + lines[0], "\n", + strings.Join(noConstraint, ",\n"), "\n", + lines[len(lines)-1], + ) } - - return strings.Join(noConstraint, "\n"), err + return ddl, err } // FindColumn find column diff --git a/database/show_test.go b/database/show_test.go index ceea0abffc4883dd4d4a679603c7561c86ebead7..a8ecbee9f319357ce7259cde8eb1fa19ba513a9e 100644 --- a/database/show_test.go +++ b/database/show_test.go @@ -82,7 +82,9 @@ func TestShowCreateTable(t *testing.T) { connTest.Database = "sakila" tables := []string{ "film", + "category", "customer_list", + "inventory", } err := common.GoldenDiff(func() { for _, table := range tables { diff --git a/database/testdata/TestShowCreateTable.golden b/database/testdata/TestShowCreateTable.golden index 74208c702b85a1014f34c2a7561c273ade684f04..d91cc1c055b3287b3b111ba8390f8f5a08216264 100644 --- a/database/testdata/TestShowCreateTable.golden +++ b/database/testdata/TestShowCreateTable.golden @@ -17,4 +17,19 @@ CREATE TABLE `film` ( KEY `idx_fk_language_id` (`language_id`), KEY `idx_fk_original_language_id` (`original_language_id`) ) ENGINE=InnoDB AUTO_INCREMENT=1001 DEFAULT CHARSET=utf8 +CREATE TABLE `category` ( + `category_id` tinyint(3) unsigned NOT NULL AUTO_INCREMENT, + `name` varchar(25) NOT NULL, + `last_update` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY (`category_id`) +) ENGINE=InnoDB AUTO_INCREMENT=17 DEFAULT CHARSET=utf8 CREATE ALGORITHM=UNDEFINED DEFINER=`root`@`localhost` SQL SECURITY DEFINER VIEW `customer_list` AS select `cu`.`customer_id` AS `ID`,concat(`cu`.`first_name`,_utf8mb3' ',`cu`.`last_name`) AS `name`,`a`.`address` AS `address`,`a`.`postal_code` AS `zip code`,`a`.`phone` AS `phone`,`city`.`city` AS `city`,`country`.`country` AS `country`,if(`cu`.`active`,_utf8mb3'active',_utf8mb3'') AS `notes`,`cu`.`store_id` AS `SID` from (((`customer` `cu` join `address` `a` on((`cu`.`address_id` = `a`.`address_id`))) join `city` on((`a`.`city_id` = `city`.`city_id`))) join `country` on((`city`.`country_id` = `country`.`country_id`))) +CREATE TABLE `inventory` ( + `inventory_id` mediumint(8) unsigned NOT NULL AUTO_INCREMENT, + `film_id` smallint(5) unsigned NOT NULL, + `store_id` tinyint(3) unsigned NOT NULL, + `last_update` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY (`inventory_id`), + KEY `idx_fk_film_id` (`film_id`), + KEY `idx_store_id_film_id` (`store_id`,`film_id`) +) ENGINE=InnoDB AUTO_INCREMENT=4582 DEFAULT CHARSET=utf8 diff --git a/env/env.go b/env/env.go index 336152bd84cbd7339d49cb36ff7a878e83e86c6f..c407e37ed40e40d491978e4792548b8f097bafeb 100644 --- a/env/env.go +++ b/env/env.go @@ -453,7 +453,7 @@ func (ve VirtualEnv) createTable(rEnv *database.Connector, dbName, tbName string res, err := ve.Query(ddl) if err != nil { // 有可能是用户新建表,因此线上环境查不到 - common.Log.Error("createTable, %s Error : %v", tbName, err) + common.Log.Error("createTable: %s Error : %v", tbName, err) return err } res.Rows.Close() @@ -461,13 +461,9 @@ func (ve VirtualEnv) createTable(rEnv *database.Connector, dbName, tbName string // 泵取数据 if common.Config.Sampling { common.Log.Debug("createTable, Start Sampling data from %s.%s to %s.%s ...", dbName, tbName, ve.DBRef[dbName], tbName) - err := ve.SamplingData(rEnv, tbName) - if err != nil { - common.Log.Error(" (ve VirtualEnv) createTable SamplingData Error: %v", err) - return err - } + err = ve.SamplingData(rEnv, dbName, tbName) } - return nil + return err } // GenTableColumns 为 Rewrite 提供的结构体初始化 diff --git a/env/env_test.go b/env/env_test.go index 6f48a6310b5223c721fdcc07679879e8bc89c76c..0091ae2472db21bfec08015d278469a8fb45d1ae 100644 --- a/env/env_test.go +++ b/env/env_test.go @@ -215,3 +215,60 @@ func TestGenTableColumns(t *testing.T) { } } } + +func TestCreateTable(t *testing.T) { + orgSamplingCondition := common.Config.SamplingCondition + common.Config.SamplingCondition = "LIMIT 1" + + vEnv, rEnv := BuildEnv() + defer vEnv.CleanUp() + // TODO: support VIEW, + tables := []string{ + "actor", + // "actor_info", // VIEW + "address", + "category", + "city", + "country", + "customer", + "customer_list", + "film", + "film_actor", + "film_category", + "film_list", + "film_text", + "inventory", + "language", + "nicer_but_slower_film_list", + "payment", + "rental", + // "sales_by_film_category", // VIEW + // "sales_by_store", // VIEW + "staff", + "staff_list", + "store", + } + for _, table := range tables { + err := vEnv.createTable(rEnv, "sakila", table) + if err != nil { + t.Error(err) + } + } + common.Config.SamplingCondition = orgSamplingCondition +} + +func TestCreateDatabase(t *testing.T) { + vEnv, rEnv := BuildEnv() + defer vEnv.CleanUp() + err := vEnv.createDatabase(rEnv, "sakila") + if err != nil { + t.Error(err) + } + if vEnv.DBHash("sakila") == "sakila" { + t.Errorf("database: sakila rehashed failed!") + } + + if vEnv.DBHash("not_exist_db") != "not_exist_db" { + t.Errorf("database: not_exist_db rehashed!") + } +}