diff --git a/Makefile b/Makefile index 5e6ce8fd650706eaba9cfaec3524b7d3db250ff3..2a82688b5f9f8eb601cc85f142e3f877807cafa5 100644 --- a/Makefile +++ b/Makefile @@ -61,7 +61,7 @@ fmt: go_version_check .PHONY: test test: @echo "\033[92mRun all test cases ...\033[0m" - go test ./... + go test -race ./... @echo "test Success!" # Code Coverage diff --git a/database/mysql_test.go b/database/mysql_test.go index c8e6a1155be006d559c917e0a9f87ed6dba876c6..1688ba12cefe9794b164130d478a745578de0c11 100644 --- a/database/mysql_test.go +++ b/database/mysql_test.go @@ -48,13 +48,13 @@ func init() { } func TestNewConnection(t *testing.T) { - _, err := connTest.NewConnection() + conn, err := connTest.NewConnection() if err != nil { t.Errorf("TestNewConnection, Error: %s", err.Error()) } + defer conn.Close() } -// TODO: go test -race不通过待解决 func TestQuery(t *testing.T) { res, err := connTest.Query("select 0") if err != nil { diff --git a/database/show.go b/database/show.go index a49dc5e978e9f3051b4a72ad695d1f0681b3c805..404274de502349897f0e10465136d8125bfcb851 100644 --- a/database/show.go +++ b/database/show.go @@ -372,19 +372,38 @@ func (td TableDesc) Columns() []string { // showCreate show create func (db *Connector) showCreate(createType, name string) (string, error) { // 执行 show create table|database - // createType = [table|database] + // createType = [table|database|view] res, err := db.Query(fmt.Sprintf("show create %s `%s`", createType, name)) if err != nil { return "", err } - // 获取 CREATE TABLE 语句 - var tableName, createTable string + // columns info + var colByPass []byte + var create string + createFields := make([]interface{}, 0) + fields := map[string]interface{}{ + "Create View": &create, + "Create Table": &create, + "Create Database": &create, + "Table": &colByPass, + "Database": &colByPass, + "View": &colByPass, + "character_set_client": &colByPass, + "collation_connection": &colByPass, + } + cols, err := res.Rows.Columns() + common.LogIfError(err, "") + for _, col := range cols { + createFields = append(createFields, fields[col]) + } + + // 获取 CREATE TABLE|VIEW|DATABASE 语句 for res.Rows.Next() { - res.Rows.Scan(&tableName, &createTable) + res.Rows.Scan(createFields...) } - return createTable, err + return create, err } // ShowCreateDatabase show create database diff --git a/database/show_test.go b/database/show_test.go index 97e8d04a134f4172a72b35625ffbc23751961587..49b2441674fb3c8babab042ffc9bf365446ced82 100644 --- a/database/show_test.go +++ b/database/show_test.go @@ -23,7 +23,6 @@ import ( "github.com/XiaoMi/soar/common" "github.com/kr/pretty" - "vitess.io/vitess/go/vt/sqlparser" ) func TestShowTableStatus(t *testing.T) { @@ -72,18 +71,18 @@ func TestShowTables(t *testing.T) { func TestShowCreateTable(t *testing.T) { orgDatabase := connTest.Database connTest.Database = "sakila" - ts, err := connTest.ShowCreateTable("film") - if err != nil { - t.Error("ShowCreateTable Error: ", err) - } - - err = common.GoldenDiff(func() { - fmt.Println(ts) - stmt, err := sqlparser.Parse(ts) - if err != nil { - t.Error(err.Error()) + tables := []string{ + "film", + "customer_list", + } + err := common.GoldenDiff(func() { + for _, table := range tables { + ts, err := connTest.ShowCreateTable(table) + if err != nil { + t.Error("ShowCreateTable Error: ", err) + } + fmt.Println(ts) } - pretty.Println(stmt, err) }, t.Name(), update) if err != nil { t.Error(err) diff --git a/database/testdata/TestShowCreateTable.golden b/database/testdata/TestShowCreateTable.golden index 377ef8d32fc8da8685f4b7e4268ba613371b489a..74208c702b85a1014f34c2a7561c273ade684f04 100644 --- a/database/testdata/TestShowCreateTable.golden +++ b/database/testdata/TestShowCreateTable.golden @@ -17,18 +17,4 @@ CREATE TABLE `film` ( KEY `idx_fk_language_id` (`language_id`), KEY `idx_fk_original_language_id` (`original_language_id`) ) ENGINE=InnoDB AUTO_INCREMENT=1001 DEFAULT CHARSET=utf8 -&sqlparser.DDL{ - Action: "create", - FromTables: nil, - ToTables: nil, - Table: sqlparser.TableName{ - Name: sqlparser.TableIdent{v:"film"}, - Qualifier: sqlparser.TableIdent{}, - }, - IfExists: false, - TableSpec: (*sqlparser.TableSpec)(nil), - OptLike: (*sqlparser.OptLike)(nil), - PartitionSpec: (*sqlparser.PartitionSpec)(nil), - VindexSpec: (*sqlparser.VindexSpec)(nil), - VindexCols: nil, -} nil +CREATE ALGORITHM=UNDEFINED DEFINER=`root`@`localhost` SQL SECURITY DEFINER VIEW `customer_list` AS select `cu`.`customer_id` AS `ID`,concat(`cu`.`first_name`,_utf8mb3' ',`cu`.`last_name`) AS `name`,`a`.`address` AS `address`,`a`.`postal_code` AS `zip code`,`a`.`phone` AS `phone`,`city`.`city` AS `city`,`country`.`country` AS `country`,if(`cu`.`active`,_utf8mb3'active',_utf8mb3'') AS `notes`,`cu`.`store_id` AS `SID` from (((`customer` `cu` join `address` `a` on((`cu`.`address_id` = `a`.`address_id`))) join `city` on((`a`.`city_id` = `city`.`city_id`))) join `country` on((`city`.`country_id` = `country`.`country_id`))) diff --git a/env/env.go b/env/env.go index 7ae1c5120c84f1015a14ea3a4d6b86972de297ec..bc7491bd042aae3edcb4802fcca3501539996abf 100644 --- a/env/env.go +++ b/env/env.go @@ -331,6 +331,10 @@ func (ve *VirtualEnv) BuildVirtualEnv(rEnv *database.Connector, SQLs ...string) } startIdx := strings.Index(viewDDL, "AS") + if startIdx < 0 || viewDDL == "" { + common.Log.Error("BuildVirtualEnv '%s' got '%s', Index: %d", tb.TableName, viewDDL, startIdx) + return false + } viewDDL = viewDDL[startIdx+2:] if !ve.BuildVirtualEnv(&tmpEnv, viewDDL) { return false