diff --git a/ast/rewrite.go b/ast/rewrite.go index bc459b8c9be590b37c03153876d2472735ecd003..c7a07cc1409e62efb1650739cd55fee3ec3e1e91 100644 --- a/ast/rewrite.go +++ b/ast/rewrite.go @@ -672,6 +672,18 @@ func (rw *Rewrite) RewriteStar2Columns() *Rewrite { return rw } + // 单张表 select * 不补全表名,避免SQL过长,多张表的 select tb1.*, tb2.* 需要补全表名 + var multiTable bool + if len(rw.Columns) > 1 { + multiTable = true + } else { + for db := range rw.Columns { + if len(rw.Columns[db]) > 1 { + multiTable = true + } + } + } + err := sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) { switch n := node.(type) { case *sqlparser.Select: @@ -692,12 +704,16 @@ func (rw *Rewrite) RewriteStar2Columns() *Rewrite { for _, tables := range rw.Columns { for _, cols := range tables { for _, col := range cols { + var table string + if multiTable { + table = col.Table + } newExpr := &sqlparser.AliasedExpr{ Expr: &sqlparser.ColName{ Metadata: nil, Name: sqlparser.NewColIdent(col.Name), Qualifier: sqlparser.TableName{ - Name: sqlparser.NewTableIdent(col.Table), + Name: sqlparser.NewTableIdent(table), // 因为不建议跨DB的查询,所以这里的db前缀将不进行补齐 Qualifier: sqlparser.TableIdent{}, }, diff --git a/ast/rewrite_test.go b/ast/rewrite_test.go index 361d051ac32ce5842f34f00e2e2bf1f269241f74..817e2eb43699b4bfdaee95238105a78824bea6aa 100644 --- a/ast/rewrite_test.go +++ b/ast/rewrite_test.go @@ -110,11 +110,11 @@ func TestRewriteStar2Columns(t *testing.T) { testSQL := []map[string]string{ { "input": `SELECT * FROM film`, - "output": `select film.film_id, film.title from film`, + "output": `select film_id, title from film`, }, { - "input": `SELECT film.*, actor.actor_id FROM film,actor`, - "output": `select film.film_id, film.title, actor.actor_id from film, actor`, + "input": `SELECT film.* FROM film`, + "output": `select film_id, title from film`, }, } @@ -133,6 +133,36 @@ func TestRewriteStar2Columns(t *testing.T) { t.Errorf("want: %s\ngot: %s", sql["output"], rw.NewSQL) } } + + testSQL2 := []map[string]string{ + { + "input": `SELECT film.* FROM film, actor`, + "output": `select film.film_id, film.title from film, actor`, + }, + { + "input": `SELECT film.*, actor.actor_id FROM film, actor`, + "output": `select film.film_id, film.title, actor.actor_id from film, actor`, + }, + } + + for _, sql := range testSQL2 { + rw := NewRewrite(sql["input"]) + rw.Columns = map[string]map[string][]*common.Column{ + "sakila": { + "film": { + {Name: "film_id", Table: "film"}, + {Name: "title", Table: "film"}, + }, + "actor": { + {Name: "actor_id", Table: "actor"}, + }, + }, + } + rw.RewriteStar2Columns() + if rw.NewSQL != sql["output"] { + t.Errorf("want: %s\ngot: %s", sql["output"], rw.NewSQL) + } + } common.Config.TestDSN.Disable = orgTestDSNStatus common.Log.Debug("Exiting function: %s", common.GetFunctionName()) }