提交 b9e152bf 编写于 作者: martianzhang's avatar martianzhang

sampling data type trading

上级 86da258b
...@@ -23,6 +23,7 @@ import ( ...@@ -23,6 +23,7 @@ import (
) )
func TestDigestExplainText(t *testing.T) { func TestDigestExplainText(t *testing.T) {
common.Log.Debug("Enter function: %s", common.GetFunctionName())
var text = `+----+-------------+---------+-------+---------------------------------------------------------+-------------------+---------+---------------------------+------+-------------+ var text = `+----+-------------+---------+-------+---------------------------------------------------------+-------------------+---------+---------------------------+------+-------------+
| id | select_type | table | type | possible_keys | key | key_len | ref | rows | Extra | | id | select_type | table | type | possible_keys | key | key_len | ref | rows | Extra |
+----+-------------+---------+-------+---------------------------------------------------------+-------------------+---------+---------------------------+------+-------------+ +----+-------------+---------+-------+---------------------------------------------------------+-------------------+---------+---------------------------+------+-------------+
...@@ -34,4 +35,5 @@ func TestDigestExplainText(t *testing.T) { ...@@ -34,4 +35,5 @@ func TestDigestExplainText(t *testing.T) {
if nil != err { if nil != err {
t.Fatal(err) t.Fatal(err)
} }
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
} }
...@@ -35,7 +35,7 @@ func TestRuleImplicitAlias(t *testing.T) { ...@@ -35,7 +35,7 @@ func TestRuleImplicitAlias(t *testing.T) {
"select col from tbl tb where id < 1000", "select col from tbl tb where id < 1000",
}, },
{ {
"do 1", "select 1",
}, },
} }
for _, sql := range sqls[0] { for _, sql := range sqls[0] {
......
...@@ -95,8 +95,8 @@ func TestRuleImplicitConversion(t *testing.T) { ...@@ -95,8 +95,8 @@ func TestRuleImplicitConversion(t *testing.T) {
} }
} }
} }
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
common.Config.OnlineDSN = dsn common.Config.OnlineDSN = dsn
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
} }
// JOI.003 & JOI.004 // JOI.003 & JOI.004
...@@ -383,9 +383,11 @@ func TestDuplicateKeyChecker(t *testing.T) { ...@@ -383,9 +383,11 @@ func TestDuplicateKeyChecker(t *testing.T) {
if len(rule) != 0 { if len(rule) != 0 {
t.Errorf("got rules: %s", pretty.Sprint(rule)) t.Errorf("got rules: %s", pretty.Sprint(rule))
} }
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
} }
func TestMergeAdvices(t *testing.T) { func TestMergeAdvices(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
dst := []IndexInfo{ dst := []IndexInfo{
{ {
Name: "test", Name: "test",
...@@ -405,6 +407,7 @@ func TestMergeAdvices(t *testing.T) { ...@@ -405,6 +407,7 @@ func TestMergeAdvices(t *testing.T) {
if len(advise) != 1 { if len(advise) != 1 {
t.Error(pretty.Sprint(advise)) t.Error(pretty.Sprint(advise))
} }
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
} }
func TestIdxColsTypeCheck(t *testing.T) { func TestIdxColsTypeCheck(t *testing.T) {
...@@ -450,13 +453,16 @@ func TestIdxColsTypeCheck(t *testing.T) { ...@@ -450,13 +453,16 @@ func TestIdxColsTypeCheck(t *testing.T) {
} }
} }
} }
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
} }
func TestGetRandomIndexSuffix(t *testing.T) { func TestGetRandomIndexSuffix(t *testing.T) {
common.Log.Debug("Enter function: %s", common.GetFunctionName())
for i := 0; i < 5; i++ { for i := 0; i < 5; i++ {
r := getRandomIndexSuffix() r := getRandomIndexSuffix()
if !(strings.HasPrefix(r, "_") && len(r) == 5) { if !(strings.HasPrefix(r, "_") && len(r) == 5) {
t.Errorf("getRandomIndexSuffix should return a string with prefix `_` and 5 length, but got:%s", r) t.Errorf("getRandomIndexSuffix should return a string with prefix `_` and 5 length, but got:%s", r)
} }
} }
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
} }
...@@ -58,7 +58,7 @@ func NewQuery4Audit(sql string, options ...string) (*Query4Audit, error) { ...@@ -58,7 +58,7 @@ func NewQuery4Audit(sql string, options ...string) (*Query4Audit, error) {
// vitess 语法解析不上报,以 tidb parser 为主 // vitess 语法解析不上报,以 tidb parser 为主
q.Stmt, vErr = sqlparser.Parse(sql) q.Stmt, vErr = sqlparser.Parse(sql)
if vErr != nil { if vErr != nil {
common.Log.Warn("NewQuery4Audit vitess parse Error: %s", vErr.Error()) common.Log.Warn("NewQuery4Audit vitess parse Error: %s, Query: %s", vErr.Error(), sql)
} }
// TODO: charset, collation // TODO: charset, collation
......
...@@ -23,29 +23,37 @@ import ( ...@@ -23,29 +23,37 @@ import (
) )
func TestListTestSQLs(t *testing.T) { func TestListTestSQLs(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
err := common.GoldenDiff(func() { ListTestSQLs() }, t.Name(), update) err := common.GoldenDiff(func() { ListTestSQLs() }, t.Name(), update)
if nil != err { if nil != err {
t.Fatal(err) t.Fatal(err)
} }
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
} }
func TestListHeuristicRules(t *testing.T) { func TestListHeuristicRules(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
err := common.GoldenDiff(func() { ListHeuristicRules(HeuristicRules) }, t.Name(), update) err := common.GoldenDiff(func() { ListHeuristicRules(HeuristicRules) }, t.Name(), update)
if nil != err { if nil != err {
t.Fatal(err) t.Fatal(err)
} }
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
} }
func TestInBlackList(t *testing.T) { func TestInBlackList(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
common.BlackList = []string{"select"} common.BlackList = []string{"select"}
if !InBlackList("select 1") { if !InBlackList("select 1") {
t.Error("should be true") t.Error("should be true")
} }
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
} }
func TestIsIgnoreRule(t *testing.T) { func TestIsIgnoreRule(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
common.Config.IgnoreRules = []string{"test"} common.Config.IgnoreRules = []string{"test"}
if !IsIgnoreRule("test") { if !IsIgnoreRule("test") {
t.Error("should be true") t.Error("should be true")
} }
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
} }
...@@ -27,6 +27,7 @@ import ( ...@@ -27,6 +27,7 @@ import (
) )
func TestGetTableFromExprs(t *testing.T) { func TestGetTableFromExprs(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
tbExprs := sqlparser.TableExprs{ tbExprs := sqlparser.TableExprs{
&sqlparser.AliasedTableExpr{ &sqlparser.AliasedTableExpr{
Expr: sqlparser.TableName{ Expr: sqlparser.TableName{
...@@ -40,9 +41,11 @@ func TestGetTableFromExprs(t *testing.T) { ...@@ -40,9 +41,11 @@ func TestGetTableFromExprs(t *testing.T) {
if tb, ok := meta["db"]; !ok { if tb, ok := meta["db"]; !ok {
t.Errorf("no table qualifier, meta: %s", pretty.Sprint(tb)) t.Errorf("no table qualifier, meta: %s", pretty.Sprint(tb))
} }
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
} }
func TestGetParseTableWithStmt(t *testing.T) { func TestGetParseTableWithStmt(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
for _, sql := range common.TestSQLs { for _, sql := range common.TestSQLs {
fmt.Println(sql) fmt.Println(sql)
stmt, err := sqlparser.Parse(sql) stmt, err := sqlparser.Parse(sql)
...@@ -53,9 +56,11 @@ func TestGetParseTableWithStmt(t *testing.T) { ...@@ -53,9 +56,11 @@ func TestGetParseTableWithStmt(t *testing.T) {
pretty.Println(meta) pretty.Println(meta)
fmt.Println() fmt.Println()
} }
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
} }
func TestFindCondition(t *testing.T) { func TestFindCondition(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
for _, sql := range common.TestSQLs { for _, sql := range common.TestSQLs {
fmt.Println(sql) fmt.Println(sql)
stmt, err := sqlparser.Parse(sql) stmt, err := sqlparser.Parse(sql)
...@@ -71,9 +76,11 @@ func TestFindCondition(t *testing.T) { ...@@ -71,9 +76,11 @@ func TestFindCondition(t *testing.T) {
pretty.Println(inEq) pretty.Println(inEq)
fmt.Println() fmt.Println()
} }
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
} }
func TestFindGroupBy(t *testing.T) { func TestFindGroupBy(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
sqlList := []string{ sqlList := []string{
"select a from t group by c", "select a from t group by c",
} }
...@@ -88,9 +95,11 @@ func TestFindGroupBy(t *testing.T) { ...@@ -88,9 +95,11 @@ func TestFindGroupBy(t *testing.T) {
pretty.Println(res) pretty.Println(res)
fmt.Println() fmt.Println()
} }
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
} }
func TestFindOrderBy(t *testing.T) { func TestFindOrderBy(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
sqlList := []string{ sqlList := []string{
"select a from t group by c order by d, c desc", "select a from t group by c order by d, c desc",
"select a from t group by c order by d desc", "select a from t group by c order by d desc",
...@@ -106,9 +115,11 @@ func TestFindOrderBy(t *testing.T) { ...@@ -106,9 +115,11 @@ func TestFindOrderBy(t *testing.T) {
pretty.Println(res) pretty.Println(res)
fmt.Println() fmt.Println()
} }
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
} }
func TestFindSubquery(t *testing.T) { func TestFindSubquery(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
sqlList := []string{ sqlList := []string{
"SELECT * FROM t1 WHERE column1 = (SELECT column1 FROM (SELECT column1 FROM t2) a);", "SELECT * FROM t1 WHERE column1 = (SELECT column1 FROM (SELECT column1 FROM t2) a);",
"select column1 from t2", "select column1 from t2",
...@@ -127,10 +138,11 @@ func TestFindSubquery(t *testing.T) { ...@@ -127,10 +138,11 @@ func TestFindSubquery(t *testing.T) {
fmt.Println(len(subquery)) fmt.Println(len(subquery))
pretty.Println(subquery) pretty.Println(subquery)
} }
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
} }
func TestFindJoinTable(t *testing.T) { func TestFindJoinTable(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
sqlList := []string{ sqlList := []string{
"SELECT * FROM t1 LEFT JOIN (t2 CROSS JOIN t3 CROSS JOIN t4) ON (t2.a = t1.a AND t3.b = t1.b AND t4.c = t1.c)", "SELECT * FROM t1 LEFT JOIN (t2 CROSS JOIN t3 CROSS JOIN t4) ON (t2.a = t1.a AND t3.b = t1.b AND t4.c = t1.c)",
"select ID,name from (select address from customer_list where SID=1 order by phone limit 50,10) a join customer_list l on (a.address=l.address) join city c on (c.city=l.city) order by phone desc;", "select ID,name from (select address from customer_list where SID=1 order by phone limit 50,10) a join customer_list l on (a.address=l.address) join city c on (c.city=l.city) order by phone desc;",
...@@ -151,9 +163,11 @@ func TestFindJoinTable(t *testing.T) { ...@@ -151,9 +163,11 @@ func TestFindJoinTable(t *testing.T) {
joinMeta := FindJoinTable(stmt, nil) joinMeta := FindJoinTable(stmt, nil)
pretty.Println(joinMeta) pretty.Println(joinMeta)
} }
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
} }
func TestFindJoinCols(t *testing.T) { func TestFindJoinCols(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
sqlList := []string{ sqlList := []string{
"SELECT * FROM t1 LEFT JOIN (t2 CROSS JOIN t3 CROSS JOIN t4) ON (t2.a = t1.a AND t3.b = t1.b AND t4.c = t1.c)", "SELECT * FROM t1 LEFT JOIN (t2 CROSS JOIN t3 CROSS JOIN t4) ON (t2.a = t1.a AND t3.b = t1.b AND t4.c = t1.c)",
"select t from a LEFT JOIN b USING (c1, c2, c3)", "select t from a LEFT JOIN b USING (c1, c2, c3)",
...@@ -175,9 +189,11 @@ func TestFindJoinCols(t *testing.T) { ...@@ -175,9 +189,11 @@ func TestFindJoinCols(t *testing.T) {
columns := FindJoinCols(stmt) columns := FindJoinCols(stmt)
pretty.Println(columns) pretty.Println(columns)
} }
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
} }
func TestFindJoinColBeWhereEQ(t *testing.T) { func TestFindJoinColBeWhereEQ(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
sqlList := []string{ sqlList := []string{
"select ID,name from (select address from customer_list where SID=1 order by phone limit 50,10) a join customer_list l on (a.address=l.address) join city c on (c.city=l.city) order by phone desc;", "select ID,name from (select address from customer_list where SID=1 order by phone limit 50,10) a join customer_list l on (a.address=l.address) join city c on (c.city=l.city) order by phone desc;",
"SELECT * FROM t1 LEFT JOIN (t2, t3, t4) ON (t2.a = t1.a AND t3.b = t1.b AND t4.c = t1.c)", "SELECT * FROM t1 LEFT JOIN (t2, t3, t4) ON (t2.a = t1.a AND t3.b = t1.b AND t4.c = t1.c)",
...@@ -197,9 +213,11 @@ func TestFindJoinColBeWhereEQ(t *testing.T) { ...@@ -197,9 +213,11 @@ func TestFindJoinColBeWhereEQ(t *testing.T) {
columns := FindEQColsInJoinCond(stmt) columns := FindEQColsInJoinCond(stmt)
pretty.Println(columns) pretty.Println(columns)
} }
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
} }
func TestFindJoinColBeWhereINEQ(t *testing.T) { func TestFindJoinColBeWhereINEQ(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
sqlList := []string{ sqlList := []string{
"select ID,name from (select address from customer_list where SID=1 order by phone limit 50,10) a join customer_list l on (a.address=l.address) join city c on (c.city=l.city) order by phone desc;", "select ID,name from (select address from customer_list where SID=1 order by phone limit 50,10) a join customer_list l on (a.address=l.address) join city c on (c.city=l.city) order by phone desc;",
"SELECT * FROM t1 LEFT JOIN (t2, t3, t4) ON (t2.a = t1.a AND t3.b = t1.b AND t4.c = t1.c)", "SELECT * FROM t1 LEFT JOIN (t2, t3, t4) ON (t2.a = t1.a AND t3.b = t1.b AND t4.c = t1.c)",
...@@ -219,9 +237,11 @@ func TestFindJoinColBeWhereINEQ(t *testing.T) { ...@@ -219,9 +237,11 @@ func TestFindJoinColBeWhereINEQ(t *testing.T) {
columns := FindINEQColsInJoinCond(stmt) columns := FindINEQColsInJoinCond(stmt)
pretty.Println(columns) pretty.Println(columns)
} }
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
} }
func TestFindAllCondition(t *testing.T) { func TestFindAllCondition(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
sqlList := []string{ sqlList := []string{
"SELECT * FROM t1 LEFT JOIN (t2 CROSS JOIN t3 CROSS JOIN t4) ON (t2.a = t1.a AND t3.b = t1.b AND t4.c = t1.c)", "SELECT * FROM t1 LEFT JOIN (t2 CROSS JOIN t3 CROSS JOIN t4) ON (t2.a = t1.a AND t3.b = t1.b AND t4.c = t1.c)",
"select t from a LEFT JOIN b USING (c1, c2, c3)", "select t from a LEFT JOIN b USING (c1, c2, c3)",
...@@ -247,9 +267,11 @@ func TestFindAllCondition(t *testing.T) { ...@@ -247,9 +267,11 @@ func TestFindAllCondition(t *testing.T) {
columns := FindAllCondition(stmt) columns := FindAllCondition(stmt)
pretty.Println(columns) pretty.Println(columns)
} }
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
} }
func TestFindColumn(t *testing.T) { func TestFindColumn(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
sqlList := []string{ sqlList := []string{
"select col, col2, sum(col1) from tb group by col", "select col, col2, sum(col1) from tb group by col",
"select col from tb group by col,sum(col1)", "select col from tb group by col,sum(col1)",
...@@ -266,9 +288,11 @@ func TestFindColumn(t *testing.T) { ...@@ -266,9 +288,11 @@ func TestFindColumn(t *testing.T) {
columns := FindColumn(stmt) columns := FindColumn(stmt)
pretty.Println(columns) pretty.Println(columns)
} }
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
} }
func TestFindAllCols(t *testing.T) { func TestFindAllCols(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
sqlList := []string{ sqlList := []string{
"select * from tb where a = '1' order by c", "select * from tb where a = '1' order by c",
"select * from tb where a = '1' group by c", "select * from tb where a = '1' group by c",
...@@ -296,9 +320,11 @@ func TestFindAllCols(t *testing.T) { ...@@ -296,9 +320,11 @@ func TestFindAllCols(t *testing.T) {
t.Error(fmt.Errorf("want 'c' got %v", columns)) t.Error(fmt.Errorf("want 'c' got %v", columns))
} }
} }
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
} }
func TestGetSubqueryDepth(t *testing.T) { func TestGetSubqueryDepth(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
sqlList := []string{ sqlList := []string{
"SELECT * FROM t1 LEFT JOIN (t2 CROSS JOIN t3 CROSS JOIN t4) ON (t2.a = t1.a AND t3.b = t1.b AND t4.c = t1.c)", "SELECT * FROM t1 LEFT JOIN (t2 CROSS JOIN t3 CROSS JOIN t4) ON (t2.a = t1.a AND t3.b = t1.b AND t4.c = t1.c)",
"select t from a LEFT JOIN b USING (c1, c2, c3)", "select t from a LEFT JOIN b USING (c1, c2, c3)",
...@@ -323,9 +349,11 @@ func TestGetSubqueryDepth(t *testing.T) { ...@@ -323,9 +349,11 @@ func TestGetSubqueryDepth(t *testing.T) {
dep := GetSubqueryDepth(stmt) dep := GetSubqueryDepth(stmt)
fmt.Println(dep) fmt.Println(dep)
} }
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
} }
func TestAppendTable(t *testing.T) { func TestAppendTable(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
sqlList := []string{ sqlList := []string{
"select ID,name from (select address from customer_list where SID=1 order by phone limit 50,10) a join customer_list l on (a.address=l.address) join city c on (c.city=l.city) order by phone desc;", "select ID,name from (select address from customer_list where SID=1 order by phone limit 50,10) a join customer_list l on (a.address=l.address) join city c on (c.city=l.city) order by phone desc;",
} }
...@@ -367,4 +395,5 @@ func TestAppendTable(t *testing.T) { ...@@ -367,4 +395,5 @@ func TestAppendTable(t *testing.T) {
if meta[""].Table["customer_list"].TableAliases[0] != "l" || meta[""].Table["city"].TableAliases[0] != "c" { if meta[""].Table["customer_list"].TableAliases[0] != "l" || meta[""].Table["city"].TableAliases[0] != "c" {
t.Error("alias filed\n", pretty.Sprint(meta)) t.Error("alias filed\n", pretty.Sprint(meta))
} }
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
} }
...@@ -128,6 +128,7 @@ var TestSqlsPretty = []string{ ...@@ -128,6 +128,7 @@ var TestSqlsPretty = []string{
} }
func TestPretty(t *testing.T) { func TestPretty(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
err := common.GoldenDiff(func() { err := common.GoldenDiff(func() {
for _, sql := range append(TestSqlsPretty, common.TestSQLs...) { for _, sql := range append(TestSqlsPretty, common.TestSQLs...) {
fmt.Println(sql) fmt.Println(sql)
...@@ -137,9 +138,11 @@ func TestPretty(t *testing.T) { ...@@ -137,9 +138,11 @@ func TestPretty(t *testing.T) {
if nil != err { if nil != err {
t.Fatal(err) t.Fatal(err)
} }
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
} }
func TestIsKeyword(t *testing.T) { func TestIsKeyword(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
tks := map[string]bool{ tks := map[string]bool{
"AGAINST": true, "AGAINST": true,
"AUTO_INCREMENT": true, "AUTO_INCREMENT": true,
...@@ -155,9 +158,11 @@ func TestIsKeyword(t *testing.T) { ...@@ -155,9 +158,11 @@ func TestIsKeyword(t *testing.T) {
t.Error("isKeyword:", tk) t.Error("isKeyword:", tk)
} }
} }
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
} }
func TestRemoveComments(t *testing.T) { func TestRemoveComments(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
for _, sql := range TestSqlsPretty { for _, sql := range TestSqlsPretty {
stmt, _ := sqlparser.Parse(sql) stmt, _ := sqlparser.Parse(sql)
newSQL := sqlparser.String(stmt) newSQL := sqlparser.String(stmt)
...@@ -165,9 +170,11 @@ func TestRemoveComments(t *testing.T) { ...@@ -165,9 +170,11 @@ func TestRemoveComments(t *testing.T) {
fmt.Print(newSQL) fmt.Print(newSQL)
} }
} }
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
} }
func TestMysqlEscapeString(t *testing.T) { func TestMysqlEscapeString(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
var strs = []map[string]string{ var strs = []map[string]string{
{ {
"input": "abc", "input": "abc",
...@@ -198,4 +205,5 @@ abc`, ...@@ -198,4 +205,5 @@ abc`,
} }
} }
} }
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
} }
...@@ -25,6 +25,7 @@ import ( ...@@ -25,6 +25,7 @@ import (
) )
func TestRewrite(t *testing.T) { func TestRewrite(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
orgTestDSNStatus := common.Config.TestDSN.Disable orgTestDSNStatus := common.Config.TestDSN.Disable
common.Config.TestDSN.Disable = false common.Config.TestDSN.Disable = false
testSQL := []map[string]string{ testSQL := []map[string]string{
...@@ -99,9 +100,11 @@ func TestRewrite(t *testing.T) { ...@@ -99,9 +100,11 @@ func TestRewrite(t *testing.T) {
} }
} }
common.Config.TestDSN.Disable = orgTestDSNStatus common.Config.TestDSN.Disable = orgTestDSNStatus
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
} }
func TestRewriteStar2Columns(t *testing.T) { func TestRewriteStar2Columns(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
orgTestDSNStatus := common.Config.TestDSN.Disable orgTestDSNStatus := common.Config.TestDSN.Disable
common.Config.TestDSN.Disable = false common.Config.TestDSN.Disable = false
testSQL := []map[string]string{ testSQL := []map[string]string{
...@@ -131,9 +134,11 @@ func TestRewriteStar2Columns(t *testing.T) { ...@@ -131,9 +134,11 @@ func TestRewriteStar2Columns(t *testing.T) {
} }
} }
common.Config.TestDSN.Disable = orgTestDSNStatus common.Config.TestDSN.Disable = orgTestDSNStatus
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
} }
func TestRewriteInsertColumns(t *testing.T) { func TestRewriteInsertColumns(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
testSQL := []map[string]string{ testSQL := []map[string]string{
{ {
"input": `insert into film values(1,2,3,4,5)`, "input": `insert into film values(1,2,3,4,5)`,
...@@ -173,9 +178,11 @@ func TestRewriteInsertColumns(t *testing.T) { ...@@ -173,9 +178,11 @@ func TestRewriteInsertColumns(t *testing.T) {
t.Errorf("want: %s\ngot: %s", sql["output"], rw.NewSQL) t.Errorf("want: %s\ngot: %s", sql["output"], rw.NewSQL)
} }
} }
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
} }
func TestRewriteHaving(t *testing.T) { func TestRewriteHaving(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
testSQL := []map[string]string{ testSQL := []map[string]string{
{ {
"input": `SELECT state, COUNT(*) FROM Drivers GROUP BY state HAVING state IN ('GA', 'TX') ORDER BY state`, "input": `SELECT state, COUNT(*) FROM Drivers GROUP BY state HAVING state IN ('GA', 'TX') ORDER BY state`,
...@@ -196,9 +203,11 @@ func TestRewriteHaving(t *testing.T) { ...@@ -196,9 +203,11 @@ func TestRewriteHaving(t *testing.T) {
t.Errorf("want: %s\ngot: %s", sql["output"], rw.NewSQL) t.Errorf("want: %s\ngot: %s", sql["output"], rw.NewSQL)
} }
} }
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
} }
func TestRewriteAddOrderByNull(t *testing.T) { func TestRewriteAddOrderByNull(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
testSQL := []map[string]string{ testSQL := []map[string]string{
{ {
"input": "SELECT sum(col1) FROM tbl GROUP BY col", "input": "SELECT sum(col1) FROM tbl GROUP BY col",
...@@ -211,9 +220,11 @@ func TestRewriteAddOrderByNull(t *testing.T) { ...@@ -211,9 +220,11 @@ func TestRewriteAddOrderByNull(t *testing.T) {
t.Errorf("want: %s\ngot: %s", sql["output"], rw.NewSQL) t.Errorf("want: %s\ngot: %s", sql["output"], rw.NewSQL)
} }
} }
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
} }
func TestRewriteRemoveDMLOrderBy(t *testing.T) { func TestRewriteRemoveDMLOrderBy(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
testSQL := []map[string]string{ testSQL := []map[string]string{
{ {
"input": "DELETE FROM tbl WHERE col1=1 ORDER BY col", "input": "DELETE FROM tbl WHERE col1=1 ORDER BY col",
...@@ -230,9 +241,11 @@ func TestRewriteRemoveDMLOrderBy(t *testing.T) { ...@@ -230,9 +241,11 @@ func TestRewriteRemoveDMLOrderBy(t *testing.T) {
t.Errorf("want: %s\ngot: %s", sql["output"], rw.NewSQL) t.Errorf("want: %s\ngot: %s", sql["output"], rw.NewSQL)
} }
} }
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
} }
func TestRewriteGroupByConst(t *testing.T) { func TestRewriteGroupByConst(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
testSQL := []map[string]string{ testSQL := []map[string]string{
{ {
"input": "select 1;", "input": "select 1;",
...@@ -259,9 +272,11 @@ func TestRewriteGroupByConst(t *testing.T) { ...@@ -259,9 +272,11 @@ func TestRewriteGroupByConst(t *testing.T) {
t.Errorf("want: %s\ngot: %s", sql["output"], rw.NewSQL) t.Errorf("want: %s\ngot: %s", sql["output"], rw.NewSQL)
} }
} }
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
} }
func TestRewriteStandard(t *testing.T) { func TestRewriteStandard(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
testSQL := []map[string]string{ testSQL := []map[string]string{
{ {
"input": "SELECT sum(col1) FROM tbl GROUP BY 1;", "input": "SELECT sum(col1) FROM tbl GROUP BY 1;",
...@@ -274,9 +289,11 @@ func TestRewriteStandard(t *testing.T) { ...@@ -274,9 +289,11 @@ func TestRewriteStandard(t *testing.T) {
t.Errorf("want: %s\ngot: %s", sql["output"], rw.NewSQL) t.Errorf("want: %s\ngot: %s", sql["output"], rw.NewSQL)
} }
} }
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
} }
func TestRewriteCountStar(t *testing.T) { func TestRewriteCountStar(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
testSQL := []map[string]string{ testSQL := []map[string]string{
{ {
"input": "SELECT count(col) FROM tbl GROUP BY 1;", "input": "SELECT count(col) FROM tbl GROUP BY 1;",
...@@ -293,9 +310,11 @@ func TestRewriteCountStar(t *testing.T) { ...@@ -293,9 +310,11 @@ func TestRewriteCountStar(t *testing.T) {
t.Errorf("want: %s\ngot: %s", sql["output"], rw.NewSQL) t.Errorf("want: %s\ngot: %s", sql["output"], rw.NewSQL)
} }
} }
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
} }
func TestRewriteInnoDB(t *testing.T) { func TestRewriteInnoDB(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
testSQL := []map[string]string{ testSQL := []map[string]string{
{ {
"input": "CREATE TABLE t1(id bigint(20) NOT NULL AUTO_INCREMENT);", "input": "CREATE TABLE t1(id bigint(20) NOT NULL AUTO_INCREMENT);",
...@@ -312,9 +331,11 @@ func TestRewriteInnoDB(t *testing.T) { ...@@ -312,9 +331,11 @@ func TestRewriteInnoDB(t *testing.T) {
t.Errorf("want: %s\ngot: %s", sql["output"], rw.NewSQL) t.Errorf("want: %s\ngot: %s", sql["output"], rw.NewSQL)
} }
} }
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
} }
func TestRewriteAutoIncrement(t *testing.T) { func TestRewriteAutoIncrement(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
testSQL := []map[string]string{ testSQL := []map[string]string{
{ {
"input": "CREATE TABLE t1(id bigint(20) NOT NULL AUTO_INCREMENT) ENGINE=InnoDB AUTO_INCREMENT=123802;", "input": "CREATE TABLE t1(id bigint(20) NOT NULL AUTO_INCREMENT) ENGINE=InnoDB AUTO_INCREMENT=123802;",
...@@ -331,9 +352,11 @@ func TestRewriteAutoIncrement(t *testing.T) { ...@@ -331,9 +352,11 @@ func TestRewriteAutoIncrement(t *testing.T) {
t.Errorf("want: %s\ngot: %s", sql["output"], rw.NewSQL) t.Errorf("want: %s\ngot: %s", sql["output"], rw.NewSQL)
} }
} }
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
} }
func TestRewriteIntWidth(t *testing.T) { func TestRewriteIntWidth(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
testSQL := []map[string]string{ testSQL := []map[string]string{
{ {
"input": "CREATE TABLE t1(id bigint(10) NOT NULL AUTO_INCREMENT) ENGINE=InnoDB AUTO_INCREMENT=123802;", "input": "CREATE TABLE t1(id bigint(10) NOT NULL AUTO_INCREMENT) ENGINE=InnoDB AUTO_INCREMENT=123802;",
...@@ -358,9 +381,11 @@ func TestRewriteIntWidth(t *testing.T) { ...@@ -358,9 +381,11 @@ func TestRewriteIntWidth(t *testing.T) {
t.Errorf("want: %s\ngot: %s", sql["output"], rw.NewSQL) t.Errorf("want: %s\ngot: %s", sql["output"], rw.NewSQL)
} }
} }
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
} }
func TestRewriteAlwaysTrue(t *testing.T) { func TestRewriteAlwaysTrue(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
testSQL := []map[string]string{ testSQL := []map[string]string{
{ {
"input": "SELECT count(col) FROM tbl where 1=1;", "input": "SELECT count(col) FROM tbl where 1=1;",
...@@ -427,10 +452,12 @@ func TestRewriteAlwaysTrue(t *testing.T) { ...@@ -427,10 +452,12 @@ func TestRewriteAlwaysTrue(t *testing.T) {
t.Errorf("want: %s\ngot: %s", sql["output"], rw.NewSQL) t.Errorf("want: %s\ngot: %s", sql["output"], rw.NewSQL)
} }
} }
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
} }
// TODO: // TODO:
func TestRewriteSubQuery2Join(t *testing.T) { func TestRewriteSubQuery2Join(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
orgTestDSNStatus := common.Config.TestDSN.Disable orgTestDSNStatus := common.Config.TestDSN.Disable
common.Config.TestDSN.Disable = true common.Config.TestDSN.Disable = true
testSQL := []map[string]string{ testSQL := []map[string]string{
...@@ -458,9 +485,11 @@ func TestRewriteSubQuery2Join(t *testing.T) { ...@@ -458,9 +485,11 @@ func TestRewriteSubQuery2Join(t *testing.T) {
} }
} }
common.Config.TestDSN.Disable = orgTestDSNStatus common.Config.TestDSN.Disable = orgTestDSNStatus
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
} }
func TestRewriteDML2Select(t *testing.T) { func TestRewriteDML2Select(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
testSQL := []map[string]string{ testSQL := []map[string]string{
{ {
"input": "DELETE city, country FROM city INNER JOIN country using (country_id) WHERE city.city_id = 1;", "input": "DELETE city, country FROM city INNER JOIN country using (country_id) WHERE city.city_id = 1;",
...@@ -513,9 +542,11 @@ func TestRewriteDML2Select(t *testing.T) { ...@@ -513,9 +542,11 @@ func TestRewriteDML2Select(t *testing.T) {
t.Errorf("want: %s\ngot: %s", sql["output"], rw.NewSQL) t.Errorf("want: %s\ngot: %s", sql["output"], rw.NewSQL)
} }
} }
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
} }
func TestRewriteDistinctStar(t *testing.T) { func TestRewriteDistinctStar(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
testSQL := []map[string]string{ testSQL := []map[string]string{
{ {
"input": `SELECT DISTINCT * FROM film;`, "input": `SELECT DISTINCT * FROM film;`,
...@@ -549,9 +580,11 @@ func TestRewriteDistinctStar(t *testing.T) { ...@@ -549,9 +580,11 @@ func TestRewriteDistinctStar(t *testing.T) {
t.Errorf("want: %s\ngot: %s", sql["output"], rw.NewSQL) t.Errorf("want: %s\ngot: %s", sql["output"], rw.NewSQL)
} }
} }
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
} }
func TestMergeAlterTables(t *testing.T) { func TestMergeAlterTables(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
sqls := []string{ sqls := []string{
// ADD|DROP INDEX // ADD|DROP INDEX
// TODO: PRIMARY KEY, [UNIQUE|FULLTEXT|SPATIAL] INDEX // TODO: PRIMARY KEY, [UNIQUE|FULLTEXT|SPATIAL] INDEX
...@@ -602,9 +635,11 @@ func TestMergeAlterTables(t *testing.T) { ...@@ -602,9 +635,11 @@ func TestMergeAlterTables(t *testing.T) {
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
} }
func TestRewriteUnionAll(t *testing.T) { func TestRewriteUnionAll(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
testSQL := []map[string]string{ testSQL := []map[string]string{
{ {
"input": `select country_id from city union select country_id from country;`, "input": `select country_id from city union select country_id from country;`,
...@@ -617,8 +652,10 @@ func TestRewriteUnionAll(t *testing.T) { ...@@ -617,8 +652,10 @@ func TestRewriteUnionAll(t *testing.T) {
t.Errorf("want: %s\ngot: %s", sql["output"], rw.NewSQL) t.Errorf("want: %s\ngot: %s", sql["output"], rw.NewSQL)
} }
} }
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
} }
func TestRewriteTruncate(t *testing.T) { func TestRewriteTruncate(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
testSQL := []map[string]string{ testSQL := []map[string]string{
{ {
"input": `delete from tbl;`, "input": `delete from tbl;`,
...@@ -631,9 +668,11 @@ func TestRewriteTruncate(t *testing.T) { ...@@ -631,9 +668,11 @@ func TestRewriteTruncate(t *testing.T) {
t.Errorf("want: %s\ngot: %s", sql["output"], rw.NewSQL) t.Errorf("want: %s\ngot: %s", sql["output"], rw.NewSQL)
} }
} }
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
} }
func TestRewriteOr2In(t *testing.T) { func TestRewriteOr2In(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
testSQL := []map[string]string{ testSQL := []map[string]string{
{ {
"input": `select country_id from city where country_id = 1 or country_id = 2 or country_id = 3;`, "input": `select country_id from city where country_id = 1 or country_id = 2 or country_id = 3;`,
...@@ -672,9 +711,11 @@ func TestRewriteOr2In(t *testing.T) { ...@@ -672,9 +711,11 @@ func TestRewriteOr2In(t *testing.T) {
t.Errorf("want: %s\ngot: %s", sql["output"], rw.NewSQL) t.Errorf("want: %s\ngot: %s", sql["output"], rw.NewSQL)
} }
} }
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
} }
func TestRmParenthesis(t *testing.T) { func TestRmParenthesis(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
testSQL := []map[string]string{ testSQL := []map[string]string{
{ {
"input": `select country_id from city where (country_id = 1);`, "input": `select country_id from city where (country_id = 1);`,
...@@ -699,13 +740,16 @@ func TestRmParenthesis(t *testing.T) { ...@@ -699,13 +740,16 @@ func TestRmParenthesis(t *testing.T) {
t.Errorf("want: %s\ngot: %s", sql["output"], rw.NewSQL) t.Errorf("want: %s\ngot: %s", sql["output"], rw.NewSQL)
} }
} }
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
} }
func TestListRewriteRules(t *testing.T) { func TestListRewriteRules(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
err := common.GoldenDiff(func() { err := common.GoldenDiff(func() {
ListRewriteRules(RewriteRules) ListRewriteRules(RewriteRules)
}, t.Name(), update) }, t.Name(), update)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
} }
...@@ -26,6 +26,7 @@ import ( ...@@ -26,6 +26,7 @@ import (
) )
func TestTokenize(t *testing.T) { func TestTokenize(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
err := common.GoldenDiff(func() { err := common.GoldenDiff(func() {
for _, sql := range common.TestSQLs { for _, sql := range common.TestSQLs {
fmt.Println(sql) fmt.Println(sql)
...@@ -35,9 +36,11 @@ func TestTokenize(t *testing.T) { ...@@ -35,9 +36,11 @@ func TestTokenize(t *testing.T) {
if nil != err { if nil != err {
t.Fatal(err) t.Fatal(err)
} }
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
} }
func TestTokenizer(t *testing.T) { func TestTokenizer(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
sqls := []string{ sqls := []string{
"select c1,c2,c3 from t1,t2 join t3 on t1.c1=t2.c1 and t1.c3=t3.c1 where id>1000", "select c1,c2,c3 from t1,t2 join t3 on t1.c1=t2.c1 and t1.c3=t3.c1 where id>1000",
"select sourcetable, if(f.lastcontent = ?, f.lastupdate, f.lastcontent) as lastactivity, f.totalcount as activity, type.class as type, (f.nodeoptions & ?) as nounsubscribe from node as f inner join contenttype as type on type.contenttypeid = f.contenttypeid inner join subscribed as sd on sd.did = f.nodeid and sd.userid = ? union all select f.name as title, f.userid as keyval, ? as sourcetable, ifnull(f.lastpost, f.joindate) as lastactivity, f.posts as activity, ? as type, ? as nounsubscribe from user as f inner join userlist as ul on ul.relationid = f.userid and ul.userid = ? where ul.type = ? and ul.aq = ? order by title limit ?", "select sourcetable, if(f.lastcontent = ?, f.lastupdate, f.lastcontent) as lastactivity, f.totalcount as activity, type.class as type, (f.nodeoptions & ?) as nounsubscribe from node as f inner join contenttype as type on type.contenttypeid = f.contenttypeid inner join subscribed as sd on sd.did = f.nodeid and sd.userid = ? union all select f.name as title, f.userid as keyval, ? as sourcetable, ifnull(f.lastpost, f.joindate) as lastactivity, f.posts as activity, ? as type, ? as nounsubscribe from user as f inner join userlist as ul on ul.relationid = f.userid and ul.userid = ? where ul.type = ? and ul.aq = ? order by title limit ?",
...@@ -57,9 +60,11 @@ func TestTokenizer(t *testing.T) { ...@@ -57,9 +60,11 @@ func TestTokenizer(t *testing.T) {
if nil != err { if nil != err {
t.Fatal(err) t.Fatal(err)
} }
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
} }
func TestGetQuotedString(t *testing.T) { func TestGetQuotedString(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
var str = []string{ var str = []string{
`"hello world"`, `"hello world"`,
"`hello world`", "`hello world`",
...@@ -82,9 +87,11 @@ func TestGetQuotedString(t *testing.T) { ...@@ -82,9 +87,11 @@ func TestGetQuotedString(t *testing.T) {
if nil != err { if nil != err {
t.Fatal(err) t.Fatal(err)
} }
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
} }
func TestCompress(t *testing.T) { func TestCompress(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
err := common.GoldenDiff(func() { err := common.GoldenDiff(func() {
for _, sql := range common.TestSQLs { for _, sql := range common.TestSQLs {
fmt.Println(sql) fmt.Println(sql)
...@@ -94,10 +101,11 @@ func TestCompress(t *testing.T) { ...@@ -94,10 +101,11 @@ func TestCompress(t *testing.T) {
if nil != err { if nil != err {
t.Fatal(err) t.Fatal(err)
} }
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
} }
func TestFormat(t *testing.T) { func TestFormat(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
err := common.GoldenDiff(func() { err := common.GoldenDiff(func() {
for _, sql := range common.TestSQLs { for _, sql := range common.TestSQLs {
fmt.Println(sql) fmt.Println(sql)
...@@ -107,9 +115,11 @@ func TestFormat(t *testing.T) { ...@@ -107,9 +115,11 @@ func TestFormat(t *testing.T) {
if nil != err { if nil != err {
t.Fatal(err) t.Fatal(err)
} }
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
} }
func TestSplitStatement(t *testing.T) { func TestSplitStatement(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
bufs := [][]byte{ bufs := [][]byte{
[]byte("select * from test;hello"), []byte("select * from test;hello"),
[]byte("select 'asd;fas', col from test;hello"), []byte("select 'asd;fas', col from test;hello"),
...@@ -181,9 +191,11 @@ select col from tb; ...@@ -181,9 +191,11 @@ select col from tb;
if nil != err { if nil != err {
t.Fatal(err) t.Fatal(err)
} }
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
} }
func TestLeftNewLines(t *testing.T) { func TestLeftNewLines(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
bufs := [][]byte{ bufs := [][]byte{
[]byte(` []byte(`
select * from test;hello`), select * from test;hello`),
...@@ -200,9 +212,11 @@ func TestLeftNewLines(t *testing.T) { ...@@ -200,9 +212,11 @@ func TestLeftNewLines(t *testing.T) {
if nil != err { if nil != err {
t.Fatal(err) t.Fatal(err)
} }
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
} }
func TestNewLines(t *testing.T) { func TestNewLines(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
bufs := [][]byte{ bufs := [][]byte{
[]byte(` []byte(`
select * from test;hello`), select * from test;hello`),
...@@ -219,4 +233,5 @@ func TestNewLines(t *testing.T) { ...@@ -219,4 +233,5 @@ func TestNewLines(t *testing.T) {
if nil != err { if nil != err {
t.Fatal(err) t.Fatal(err)
} }
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
} }
...@@ -56,6 +56,7 @@ type Configuration struct { ...@@ -56,6 +56,7 @@ type Configuration struct {
OnlySyntaxCheck bool `yaml:"only-syntax-check"` // 只做语法检查不输出优化建议 OnlySyntaxCheck bool `yaml:"only-syntax-check"` // 只做语法检查不输出优化建议
SamplingStatisticTarget int `yaml:"sampling-statistic-target"` // 数据采样因子,对应 PostgreSQL 的 default_statistics_target SamplingStatisticTarget int `yaml:"sampling-statistic-target"` // 数据采样因子,对应 PostgreSQL 的 default_statistics_target
Sampling bool `yaml:"sampling"` // 数据采样开关 Sampling bool `yaml:"sampling"` // 数据采样开关
SamplingCondition string `yaml:"sampling-condition"` // 指定采样条件,如:WHERE xxx LIMIT xxx;
Profiling bool `yaml:"profiling"` // 在开启数据采样的情况下,在测试环境执行进行profile Profiling bool `yaml:"profiling"` // 在开启数据采样的情况下,在测试环境执行进行profile
Trace bool `yaml:"trace"` // 在开启数据采样的情况下,在测试环境执行进行Trace Trace bool `yaml:"trace"` // 在开启数据采样的情况下,在测试环境执行进行Trace
Explain bool `yaml:"explain"` // Explain开关 Explain bool `yaml:"explain"` // Explain开关
...@@ -506,6 +507,7 @@ func readCmdFlags() error { ...@@ -506,6 +507,7 @@ func readCmdFlags() error {
explain := flag.Bool("explain", Config.Explain, "Explain, 是否开启Explain执行计划分析") explain := flag.Bool("explain", Config.Explain, "Explain, 是否开启Explain执行计划分析")
sampling := flag.Bool("sampling", Config.Sampling, "Sampling, 数据采样开关") sampling := flag.Bool("sampling", Config.Sampling, "Sampling, 数据采样开关")
samplingStatisticTarget := flag.Int("sampling-statistic-target", Config.SamplingStatisticTarget, "SamplingStatisticTarget, 数据采样因子,对应 PostgreSQL 的 default_statistics_target") samplingStatisticTarget := flag.Int("sampling-statistic-target", Config.SamplingStatisticTarget, "SamplingStatisticTarget, 数据采样因子,对应 PostgreSQL 的 default_statistics_target")
samplingCondition := flag.String("sampling-condition", Config.SamplingCondition, "SamplingCondition, 数据采样条件,如: WHERE xxx LIMIT xxx")
delimiter := flag.String("delimiter", Config.Delimiter, "Delimiter, SQL分隔符") delimiter := flag.String("delimiter", Config.Delimiter, "Delimiter, SQL分隔符")
// +++++++++++++++日志相关+++++++++++++++++ // +++++++++++++++日志相关+++++++++++++++++
logLevel := flag.Int("log-level", Config.LogLevel, "LogLevel, 日志级别, [0:Emergency, 1:Alert, 2:Critical, 3:Error, 4:Warning, 5:Notice, 6:Informational, 7:Debug]") logLevel := flag.Int("log-level", Config.LogLevel, "LogLevel, 日志级别, [0:Emergency, 1:Alert, 2:Critical, 3:Error, 4:Warning, 5:Notice, 6:Informational, 7:Debug]")
...@@ -585,6 +587,7 @@ func readCmdFlags() error { ...@@ -585,6 +587,7 @@ func readCmdFlags() error {
Config.Explain = *explain Config.Explain = *explain
Config.Sampling = *sampling Config.Sampling = *sampling
Config.SamplingStatisticTarget = *samplingStatisticTarget Config.SamplingStatisticTarget = *samplingStatisticTarget
Config.SamplingCondition = *samplingCondition
Config.LogLevel = *logLevel Config.LogLevel = *logLevel
if strings.HasPrefix(*logOutput, "/") { if strings.HasPrefix(*logOutput, "/") {
......
...@@ -26,6 +26,10 @@ import ( ...@@ -26,6 +26,10 @@ import (
var update = flag.Bool("update", false, "update .golden files") var update = flag.Bool("update", false, "update .golden files")
func init() {
BaseDir = DevPath
}
func TestParseConfig(t *testing.T) { func TestParseConfig(t *testing.T) {
err := ParseConfig("") err := ParseConfig("")
if err != nil { if err != nil {
...@@ -37,7 +41,7 @@ func TestReadConfigFile(t *testing.T) { ...@@ -37,7 +41,7 @@ func TestReadConfigFile(t *testing.T) {
if Config == nil { if Config == nil {
Config = new(Configuration) Config = new(Configuration)
} }
Config.readConfigFile("../soar.yaml") Config.readConfigFile(DevPath + "/soar.yaml")
} }
func TestParseDSN(t *testing.T) { func TestParseDSN(t *testing.T) {
......
...@@ -21,15 +21,11 @@ import ( ...@@ -21,15 +21,11 @@ import (
"testing" "testing"
) )
func init() {
BaseDir = DevPath
}
func TestLogger(t *testing.T) { func TestLogger(t *testing.T) {
Log.Info("info") Log.Info("TestLogger_Info")
Log.Debug("debug") Log.Debug("TestLogger_Debug")
Log.Warning("warning") Log.Warning("TestLogger_Warning")
Log.Error("error") Log.Error("Warning_Error")
} }
func TestCaller(t *testing.T) { func TestCaller(t *testing.T) {
...@@ -47,7 +43,7 @@ func TestGetFunctionName(t *testing.T) { ...@@ -47,7 +43,7 @@ func TestGetFunctionName(t *testing.T) {
} }
func TestIfError(t *testing.T) { func TestIfError(t *testing.T) {
err := errors.New("test") err := errors.New("TestIfError")
LogIfError(err, "") LogIfError(err, "")
LogIfError(err, "func %s", "func_test") LogIfError(err, "func %s", "func_test")
} }
......
...@@ -2332,6 +2332,7 @@ possible_keys: idx_fk_country_id,idx_country_id_city,idx_all,idx_other ...@@ -2332,6 +2332,7 @@ possible_keys: idx_fk_country_id,idx_country_id_city,idx_all,idx_other
} }
func TestExplain(t *testing.T) { func TestExplain(t *testing.T) {
common.Log.Debug("Enter function: %s", common.GetFunctionName())
// TraditionalFormatExplain // TraditionalFormatExplain
for idx, sql := range sqls { for idx, sql := range sqls {
exp, err := connTest.Explain(sql, TraditionalExplainType, TraditionalFormatExplain) exp, err := connTest.Explain(sql, TraditionalExplainType, TraditionalFormatExplain)
...@@ -2350,9 +2351,11 @@ func TestExplain(t *testing.T) { ...@@ -2350,9 +2351,11 @@ func TestExplain(t *testing.T) {
pretty.Println("No.:", idx, "\nOld: ", sql, "\nNew: ", exp.SQL) pretty.Println("No.:", idx, "\nOld: ", sql, "\nNew: ", exp.SQL)
pretty.Println(exp) pretty.Println(exp)
} }
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
} }
func TestParseExplainText(t *testing.T) { func TestParseExplainText(t *testing.T) {
common.Log.Debug("Enter function: %s", common.GetFunctionName())
for _, content := range exp { for _, content := range exp {
pretty.Println(RemoveSQLComments(content)) pretty.Println(RemoveSQLComments(content))
pretty.Println(ParseExplainText(content)) pretty.Println(ParseExplainText(content))
...@@ -2364,26 +2367,32 @@ func TestParseExplainText(t *testing.T) { ...@@ -2364,26 +2367,32 @@ func TestParseExplainText(t *testing.T) {
pretty.Println(explainInfo) pretty.Println(explainInfo)
fmt.Println(err) fmt.Println(err)
*/ */
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
} }
func TestFindTablesInJson(t *testing.T) { func TestFindTablesInJson(t *testing.T) {
common.Log.Debug("Enter function: %s", common.GetFunctionName())
idx := 9 idx := 9
for _, j := range exp[idx : idx+1] { for _, j := range exp[idx : idx+1] {
pretty.Println(j) pretty.Println(j)
findTablesInJSON(j, 0) findTablesInJSON(j, 0)
} }
pretty.Println(len(explainJSONTables), explainJSONTables) pretty.Println(len(explainJSONTables), explainJSONTables)
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
} }
func TestFormatJsonIntoTraditional(t *testing.T) { func TestFormatJsonIntoTraditional(t *testing.T) {
common.Log.Debug("Enter function: %s", common.GetFunctionName())
idx := 11 idx := 11
for _, j := range exp[idx : idx+1] { for _, j := range exp[idx : idx+1] {
pretty.Println(j) pretty.Println(j)
pretty.Println(FormatJSONIntoTraditional(j)) pretty.Println(FormatJSONIntoTraditional(j))
} }
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
} }
func TestPrintMarkdownExplainTable(t *testing.T) { func TestPrintMarkdownExplainTable(t *testing.T) {
common.Log.Debug("Enter function: %s", common.GetFunctionName())
expInfo, err := connTest.Explain("select 1", TraditionalExplainType, TraditionalFormatExplain) expInfo, err := connTest.Explain("select 1", TraditionalExplainType, TraditionalFormatExplain)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
...@@ -2395,9 +2404,11 @@ func TestPrintMarkdownExplainTable(t *testing.T) { ...@@ -2395,9 +2404,11 @@ func TestPrintMarkdownExplainTable(t *testing.T) {
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
} }
func TestExplainInfoTranslator(t *testing.T) { func TestExplainInfoTranslator(t *testing.T) {
common.Log.Debug("Enter function: %s", common.GetFunctionName())
expInfo, err := connTest.Explain("select 1", TraditionalExplainType, TraditionalFormatExplain) expInfo, err := connTest.Explain("select 1", TraditionalExplainType, TraditionalFormatExplain)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
...@@ -2408,9 +2419,11 @@ func TestExplainInfoTranslator(t *testing.T) { ...@@ -2408,9 +2419,11 @@ func TestExplainInfoTranslator(t *testing.T) {
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
} }
func TestMySQLExplainWarnings(t *testing.T) { func TestMySQLExplainWarnings(t *testing.T) {
common.Log.Debug("Enter function: %s", common.GetFunctionName())
expInfo, err := connTest.Explain("select 1", TraditionalExplainType, TraditionalFormatExplain) expInfo, err := connTest.Explain("select 1", TraditionalExplainType, TraditionalFormatExplain)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
...@@ -2421,9 +2434,11 @@ func TestMySQLExplainWarnings(t *testing.T) { ...@@ -2421,9 +2434,11 @@ func TestMySQLExplainWarnings(t *testing.T) {
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
} }
func TestMySQLExplainQueryCost(t *testing.T) { func TestMySQLExplainQueryCost(t *testing.T) {
common.Log.Debug("Enter function: %s", common.GetFunctionName())
expInfo, err := connTest.Explain("select 1", TraditionalExplainType, TraditionalFormatExplain) expInfo, err := connTest.Explain("select 1", TraditionalExplainType, TraditionalFormatExplain)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
...@@ -2434,19 +2449,24 @@ func TestMySQLExplainQueryCost(t *testing.T) { ...@@ -2434,19 +2449,24 @@ func TestMySQLExplainQueryCost(t *testing.T) {
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
} }
func TestSupportExplainWrite(t *testing.T) { func TestSupportExplainWrite(t *testing.T) {
common.Log.Debug("Enter function: %s", common.GetFunctionName())
_, err := connTest.supportExplainWrite() _, err := connTest.supportExplainWrite()
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
} }
func TestExplainAbleSQL(t *testing.T) { func TestExplainAbleSQL(t *testing.T) {
common.Log.Debug("Enter function: %s", common.GetFunctionName())
for _, sql := range sqls { for _, sql := range sqls {
if _, err := connTest.explainAbleSQL(sql); err != nil { if _, err := connTest.explainAbleSQL(sql); err != nil {
t.Errorf("SQL: %s, not explain able", sql) t.Errorf("SQL: %s, not explain able", sql)
} }
} }
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
} }
...@@ -104,8 +104,9 @@ func (db *Connector) Query(sql string, params ...interface{}) (QueryResult, erro ...@@ -104,8 +104,9 @@ func (db *Connector) Query(sql string, params ...interface{}) (QueryResult, erro
if common.Config.ShowLastQueryCost { if common.Config.ShowLastQueryCost {
cost, err := db.Conn.Query("SHOW SESSION STATUS LIKE 'last_query_cost'") cost, err := db.Conn.Query("SHOW SESSION STATUS LIKE 'last_query_cost'")
if err == nil { if err == nil {
var varName string
if cost.Next() { if cost.Next() {
err = cost.Scan(res.QueryCost) err = cost.Scan(&varName, &res.QueryCost)
common.LogIfError(err, "") common.LogIfError(err, "")
} }
if err := cost.Close(); err != nil { if err := cost.Close(); err != nil {
......
...@@ -48,6 +48,7 @@ func init() { ...@@ -48,6 +48,7 @@ func init() {
} }
func TestQuery(t *testing.T) { func TestQuery(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
res, err := connTest.Query("select 0") res, err := connTest.Query("select 0")
if err != nil { if err != nil {
t.Error(err.Error()) t.Error(err.Error())
...@@ -64,9 +65,11 @@ func TestQuery(t *testing.T) { ...@@ -64,9 +65,11 @@ func TestQuery(t *testing.T) {
} }
res.Rows.Close() res.Rows.Close()
// TODO: timeout test // TODO: timeout test
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
} }
func TestColumnCardinality(t *testing.T) { func TestColumnCardinality(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
orgDatabase := connTest.Database orgDatabase := connTest.Database
connTest.Database = "sakila" connTest.Database = "sakila"
a := connTest.ColumnCardinality("actor", "first_name") a := connTest.ColumnCardinality("actor", "first_name")
...@@ -74,9 +77,11 @@ func TestColumnCardinality(t *testing.T) { ...@@ -74,9 +77,11 @@ func TestColumnCardinality(t *testing.T) {
t.Error("sakila.actor.first_name cardinality should in [0, 1], now it's", a) t.Error("sakila.actor.first_name cardinality should in [0, 1], now it's", a)
} }
connTest.Database = orgDatabase connTest.Database = orgDatabase
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
} }
func TestDangerousSQL(t *testing.T) { func TestDangerousSQL(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
testCase := map[string]bool{ testCase := map[string]bool{
"select * from tb;delete from tb;": true, "select * from tb;delete from tb;": true,
"show database;": false, "show database;": false,
...@@ -91,9 +96,11 @@ func TestDangerousSQL(t *testing.T) { ...@@ -91,9 +96,11 @@ func TestDangerousSQL(t *testing.T) {
t.Errorf("SQL:%s got:%v want:%v", sql, got, want) t.Errorf("SQL:%s got:%v want:%v", sql, got, want)
} }
} }
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
} }
func TestWarningsAndQueryCost(t *testing.T) { func TestWarningsAndQueryCost(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
common.Config.ShowWarnings = true common.Config.ShowWarnings = true
common.Config.ShowLastQueryCost = true common.Config.ShowLastQueryCost = true
res, err := connTest.Query("explain select * from sakila.film") res, err := connTest.Query("explain select * from sakila.film")
...@@ -111,17 +118,21 @@ func TestWarningsAndQueryCost(t *testing.T) { ...@@ -111,17 +118,21 @@ func TestWarningsAndQueryCost(t *testing.T) {
res.Warning.Close() res.Warning.Close()
fmt.Println(res.QueryCost, err) fmt.Println(res.QueryCost, err)
} }
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
} }
func TestVersion(t *testing.T) { func TestVersion(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
version, err := connTest.Version() version, err := connTest.Version()
if err != nil { if err != nil {
t.Error(err.Error()) t.Error(err.Error())
} }
fmt.Println(version) fmt.Println(version)
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
} }
func TestRemoveSQLComments(t *testing.T) { func TestRemoveSQLComments(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
SQLs := []string{ SQLs := []string{
`-- comment`, `-- comment`,
`--`, `--`,
...@@ -140,9 +151,11 @@ comment*/`, ...@@ -140,9 +151,11 @@ comment*/`,
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
} }
func TestSingleIntValue(t *testing.T) { func TestSingleIntValue(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
val, err := connTest.SingleIntValue("read_only") val, err := connTest.SingleIntValue("read_only")
if err != nil { if err != nil {
t.Error(err) t.Error(err)
...@@ -150,13 +163,16 @@ func TestSingleIntValue(t *testing.T) { ...@@ -150,13 +163,16 @@ func TestSingleIntValue(t *testing.T) {
if val < 0 { if val < 0 {
t.Error("SingleIntValue, return should large than zero") t.Error("SingleIntValue, return should large than zero")
} }
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
} }
func TestIsView(t *testing.T) { func TestIsView(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
originalDatabase := connTest.Database originalDatabase := connTest.Database
connTest.Database = "sakila" connTest.Database = "sakila"
if !connTest.IsView("actor_info") { if !connTest.IsView("actor_info") {
t.Error("actor_info should be a VIEW") t.Error("actor_info should be a VIEW")
} }
connTest.Database = originalDatabase connTest.Database = originalDatabase
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
} }
...@@ -16,9 +16,14 @@ ...@@ -16,9 +16,14 @@
package database package database
import "testing" import (
"testing"
"github.com/XiaoMi/soar/common"
)
func TestCurrentUser(t *testing.T) { func TestCurrentUser(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
user, host, err := connTest.CurrentUser() user, host, err := connTest.CurrentUser()
if err != nil { if err != nil {
t.Error(err.Error()) t.Error(err.Error())
...@@ -26,16 +31,21 @@ func TestCurrentUser(t *testing.T) { ...@@ -26,16 +31,21 @@ func TestCurrentUser(t *testing.T) {
if user != "root" || host != "%" { if user != "root" || host != "%" {
t.Errorf("Want user: root, host: %%. Get user: %s, host: %s", user, host) t.Errorf("Want user: root, host: %%. Get user: %s, host: %s", user, host)
} }
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
} }
func TestHasSelectPrivilege(t *testing.T) { func TestHasSelectPrivilege(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
if !connTest.HasSelectPrivilege() { if !connTest.HasSelectPrivilege() {
t.Errorf("DSN: %s, User: %s, should has select privilege", connTest.Addr, connTest.User) t.Errorf("DSN: %s, User: %s, should has select privilege", connTest.Addr, connTest.User)
} }
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
} }
func TestHasAllPrivilege(t *testing.T) { func TestHasAllPrivilege(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
if !connTest.HasAllPrivilege() { if !connTest.HasAllPrivilege() {
t.Errorf("DSN: %s, User: %s, should has all privilege", connTest.Addr, connTest.User) t.Errorf("DSN: %s, User: %s, should has all privilege", connTest.Addr, connTest.User)
} }
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
} }
...@@ -19,21 +19,26 @@ package database ...@@ -19,21 +19,26 @@ package database
import ( import (
"testing" "testing"
"github.com/XiaoMi/soar/common"
"github.com/kr/pretty" "github.com/kr/pretty"
) )
func TestProfiling(t *testing.T) { func TestProfiling(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
rows, err := connTest.Profiling("select 1") rows, err := connTest.Profiling("select 1")
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
pretty.Println(rows) pretty.Println(rows)
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
} }
func TestFormatProfiling(t *testing.T) { func TestFormatProfiling(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
res, err := connTest.Profiling("select 1") res, err := connTest.Profiling("select 1")
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
pretty.Println(FormatProfiling(res)) pretty.Println(FormatProfiling(res))
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
} }
...@@ -17,11 +17,15 @@ ...@@ -17,11 +17,15 @@
package database package database
import ( import (
"database/sql"
"fmt" "fmt"
"time"
"github.com/XiaoMi/soar/common"
"strings" "strings"
"database/sql"
"github.com/XiaoMi/soar/common"
"github.com/ziutek/mymysql/mysql"
) )
/*-------------------- /*--------------------
...@@ -44,99 +48,125 @@ import ( ...@@ -44,99 +48,125 @@ import (
*-------------------- *--------------------
*/ */
// SamplingData 将数据从Remote拉取到 db 中 // SamplingData 将数据从 onlineConn 拉取到 db 中
func (db *Connector) SamplingData(remote *Connector, tables ...string) error { func (db *Connector) SamplingData(onlineConn *Connector, database string, tables ...string) error {
var err error
if database == db.Database {
return fmt.Errorf("SamplingData the same database, From: %s/%s, To: %s/%s", onlineConn.Addr, database, db.Addr, db.Database)
}
// 计算需要泵取的数据量 // 计算需要泵取的数据量
wantRowsCount := 300 * common.Config.SamplingStatisticTarget wantRowsCount := 300 * common.Config.SamplingStatisticTarget
// 设置数据采样单条 SQL 中 value 的数量
// 该数值越大,在内存中缓存的data就越多,但相对的,插入时速度就越快
maxValCount := 200
for _, table := range tables { for _, table := range tables {
// 表类型检查 // 表类型检查
if remote.IsView(table) { if onlineConn.IsView(table) {
return nil
}
tableStatus, err := remote.ShowTableStatus(table)
if err != nil {
return err
}
if len(tableStatus.Rows) == 0 {
common.Log.Info("SamplingData, Table %s with no data, stop sampling", table)
return nil return nil
} }
tableRows := tableStatus.Rows[0].Rows // generate where condition
if tableRows == 0 { var where string
common.Log.Info("SamplingData, Table %s with no data, stop sampling", table) if common.Config.SamplingCondition == "" {
return nil tableStatus, err := onlineConn.ShowTableStatus(table)
if err != nil {
return err
}
if len(tableStatus.Rows) == 0 {
common.Log.Info("SamplingData, Table %s with no data, stop sampling", table)
return nil
}
tableRows := tableStatus.Rows[0].Rows
if tableRows == 0 {
common.Log.Info("SamplingData, Table %s with no data, stop sampling", table)
return nil
}
factor := float64(wantRowsCount) / float64(tableRows)
common.Log.Debug("SamplingData, tableRows: %d, wantRowsCount: %d, factor: %f", tableRows, wantRowsCount, factor)
where = fmt.Sprintf("WHERE RAND() <= %f LIMIT %d", factor, wantRowsCount)
if factor >= 1 {
where = ""
}
} else {
where = common.Config.SamplingCondition
} }
factor := float64(wantRowsCount) / float64(tableRows) err = db.startSampling(onlineConn.Conn, database, table, where)
common.Log.Debug("SamplingData, tableRows: %d, wantRowsCount: %d, factor: %f", tableRows, wantRowsCount, factor)
err = startSampling(remote.Conn, db.Conn, db.Database, table, factor, wantRowsCount, maxValCount)
if err != nil {
common.Log.Error("(db *Connector) SamplingData Error : %v", err)
}
} }
return nil return err
} }
// startSampling sampling data from OnlineDSN to TestDSN // startSampling sampling data from OnlineDSN to TestDSN
// 因为涉及到的数据量问题,所以泵取与插入时同时进行的 func (db *Connector) startSampling(onlineConn *sql.DB, database, table string, where string) error {
// TODO: 加 ref link samplingQuery := fmt.Sprintf("SELECT * FROM `%s`.`%s` %s", database, table, where)
func startSampling(conn, localConn *sql.DB, database, table string, factor float64, wants, maxValCount int) error { common.Log.Debug("startSampling with Query: %s", samplingQuery)
// generate where condition res, err := onlineConn.Query(samplingQuery)
where := fmt.Sprintf("WHERE RAND() <= %f", factor)
if factor >= 1 {
where = ""
}
res, err := conn.Query(fmt.Sprintf("SELECT * FROM `%s`.`%s` %s LIMIT %d;", database, table, where, wants))
if err != nil { if err != nil {
return err return err
} }
// column info // columns list
columns, err := res.Columns() columns, err := res.Columns()
if err != nil { if err != nil {
return err return err
} }
row := make(map[string][]byte, len(columns)) row := make([][]byte, len(columns))
tableFields := make([]interface{}, 0) tableFields := make([]interface{}, 0)
for _, col := range columns { for i := range columns {
if _, ok := row[col]; ok { tableFields = append(tableFields, &row[i])
tableFields = append(tableFields, row[col]) }
} columnTypes, err := res.ColumnTypes()
if err != nil {
return err
} }
// sampling data // sampling data
var valuesStr string var valuesCount int
var values []string var valuesStr []string
maxValuesCount := 200 // one time insert values count, TODO: config able
columnsStr := "`" + strings.Join(columns, "`,`") + "`" columnsStr := "`" + strings.Join(columns, "`,`") + "`"
for res.Next() { for res.Next() {
var values []string
res.Scan(tableFields...) res.Scan(tableFields...)
for _, val := range row { for i, val := range row {
values = append(values, fmt.Sprintf(`unhex("%s")`, fmt.Sprintf("%x", val))) if val == nil {
values = append(values, "NULL")
} else {
switch columnTypes[i].DatabaseTypeName() {
case "TIMESTAMP", "DATETIME":
t, err := time.Parse(time.RFC3339, string(val))
common.LogIfWarn(err, "")
values = append(values, fmt.Sprintf(`"%s"`, mysql.TimeString(t)))
default:
values = append(values, fmt.Sprintf(`unhex("%s")`, fmt.Sprintf("%x", val)))
}
}
}
valuesStr = append(valuesStr, "("+strings.Join(values, `,`)+")")
valuesCount++
if maxValuesCount <= valuesCount {
err = db.doSampling(table, columnsStr, strings.Join(valuesStr, `,`))
if err != nil {
break
}
values = make([]string, 0)
valuesStr = make([]string, 0)
valuesCount = 0
} }
valuesStr = fmt.Sprintf(`(%s)`, strings.Join(values, `,`))
doSampling(localConn, database, table, columnsStr, valuesStr)
} }
res.Close() res.Close()
return nil return err
} }
// 将泵取的数据转换成Insert语句并在数据库中执行 // 将泵取的数据转换成 insert 语句并在 testConn 数据库中执行
func doSampling(conn *sql.DB, dbName, table, colDef, values string) { func (db *Connector) doSampling(table, colDef, values string) error {
query := fmt.Sprintf("INSERT INTO `%s`.`%s` (%s) VALUES %s;", dbName, table, // db.Database is hashed database name
colDef, values) query := fmt.Sprintf("INSERT INTO `%s`.`%s` (%s) VALUES %s;", db.Database, table, colDef, values)
res, err := db.Query(query)
_, err := conn.Exec(query) if res.Rows != nil {
if err != nil { res.Rows.Close()
common.Log.Error("doSampling Error from %s.%s: %v", dbName, table, err)
} }
return err
} }
/*
* Copyright 2018 Xiaomi, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package database
import (
"testing"
"github.com/XiaoMi/soar/common"
)
func init() {
common.BaseDir = common.DevPath
}
func TestSamplingData(t *testing.T) {
connOnline, err := NewConnector(common.Config.OnlineDSN)
if err != nil {
t.Error(err)
}
err = connTest.SamplingData(connOnline, "film")
if err != nil {
t.Error(err)
}
}
...@@ -459,27 +459,27 @@ func (db *Connector) ShowCreateTable(tableName string) (string, error) { ...@@ -459,27 +459,27 @@ func (db *Connector) ShowCreateTable(tableName string) (string, error) {
ddl, err := db.showCreate("table", tableName) ddl, err := db.showCreate("table", tableName)
// 去除外键关联条件 // 去除外键关联条件
var noConstraint []string lines := strings.Split(ddl, "\n")
relationReg, _ := regexp.Compile("CONSTRAINT") // CREATE VIEW ONLY 1 LINE
for _, line := range strings.Split(ddl, "\n") { if len(lines) > 2 {
var noConstraint []string
if relationReg.Match([]byte(line)) { relationReg, _ := regexp.Compile("CONSTRAINT")
continue for _, line := range lines[1 : len(lines)-1] {
} if relationReg.Match([]byte(line)) {
continue
// 去除外键语句会使DDL中多一个','导致语法错误,要把多余的逗号去除
if strings.Index(line, ")") == 0 {
lineWrongSyntax := noConstraint[len(noConstraint)-1]
// 如果')'前一句的末尾是',' 删除 ',' 保证语法正确性
if strings.Index(lineWrongSyntax, ",") == len(lineWrongSyntax)-1 {
noConstraint[len(noConstraint)-1] = lineWrongSyntax[:len(lineWrongSyntax)-1]
} }
line = strings.TrimSuffix(line, ",")
noConstraint = append(noConstraint, line)
} }
noConstraint = append(noConstraint, line) // 去除外键语句会使DDL中多一个','导致语法错误,要把多余的逗号去除
ddl = fmt.Sprint(
lines[0], "\n",
strings.Join(noConstraint, ",\n"), "\n",
lines[len(lines)-1],
)
} }
return ddl, err
return strings.Join(noConstraint, "\n"), err
} }
// FindColumn find column // FindColumn find column
......
...@@ -26,6 +26,7 @@ import ( ...@@ -26,6 +26,7 @@ import (
) )
func TestShowTableStatus(t *testing.T) { func TestShowTableStatus(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
orgDatabase := connTest.Database orgDatabase := connTest.Database
connTest.Database = "sakila" connTest.Database = "sakila"
ts, err := connTest.ShowTableStatus("film") ts, err := connTest.ShowTableStatus("film")
...@@ -47,9 +48,11 @@ func TestShowTableStatus(t *testing.T) { ...@@ -47,9 +48,11 @@ func TestShowTableStatus(t *testing.T) {
} }
pretty.Println(ts) pretty.Println(ts)
connTest.Database = orgDatabase connTest.Database = orgDatabase
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
} }
func TestShowTables(t *testing.T) { func TestShowTables(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
orgDatabase := connTest.Database orgDatabase := connTest.Database
connTest.Database = "sakila" connTest.Database = "sakila"
ts, err := connTest.ShowTables() ts, err := connTest.ShowTables()
...@@ -66,23 +69,29 @@ func TestShowTables(t *testing.T) { ...@@ -66,23 +69,29 @@ func TestShowTables(t *testing.T) {
t.Error(err) t.Error(err)
} }
connTest.Database = orgDatabase connTest.Database = orgDatabase
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
} }
func TestShowCreateDatabase(t *testing.T) { func TestShowCreateDatabase(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
err := common.GoldenDiff(func() { err := common.GoldenDiff(func() {
fmt.Println(connTest.ShowCreateDatabase("sakila")) fmt.Println(connTest.ShowCreateDatabase("sakila"))
}, t.Name(), update) }, t.Name(), update)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
} }
func TestShowCreateTable(t *testing.T) { func TestShowCreateTable(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
orgDatabase := connTest.Database orgDatabase := connTest.Database
connTest.Database = "sakila" connTest.Database = "sakila"
tables := []string{ tables := []string{
"film", "film",
"category",
"customer_list", "customer_list",
"inventory",
} }
err := common.GoldenDiff(func() { err := common.GoldenDiff(func() {
for _, table := range tables { for _, table := range tables {
...@@ -97,9 +106,11 @@ func TestShowCreateTable(t *testing.T) { ...@@ -97,9 +106,11 @@ func TestShowCreateTable(t *testing.T) {
t.Error(err) t.Error(err)
} }
connTest.Database = orgDatabase connTest.Database = orgDatabase
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
} }
func TestShowIndex(t *testing.T) { func TestShowIndex(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
orgDatabase := connTest.Database orgDatabase := connTest.Database
connTest.Database = "sakila" connTest.Database = "sakila"
ti, err := connTest.ShowIndex("film") ti, err := connTest.ShowIndex("film")
...@@ -114,11 +125,12 @@ func TestShowIndex(t *testing.T) { ...@@ -114,11 +125,12 @@ func TestShowIndex(t *testing.T) {
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
connTest.Database = orgDatabase connTest.Database = orgDatabase
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
} }
func TestShowColumns(t *testing.T) { func TestShowColumns(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
orgDatabase := connTest.Database orgDatabase := connTest.Database
connTest.Database = "sakila" connTest.Database = "sakila"
ti, err := connTest.ShowColumns("actor_info") ti, err := connTest.ShowColumns("actor_info")
...@@ -134,9 +146,11 @@ func TestShowColumns(t *testing.T) { ...@@ -134,9 +146,11 @@ func TestShowColumns(t *testing.T) {
} }
connTest.Database = orgDatabase connTest.Database = orgDatabase
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
} }
func TestFindColumn(t *testing.T) { func TestFindColumn(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
ti, err := connTest.FindColumn("film_id", "sakila", "film") ti, err := connTest.FindColumn("film_id", "sakila", "film")
if err != nil { if err != nil {
t.Error("FindColumn Error: ", err) t.Error("FindColumn Error: ", err)
...@@ -147,15 +161,19 @@ func TestFindColumn(t *testing.T) { ...@@ -147,15 +161,19 @@ func TestFindColumn(t *testing.T) {
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
} }
func TestIsFKey(t *testing.T) { func TestIsFKey(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
if !connTest.IsForeignKey("sakila", "film", "language_id") { if !connTest.IsForeignKey("sakila", "film", "language_id") {
t.Error("want True. got false") t.Error("want True. got false")
} }
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
} }
func TestShowReference(t *testing.T) { func TestShowReference(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
rv, err := connTest.ShowReference("sakila", "film") rv, err := connTest.ShowReference("sakila", "film")
if err != nil { if err != nil {
t.Error("ShowReference Error: ", err) t.Error("ShowReference Error: ", err)
...@@ -167,4 +185,5 @@ func TestShowReference(t *testing.T) { ...@@ -167,4 +185,5 @@ func TestShowReference(t *testing.T) {
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
} }
...@@ -17,4 +17,19 @@ CREATE TABLE `film` ( ...@@ -17,4 +17,19 @@ CREATE TABLE `film` (
KEY `idx_fk_language_id` (`language_id`), KEY `idx_fk_language_id` (`language_id`),
KEY `idx_fk_original_language_id` (`original_language_id`) KEY `idx_fk_original_language_id` (`original_language_id`)
) ENGINE=InnoDB AUTO_INCREMENT=1001 DEFAULT CHARSET=utf8 ) ENGINE=InnoDB AUTO_INCREMENT=1001 DEFAULT CHARSET=utf8
CREATE TABLE `category` (
`category_id` tinyint(3) unsigned NOT NULL AUTO_INCREMENT,
`name` varchar(25) NOT NULL,
`last_update` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
PRIMARY KEY (`category_id`)
) ENGINE=InnoDB AUTO_INCREMENT=17 DEFAULT CHARSET=utf8
CREATE ALGORITHM=UNDEFINED DEFINER=`root`@`localhost` SQL SECURITY DEFINER VIEW `customer_list` AS select `cu`.`customer_id` AS `ID`,concat(`cu`.`first_name`,_utf8mb3' ',`cu`.`last_name`) AS `name`,`a`.`address` AS `address`,`a`.`postal_code` AS `zip code`,`a`.`phone` AS `phone`,`city`.`city` AS `city`,`country`.`country` AS `country`,if(`cu`.`active`,_utf8mb3'active',_utf8mb3'') AS `notes`,`cu`.`store_id` AS `SID` from (((`customer` `cu` join `address` `a` on((`cu`.`address_id` = `a`.`address_id`))) join `city` on((`a`.`city_id` = `city`.`city_id`))) join `country` on((`city`.`country_id` = `country`.`country_id`))) CREATE ALGORITHM=UNDEFINED DEFINER=`root`@`localhost` SQL SECURITY DEFINER VIEW `customer_list` AS select `cu`.`customer_id` AS `ID`,concat(`cu`.`first_name`,_utf8mb3' ',`cu`.`last_name`) AS `name`,`a`.`address` AS `address`,`a`.`postal_code` AS `zip code`,`a`.`phone` AS `phone`,`city`.`city` AS `city`,`country`.`country` AS `country`,if(`cu`.`active`,_utf8mb3'active',_utf8mb3'') AS `notes`,`cu`.`store_id` AS `SID` from (((`customer` `cu` join `address` `a` on((`cu`.`address_id` = `a`.`address_id`))) join `city` on((`a`.`city_id` = `city`.`city_id`))) join `country` on((`city`.`country_id` = `country`.`country_id`)))
CREATE TABLE `inventory` (
`inventory_id` mediumint(8) unsigned NOT NULL AUTO_INCREMENT,
`film_id` smallint(5) unsigned NOT NULL,
`store_id` tinyint(3) unsigned NOT NULL,
`last_update` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
PRIMARY KEY (`inventory_id`),
KEY `idx_fk_film_id` (`film_id`),
KEY `idx_store_id_film_id` (`store_id`,`film_id`)
) ENGINE=InnoDB AUTO_INCREMENT=4582 DEFAULT CHARSET=utf8
...@@ -25,6 +25,7 @@ import ( ...@@ -25,6 +25,7 @@ import (
) )
func TestTrace(t *testing.T) { func TestTrace(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
res, err := connTest.Trace("select 1") res, err := connTest.Trace("select 1")
if err != nil { if err != nil {
t.Error(err) t.Error(err)
...@@ -36,9 +37,11 @@ func TestTrace(t *testing.T) { ...@@ -36,9 +37,11 @@ func TestTrace(t *testing.T) {
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
} }
func TestFormatTrace(t *testing.T) { func TestFormatTrace(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
res, err := connTest.Trace("select 1") res, err := connTest.Trace("select 1")
if err != nil { if err != nil {
t.Error(err) t.Error(err)
...@@ -50,4 +53,5 @@ func TestFormatTrace(t *testing.T) { ...@@ -50,4 +53,5 @@ func TestFormatTrace(t *testing.T) {
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
} }
...@@ -82,7 +82,7 @@ func BuildEnv() (*VirtualEnv, *database.Connector) { ...@@ -82,7 +82,7 @@ func BuildEnv() (*VirtualEnv, *database.Connector) {
common.LogIfError(err, "") common.LogIfError(err, "")
// 检查线上环境可用性版本 // 检查线上环境可用性版本
rEnvVersion, err := vEnv.Version() rEnvVersion, err := connOnline.Version()
common.Config.OnlineDSN.Version = rEnvVersion common.Config.OnlineDSN.Version = rEnvVersion
if err != nil { if err != nil {
common.Log.Warn("BuildEnv OnlineDSN: %s:********@%s/%s not available , Error: %s", common.Log.Warn("BuildEnv OnlineDSN: %s:********@%s/%s not available , Error: %s",
...@@ -245,20 +245,20 @@ func (ve *VirtualEnv) BuildVirtualEnv(rEnv *database.Connector, SQLs ...string) ...@@ -245,20 +245,20 @@ func (ve *VirtualEnv) BuildVirtualEnv(rEnv *database.Connector, SQLs ...string)
// 为了支持并发,需要将DB进行映射,但db.table这种形式无法保证DB的映射是正确的 // 为了支持并发,需要将DB进行映射,但db.table这种形式无法保证DB的映射是正确的
// TODO:暂不支持 create db.tableName (id int) 形式的建表语句 // TODO:暂不支持 create db.tableName (id int) 形式的建表语句
if stmt.Table.Qualifier.String() != "" { if stmt.Table.Qualifier.String() != "" {
common.Log.Error("BuildVirtualEnv DDL Not support '.'") common.Log.Error("BuildVirtualEnv DDL Not support db.tb format")
return false return false
} }
for _, tb := range stmt.FromTables { for _, tb := range stmt.FromTables {
if tb.Qualifier.String() != "" { if tb.Qualifier.String() != "" {
common.Log.Error("BuildVirtualEnv DDL Not support '.'") common.Log.Error("BuildVirtualEnv DDL Not support db.tb format")
return false return false
} }
} }
for _, tb := range stmt.ToTables { for _, tb := range stmt.ToTables {
if tb.Qualifier.String() != "" { if tb.Qualifier.String() != "" {
common.Log.Error("BuildVirtualEnv DDL Not support '.'") common.Log.Error("BuildVirtualEnv DDL Not support db.tb format")
return false return false
} }
} }
...@@ -338,7 +338,7 @@ func (ve *VirtualEnv) BuildVirtualEnv(rEnv *database.Connector, SQLs ...string) ...@@ -338,7 +338,7 @@ func (ve *VirtualEnv) BuildVirtualEnv(rEnv *database.Connector, SQLs ...string)
err = ve.createTable(tmpEnv, db, tb.TableName) err = ve.createTable(tmpEnv, db, tb.TableName)
if err != nil { if err != nil {
common.Log.Error("BuildVirtualEnv Error : %v", err) common.Log.Error("BuildVirtualEnv %s.%s Error : %v", db, tb.TableName, err)
return false return false
} }
} }
...@@ -453,7 +453,7 @@ func (ve VirtualEnv) createTable(rEnv *database.Connector, dbName, tbName string ...@@ -453,7 +453,7 @@ func (ve VirtualEnv) createTable(rEnv *database.Connector, dbName, tbName string
res, err := ve.Query(ddl) res, err := ve.Query(ddl)
if err != nil { if err != nil {
// 有可能是用户新建表,因此线上环境查不到 // 有可能是用户新建表,因此线上环境查不到
common.Log.Error("createTable, %s Error : %v", tbName, err) common.Log.Error("createTable: %s Error : %v", tbName, err)
return err return err
} }
res.Rows.Close() res.Rows.Close()
...@@ -461,13 +461,9 @@ func (ve VirtualEnv) createTable(rEnv *database.Connector, dbName, tbName string ...@@ -461,13 +461,9 @@ func (ve VirtualEnv) createTable(rEnv *database.Connector, dbName, tbName string
// 泵取数据 // 泵取数据
if common.Config.Sampling { if common.Config.Sampling {
common.Log.Debug("createTable, Start Sampling data from %s.%s to %s.%s ...", dbName, tbName, ve.DBRef[dbName], tbName) common.Log.Debug("createTable, Start Sampling data from %s.%s to %s.%s ...", dbName, tbName, ve.DBRef[dbName], tbName)
err := ve.SamplingData(rEnv, tbName) err = ve.SamplingData(rEnv, dbName, tbName)
if err != nil {
common.Log.Error(" (ve VirtualEnv) createTable SamplingData Error: %v", err)
return err
}
} }
return nil return err
} }
// GenTableColumns 为 Rewrite 提供的结构体初始化 // GenTableColumns 为 Rewrite 提供的结构体初始化
......
...@@ -49,6 +49,7 @@ func init() { ...@@ -49,6 +49,7 @@ func init() {
} }
func TestNewVirtualEnv(t *testing.T) { func TestNewVirtualEnv(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
testSQL := []string{ testSQL := []string{
"create table t(id int,c1 varchar(20),PRIMARY KEY (id));", "create table t(id int,c1 varchar(20),PRIMARY KEY (id));",
"alter table t add index `idx_c1`(c1);", "alter table t add index `idx_c1`(c1);",
...@@ -117,9 +118,11 @@ func TestNewVirtualEnv(t *testing.T) { ...@@ -117,9 +118,11 @@ func TestNewVirtualEnv(t *testing.T) {
} }
} }
}, t.Name(), update) }, t.Name(), update)
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
} }
func TestCleanupTestDatabase(t *testing.T) { func TestCleanupTestDatabase(t *testing.T) {
common.Log.Debug("Enter function: %s", common.GetFunctionName())
vEnv, _ := BuildEnv() vEnv, _ := BuildEnv()
if common.Config.TestDSN.Disable { if common.Config.TestDSN.Disable {
common.Log.Warn("common.Config.TestDSN.Disable=true, by pass TestCleanupTestDatabase") common.Log.Warn("common.Config.TestDSN.Disable=true, by pass TestCleanupTestDatabase")
...@@ -146,9 +149,11 @@ func TestCleanupTestDatabase(t *testing.T) { ...@@ -146,9 +149,11 @@ func TestCleanupTestDatabase(t *testing.T) {
if err != nil { if err != nil {
t.Error("optimizer_060102150405 not exist, should not be dropped") t.Error("optimizer_060102150405 not exist, should not be dropped")
} }
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
} }
func TestGenTableColumns(t *testing.T) { func TestGenTableColumns(t *testing.T) {
common.Log.Debug("Enter function: %s", common.GetFunctionName())
vEnv, rEnv := BuildEnv() vEnv, rEnv := BuildEnv()
defer vEnv.CleanUp() defer vEnv.CleanUp()
...@@ -214,4 +219,66 @@ func TestGenTableColumns(t *testing.T) { ...@@ -214,4 +219,66 @@ func TestGenTableColumns(t *testing.T) {
} }
} }
} }
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
}
func TestCreateTable(t *testing.T) {
common.Log.Debug("Enter function: %s", common.GetFunctionName())
orgSamplingCondition := common.Config.SamplingCondition
common.Config.SamplingCondition = "LIMIT 1"
vEnv, rEnv := BuildEnv()
defer vEnv.CleanUp()
// TODO: support VIEW,
tables := []string{
"actor",
// "actor_info", // VIEW
"address",
"category",
"city",
"country",
"customer",
"customer_list",
"film",
"film_actor",
"film_category",
"film_list",
"film_text",
"inventory",
"language",
"nicer_but_slower_film_list",
"payment",
"rental",
// "sales_by_film_category", // VIEW
// "sales_by_store", // VIEW
"staff",
"staff_list",
"store",
}
for _, table := range tables {
err := vEnv.createTable(rEnv, "sakila", table)
if err != nil {
t.Error(err)
}
}
common.Config.SamplingCondition = orgSamplingCondition
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
}
func TestCreateDatabase(t *testing.T) {
common.Log.Debug("Enter function: %s", common.GetFunctionName())
vEnv, rEnv := BuildEnv()
defer vEnv.CleanUp()
err := vEnv.createDatabase(rEnv, "sakila")
if err != nil {
t.Error(err)
}
if vEnv.DBHash("sakila") == "sakila" {
t.Errorf("database: sakila rehashed failed!")
}
if vEnv.DBHash("not_exist_db") != "not_exist_db" {
t.Errorf("database: not_exist_db rehashed!")
}
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
} }
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册