提交 a28d294a 编写于 作者: J jianzhiyao

upgrade acquire method return *sql.Stmt

上级 a5e8344a
...@@ -125,7 +125,6 @@ func (d *DB) getStmtDecorator(query string) (*stmtDecorator, error) { ...@@ -125,7 +125,6 @@ func (d *DB) getStmtDecorator(query string) (*stmtDecorator, error) {
c, ok := d.stmtDecorators.Get(query) c, ok := d.stmtDecorators.Get(query)
d.RUnlock() d.RUnlock()
if ok { if ok {
c.(*stmtDecorator).acquire()
return c.(*stmtDecorator), nil return c.(*stmtDecorator), nil
} }
...@@ -133,7 +132,6 @@ func (d *DB) getStmtDecorator(query string) (*stmtDecorator, error) { ...@@ -133,7 +132,6 @@ func (d *DB) getStmtDecorator(query string) (*stmtDecorator, error) {
c, ok = d.stmtDecorators.Get(query) c, ok = d.stmtDecorators.Get(query)
if ok { if ok {
d.Unlock() d.Unlock()
c.(*stmtDecorator).acquire()
return c.(*stmtDecorator), nil return c.(*stmtDecorator), nil
} }
...@@ -146,7 +144,6 @@ func (d *DB) getStmtDecorator(query string) (*stmtDecorator, error) { ...@@ -146,7 +144,6 @@ func (d *DB) getStmtDecorator(query string) (*stmtDecorator, error) {
d.stmtDecorators.Add(query, sd) d.stmtDecorators.Add(query, sd)
d.Unlock() d.Unlock()
sd.acquire()
return sd, nil return sd, nil
} }
...@@ -163,7 +160,7 @@ func (d *DB) Exec(query string, args ...interface{}) (sql.Result, error) { ...@@ -163,7 +160,7 @@ func (d *DB) Exec(query string, args ...interface{}) (sql.Result, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
stmt := sd.getStmt() stmt := sd.acquire()
defer sd.release() defer sd.release()
return stmt.Exec(args...) return stmt.Exec(args...)
} }
...@@ -173,7 +170,7 @@ func (d *DB) ExecContext(ctx context.Context, query string, args ...interface{}) ...@@ -173,7 +170,7 @@ func (d *DB) ExecContext(ctx context.Context, query string, args ...interface{})
if err != nil { if err != nil {
return nil, err return nil, err
} }
stmt := sd.getStmt() stmt := sd.acquire()
defer sd.release() defer sd.release()
return stmt.ExecContext(ctx, args...) return stmt.ExecContext(ctx, args...)
} }
...@@ -183,7 +180,7 @@ func (d *DB) Query(query string, args ...interface{}) (*sql.Rows, error) { ...@@ -183,7 +180,7 @@ func (d *DB) Query(query string, args ...interface{}) (*sql.Rows, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
stmt := sd.getStmt() stmt := sd.acquire()
defer sd.release() defer sd.release()
return stmt.Query(args...) return stmt.Query(args...)
} }
...@@ -193,7 +190,7 @@ func (d *DB) QueryContext(ctx context.Context, query string, args ...interface{} ...@@ -193,7 +190,7 @@ func (d *DB) QueryContext(ctx context.Context, query string, args ...interface{}
if err != nil { if err != nil {
return nil, err return nil, err
} }
stmt := sd.getStmt() stmt := sd.acquire()
defer sd.release() defer sd.release()
return stmt.QueryContext(ctx, args...) return stmt.QueryContext(ctx, args...)
} }
...@@ -203,7 +200,7 @@ func (d *DB) QueryRow(query string, args ...interface{}) *sql.Row { ...@@ -203,7 +200,7 @@ func (d *DB) QueryRow(query string, args ...interface{}) *sql.Row {
if err != nil { if err != nil {
panic(err) panic(err)
} }
stmt := sd.getStmt() stmt := sd.acquire()
defer sd.release() defer sd.release()
return stmt.QueryRow(args...) return stmt.QueryRow(args...)
...@@ -214,7 +211,7 @@ func (d *DB) QueryRowContext(ctx context.Context, query string, args ...interfac ...@@ -214,7 +211,7 @@ func (d *DB) QueryRowContext(ctx context.Context, query string, args ...interfac
if err != nil { if err != nil {
panic(err) panic(err)
} }
stmt := sd.getStmt() stmt := sd.acquire()
defer sd.release() defer sd.release()
return stmt.QueryRowContext(ctx, args) return stmt.QueryRowContext(ctx, args)
} }
...@@ -423,15 +420,14 @@ func GetDB(aliasNames ...string) (*sql.DB, error) { ...@@ -423,15 +420,14 @@ func GetDB(aliasNames ...string) (*sql.DB, error) {
type stmtDecorator struct { type stmtDecorator struct {
wg sync.WaitGroup wg sync.WaitGroup
lastUse int64
stmt *sql.Stmt stmt *sql.Stmt
} }
func (s *stmtDecorator) getStmt() *sql.Stmt { func (s *stmtDecorator) acquire() *sql.Stmt{
return s.stmt
}
func (s *stmtDecorator) acquire() {
s.wg.Add(1) s.wg.Add(1)
s.lastUse = time.Now().Unix()
return s.stmt
} }
func (s *stmtDecorator) release() { func (s *stmtDecorator) release() {
...@@ -447,6 +443,7 @@ func (s *stmtDecorator) destroy() { ...@@ -447,6 +443,7 @@ func (s *stmtDecorator) destroy() {
func newStmtDecorator(sqlStmt *sql.Stmt) *stmtDecorator { func newStmtDecorator(sqlStmt *sql.Stmt) *stmtDecorator {
return &stmtDecorator{ return &stmtDecorator{
stmt: sqlStmt, stmt: sqlStmt,
lastUse: time.Now().Unix(),
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册