diff --git a/advisor/explainer_test.go b/advisor/explainer_test.go index 0fc8a86a4ff0af53ece1b61bca74afb78998123f..87dc0f45a1c631959bb5eed6fbfa54537ff4fdc8 100644 --- a/advisor/explainer_test.go +++ b/advisor/explainer_test.go @@ -23,6 +23,7 @@ import ( ) func TestDigestExplainText(t *testing.T) { + common.Log.Debug("Enter function: %s", common.GetFunctionName()) var text = `+----+-------------+---------+-------+---------------------------------------------------------+-------------------+---------+---------------------------+------+-------------+ | id | select_type | table | type | possible_keys | key | key_len | ref | rows | Extra | +----+-------------+---------+-------+---------------------------------------------------------+-------------------+---------+---------------------------+------+-------------+ @@ -34,4 +35,5 @@ func TestDigestExplainText(t *testing.T) { if nil != err { t.Fatal(err) } + common.Log.Debug("Exiting function: %s", common.GetFunctionName()) } diff --git a/advisor/heuristic_test.go b/advisor/heuristic_test.go index 8018e72e736c6414120158fd9f359a0efea35828..17132cafb189d2b5e13685743f231e32e12e4c0f 100644 --- a/advisor/heuristic_test.go +++ b/advisor/heuristic_test.go @@ -35,7 +35,7 @@ func TestRuleImplicitAlias(t *testing.T) { "select col from tbl tb where id < 1000", }, { - "do 1", + "select 1", }, } for _, sql := range sqls[0] { diff --git a/advisor/index_test.go b/advisor/index_test.go index 0a355bb636fe28e835294a67623032f9e0f1402d..b3f8ebdb4e7c0288bd1153ec749f1c4221c1db5c 100644 --- a/advisor/index_test.go +++ b/advisor/index_test.go @@ -95,8 +95,8 @@ func TestRuleImplicitConversion(t *testing.T) { } } } - common.Log.Debug("Exiting function: %s", common.GetFunctionName()) common.Config.OnlineDSN = dsn + common.Log.Debug("Exiting function: %s", common.GetFunctionName()) } // JOI.003 & JOI.004 @@ -383,9 +383,11 @@ func TestDuplicateKeyChecker(t *testing.T) { if len(rule) != 0 { t.Errorf("got rules: %s", pretty.Sprint(rule)) } + common.Log.Debug("Exiting function: %s", common.GetFunctionName()) } func TestMergeAdvices(t *testing.T) { + common.Log.Debug("Entering function: %s", common.GetFunctionName()) dst := []IndexInfo{ { Name: "test", @@ -405,6 +407,7 @@ func TestMergeAdvices(t *testing.T) { if len(advise) != 1 { t.Error(pretty.Sprint(advise)) } + common.Log.Debug("Exiting function: %s", common.GetFunctionName()) } func TestIdxColsTypeCheck(t *testing.T) { @@ -450,13 +453,16 @@ func TestIdxColsTypeCheck(t *testing.T) { } } } + common.Log.Debug("Exiting function: %s", common.GetFunctionName()) } func TestGetRandomIndexSuffix(t *testing.T) { + common.Log.Debug("Enter function: %s", common.GetFunctionName()) for i := 0; i < 5; i++ { r := getRandomIndexSuffix() if !(strings.HasPrefix(r, "_") && len(r) == 5) { t.Errorf("getRandomIndexSuffix should return a string with prefix `_` and 5 length, but got:%s", r) } } + common.Log.Debug("Exiting function: %s", common.GetFunctionName()) } diff --git a/advisor/rules.go b/advisor/rules.go index 404e84518a6a4e4ad3aa30f8ca884f66ac88c3df..2abff0156e657e434927a2bc95815b552e583f97 100644 --- a/advisor/rules.go +++ b/advisor/rules.go @@ -58,7 +58,7 @@ func NewQuery4Audit(sql string, options ...string) (*Query4Audit, error) { // vitess 语法解析不上报,以 tidb parser 为主 q.Stmt, vErr = sqlparser.Parse(sql) if vErr != nil { - common.Log.Warn("NewQuery4Audit vitess parse Error: %s", vErr.Error()) + common.Log.Warn("NewQuery4Audit vitess parse Error: %s, Query: %s", vErr.Error(), sql) } // TODO: charset, collation diff --git a/advisor/rules_test.go b/advisor/rules_test.go index b97e912ce9a8065a43a6501f1e660c257e50695f..843351a30d3f2370527041e1973337ef4d573a14 100644 --- a/advisor/rules_test.go +++ b/advisor/rules_test.go @@ -23,29 +23,37 @@ import ( ) func TestListTestSQLs(t *testing.T) { + common.Log.Debug("Entering function: %s", common.GetFunctionName()) err := common.GoldenDiff(func() { ListTestSQLs() }, t.Name(), update) if nil != err { t.Fatal(err) } + common.Log.Debug("Exiting function: %s", common.GetFunctionName()) } func TestListHeuristicRules(t *testing.T) { + common.Log.Debug("Entering function: %s", common.GetFunctionName()) err := common.GoldenDiff(func() { ListHeuristicRules(HeuristicRules) }, t.Name(), update) if nil != err { t.Fatal(err) } + common.Log.Debug("Exiting function: %s", common.GetFunctionName()) } func TestInBlackList(t *testing.T) { + common.Log.Debug("Entering function: %s", common.GetFunctionName()) common.BlackList = []string{"select"} if !InBlackList("select 1") { t.Error("should be true") } + common.Log.Debug("Exiting function: %s", common.GetFunctionName()) } func TestIsIgnoreRule(t *testing.T) { + common.Log.Debug("Entering function: %s", common.GetFunctionName()) common.Config.IgnoreRules = []string{"test"} if !IsIgnoreRule("test") { t.Error("should be true") } + common.Log.Debug("Exiting function: %s", common.GetFunctionName()) } diff --git a/ast/meta_test.go b/ast/meta_test.go index f072978ec88b4cc376b8f8e486bae1e98a15a964..f4817ae37b3222234808a18ed83c6ef082878011 100644 --- a/ast/meta_test.go +++ b/ast/meta_test.go @@ -27,6 +27,7 @@ import ( ) func TestGetTableFromExprs(t *testing.T) { + common.Log.Debug("Entering function: %s", common.GetFunctionName()) tbExprs := sqlparser.TableExprs{ &sqlparser.AliasedTableExpr{ Expr: sqlparser.TableName{ @@ -40,9 +41,11 @@ func TestGetTableFromExprs(t *testing.T) { if tb, ok := meta["db"]; !ok { t.Errorf("no table qualifier, meta: %s", pretty.Sprint(tb)) } + common.Log.Debug("Exiting function: %s", common.GetFunctionName()) } func TestGetParseTableWithStmt(t *testing.T) { + common.Log.Debug("Entering function: %s", common.GetFunctionName()) for _, sql := range common.TestSQLs { fmt.Println(sql) stmt, err := sqlparser.Parse(sql) @@ -53,9 +56,11 @@ func TestGetParseTableWithStmt(t *testing.T) { pretty.Println(meta) fmt.Println() } + common.Log.Debug("Exiting function: %s", common.GetFunctionName()) } func TestFindCondition(t *testing.T) { + common.Log.Debug("Entering function: %s", common.GetFunctionName()) for _, sql := range common.TestSQLs { fmt.Println(sql) stmt, err := sqlparser.Parse(sql) @@ -71,9 +76,11 @@ func TestFindCondition(t *testing.T) { pretty.Println(inEq) fmt.Println() } + common.Log.Debug("Exiting function: %s", common.GetFunctionName()) } func TestFindGroupBy(t *testing.T) { + common.Log.Debug("Entering function: %s", common.GetFunctionName()) sqlList := []string{ "select a from t group by c", } @@ -88,9 +95,11 @@ func TestFindGroupBy(t *testing.T) { pretty.Println(res) fmt.Println() } + common.Log.Debug("Exiting function: %s", common.GetFunctionName()) } func TestFindOrderBy(t *testing.T) { + common.Log.Debug("Entering function: %s", common.GetFunctionName()) sqlList := []string{ "select a from t group by c order by d, c desc", "select a from t group by c order by d desc", @@ -106,9 +115,11 @@ func TestFindOrderBy(t *testing.T) { pretty.Println(res) fmt.Println() } + common.Log.Debug("Exiting function: %s", common.GetFunctionName()) } func TestFindSubquery(t *testing.T) { + common.Log.Debug("Entering function: %s", common.GetFunctionName()) sqlList := []string{ "SELECT * FROM t1 WHERE column1 = (SELECT column1 FROM (SELECT column1 FROM t2) a);", "select column1 from t2", @@ -127,10 +138,11 @@ func TestFindSubquery(t *testing.T) { fmt.Println(len(subquery)) pretty.Println(subquery) } - + common.Log.Debug("Exiting function: %s", common.GetFunctionName()) } func TestFindJoinTable(t *testing.T) { + common.Log.Debug("Entering function: %s", common.GetFunctionName()) sqlList := []string{ "SELECT * FROM t1 LEFT JOIN (t2 CROSS JOIN t3 CROSS JOIN t4) ON (t2.a = t1.a AND t3.b = t1.b AND t4.c = t1.c)", "select ID,name from (select address from customer_list where SID=1 order by phone limit 50,10) a join customer_list l on (a.address=l.address) join city c on (c.city=l.city) order by phone desc;", @@ -151,9 +163,11 @@ func TestFindJoinTable(t *testing.T) { joinMeta := FindJoinTable(stmt, nil) pretty.Println(joinMeta) } + common.Log.Debug("Exiting function: %s", common.GetFunctionName()) } func TestFindJoinCols(t *testing.T) { + common.Log.Debug("Entering function: %s", common.GetFunctionName()) sqlList := []string{ "SELECT * FROM t1 LEFT JOIN (t2 CROSS JOIN t3 CROSS JOIN t4) ON (t2.a = t1.a AND t3.b = t1.b AND t4.c = t1.c)", "select t from a LEFT JOIN b USING (c1, c2, c3)", @@ -175,9 +189,11 @@ func TestFindJoinCols(t *testing.T) { columns := FindJoinCols(stmt) pretty.Println(columns) } + common.Log.Debug("Exiting function: %s", common.GetFunctionName()) } func TestFindJoinColBeWhereEQ(t *testing.T) { + common.Log.Debug("Entering function: %s", common.GetFunctionName()) sqlList := []string{ "select ID,name from (select address from customer_list where SID=1 order by phone limit 50,10) a join customer_list l on (a.address=l.address) join city c on (c.city=l.city) order by phone desc;", "SELECT * FROM t1 LEFT JOIN (t2, t3, t4) ON (t2.a = t1.a AND t3.b = t1.b AND t4.c = t1.c)", @@ -197,9 +213,11 @@ func TestFindJoinColBeWhereEQ(t *testing.T) { columns := FindEQColsInJoinCond(stmt) pretty.Println(columns) } + common.Log.Debug("Exiting function: %s", common.GetFunctionName()) } func TestFindJoinColBeWhereINEQ(t *testing.T) { + common.Log.Debug("Entering function: %s", common.GetFunctionName()) sqlList := []string{ "select ID,name from (select address from customer_list where SID=1 order by phone limit 50,10) a join customer_list l on (a.address=l.address) join city c on (c.city=l.city) order by phone desc;", "SELECT * FROM t1 LEFT JOIN (t2, t3, t4) ON (t2.a = t1.a AND t3.b = t1.b AND t4.c = t1.c)", @@ -219,9 +237,11 @@ func TestFindJoinColBeWhereINEQ(t *testing.T) { columns := FindINEQColsInJoinCond(stmt) pretty.Println(columns) } + common.Log.Debug("Exiting function: %s", common.GetFunctionName()) } func TestFindAllCondition(t *testing.T) { + common.Log.Debug("Entering function: %s", common.GetFunctionName()) sqlList := []string{ "SELECT * FROM t1 LEFT JOIN (t2 CROSS JOIN t3 CROSS JOIN t4) ON (t2.a = t1.a AND t3.b = t1.b AND t4.c = t1.c)", "select t from a LEFT JOIN b USING (c1, c2, c3)", @@ -247,9 +267,11 @@ func TestFindAllCondition(t *testing.T) { columns := FindAllCondition(stmt) pretty.Println(columns) } + common.Log.Debug("Exiting function: %s", common.GetFunctionName()) } func TestFindColumn(t *testing.T) { + common.Log.Debug("Entering function: %s", common.GetFunctionName()) sqlList := []string{ "select col, col2, sum(col1) from tb group by col", "select col from tb group by col,sum(col1)", @@ -266,9 +288,11 @@ func TestFindColumn(t *testing.T) { columns := FindColumn(stmt) pretty.Println(columns) } + common.Log.Debug("Exiting function: %s", common.GetFunctionName()) } func TestFindAllCols(t *testing.T) { + common.Log.Debug("Entering function: %s", common.GetFunctionName()) sqlList := []string{ "select * from tb where a = '1' order by c", "select * from tb where a = '1' group by c", @@ -296,9 +320,11 @@ func TestFindAllCols(t *testing.T) { t.Error(fmt.Errorf("want 'c' got %v", columns)) } } + common.Log.Debug("Exiting function: %s", common.GetFunctionName()) } func TestGetSubqueryDepth(t *testing.T) { + common.Log.Debug("Entering function: %s", common.GetFunctionName()) sqlList := []string{ "SELECT * FROM t1 LEFT JOIN (t2 CROSS JOIN t3 CROSS JOIN t4) ON (t2.a = t1.a AND t3.b = t1.b AND t4.c = t1.c)", "select t from a LEFT JOIN b USING (c1, c2, c3)", @@ -323,9 +349,11 @@ func TestGetSubqueryDepth(t *testing.T) { dep := GetSubqueryDepth(stmt) fmt.Println(dep) } + common.Log.Debug("Exiting function: %s", common.GetFunctionName()) } func TestAppendTable(t *testing.T) { + common.Log.Debug("Entering function: %s", common.GetFunctionName()) sqlList := []string{ "select ID,name from (select address from customer_list where SID=1 order by phone limit 50,10) a join customer_list l on (a.address=l.address) join city c on (c.city=l.city) order by phone desc;", } @@ -367,4 +395,5 @@ func TestAppendTable(t *testing.T) { if meta[""].Table["customer_list"].TableAliases[0] != "l" || meta[""].Table["city"].TableAliases[0] != "c" { t.Error("alias filed\n", pretty.Sprint(meta)) } + common.Log.Debug("Exiting function: %s", common.GetFunctionName()) } diff --git a/ast/pretty_test.go b/ast/pretty_test.go index 2046cad0ab005a8ce951ea1e553fdde9623a44e5..e67d95550331694398b5521468ec027dea92e64e 100644 --- a/ast/pretty_test.go +++ b/ast/pretty_test.go @@ -128,6 +128,7 @@ var TestSqlsPretty = []string{ } func TestPretty(t *testing.T) { + common.Log.Debug("Entering function: %s", common.GetFunctionName()) err := common.GoldenDiff(func() { for _, sql := range append(TestSqlsPretty, common.TestSQLs...) { fmt.Println(sql) @@ -137,9 +138,11 @@ func TestPretty(t *testing.T) { if nil != err { t.Fatal(err) } + common.Log.Debug("Exiting function: %s", common.GetFunctionName()) } func TestIsKeyword(t *testing.T) { + common.Log.Debug("Entering function: %s", common.GetFunctionName()) tks := map[string]bool{ "AGAINST": true, "AUTO_INCREMENT": true, @@ -155,9 +158,11 @@ func TestIsKeyword(t *testing.T) { t.Error("isKeyword:", tk) } } + common.Log.Debug("Exiting function: %s", common.GetFunctionName()) } func TestRemoveComments(t *testing.T) { + common.Log.Debug("Entering function: %s", common.GetFunctionName()) for _, sql := range TestSqlsPretty { stmt, _ := sqlparser.Parse(sql) newSQL := sqlparser.String(stmt) @@ -165,9 +170,11 @@ func TestRemoveComments(t *testing.T) { fmt.Print(newSQL) } } + common.Log.Debug("Exiting function: %s", common.GetFunctionName()) } func TestMysqlEscapeString(t *testing.T) { + common.Log.Debug("Entering function: %s", common.GetFunctionName()) var strs = []map[string]string{ { "input": "abc", @@ -198,4 +205,5 @@ abc`, } } } + common.Log.Debug("Exiting function: %s", common.GetFunctionName()) } diff --git a/ast/rewrite_test.go b/ast/rewrite_test.go index b0758e43bd7ed647804497112a18af589212c38a..361d051ac32ce5842f34f00e2e2bf1f269241f74 100644 --- a/ast/rewrite_test.go +++ b/ast/rewrite_test.go @@ -25,6 +25,7 @@ import ( ) func TestRewrite(t *testing.T) { + common.Log.Debug("Entering function: %s", common.GetFunctionName()) orgTestDSNStatus := common.Config.TestDSN.Disable common.Config.TestDSN.Disable = false testSQL := []map[string]string{ @@ -99,9 +100,11 @@ func TestRewrite(t *testing.T) { } } common.Config.TestDSN.Disable = orgTestDSNStatus + common.Log.Debug("Exiting function: %s", common.GetFunctionName()) } func TestRewriteStar2Columns(t *testing.T) { + common.Log.Debug("Entering function: %s", common.GetFunctionName()) orgTestDSNStatus := common.Config.TestDSN.Disable common.Config.TestDSN.Disable = false testSQL := []map[string]string{ @@ -131,9 +134,11 @@ func TestRewriteStar2Columns(t *testing.T) { } } common.Config.TestDSN.Disable = orgTestDSNStatus + common.Log.Debug("Exiting function: %s", common.GetFunctionName()) } func TestRewriteInsertColumns(t *testing.T) { + common.Log.Debug("Entering function: %s", common.GetFunctionName()) testSQL := []map[string]string{ { "input": `insert into film values(1,2,3,4,5)`, @@ -173,9 +178,11 @@ func TestRewriteInsertColumns(t *testing.T) { t.Errorf("want: %s\ngot: %s", sql["output"], rw.NewSQL) } } + common.Log.Debug("Exiting function: %s", common.GetFunctionName()) } func TestRewriteHaving(t *testing.T) { + common.Log.Debug("Entering function: %s", common.GetFunctionName()) testSQL := []map[string]string{ { "input": `SELECT state, COUNT(*) FROM Drivers GROUP BY state HAVING state IN ('GA', 'TX') ORDER BY state`, @@ -196,9 +203,11 @@ func TestRewriteHaving(t *testing.T) { t.Errorf("want: %s\ngot: %s", sql["output"], rw.NewSQL) } } + common.Log.Debug("Exiting function: %s", common.GetFunctionName()) } func TestRewriteAddOrderByNull(t *testing.T) { + common.Log.Debug("Entering function: %s", common.GetFunctionName()) testSQL := []map[string]string{ { "input": "SELECT sum(col1) FROM tbl GROUP BY col", @@ -211,9 +220,11 @@ func TestRewriteAddOrderByNull(t *testing.T) { t.Errorf("want: %s\ngot: %s", sql["output"], rw.NewSQL) } } + common.Log.Debug("Exiting function: %s", common.GetFunctionName()) } func TestRewriteRemoveDMLOrderBy(t *testing.T) { + common.Log.Debug("Entering function: %s", common.GetFunctionName()) testSQL := []map[string]string{ { "input": "DELETE FROM tbl WHERE col1=1 ORDER BY col", @@ -230,9 +241,11 @@ func TestRewriteRemoveDMLOrderBy(t *testing.T) { t.Errorf("want: %s\ngot: %s", sql["output"], rw.NewSQL) } } + common.Log.Debug("Exiting function: %s", common.GetFunctionName()) } func TestRewriteGroupByConst(t *testing.T) { + common.Log.Debug("Entering function: %s", common.GetFunctionName()) testSQL := []map[string]string{ { "input": "select 1;", @@ -259,9 +272,11 @@ func TestRewriteGroupByConst(t *testing.T) { t.Errorf("want: %s\ngot: %s", sql["output"], rw.NewSQL) } } + common.Log.Debug("Exiting function: %s", common.GetFunctionName()) } func TestRewriteStandard(t *testing.T) { + common.Log.Debug("Entering function: %s", common.GetFunctionName()) testSQL := []map[string]string{ { "input": "SELECT sum(col1) FROM tbl GROUP BY 1;", @@ -274,9 +289,11 @@ func TestRewriteStandard(t *testing.T) { t.Errorf("want: %s\ngot: %s", sql["output"], rw.NewSQL) } } + common.Log.Debug("Exiting function: %s", common.GetFunctionName()) } func TestRewriteCountStar(t *testing.T) { + common.Log.Debug("Entering function: %s", common.GetFunctionName()) testSQL := []map[string]string{ { "input": "SELECT count(col) FROM tbl GROUP BY 1;", @@ -293,9 +310,11 @@ func TestRewriteCountStar(t *testing.T) { t.Errorf("want: %s\ngot: %s", sql["output"], rw.NewSQL) } } + common.Log.Debug("Exiting function: %s", common.GetFunctionName()) } func TestRewriteInnoDB(t *testing.T) { + common.Log.Debug("Entering function: %s", common.GetFunctionName()) testSQL := []map[string]string{ { "input": "CREATE TABLE t1(id bigint(20) NOT NULL AUTO_INCREMENT);", @@ -312,9 +331,11 @@ func TestRewriteInnoDB(t *testing.T) { t.Errorf("want: %s\ngot: %s", sql["output"], rw.NewSQL) } } + common.Log.Debug("Exiting function: %s", common.GetFunctionName()) } func TestRewriteAutoIncrement(t *testing.T) { + common.Log.Debug("Entering function: %s", common.GetFunctionName()) testSQL := []map[string]string{ { "input": "CREATE TABLE t1(id bigint(20) NOT NULL AUTO_INCREMENT) ENGINE=InnoDB AUTO_INCREMENT=123802;", @@ -331,9 +352,11 @@ func TestRewriteAutoIncrement(t *testing.T) { t.Errorf("want: %s\ngot: %s", sql["output"], rw.NewSQL) } } + common.Log.Debug("Exiting function: %s", common.GetFunctionName()) } func TestRewriteIntWidth(t *testing.T) { + common.Log.Debug("Entering function: %s", common.GetFunctionName()) testSQL := []map[string]string{ { "input": "CREATE TABLE t1(id bigint(10) NOT NULL AUTO_INCREMENT) ENGINE=InnoDB AUTO_INCREMENT=123802;", @@ -358,9 +381,11 @@ func TestRewriteIntWidth(t *testing.T) { t.Errorf("want: %s\ngot: %s", sql["output"], rw.NewSQL) } } + common.Log.Debug("Exiting function: %s", common.GetFunctionName()) } func TestRewriteAlwaysTrue(t *testing.T) { + common.Log.Debug("Entering function: %s", common.GetFunctionName()) testSQL := []map[string]string{ { "input": "SELECT count(col) FROM tbl where 1=1;", @@ -427,10 +452,12 @@ func TestRewriteAlwaysTrue(t *testing.T) { t.Errorf("want: %s\ngot: %s", sql["output"], rw.NewSQL) } } + common.Log.Debug("Exiting function: %s", common.GetFunctionName()) } // TODO: func TestRewriteSubQuery2Join(t *testing.T) { + common.Log.Debug("Entering function: %s", common.GetFunctionName()) orgTestDSNStatus := common.Config.TestDSN.Disable common.Config.TestDSN.Disable = true testSQL := []map[string]string{ @@ -458,9 +485,11 @@ func TestRewriteSubQuery2Join(t *testing.T) { } } common.Config.TestDSN.Disable = orgTestDSNStatus + common.Log.Debug("Exiting function: %s", common.GetFunctionName()) } func TestRewriteDML2Select(t *testing.T) { + common.Log.Debug("Entering function: %s", common.GetFunctionName()) testSQL := []map[string]string{ { "input": "DELETE city, country FROM city INNER JOIN country using (country_id) WHERE city.city_id = 1;", @@ -513,9 +542,11 @@ func TestRewriteDML2Select(t *testing.T) { t.Errorf("want: %s\ngot: %s", sql["output"], rw.NewSQL) } } + common.Log.Debug("Exiting function: %s", common.GetFunctionName()) } func TestRewriteDistinctStar(t *testing.T) { + common.Log.Debug("Entering function: %s", common.GetFunctionName()) testSQL := []map[string]string{ { "input": `SELECT DISTINCT * FROM film;`, @@ -549,9 +580,11 @@ func TestRewriteDistinctStar(t *testing.T) { t.Errorf("want: %s\ngot: %s", sql["output"], rw.NewSQL) } } + common.Log.Debug("Exiting function: %s", common.GetFunctionName()) } func TestMergeAlterTables(t *testing.T) { + common.Log.Debug("Entering function: %s", common.GetFunctionName()) sqls := []string{ // ADD|DROP INDEX // TODO: PRIMARY KEY, [UNIQUE|FULLTEXT|SPATIAL] INDEX @@ -602,9 +635,11 @@ func TestMergeAlterTables(t *testing.T) { if err != nil { t.Error(err) } + common.Log.Debug("Exiting function: %s", common.GetFunctionName()) } func TestRewriteUnionAll(t *testing.T) { + common.Log.Debug("Entering function: %s", common.GetFunctionName()) testSQL := []map[string]string{ { "input": `select country_id from city union select country_id from country;`, @@ -617,8 +652,10 @@ func TestRewriteUnionAll(t *testing.T) { t.Errorf("want: %s\ngot: %s", sql["output"], rw.NewSQL) } } + common.Log.Debug("Exiting function: %s", common.GetFunctionName()) } func TestRewriteTruncate(t *testing.T) { + common.Log.Debug("Entering function: %s", common.GetFunctionName()) testSQL := []map[string]string{ { "input": `delete from tbl;`, @@ -631,9 +668,11 @@ func TestRewriteTruncate(t *testing.T) { t.Errorf("want: %s\ngot: %s", sql["output"], rw.NewSQL) } } + common.Log.Debug("Exiting function: %s", common.GetFunctionName()) } func TestRewriteOr2In(t *testing.T) { + common.Log.Debug("Entering function: %s", common.GetFunctionName()) testSQL := []map[string]string{ { "input": `select country_id from city where country_id = 1 or country_id = 2 or country_id = 3;`, @@ -672,9 +711,11 @@ func TestRewriteOr2In(t *testing.T) { t.Errorf("want: %s\ngot: %s", sql["output"], rw.NewSQL) } } + common.Log.Debug("Exiting function: %s", common.GetFunctionName()) } func TestRmParenthesis(t *testing.T) { + common.Log.Debug("Entering function: %s", common.GetFunctionName()) testSQL := []map[string]string{ { "input": `select country_id from city where (country_id = 1);`, @@ -699,13 +740,16 @@ func TestRmParenthesis(t *testing.T) { t.Errorf("want: %s\ngot: %s", sql["output"], rw.NewSQL) } } + common.Log.Debug("Exiting function: %s", common.GetFunctionName()) } func TestListRewriteRules(t *testing.T) { + common.Log.Debug("Entering function: %s", common.GetFunctionName()) err := common.GoldenDiff(func() { ListRewriteRules(RewriteRules) }, t.Name(), update) if err != nil { t.Error(err) } + common.Log.Debug("Exiting function: %s", common.GetFunctionName()) } diff --git a/ast/token_test.go b/ast/token_test.go index 97d62542362106de893f45674a2dc788d332e164..939cd44857d3529e473b4ba7519f8ad9eaa0d00d 100644 --- a/ast/token_test.go +++ b/ast/token_test.go @@ -26,6 +26,7 @@ import ( ) func TestTokenize(t *testing.T) { + common.Log.Debug("Entering function: %s", common.GetFunctionName()) err := common.GoldenDiff(func() { for _, sql := range common.TestSQLs { fmt.Println(sql) @@ -35,9 +36,11 @@ func TestTokenize(t *testing.T) { if nil != err { t.Fatal(err) } + common.Log.Debug("Exiting function: %s", common.GetFunctionName()) } func TestTokenizer(t *testing.T) { + common.Log.Debug("Entering function: %s", common.GetFunctionName()) sqls := []string{ "select c1,c2,c3 from t1,t2 join t3 on t1.c1=t2.c1 and t1.c3=t3.c1 where id>1000", "select sourcetable, if(f.lastcontent = ?, f.lastupdate, f.lastcontent) as lastactivity, f.totalcount as activity, type.class as type, (f.nodeoptions & ?) as nounsubscribe from node as f inner join contenttype as type on type.contenttypeid = f.contenttypeid inner join subscribed as sd on sd.did = f.nodeid and sd.userid = ? union all select f.name as title, f.userid as keyval, ? as sourcetable, ifnull(f.lastpost, f.joindate) as lastactivity, f.posts as activity, ? as type, ? as nounsubscribe from user as f inner join userlist as ul on ul.relationid = f.userid and ul.userid = ? where ul.type = ? and ul.aq = ? order by title limit ?", @@ -57,9 +60,11 @@ func TestTokenizer(t *testing.T) { if nil != err { t.Fatal(err) } + common.Log.Debug("Exiting function: %s", common.GetFunctionName()) } func TestGetQuotedString(t *testing.T) { + common.Log.Debug("Entering function: %s", common.GetFunctionName()) var str = []string{ `"hello world"`, "`hello world`", @@ -82,9 +87,11 @@ func TestGetQuotedString(t *testing.T) { if nil != err { t.Fatal(err) } + common.Log.Debug("Exiting function: %s", common.GetFunctionName()) } func TestCompress(t *testing.T) { + common.Log.Debug("Entering function: %s", common.GetFunctionName()) err := common.GoldenDiff(func() { for _, sql := range common.TestSQLs { fmt.Println(sql) @@ -94,10 +101,11 @@ func TestCompress(t *testing.T) { if nil != err { t.Fatal(err) } - + common.Log.Debug("Exiting function: %s", common.GetFunctionName()) } func TestFormat(t *testing.T) { + common.Log.Debug("Entering function: %s", common.GetFunctionName()) err := common.GoldenDiff(func() { for _, sql := range common.TestSQLs { fmt.Println(sql) @@ -107,9 +115,11 @@ func TestFormat(t *testing.T) { if nil != err { t.Fatal(err) } + common.Log.Debug("Exiting function: %s", common.GetFunctionName()) } func TestSplitStatement(t *testing.T) { + common.Log.Debug("Entering function: %s", common.GetFunctionName()) bufs := [][]byte{ []byte("select * from test;hello"), []byte("select 'asd;fas', col from test;hello"), @@ -181,9 +191,11 @@ select col from tb; if nil != err { t.Fatal(err) } + common.Log.Debug("Exiting function: %s", common.GetFunctionName()) } func TestLeftNewLines(t *testing.T) { + common.Log.Debug("Entering function: %s", common.GetFunctionName()) bufs := [][]byte{ []byte(` select * from test;hello`), @@ -200,9 +212,11 @@ func TestLeftNewLines(t *testing.T) { if nil != err { t.Fatal(err) } + common.Log.Debug("Exiting function: %s", common.GetFunctionName()) } func TestNewLines(t *testing.T) { + common.Log.Debug("Entering function: %s", common.GetFunctionName()) bufs := [][]byte{ []byte(` select * from test;hello`), @@ -219,4 +233,5 @@ func TestNewLines(t *testing.T) { if nil != err { t.Fatal(err) } + common.Log.Debug("Exiting function: %s", common.GetFunctionName()) } diff --git a/common/config.go b/common/config.go index b74f9728668df5a7d013f655cb8949addcaa53ba..e796fbfd59c6c6074781ed4b7bc4331cd296a3cb 100644 --- a/common/config.go +++ b/common/config.go @@ -56,6 +56,7 @@ type Configuration struct { OnlySyntaxCheck bool `yaml:"only-syntax-check"` // 只做语法检查不输出优化建议 SamplingStatisticTarget int `yaml:"sampling-statistic-target"` // 数据采样因子,对应 PostgreSQL 的 default_statistics_target Sampling bool `yaml:"sampling"` // 数据采样开关 + SamplingCondition string `yaml:"sampling-condition"` // 指定采样条件,如:WHERE xxx LIMIT xxx; Profiling bool `yaml:"profiling"` // 在开启数据采样的情况下,在测试环境执行进行profile Trace bool `yaml:"trace"` // 在开启数据采样的情况下,在测试环境执行进行Trace Explain bool `yaml:"explain"` // Explain开关 @@ -506,6 +507,7 @@ func readCmdFlags() error { explain := flag.Bool("explain", Config.Explain, "Explain, 是否开启Explain执行计划分析") sampling := flag.Bool("sampling", Config.Sampling, "Sampling, 数据采样开关") samplingStatisticTarget := flag.Int("sampling-statistic-target", Config.SamplingStatisticTarget, "SamplingStatisticTarget, 数据采样因子,对应 PostgreSQL 的 default_statistics_target") + samplingCondition := flag.String("sampling-condition", Config.SamplingCondition, "SamplingCondition, 数据采样条件,如: WHERE xxx LIMIT xxx") delimiter := flag.String("delimiter", Config.Delimiter, "Delimiter, SQL分隔符") // +++++++++++++++日志相关+++++++++++++++++ logLevel := flag.Int("log-level", Config.LogLevel, "LogLevel, 日志级别, [0:Emergency, 1:Alert, 2:Critical, 3:Error, 4:Warning, 5:Notice, 6:Informational, 7:Debug]") @@ -585,6 +587,7 @@ func readCmdFlags() error { Config.Explain = *explain Config.Sampling = *sampling Config.SamplingStatisticTarget = *samplingStatisticTarget + Config.SamplingCondition = *samplingCondition Config.LogLevel = *logLevel if strings.HasPrefix(*logOutput, "/") { diff --git a/common/config_test.go b/common/config_test.go index 48f5a7acdcda34fc26ad604e69d310c1c9244b34..41858401b430aec12d98fc7a1a0a95ca76555d85 100644 --- a/common/config_test.go +++ b/common/config_test.go @@ -26,6 +26,10 @@ import ( var update = flag.Bool("update", false, "update .golden files") +func init() { + BaseDir = DevPath +} + func TestParseConfig(t *testing.T) { err := ParseConfig("") if err != nil { @@ -37,7 +41,7 @@ func TestReadConfigFile(t *testing.T) { if Config == nil { Config = new(Configuration) } - Config.readConfigFile("../soar.yaml") + Config.readConfigFile(DevPath + "/soar.yaml") } func TestParseDSN(t *testing.T) { diff --git a/common/logger_test.go b/common/logger_test.go index 66e2706a7f011f19edac80715fd18283617a6a2a..0add12b16d3cd55c15dbe8fcc9e8976b42b09a5c 100644 --- a/common/logger_test.go +++ b/common/logger_test.go @@ -21,15 +21,11 @@ import ( "testing" ) -func init() { - BaseDir = DevPath -} - func TestLogger(t *testing.T) { - Log.Info("info") - Log.Debug("debug") - Log.Warning("warning") - Log.Error("error") + Log.Info("TestLogger_Info") + Log.Debug("TestLogger_Debug") + Log.Warning("TestLogger_Warning") + Log.Error("Warning_Error") } func TestCaller(t *testing.T) { @@ -47,7 +43,7 @@ func TestGetFunctionName(t *testing.T) { } func TestIfError(t *testing.T) { - err := errors.New("test") + err := errors.New("TestIfError") LogIfError(err, "") LogIfError(err, "func %s", "func_test") } diff --git a/database/explain_test.go b/database/explain_test.go index 9366a18ac3afa1cb864100d9b308207aca021e28..de4900d54e40b82861cf23073b027d06129d663d 100644 --- a/database/explain_test.go +++ b/database/explain_test.go @@ -2332,6 +2332,7 @@ possible_keys: idx_fk_country_id,idx_country_id_city,idx_all,idx_other } func TestExplain(t *testing.T) { + common.Log.Debug("Enter function: %s", common.GetFunctionName()) // TraditionalFormatExplain for idx, sql := range sqls { exp, err := connTest.Explain(sql, TraditionalExplainType, TraditionalFormatExplain) @@ -2350,9 +2351,11 @@ func TestExplain(t *testing.T) { pretty.Println("No.:", idx, "\nOld: ", sql, "\nNew: ", exp.SQL) pretty.Println(exp) } + common.Log.Debug("Exiting function: %s", common.GetFunctionName()) } func TestParseExplainText(t *testing.T) { + common.Log.Debug("Enter function: %s", common.GetFunctionName()) for _, content := range exp { pretty.Println(RemoveSQLComments(content)) pretty.Println(ParseExplainText(content)) @@ -2364,26 +2367,32 @@ func TestParseExplainText(t *testing.T) { pretty.Println(explainInfo) fmt.Println(err) */ + common.Log.Debug("Exiting function: %s", common.GetFunctionName()) } func TestFindTablesInJson(t *testing.T) { + common.Log.Debug("Enter function: %s", common.GetFunctionName()) idx := 9 for _, j := range exp[idx : idx+1] { pretty.Println(j) findTablesInJSON(j, 0) } pretty.Println(len(explainJSONTables), explainJSONTables) + common.Log.Debug("Exiting function: %s", common.GetFunctionName()) } func TestFormatJsonIntoTraditional(t *testing.T) { + common.Log.Debug("Enter function: %s", common.GetFunctionName()) idx := 11 for _, j := range exp[idx : idx+1] { pretty.Println(j) pretty.Println(FormatJSONIntoTraditional(j)) } + common.Log.Debug("Exiting function: %s", common.GetFunctionName()) } func TestPrintMarkdownExplainTable(t *testing.T) { + common.Log.Debug("Enter function: %s", common.GetFunctionName()) expInfo, err := connTest.Explain("select 1", TraditionalExplainType, TraditionalFormatExplain) if err != nil { t.Error(err) @@ -2395,9 +2404,11 @@ func TestPrintMarkdownExplainTable(t *testing.T) { if err != nil { t.Error(err) } + common.Log.Debug("Exiting function: %s", common.GetFunctionName()) } func TestExplainInfoTranslator(t *testing.T) { + common.Log.Debug("Enter function: %s", common.GetFunctionName()) expInfo, err := connTest.Explain("select 1", TraditionalExplainType, TraditionalFormatExplain) if err != nil { t.Error(err) @@ -2408,9 +2419,11 @@ func TestExplainInfoTranslator(t *testing.T) { if err != nil { t.Error(err) } + common.Log.Debug("Exiting function: %s", common.GetFunctionName()) } func TestMySQLExplainWarnings(t *testing.T) { + common.Log.Debug("Enter function: %s", common.GetFunctionName()) expInfo, err := connTest.Explain("select 1", TraditionalExplainType, TraditionalFormatExplain) if err != nil { t.Error(err) @@ -2421,9 +2434,11 @@ func TestMySQLExplainWarnings(t *testing.T) { if err != nil { t.Error(err) } + common.Log.Debug("Exiting function: %s", common.GetFunctionName()) } func TestMySQLExplainQueryCost(t *testing.T) { + common.Log.Debug("Enter function: %s", common.GetFunctionName()) expInfo, err := connTest.Explain("select 1", TraditionalExplainType, TraditionalFormatExplain) if err != nil { t.Error(err) @@ -2434,19 +2449,24 @@ func TestMySQLExplainQueryCost(t *testing.T) { if err != nil { t.Error(err) } + common.Log.Debug("Exiting function: %s", common.GetFunctionName()) } func TestSupportExplainWrite(t *testing.T) { + common.Log.Debug("Enter function: %s", common.GetFunctionName()) _, err := connTest.supportExplainWrite() if err != nil { t.Error(err) } + common.Log.Debug("Exiting function: %s", common.GetFunctionName()) } func TestExplainAbleSQL(t *testing.T) { + common.Log.Debug("Enter function: %s", common.GetFunctionName()) for _, sql := range sqls { if _, err := connTest.explainAbleSQL(sql); err != nil { t.Errorf("SQL: %s, not explain able", sql) } } + common.Log.Debug("Exiting function: %s", common.GetFunctionName()) } diff --git a/database/mysql.go b/database/mysql.go index ce7cbd0a0045e3eff1b0d07391e474031f4165be..f3a48c06ce3a1af8f7a33d5a4b0ae16d03a1a7af 100644 --- a/database/mysql.go +++ b/database/mysql.go @@ -104,8 +104,9 @@ func (db *Connector) Query(sql string, params ...interface{}) (QueryResult, erro if common.Config.ShowLastQueryCost { cost, err := db.Conn.Query("SHOW SESSION STATUS LIKE 'last_query_cost'") if err == nil { + var varName string if cost.Next() { - err = cost.Scan(res.QueryCost) + err = cost.Scan(&varName, &res.QueryCost) common.LogIfError(err, "") } if err := cost.Close(); err != nil { diff --git a/database/mysql_test.go b/database/mysql_test.go index 9f4827e77dbbbfdbb2011f4ddca4ff237914aaf9..526cec054952d345305ded8f5bc28e0f9bc2b376 100644 --- a/database/mysql_test.go +++ b/database/mysql_test.go @@ -48,6 +48,7 @@ func init() { } func TestQuery(t *testing.T) { + common.Log.Debug("Entering function: %s", common.GetFunctionName()) res, err := connTest.Query("select 0") if err != nil { t.Error(err.Error()) @@ -64,9 +65,11 @@ func TestQuery(t *testing.T) { } res.Rows.Close() // TODO: timeout test + common.Log.Debug("Exiting function: %s", common.GetFunctionName()) } func TestColumnCardinality(t *testing.T) { + common.Log.Debug("Entering function: %s", common.GetFunctionName()) orgDatabase := connTest.Database connTest.Database = "sakila" a := connTest.ColumnCardinality("actor", "first_name") @@ -74,9 +77,11 @@ func TestColumnCardinality(t *testing.T) { t.Error("sakila.actor.first_name cardinality should in [0, 1], now it's", a) } connTest.Database = orgDatabase + common.Log.Debug("Exiting function: %s", common.GetFunctionName()) } func TestDangerousSQL(t *testing.T) { + common.Log.Debug("Entering function: %s", common.GetFunctionName()) testCase := map[string]bool{ "select * from tb;delete from tb;": true, "show database;": false, @@ -91,9 +96,11 @@ func TestDangerousSQL(t *testing.T) { t.Errorf("SQL:%s got:%v want:%v", sql, got, want) } } + common.Log.Debug("Exiting function: %s", common.GetFunctionName()) } func TestWarningsAndQueryCost(t *testing.T) { + common.Log.Debug("Entering function: %s", common.GetFunctionName()) common.Config.ShowWarnings = true common.Config.ShowLastQueryCost = true res, err := connTest.Query("explain select * from sakila.film") @@ -111,17 +118,21 @@ func TestWarningsAndQueryCost(t *testing.T) { res.Warning.Close() fmt.Println(res.QueryCost, err) } + common.Log.Debug("Exiting function: %s", common.GetFunctionName()) } func TestVersion(t *testing.T) { + common.Log.Debug("Entering function: %s", common.GetFunctionName()) version, err := connTest.Version() if err != nil { t.Error(err.Error()) } fmt.Println(version) + common.Log.Debug("Exiting function: %s", common.GetFunctionName()) } func TestRemoveSQLComments(t *testing.T) { + common.Log.Debug("Entering function: %s", common.GetFunctionName()) SQLs := []string{ `-- comment`, `--`, @@ -140,9 +151,11 @@ comment*/`, if err != nil { t.Error(err) } + common.Log.Debug("Exiting function: %s", common.GetFunctionName()) } func TestSingleIntValue(t *testing.T) { + common.Log.Debug("Entering function: %s", common.GetFunctionName()) val, err := connTest.SingleIntValue("read_only") if err != nil { t.Error(err) @@ -150,13 +163,16 @@ func TestSingleIntValue(t *testing.T) { if val < 0 { t.Error("SingleIntValue, return should large than zero") } + common.Log.Debug("Exiting function: %s", common.GetFunctionName()) } func TestIsView(t *testing.T) { + common.Log.Debug("Entering function: %s", common.GetFunctionName()) originalDatabase := connTest.Database connTest.Database = "sakila" if !connTest.IsView("actor_info") { t.Error("actor_info should be a VIEW") } connTest.Database = originalDatabase + common.Log.Debug("Exiting function: %s", common.GetFunctionName()) } diff --git a/database/privilege_test.go b/database/privilege_test.go index 3690fd3dc5c99ac6fc51051f535e63fd4c8489db..05ac77f61fb02a128d05d198c1ddc9fe4967047b 100644 --- a/database/privilege_test.go +++ b/database/privilege_test.go @@ -16,9 +16,14 @@ package database -import "testing" +import ( + "testing" + + "github.com/XiaoMi/soar/common" +) func TestCurrentUser(t *testing.T) { + common.Log.Debug("Entering function: %s", common.GetFunctionName()) user, host, err := connTest.CurrentUser() if err != nil { t.Error(err.Error()) @@ -26,16 +31,21 @@ func TestCurrentUser(t *testing.T) { if user != "root" || host != "%" { t.Errorf("Want user: root, host: %%. Get user: %s, host: %s", user, host) } + common.Log.Debug("Exiting function: %s", common.GetFunctionName()) } func TestHasSelectPrivilege(t *testing.T) { + common.Log.Debug("Entering function: %s", common.GetFunctionName()) if !connTest.HasSelectPrivilege() { t.Errorf("DSN: %s, User: %s, should has select privilege", connTest.Addr, connTest.User) } + common.Log.Debug("Exiting function: %s", common.GetFunctionName()) } func TestHasAllPrivilege(t *testing.T) { + common.Log.Debug("Entering function: %s", common.GetFunctionName()) if !connTest.HasAllPrivilege() { t.Errorf("DSN: %s, User: %s, should has all privilege", connTest.Addr, connTest.User) } + common.Log.Debug("Exiting function: %s", common.GetFunctionName()) } diff --git a/database/profiling_test.go b/database/profiling_test.go index 78d28fef46876fb1aca1cb2155cf74bab0886bae..d1e71ee26dbf0ee07132eed398334ac828815bd0 100644 --- a/database/profiling_test.go +++ b/database/profiling_test.go @@ -19,21 +19,26 @@ package database import ( "testing" + "github.com/XiaoMi/soar/common" "github.com/kr/pretty" ) func TestProfiling(t *testing.T) { + common.Log.Debug("Entering function: %s", common.GetFunctionName()) rows, err := connTest.Profiling("select 1") if err != nil { t.Error(err) } pretty.Println(rows) + common.Log.Debug("Exiting function: %s", common.GetFunctionName()) } func TestFormatProfiling(t *testing.T) { + common.Log.Debug("Entering function: %s", common.GetFunctionName()) res, err := connTest.Profiling("select 1") if err != nil { t.Error(err) } pretty.Println(FormatProfiling(res)) + common.Log.Debug("Exiting function: %s", common.GetFunctionName()) } diff --git a/database/sampling.go b/database/sampling.go index fab82402c78742e241f2d011198a74f0cdd605b8..5ccbf55e92d43df2e9dc6f03036b54328c9b0bc2 100644 --- a/database/sampling.go +++ b/database/sampling.go @@ -17,11 +17,15 @@ package database import ( - "database/sql" "fmt" + "time" - "github.com/XiaoMi/soar/common" "strings" + + "database/sql" + + "github.com/XiaoMi/soar/common" + "github.com/ziutek/mymysql/mysql" ) /*-------------------- @@ -44,99 +48,125 @@ import ( *-------------------- */ -// SamplingData 将数据从Remote拉取到 db 中 -func (db *Connector) SamplingData(remote *Connector, tables ...string) error { +// SamplingData 将数据从 onlineConn 拉取到 db 中 +func (db *Connector) SamplingData(onlineConn *Connector, database string, tables ...string) error { + var err error + if database == db.Database { + return fmt.Errorf("SamplingData the same database, From: %s/%s, To: %s/%s", onlineConn.Addr, database, db.Addr, db.Database) + } + // 计算需要泵取的数据量 wantRowsCount := 300 * common.Config.SamplingStatisticTarget - // 设置数据采样单条 SQL 中 value 的数量 - // 该数值越大,在内存中缓存的data就越多,但相对的,插入时速度就越快 - maxValCount := 200 - for _, table := range tables { // 表类型检查 - if remote.IsView(table) { - return nil - } - - tableStatus, err := remote.ShowTableStatus(table) - if err != nil { - return err - } - - if len(tableStatus.Rows) == 0 { - common.Log.Info("SamplingData, Table %s with no data, stop sampling", table) + if onlineConn.IsView(table) { return nil } - tableRows := tableStatus.Rows[0].Rows - if tableRows == 0 { - common.Log.Info("SamplingData, Table %s with no data, stop sampling", table) - return nil + // generate where condition + var where string + if common.Config.SamplingCondition == "" { + tableStatus, err := onlineConn.ShowTableStatus(table) + if err != nil { + return err + } + + if len(tableStatus.Rows) == 0 { + common.Log.Info("SamplingData, Table %s with no data, stop sampling", table) + return nil + } + + tableRows := tableStatus.Rows[0].Rows + if tableRows == 0 { + common.Log.Info("SamplingData, Table %s with no data, stop sampling", table) + return nil + } + + factor := float64(wantRowsCount) / float64(tableRows) + common.Log.Debug("SamplingData, tableRows: %d, wantRowsCount: %d, factor: %f", tableRows, wantRowsCount, factor) + where = fmt.Sprintf("WHERE RAND() <= %f LIMIT %d", factor, wantRowsCount) + if factor >= 1 { + where = "" + } + } else { + where = common.Config.SamplingCondition } - factor := float64(wantRowsCount) / float64(tableRows) - common.Log.Debug("SamplingData, tableRows: %d, wantRowsCount: %d, factor: %f", tableRows, wantRowsCount, factor) - - err = startSampling(remote.Conn, db.Conn, db.Database, table, factor, wantRowsCount, maxValCount) - if err != nil { - common.Log.Error("(db *Connector) SamplingData Error : %v", err) - } + err = db.startSampling(onlineConn.Conn, database, table, where) } - return nil + return err } // startSampling sampling data from OnlineDSN to TestDSN -// 因为涉及到的数据量问题,所以泵取与插入时同时进行的 -// TODO: 加 ref link -func startSampling(conn, localConn *sql.DB, database, table string, factor float64, wants, maxValCount int) error { - // generate where condition - where := fmt.Sprintf("WHERE RAND() <= %f", factor) - if factor >= 1 { - where = "" - } - - res, err := conn.Query(fmt.Sprintf("SELECT * FROM `%s`.`%s` %s LIMIT %d;", database, table, where, wants)) +func (db *Connector) startSampling(onlineConn *sql.DB, database, table string, where string) error { + samplingQuery := fmt.Sprintf("SELECT * FROM `%s`.`%s` %s", database, table, where) + common.Log.Debug("startSampling with Query: %s", samplingQuery) + res, err := onlineConn.Query(samplingQuery) if err != nil { return err } - // column info + // columns list columns, err := res.Columns() if err != nil { return err } - row := make(map[string][]byte, len(columns)) + row := make([][]byte, len(columns)) tableFields := make([]interface{}, 0) - for _, col := range columns { - if _, ok := row[col]; ok { - tableFields = append(tableFields, row[col]) - } + for i := range columns { + tableFields = append(tableFields, &row[i]) + } + columnTypes, err := res.ColumnTypes() + if err != nil { + return err } // sampling data - var valuesStr string - var values []string + var valuesCount int + var valuesStr []string + maxValuesCount := 200 // one time insert values count, TODO: config able columnsStr := "`" + strings.Join(columns, "`,`") + "`" for res.Next() { + var values []string res.Scan(tableFields...) - for _, val := range row { - values = append(values, fmt.Sprintf(`unhex("%s")`, fmt.Sprintf("%x", val))) + for i, val := range row { + if val == nil { + values = append(values, "NULL") + } else { + switch columnTypes[i].DatabaseTypeName() { + case "TIMESTAMP", "DATETIME": + t, err := time.Parse(time.RFC3339, string(val)) + common.LogIfWarn(err, "") + values = append(values, fmt.Sprintf(`"%s"`, mysql.TimeString(t))) + default: + values = append(values, fmt.Sprintf(`unhex("%s")`, fmt.Sprintf("%x", val))) + } + } + } + valuesStr = append(valuesStr, "("+strings.Join(values, `,`)+")") + valuesCount++ + if maxValuesCount <= valuesCount { + err = db.doSampling(table, columnsStr, strings.Join(valuesStr, `,`)) + if err != nil { + break + } + values = make([]string, 0) + valuesStr = make([]string, 0) + valuesCount = 0 } - valuesStr = fmt.Sprintf(`(%s)`, strings.Join(values, `,`)) - doSampling(localConn, database, table, columnsStr, valuesStr) } res.Close() - return nil + return err } -// 将泵取的数据转换成Insert语句并在数据库中执行 -func doSampling(conn *sql.DB, dbName, table, colDef, values string) { - query := fmt.Sprintf("INSERT INTO `%s`.`%s` (%s) VALUES %s;", dbName, table, - colDef, values) - - _, err := conn.Exec(query) - if err != nil { - common.Log.Error("doSampling Error from %s.%s: %v", dbName, table, err) +// 将泵取的数据转换成 insert 语句并在 testConn 数据库中执行 +func (db *Connector) doSampling(table, colDef, values string) error { + // db.Database is hashed database name + query := fmt.Sprintf("INSERT INTO `%s`.`%s` (%s) VALUES %s;", db.Database, table, colDef, values) + res, err := db.Query(query) + if res.Rows != nil { + res.Rows.Close() } + return err } diff --git a/database/sampling_test.go b/database/sampling_test.go deleted file mode 100644 index f8525dbcc5bc58ee67544cc23046c443024bec44..0000000000000000000000000000000000000000 --- a/database/sampling_test.go +++ /dev/null @@ -1,39 +0,0 @@ -/* - * Copyright 2018 Xiaomi, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package database - -import ( - "testing" - - "github.com/XiaoMi/soar/common" -) - -func init() { - common.BaseDir = common.DevPath -} - -func TestSamplingData(t *testing.T) { - connOnline, err := NewConnector(common.Config.OnlineDSN) - if err != nil { - t.Error(err) - } - - err = connTest.SamplingData(connOnline, "film") - if err != nil { - t.Error(err) - } -} diff --git a/database/show.go b/database/show.go index 75ca837c461b8845d49c0b64cbae910374e11fee..214781cd9547fa12e4d84f60c39062996814690d 100644 --- a/database/show.go +++ b/database/show.go @@ -459,27 +459,27 @@ func (db *Connector) ShowCreateTable(tableName string) (string, error) { ddl, err := db.showCreate("table", tableName) // 去除外键关联条件 - var noConstraint []string - relationReg, _ := regexp.Compile("CONSTRAINT") - for _, line := range strings.Split(ddl, "\n") { - - if relationReg.Match([]byte(line)) { - continue - } - - // 去除外键语句会使DDL中多一个','导致语法错误,要把多余的逗号去除 - if strings.Index(line, ")") == 0 { - lineWrongSyntax := noConstraint[len(noConstraint)-1] - // 如果')'前一句的末尾是',' 删除 ',' 保证语法正确性 - if strings.Index(lineWrongSyntax, ",") == len(lineWrongSyntax)-1 { - noConstraint[len(noConstraint)-1] = lineWrongSyntax[:len(lineWrongSyntax)-1] + lines := strings.Split(ddl, "\n") + // CREATE VIEW ONLY 1 LINE + if len(lines) > 2 { + var noConstraint []string + relationReg, _ := regexp.Compile("CONSTRAINT") + for _, line := range lines[1 : len(lines)-1] { + if relationReg.Match([]byte(line)) { + continue } + line = strings.TrimSuffix(line, ",") + noConstraint = append(noConstraint, line) } - noConstraint = append(noConstraint, line) + // 去除外键语句会使DDL中多一个','导致语法错误,要把多余的逗号去除 + ddl = fmt.Sprint( + lines[0], "\n", + strings.Join(noConstraint, ",\n"), "\n", + lines[len(lines)-1], + ) } - - return strings.Join(noConstraint, "\n"), err + return ddl, err } // FindColumn find column diff --git a/database/show_test.go b/database/show_test.go index ceea0abffc4883dd4d4a679603c7561c86ebead7..165b2996986415a97c9afc2ce576331d90566b9c 100644 --- a/database/show_test.go +++ b/database/show_test.go @@ -26,6 +26,7 @@ import ( ) func TestShowTableStatus(t *testing.T) { + common.Log.Debug("Entering function: %s", common.GetFunctionName()) orgDatabase := connTest.Database connTest.Database = "sakila" ts, err := connTest.ShowTableStatus("film") @@ -47,9 +48,11 @@ func TestShowTableStatus(t *testing.T) { } pretty.Println(ts) connTest.Database = orgDatabase + common.Log.Debug("Exiting function: %s", common.GetFunctionName()) } func TestShowTables(t *testing.T) { + common.Log.Debug("Entering function: %s", common.GetFunctionName()) orgDatabase := connTest.Database connTest.Database = "sakila" ts, err := connTest.ShowTables() @@ -66,23 +69,29 @@ func TestShowTables(t *testing.T) { t.Error(err) } connTest.Database = orgDatabase + common.Log.Debug("Exiting function: %s", common.GetFunctionName()) } func TestShowCreateDatabase(t *testing.T) { + common.Log.Debug("Entering function: %s", common.GetFunctionName()) err := common.GoldenDiff(func() { fmt.Println(connTest.ShowCreateDatabase("sakila")) }, t.Name(), update) if err != nil { t.Error(err) } + common.Log.Debug("Exiting function: %s", common.GetFunctionName()) } func TestShowCreateTable(t *testing.T) { + common.Log.Debug("Entering function: %s", common.GetFunctionName()) orgDatabase := connTest.Database connTest.Database = "sakila" tables := []string{ "film", + "category", "customer_list", + "inventory", } err := common.GoldenDiff(func() { for _, table := range tables { @@ -97,9 +106,11 @@ func TestShowCreateTable(t *testing.T) { t.Error(err) } connTest.Database = orgDatabase + common.Log.Debug("Exiting function: %s", common.GetFunctionName()) } func TestShowIndex(t *testing.T) { + common.Log.Debug("Entering function: %s", common.GetFunctionName()) orgDatabase := connTest.Database connTest.Database = "sakila" ti, err := connTest.ShowIndex("film") @@ -114,11 +125,12 @@ func TestShowIndex(t *testing.T) { if err != nil { t.Error(err) } - connTest.Database = orgDatabase + common.Log.Debug("Exiting function: %s", common.GetFunctionName()) } func TestShowColumns(t *testing.T) { + common.Log.Debug("Entering function: %s", common.GetFunctionName()) orgDatabase := connTest.Database connTest.Database = "sakila" ti, err := connTest.ShowColumns("actor_info") @@ -134,9 +146,11 @@ func TestShowColumns(t *testing.T) { } connTest.Database = orgDatabase + common.Log.Debug("Exiting function: %s", common.GetFunctionName()) } func TestFindColumn(t *testing.T) { + common.Log.Debug("Entering function: %s", common.GetFunctionName()) ti, err := connTest.FindColumn("film_id", "sakila", "film") if err != nil { t.Error("FindColumn Error: ", err) @@ -147,15 +161,19 @@ func TestFindColumn(t *testing.T) { if err != nil { t.Error(err) } + common.Log.Debug("Exiting function: %s", common.GetFunctionName()) } func TestIsFKey(t *testing.T) { + common.Log.Debug("Entering function: %s", common.GetFunctionName()) if !connTest.IsForeignKey("sakila", "film", "language_id") { t.Error("want True. got false") } + common.Log.Debug("Exiting function: %s", common.GetFunctionName()) } func TestShowReference(t *testing.T) { + common.Log.Debug("Entering function: %s", common.GetFunctionName()) rv, err := connTest.ShowReference("sakila", "film") if err != nil { t.Error("ShowReference Error: ", err) @@ -167,4 +185,5 @@ func TestShowReference(t *testing.T) { if err != nil { t.Error(err) } + common.Log.Debug("Exiting function: %s", common.GetFunctionName()) } diff --git a/database/testdata/TestShowCreateTable.golden b/database/testdata/TestShowCreateTable.golden index 74208c702b85a1014f34c2a7561c273ade684f04..d91cc1c055b3287b3b111ba8390f8f5a08216264 100644 --- a/database/testdata/TestShowCreateTable.golden +++ b/database/testdata/TestShowCreateTable.golden @@ -17,4 +17,19 @@ CREATE TABLE `film` ( KEY `idx_fk_language_id` (`language_id`), KEY `idx_fk_original_language_id` (`original_language_id`) ) ENGINE=InnoDB AUTO_INCREMENT=1001 DEFAULT CHARSET=utf8 +CREATE TABLE `category` ( + `category_id` tinyint(3) unsigned NOT NULL AUTO_INCREMENT, + `name` varchar(25) NOT NULL, + `last_update` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY (`category_id`) +) ENGINE=InnoDB AUTO_INCREMENT=17 DEFAULT CHARSET=utf8 CREATE ALGORITHM=UNDEFINED DEFINER=`root`@`localhost` SQL SECURITY DEFINER VIEW `customer_list` AS select `cu`.`customer_id` AS `ID`,concat(`cu`.`first_name`,_utf8mb3' ',`cu`.`last_name`) AS `name`,`a`.`address` AS `address`,`a`.`postal_code` AS `zip code`,`a`.`phone` AS `phone`,`city`.`city` AS `city`,`country`.`country` AS `country`,if(`cu`.`active`,_utf8mb3'active',_utf8mb3'') AS `notes`,`cu`.`store_id` AS `SID` from (((`customer` `cu` join `address` `a` on((`cu`.`address_id` = `a`.`address_id`))) join `city` on((`a`.`city_id` = `city`.`city_id`))) join `country` on((`city`.`country_id` = `country`.`country_id`))) +CREATE TABLE `inventory` ( + `inventory_id` mediumint(8) unsigned NOT NULL AUTO_INCREMENT, + `film_id` smallint(5) unsigned NOT NULL, + `store_id` tinyint(3) unsigned NOT NULL, + `last_update` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY (`inventory_id`), + KEY `idx_fk_film_id` (`film_id`), + KEY `idx_store_id_film_id` (`store_id`,`film_id`) +) ENGINE=InnoDB AUTO_INCREMENT=4582 DEFAULT CHARSET=utf8 diff --git a/database/trace_test.go b/database/trace_test.go index 535d3f58cf948fbb54c1a8407a63bb3317225753..887ee4c6956bb1afbbb12765589ba061ca0d6271 100644 --- a/database/trace_test.go +++ b/database/trace_test.go @@ -25,6 +25,7 @@ import ( ) func TestTrace(t *testing.T) { + common.Log.Debug("Entering function: %s", common.GetFunctionName()) res, err := connTest.Trace("select 1") if err != nil { t.Error(err) @@ -36,9 +37,11 @@ func TestTrace(t *testing.T) { if err != nil { t.Error(err) } + common.Log.Debug("Exiting function: %s", common.GetFunctionName()) } func TestFormatTrace(t *testing.T) { + common.Log.Debug("Entering function: %s", common.GetFunctionName()) res, err := connTest.Trace("select 1") if err != nil { t.Error(err) @@ -50,4 +53,5 @@ func TestFormatTrace(t *testing.T) { if err != nil { t.Error(err) } + common.Log.Debug("Exiting function: %s", common.GetFunctionName()) } diff --git a/env/env.go b/env/env.go index 336152bd84cbd7339d49cb36ff7a878e83e86c6f..293f6217ca922d065854af67eb6fdaaa37eff14a 100644 --- a/env/env.go +++ b/env/env.go @@ -82,7 +82,7 @@ func BuildEnv() (*VirtualEnv, *database.Connector) { common.LogIfError(err, "") // 检查线上环境可用性版本 - rEnvVersion, err := vEnv.Version() + rEnvVersion, err := connOnline.Version() common.Config.OnlineDSN.Version = rEnvVersion if err != nil { common.Log.Warn("BuildEnv OnlineDSN: %s:********@%s/%s not available , Error: %s", @@ -245,20 +245,20 @@ func (ve *VirtualEnv) BuildVirtualEnv(rEnv *database.Connector, SQLs ...string) // 为了支持并发,需要将DB进行映射,但db.table这种形式无法保证DB的映射是正确的 // TODO:暂不支持 create db.tableName (id int) 形式的建表语句 if stmt.Table.Qualifier.String() != "" { - common.Log.Error("BuildVirtualEnv DDL Not support '.'") + common.Log.Error("BuildVirtualEnv DDL Not support db.tb format") return false } for _, tb := range stmt.FromTables { if tb.Qualifier.String() != "" { - common.Log.Error("BuildVirtualEnv DDL Not support '.'") + common.Log.Error("BuildVirtualEnv DDL Not support db.tb format") return false } } for _, tb := range stmt.ToTables { if tb.Qualifier.String() != "" { - common.Log.Error("BuildVirtualEnv DDL Not support '.'") + common.Log.Error("BuildVirtualEnv DDL Not support db.tb format") return false } } @@ -338,7 +338,7 @@ func (ve *VirtualEnv) BuildVirtualEnv(rEnv *database.Connector, SQLs ...string) err = ve.createTable(tmpEnv, db, tb.TableName) if err != nil { - common.Log.Error("BuildVirtualEnv Error : %v", err) + common.Log.Error("BuildVirtualEnv %s.%s Error : %v", db, tb.TableName, err) return false } } @@ -453,7 +453,7 @@ func (ve VirtualEnv) createTable(rEnv *database.Connector, dbName, tbName string res, err := ve.Query(ddl) if err != nil { // 有可能是用户新建表,因此线上环境查不到 - common.Log.Error("createTable, %s Error : %v", tbName, err) + common.Log.Error("createTable: %s Error : %v", tbName, err) return err } res.Rows.Close() @@ -461,13 +461,9 @@ func (ve VirtualEnv) createTable(rEnv *database.Connector, dbName, tbName string // 泵取数据 if common.Config.Sampling { common.Log.Debug("createTable, Start Sampling data from %s.%s to %s.%s ...", dbName, tbName, ve.DBRef[dbName], tbName) - err := ve.SamplingData(rEnv, tbName) - if err != nil { - common.Log.Error(" (ve VirtualEnv) createTable SamplingData Error: %v", err) - return err - } + err = ve.SamplingData(rEnv, dbName, tbName) } - return nil + return err } // GenTableColumns 为 Rewrite 提供的结构体初始化 diff --git a/env/env_test.go b/env/env_test.go index 6f48a6310b5223c721fdcc07679879e8bc89c76c..fa3805a252ef4714ad5ffc4121fa4b4e4c8b9d67 100644 --- a/env/env_test.go +++ b/env/env_test.go @@ -49,6 +49,7 @@ func init() { } func TestNewVirtualEnv(t *testing.T) { + common.Log.Debug("Entering function: %s", common.GetFunctionName()) testSQL := []string{ "create table t(id int,c1 varchar(20),PRIMARY KEY (id));", "alter table t add index `idx_c1`(c1);", @@ -117,9 +118,11 @@ func TestNewVirtualEnv(t *testing.T) { } } }, t.Name(), update) + common.Log.Debug("Exiting function: %s", common.GetFunctionName()) } func TestCleanupTestDatabase(t *testing.T) { + common.Log.Debug("Enter function: %s", common.GetFunctionName()) vEnv, _ := BuildEnv() if common.Config.TestDSN.Disable { common.Log.Warn("common.Config.TestDSN.Disable=true, by pass TestCleanupTestDatabase") @@ -146,9 +149,11 @@ func TestCleanupTestDatabase(t *testing.T) { if err != nil { t.Error("optimizer_060102150405 not exist, should not be dropped") } + common.Log.Debug("Exiting function: %s", common.GetFunctionName()) } func TestGenTableColumns(t *testing.T) { + common.Log.Debug("Enter function: %s", common.GetFunctionName()) vEnv, rEnv := BuildEnv() defer vEnv.CleanUp() @@ -214,4 +219,66 @@ func TestGenTableColumns(t *testing.T) { } } } + common.Log.Debug("Exiting function: %s", common.GetFunctionName()) +} + +func TestCreateTable(t *testing.T) { + common.Log.Debug("Enter function: %s", common.GetFunctionName()) + orgSamplingCondition := common.Config.SamplingCondition + common.Config.SamplingCondition = "LIMIT 1" + + vEnv, rEnv := BuildEnv() + defer vEnv.CleanUp() + // TODO: support VIEW, + tables := []string{ + "actor", + // "actor_info", // VIEW + "address", + "category", + "city", + "country", + "customer", + "customer_list", + "film", + "film_actor", + "film_category", + "film_list", + "film_text", + "inventory", + "language", + "nicer_but_slower_film_list", + "payment", + "rental", + // "sales_by_film_category", // VIEW + // "sales_by_store", // VIEW + "staff", + "staff_list", + "store", + } + for _, table := range tables { + err := vEnv.createTable(rEnv, "sakila", table) + if err != nil { + t.Error(err) + } + } + common.Config.SamplingCondition = orgSamplingCondition + common.Log.Debug("Exiting function: %s", common.GetFunctionName()) +} + +func TestCreateDatabase(t *testing.T) { + common.Log.Debug("Enter function: %s", common.GetFunctionName()) + vEnv, rEnv := BuildEnv() + defer vEnv.CleanUp() + err := vEnv.createDatabase(rEnv, "sakila") + if err != nil { + t.Error(err) + } + if vEnv.DBHash("sakila") == "sakila" { + t.Errorf("database: sakila rehashed failed!") + } + + if vEnv.DBHash("not_exist_db") != "not_exist_db" { + t.Errorf("database: not_exist_db rehashed!") + } + common.Log.Debug("Exiting function: %s", common.GetFunctionName()) }