提交 fef290c4 编写于 作者: M Mark Haines 提交者: GitHub

Add context to the server key database (#248)

上级 7596c19f
......@@ -49,7 +49,7 @@ func (d *Database) FetchKeys(
ctx context.Context,
requests map[gomatrixserverlib.PublicKeyRequest]gomatrixserverlib.Timestamp,
) (map[gomatrixserverlib.PublicKeyRequest]gomatrixserverlib.ServerKeys, error) {
return d.statements.bulkSelectServerKeys(requests)
return d.statements.bulkSelectServerKeys(ctx, requests)
}
// StoreKeys implements gomatrixserverlib.KeyDatabase
......@@ -62,7 +62,7 @@ func (d *Database) StoreKeys(
// high for a single insert statement.
var lastErr error
for request, keys := range keyMap {
if err := d.statements.upsertServerKeys(request, keys); err != nil {
if err := d.statements.upsertServerKeys(ctx, request, keys); err != nil {
// Rather than returning immediately on error we try to insert the
// remaining keys.
// Since we are inserting the keys outside of a transaction it is
......
......@@ -15,6 +15,7 @@
package keydb
import (
"context"
"database/sql"
"encoding/json"
......@@ -73,13 +74,15 @@ func (s *serverKeyStatements) prepare(db *sql.DB) (err error) {
}
func (s *serverKeyStatements) bulkSelectServerKeys(
ctx context.Context,
requests map[gomatrixserverlib.PublicKeyRequest]gomatrixserverlib.Timestamp,
) (map[gomatrixserverlib.PublicKeyRequest]gomatrixserverlib.ServerKeys, error) {
var nameAndKeyIDs []string
for request := range requests {
nameAndKeyIDs = append(nameAndKeyIDs, nameAndKeyID(request))
}
rows, err := s.bulkSelectServerKeysStmt.Query(pq.StringArray(nameAndKeyIDs))
stmt := s.bulkSelectServerKeysStmt
rows, err := stmt.QueryContext(ctx, pq.StringArray(nameAndKeyIDs))
if err != nil {
return nil, err
}
......@@ -106,15 +109,21 @@ func (s *serverKeyStatements) bulkSelectServerKeys(
}
func (s *serverKeyStatements) upsertServerKeys(
request gomatrixserverlib.PublicKeyRequest, keys gomatrixserverlib.ServerKeys,
ctx context.Context,
request gomatrixserverlib.PublicKeyRequest,
keys gomatrixserverlib.ServerKeys,
) error {
keyJSON, err := json.Marshal(keys)
if err != nil {
return err
}
_, err = s.upsertServerKeysStmt.Exec(
string(request.ServerName), string(request.KeyID), nameAndKeyID(request),
int64(keys.ValidUntilTS), keyJSON,
_, err = s.upsertServerKeysStmt.ExecContext(
ctx,
string(request.ServerName),
string(request.KeyID),
nameAndKeyID(request),
int64(keys.ValidUntilTS),
keyJSON,
)
return err
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册