提交 5ada8872 编写于 作者: M Mark Haines 提交者: GitHub

Add context to the federationsender database (#231)

上级 dc5dd4c5
...@@ -123,8 +123,12 @@ func (s *OutputRoomEvent) processMessage(ore api.OutputNewRoomEvent) error { ...@@ -123,8 +123,12 @@ func (s *OutputRoomEvent) processMessage(ore api.OutputNewRoomEvent) error {
// TODO: handle EventIDMismatchError and recover the current state by talking // TODO: handle EventIDMismatchError and recover the current state by talking
// to the roomserver // to the roomserver
oldJoinedHosts, err := s.db.UpdateRoom( oldJoinedHosts, err := s.db.UpdateRoom(
ore.Event.RoomID(), ore.LastSentEventID, ore.Event.EventID(), context.TODO(),
addsJoinedHosts, ore.RemovesStateEventIDs, ore.Event.RoomID(),
ore.LastSentEventID,
ore.Event.EventID(),
addsJoinedHosts,
ore.RemovesStateEventIDs,
) )
if err != nil { if err != nil {
return err return err
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
package storage package storage
import ( import (
"context"
"database/sql" "database/sql"
"github.com/lib/pq" "github.com/lib/pq"
...@@ -78,20 +79,29 @@ func (s *joinedHostsStatements) prepare(db *sql.DB) (err error) { ...@@ -78,20 +79,29 @@ func (s *joinedHostsStatements) prepare(db *sql.DB) (err error) {
} }
func (s *joinedHostsStatements) insertJoinedHosts( func (s *joinedHostsStatements) insertJoinedHosts(
txn *sql.Tx, roomID, eventID string, serverName gomatrixserverlib.ServerName, ctx context.Context,
txn *sql.Tx,
roomID, eventID string,
serverName gomatrixserverlib.ServerName,
) error { ) error {
_, err := common.TxStmt(txn, s.insertJoinedHostsStmt).Exec(roomID, eventID, serverName) stmt := common.TxStmt(txn, s.insertJoinedHostsStmt)
_, err := stmt.ExecContext(ctx, roomID, eventID, serverName)
return err return err
} }
func (s *joinedHostsStatements) deleteJoinedHosts(txn *sql.Tx, eventIDs []string) error { func (s *joinedHostsStatements) deleteJoinedHosts(
_, err := common.TxStmt(txn, s.deleteJoinedHostsStmt).Exec(pq.StringArray(eventIDs)) ctx context.Context, txn *sql.Tx, eventIDs []string,
) error {
stmt := common.TxStmt(txn, s.deleteJoinedHostsStmt)
_, err := stmt.ExecContext(ctx, pq.StringArray(eventIDs))
return err return err
} }
func (s *joinedHostsStatements) selectJoinedHosts(txn *sql.Tx, roomID string, func (s *joinedHostsStatements) selectJoinedHosts(
ctx context.Context, txn *sql.Tx, roomID string,
) ([]types.JoinedHost, error) { ) ([]types.JoinedHost, error) {
rows, err := common.TxStmt(txn, s.selectJoinedHostsStmt).Query(roomID) stmt := common.TxStmt(txn, s.selectJoinedHostsStmt)
rows, err := stmt.QueryContext(ctx, roomID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
package storage package storage
import ( import (
"context"
"database/sql" "database/sql"
"github.com/matrix-org/dendrite/common" "github.com/matrix-org/dendrite/common"
...@@ -66,17 +67,22 @@ func (s *roomStatements) prepare(db *sql.DB) (err error) { ...@@ -66,17 +67,22 @@ func (s *roomStatements) prepare(db *sql.DB) (err error) {
// insertRoom inserts the room if it didn't already exist. // insertRoom inserts the room if it didn't already exist.
// If the room didn't exist then last_event_id is set to the empty string. // If the room didn't exist then last_event_id is set to the empty string.
func (s *roomStatements) insertRoom(txn *sql.Tx, roomID string) error { func (s *roomStatements) insertRoom(
_, err := common.TxStmt(txn, s.insertRoomStmt).Exec(roomID) ctx context.Context, txn *sql.Tx, roomID string,
) error {
_, err := common.TxStmt(txn, s.insertRoomStmt).ExecContext(ctx, roomID)
return err return err
} }
// selectRoomForUpdate locks the row for the room and returns the last_event_id. // selectRoomForUpdate locks the row for the room and returns the last_event_id.
// The row must already exist in the table. Callers can ensure that the row // The row must already exist in the table. Callers can ensure that the row
// exists by calling insertRoom first. // exists by calling insertRoom first.
func (s *roomStatements) selectRoomForUpdate(txn *sql.Tx, roomID string) (string, error) { func (s *roomStatements) selectRoomForUpdate(
ctx context.Context, txn *sql.Tx, roomID string,
) (string, error) {
var lastEventID string var lastEventID string
err := common.TxStmt(txn, s.selectRoomForUpdateStmt).QueryRow(roomID).Scan(&lastEventID) stmt := common.TxStmt(txn, s.selectRoomForUpdateStmt)
err := stmt.QueryRowContext(ctx, roomID).Scan(&lastEventID)
if err != nil { if err != nil {
return "", err return "", err
} }
...@@ -85,7 +91,10 @@ func (s *roomStatements) selectRoomForUpdate(txn *sql.Tx, roomID string) (string ...@@ -85,7 +91,10 @@ func (s *roomStatements) selectRoomForUpdate(txn *sql.Tx, roomID string) (string
// updateRoom updates the last_event_id for the room. selectRoomForUpdate should // updateRoom updates the last_event_id for the room. selectRoomForUpdate should
// have already been called earlier within the transaction. // have already been called earlier within the transaction.
func (s *roomStatements) updateRoom(txn *sql.Tx, roomID, lastEventID string) error { func (s *roomStatements) updateRoom(
_, err := common.TxStmt(txn, s.updateRoomStmt).Exec(roomID, lastEventID) ctx context.Context, txn *sql.Tx, roomID, lastEventID string,
) error {
stmt := common.TxStmt(txn, s.updateRoomStmt)
_, err := stmt.ExecContext(ctx, roomID, lastEventID)
return err return err
} }
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
package storage package storage
import ( import (
"context"
"database/sql" "database/sql"
"github.com/matrix-org/dendrite/common" "github.com/matrix-org/dendrite/common"
...@@ -73,35 +74,38 @@ func (d *Database) SetPartitionOffset(topic string, partition int32, offset int6 ...@@ -73,35 +74,38 @@ func (d *Database) SetPartitionOffset(topic string, partition int32, offset int6
// UpdateRoom updates the joined hosts for a room and returns what the joined // UpdateRoom updates the joined hosts for a room and returns what the joined
// hosts were before the update. // hosts were before the update.
func (d *Database) UpdateRoom( func (d *Database) UpdateRoom(
ctx context.Context,
roomID, oldEventID, newEventID string, roomID, oldEventID, newEventID string,
addHosts []types.JoinedHost, addHosts []types.JoinedHost,
removeHosts []string, removeHosts []string,
) (joinedHosts []types.JoinedHost, err error) { ) (joinedHosts []types.JoinedHost, err error) {
err = common.WithTransaction(d.db, func(txn *sql.Tx) error { err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
if err = d.insertRoom(txn, roomID); err != nil { if err = d.insertRoom(ctx, txn, roomID); err != nil {
return err return err
} }
lastSentEventID, err := d.selectRoomForUpdate(txn, roomID) lastSentEventID, err := d.selectRoomForUpdate(ctx, txn, roomID)
if err != nil { if err != nil {
return err return err
} }
if lastSentEventID != oldEventID { if lastSentEventID != oldEventID {
return types.EventIDMismatchError{lastSentEventID, oldEventID} return types.EventIDMismatchError{
DatabaseID: lastSentEventID, RoomServerID: oldEventID,
}
} }
joinedHosts, err = d.selectJoinedHosts(txn, roomID) joinedHosts, err = d.selectJoinedHosts(ctx, txn, roomID)
if err != nil { if err != nil {
return err return err
} }
for _, add := range addHosts { for _, add := range addHosts {
err = d.insertJoinedHosts(txn, roomID, add.MemberEventID, add.ServerName) err = d.insertJoinedHosts(ctx, txn, roomID, add.MemberEventID, add.ServerName)
if err != nil { if err != nil {
return err return err
} }
} }
if err = d.deleteJoinedHosts(txn, removeHosts); err != nil { if err = d.deleteJoinedHosts(ctx, txn, removeHosts); err != nil {
return err return err
} }
return d.updateRoom(txn, roomID, newEventID) return d.updateRoom(ctx, txn, roomID, newEventID)
}) })
return return
} }
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册