提交 029e7182 编写于 作者: M Mark Haines 提交者: GitHub

Add context.Context to the federation client (#225)

* Add context.Context to the federation client

* gb vendor update github.com/matrix-org/gomatrixserverlib
上级 08668345
......@@ -66,7 +66,7 @@ func DirectoryRoom(
}
}
} else {
resp, err = federation.LookupRoomAlias(domain, roomAlias)
resp, err = federation.LookupRoomAlias(req.Context(), domain, roomAlias)
if err != nil {
switch x := err.(type) {
case gomatrix.HTTPError:
......
......@@ -136,7 +136,7 @@ func (r joinRoomReq) joinRoomByAlias(roomAlias string) util.JSONResponse {
func (r joinRoomReq) joinRoomByRemoteAlias(
domain gomatrixserverlib.ServerName, roomAlias string,
) util.JSONResponse {
resp, err := r.federation.LookupRoomAlias(domain, roomAlias)
resp, err := r.federation.LookupRoomAlias(r.req.Context(), domain, roomAlias)
if err != nil {
switch x := err.(type) {
case gomatrix.HTTPError:
......@@ -226,7 +226,7 @@ func (r joinRoomReq) joinRoomUsingServers(
// server was invalid this returns an error.
// Otherwise this returns a JSONResponse.
func (r joinRoomReq) joinRoomUsingServer(roomID string, server gomatrixserverlib.ServerName) (*util.JSONResponse, error) {
respMakeJoin, err := r.federation.MakeJoin(server, roomID, r.userID)
respMakeJoin, err := r.federation.MakeJoin(r.req.Context(), server, roomID, r.userID)
if err != nil {
// TODO: Check if the user was not allowed to join the room.
return nil, err
......@@ -246,12 +246,12 @@ func (r joinRoomReq) joinRoomUsingServer(roomID string, server gomatrixserverlib
return &res, nil
}
respSendJoin, err := r.federation.SendJoin(server, event)
respSendJoin, err := r.federation.SendJoin(r.req.Context(), server, event)
if err != nil {
return nil, err
}
if err = respSendJoin.Check(r.keyRing, event); err != nil {
if err = respSendJoin.Check(r.req.Context(), r.keyRing, event); err != nil {
return nil, err
}
......
......@@ -15,7 +15,9 @@
package keydb
import (
"context"
"database/sql"
"github.com/matrix-org/gomatrixserverlib"
)
......@@ -41,6 +43,7 @@ func NewDatabase(dataSourceName string) (*Database, error) {
// FetchKeys implements gomatrixserverlib.KeyDatabase
func (d *Database) FetchKeys(
ctx context.Context,
requests map[gomatrixserverlib.PublicKeyRequest]gomatrixserverlib.Timestamp,
) (map[gomatrixserverlib.PublicKeyRequest]gomatrixserverlib.ServerKeys, error) {
return d.statements.bulkSelectServerKeys(requests)
......@@ -48,6 +51,7 @@ func (d *Database) FetchKeys(
// StoreKeys implements gomatrixserverlib.KeyDatabase
func (d *Database) StoreKeys(
ctx context.Context,
keyMap map[gomatrixserverlib.PublicKeyRequest]gomatrixserverlib.ServerKeys,
) error {
// TODO: Inserting all the keys within a single transaction may
......
......@@ -76,7 +76,7 @@ func Invite(
Message: event.Redact().JSON(),
AtTS: event.OriginServerTS(),
}}
verifyResults, err := keys.VerifyJSONs(verifyRequests)
verifyResults, err := keys.VerifyJSONs(httpReq.Context(), verifyRequests)
if err != nil {
return httputil.LogThenError(httpReq, err)
}
......
......@@ -15,6 +15,7 @@
package writers
import (
"context"
"encoding/json"
"fmt"
"net/http"
......@@ -41,6 +42,7 @@ func Send(
) util.JSONResponse {
t := txnReq{
context: httpReq.Context(),
query: query,
producer: producer,
keys: keys,
......@@ -70,6 +72,7 @@ func Send(
type txnReq struct {
gomatrixserverlib.Transaction
context context.Context
query api.RoomserverQueryAPI
producer *producers.RoomserverProducer
keys gomatrixserverlib.KeyRing
......@@ -78,7 +81,7 @@ type txnReq struct {
func (t *txnReq) processTransaction() (*gomatrixserverlib.RespSend, error) {
// Check the event signatures
if err := gomatrixserverlib.VerifyEventSignatures(t.PDUs, t.keys); err != nil {
if err := gomatrixserverlib.VerifyEventSignatures(t.context, t.PDUs, t.keys); err != nil {
return nil, err
}
......@@ -110,7 +113,9 @@ func (t *txnReq) processTransaction() (*gomatrixserverlib.RespSend, error) {
// our server so we should bail processing the transaction entirely.
return nil, err
}
results[e.EventID()] = gomatrixserverlib.PDUResult{err.Error()}
results[e.EventID()] = gomatrixserverlib.PDUResult{
Error: err.Error(),
}
} else {
results[e.EventID()] = gomatrixserverlib.PDUResult{}
}
......@@ -197,12 +202,12 @@ func (t *txnReq) processEventWithMissingState(e gomatrixserverlib.Event) error {
// need to fallback to /state.
// TODO: Attempt to fill in the gap using /get_missing_events
// TODO: Attempt to fetch the state using /state_ids and /events
state, err := t.federation.LookupState(t.Origin, e.RoomID(), e.EventID())
state, err := t.federation.LookupState(t.context, t.Origin, e.RoomID(), e.EventID())
if err != nil {
return err
}
// Check that the returned state is valid.
if err := state.Check(t.keys); err != nil {
if err := state.Check(t.context, t.keys); err != nil {
return err
}
// Check that the event is allowed by the state.
......
......@@ -15,6 +15,7 @@
package writers
import (
"context"
"encoding/json"
"errors"
"fmt"
......@@ -63,7 +64,9 @@ func CreateInvitesFrom3PIDInvites(
evs := []gomatrixserverlib.Event{}
for _, inv := range body.Invites {
event, err := createInviteFrom3PIDInvite(queryAPI, cfg, inv, federation)
event, err := createInviteFrom3PIDInvite(
req.Context(), queryAPI, cfg, inv, federation,
)
if err != nil {
return httputil.LogThenError(req, err)
}
......@@ -139,7 +142,7 @@ func ExchangeThirdPartyInvite(
// Ask the requesting server to sign the newly created event so we know it
// acknowledged it
signedEvent, err := federation.SendInvite(request.Origin(), *event)
signedEvent, err := federation.SendInvite(httpReq.Context(), request.Origin(), *event)
if err != nil {
return httputil.LogThenError(httpReq, err)
}
......@@ -160,8 +163,8 @@ func ExchangeThirdPartyInvite(
// Returns an error if there was a problem building the event or fetching the
// necessary data to do so.
func createInviteFrom3PIDInvite(
queryAPI api.RoomserverQueryAPI, cfg config.Dendrite, inv invite,
federation *gomatrixserverlib.FederationClient,
ctx context.Context, queryAPI api.RoomserverQueryAPI, cfg config.Dendrite,
inv invite, federation *gomatrixserverlib.FederationClient,
) (*gomatrixserverlib.Event, error) {
// Build the event
builder := &gomatrixserverlib.EventBuilder{
......@@ -185,7 +188,10 @@ func createInviteFrom3PIDInvite(
event, err := buildMembershipEvent(builder, queryAPI, cfg)
if err == errNotInRoom {
return nil, sendToRemoteServer(inv, federation, cfg, *builder)
return nil, sendToRemoteServer(ctx, inv, federation, cfg, *builder)
}
if err != nil {
return nil, err
}
return event, nil
......@@ -253,7 +259,8 @@ func buildMembershipEvent(
// Returns an error if it couldn't get the server names to reach or if all of
// them responded with an error.
func sendToRemoteServer(
inv invite, federation *gomatrixserverlib.FederationClient, cfg config.Dendrite,
ctx context.Context, inv invite,
federation *gomatrixserverlib.FederationClient, cfg config.Dendrite,
builder gomatrixserverlib.EventBuilder,
) (err error) {
remoteServers := make([]gomatrixserverlib.ServerName, 2)
......@@ -269,7 +276,7 @@ func sendToRemoteServer(
}
for _, server := range remoteServers {
err = federation.ExchangeThirdPartyInvite(server, builder)
err = federation.ExchangeThirdPartyInvite(ctx, server, builder)
if err == nil {
return
}
......
......@@ -15,6 +15,7 @@
package queue
import (
"context"
"fmt"
"sync"
"time"
......@@ -65,7 +66,7 @@ func (oq *destinationQueue) backgroundSend() {
// TODO: handle retries.
// TODO: blacklist uncooperative servers.
_, err := oq.client.SendTransaction(*t)
_, err := oq.client.SendTransaction(context.TODO(), *t)
if err != nil {
log.WithFields(log.Fields{
"destination": oq.destination,
......
......@@ -116,7 +116,7 @@
{
"importpath": "github.com/matrix-org/gomatrixserverlib",
"repository": "https://github.com/matrix-org/gomatrixserverlib",
"revision": "790f02e8f465552dab4317ffe7ca047ccb594cbf",
"revision": "ec5a0d21b03ed4d3bd955ecc9f7a69936f64391e",
"branch": "master"
},
{
......
......@@ -17,6 +17,7 @@ package gomatrixserverlib
import (
"bytes"
"context"
"crypto/tls"
"encoding/json"
"fmt"
......@@ -103,7 +104,9 @@ func (f *federationTripper) RoundTrip(r *http.Request) (*http.Response, error) {
// LookupUserInfo gets information about a user from a given matrix homeserver
// using a bearer access token.
func (fc *Client) LookupUserInfo(matrixServer ServerName, token string) (u UserInfo, err error) {
func (fc *Client) LookupUserInfo(
ctx context.Context, matrixServer ServerName, token string,
) (u UserInfo, err error) {
url := url.URL{
Scheme: "matrix",
Host: string(matrixServer),
......@@ -111,8 +114,13 @@ func (fc *Client) LookupUserInfo(matrixServer ServerName, token string) (u UserI
RawQuery: url.Values{"access_token": []string{token}}.Encode(),
}
req, err := http.NewRequest("GET", url.String(), nil)
if err != nil {
return
}
var response *http.Response
response, err = fc.client.Get(url.String())
response, err = fc.client.Do(req.WithContext(ctx))
if response != nil {
defer response.Body.Close() // nolint: errcheck
}
......@@ -153,7 +161,7 @@ func (fc *Client) LookupUserInfo(matrixServer ServerName, token string) (u UserI
// copy of the keys.
// Returns the keys or an error if there was a problem talking to the server.
func (fc *Client) LookupServerKeys( // nolint: gocyclo
matrixServer ServerName, keyRequests map[PublicKeyRequest]Timestamp,
ctx context.Context, matrixServer ServerName, keyRequests map[PublicKeyRequest]Timestamp,
) (map[PublicKeyRequest]ServerKeys, error) {
url := url.URL{
Scheme: "matrix",
......@@ -183,7 +191,13 @@ func (fc *Client) LookupServerKeys( // nolint: gocyclo
return nil, err
}
response, err := fc.client.Post(url.String(), "application/json", bytes.NewBuffer(requestBytes))
req, err := http.NewRequest("POST", url.String(), bytes.NewBuffer(requestBytes))
if err != nil {
return nil, err
}
req.Header.Add("Content-Type", "application/json")
response, err := fc.client.Do(req.WithContext(ctx))
if response != nil {
defer response.Body.Close() // nolint: errcheck
}
......
......@@ -17,6 +17,7 @@ package gomatrixserverlib
import (
"bytes"
"context"
"crypto/sha256"
"encoding/json"
"fmt"
......@@ -188,7 +189,7 @@ func verifyEventSignature(signingName string, keyID KeyID, publicKey ed25519.Pub
// VerifyEventSignatures checks that each event in a list of events has valid
// signatures from the server that sent it.
func VerifyEventSignatures(events []Event, keyRing KeyRing) error { // nolint: gocyclo
func VerifyEventSignatures(ctx context.Context, events []Event, keyRing KeyRing) error { // nolint: gocyclo
var toVerify []VerifyJSONRequest
for _, event := range events {
redactedJSON, err := redactEvent(event.eventJSON)
......@@ -222,7 +223,7 @@ func VerifyEventSignatures(events []Event, keyRing KeyRing) error { // nolint: g
}
}
results, err := keyRing.VerifyJSONs(toVerify)
results, err := keyRing.VerifyJSONs(ctx, toVerify)
if err != nil {
return err
}
......
package gomatrixserverlib
import (
"context"
"encoding/json"
"io/ioutil"
"net/http"
......@@ -31,7 +32,7 @@ func NewFederationClient(
}
}
func (ac *FederationClient) doRequest(r FederationRequest, resBody interface{}) error {
func (ac *FederationClient) doRequest(ctx context.Context, r FederationRequest, resBody interface{}) error {
if err := r.Sign(ac.serverName, ac.serverKeyID, ac.serverPrivateKey); err != nil {
return err
}
......@@ -41,7 +42,7 @@ func (ac *FederationClient) doRequest(r FederationRequest, resBody interface{})
return err
}
res, err := ac.client.Do(req)
res, err := ac.client.Do(req.WithContext(ctx))
if res != nil {
defer res.Body.Close() // nolint: errcheck
}
......@@ -87,13 +88,15 @@ func (ac *FederationClient) doRequest(r FederationRequest, resBody interface{})
var federationPathPrefix = "/_matrix/federation/v1"
// SendTransaction sends a transaction
func (ac *FederationClient) SendTransaction(t Transaction) (res RespSend, err error) {
func (ac *FederationClient) SendTransaction(
ctx context.Context, t Transaction,
) (res RespSend, err error) {
path := federationPathPrefix + "/send/" + string(t.TransactionID) + "/"
req := NewFederationRequest("PUT", t.Destination, path)
if err = req.SetContent(t); err != nil {
return
}
err = ac.doRequest(req, &res)
err = ac.doRequest(ctx, req, &res)
return
}
......@@ -106,12 +109,14 @@ func (ac *FederationClient) SendTransaction(t Transaction) (res RespSend, err er
// If this successfully returns an acceptable event we will sign it with our
// server's key and pass it to SendJoin.
// See https://matrix.org/docs/spec/server_server/unstable.html#joining-rooms
func (ac *FederationClient) MakeJoin(s ServerName, roomID, userID string) (res RespMakeJoin, err error) {
func (ac *FederationClient) MakeJoin(
ctx context.Context, s ServerName, roomID, userID string,
) (res RespMakeJoin, err error) {
path := federationPathPrefix + "/make_join/" +
url.PathEscape(roomID) + "/" +
url.PathEscape(userID)
req := NewFederationRequest("GET", s, path)
err = ac.doRequest(req, &res)
err = ac.doRequest(ctx, req, &res)
return
}
......@@ -119,7 +124,9 @@ func (ac *FederationClient) MakeJoin(s ServerName, roomID, userID string) (res R
// remote matrix server.
// This is used to join a room the local server isn't a member of.
// See https://matrix.org/docs/spec/server_server/unstable.html#joining-rooms
func (ac *FederationClient) SendJoin(s ServerName, event Event) (res RespSendJoin, err error) {
func (ac *FederationClient) SendJoin(
ctx context.Context, s ServerName, event Event,
) (res RespSendJoin, err error) {
path := federationPathPrefix + "/send_join/" +
url.PathEscape(event.RoomID()) + "/" +
url.PathEscape(event.EventID())
......@@ -127,13 +134,15 @@ func (ac *FederationClient) SendJoin(s ServerName, event Event) (res RespSendJoi
if err = req.SetContent(event); err != nil {
return
}
err = ac.doRequest(req, &res)
err = ac.doRequest(ctx, req, &res)
return
}
// SendInvite sends an invite m.room.member event to an invited server to be
// signed by it. This is used to invite a user that is not on the local server.
func (ac *FederationClient) SendInvite(s ServerName, event Event) (res RespInvite, err error) {
func (ac *FederationClient) SendInvite(
ctx context.Context, s ServerName, event Event,
) (res RespInvite, err error) {
path := federationPathPrefix + "/invite/" +
url.PathEscape(event.RoomID()) + "/" +
url.PathEscape(event.EventID())
......@@ -141,7 +150,7 @@ func (ac *FederationClient) SendInvite(s ServerName, event Event) (res RespInvit
if err = req.SetContent(event); err != nil {
return
}
err = ac.doRequest(req, &res)
err = ac.doRequest(ctx, req, &res)
return
}
......@@ -150,38 +159,44 @@ func (ac *FederationClient) SendInvite(s ServerName, event Event) (res RespInvit
// server.
// This is used to exchange a m.room.third_party_invite event for a m.room.member
// one in a room the local server isn't a member of.
func (ac *FederationClient) ExchangeThirdPartyInvite(s ServerName, builder EventBuilder) (err error) {
func (ac *FederationClient) ExchangeThirdPartyInvite(
ctx context.Context, s ServerName, builder EventBuilder,
) (err error) {
path := federationPathPrefix + "/exchange_third_party_invite/" +
url.PathEscape(builder.RoomID)
req := NewFederationRequest("PUT", s, path)
if err = req.SetContent(builder); err != nil {
return
}
err = ac.doRequest(req, nil)
err = ac.doRequest(ctx, req, nil)
return
}
// LookupState retrieves the room state for a room at an event from a
// remote matrix server as full matrix events.
func (ac *FederationClient) LookupState(s ServerName, roomID, eventID string) (res RespState, err error) {
func (ac *FederationClient) LookupState(
ctx context.Context, s ServerName, roomID, eventID string,
) (res RespState, err error) {
path := federationPathPrefix + "/state/" +
url.PathEscape(roomID) +
"/?event_id=" +
url.QueryEscape(eventID)
req := NewFederationRequest("GET", s, path)
err = ac.doRequest(req, &res)
err = ac.doRequest(ctx, req, &res)
return
}
// LookupStateIDs retrieves the room state for a room at an event from a
// remote matrix server as lists of matrix event IDs.
func (ac *FederationClient) LookupStateIDs(s ServerName, roomID, eventID string) (res RespStateIDs, err error) {
func (ac *FederationClient) LookupStateIDs(
ctx context.Context, s ServerName, roomID, eventID string,
) (res RespStateIDs, err error) {
path := federationPathPrefix + "/state_ids/" +
url.PathEscape(roomID) +
"/?event_id=" +
url.QueryEscape(eventID)
req := NewFederationRequest("GET", s, path)
err = ac.doRequest(req, &res)
err = ac.doRequest(ctx, req, &res)
return
}
......@@ -190,10 +205,12 @@ func (ac *FederationClient) LookupStateIDs(s ServerName, roomID, eventID string)
// being looked up on.
// If the room alias doesn't exist on the remote server then a 404 gomatrix.HTTPError
// is returned.
func (ac *FederationClient) LookupRoomAlias(s ServerName, roomAlias string) (res RespDirectory, err error) {
func (ac *FederationClient) LookupRoomAlias(
ctx context.Context, s ServerName, roomAlias string,
) (res RespDirectory, err error) {
path := federationPathPrefix + "/query/directory?room_alias=" +
url.QueryEscape(roomAlias)
req := NewFederationRequest("GET", s, path)
err = ac.doRequest(req, &res)
err = ac.doRequest(ctx, req, &res)
return
}
package gomatrixserverlib
import (
"context"
"encoding/json"
"fmt"
)
......@@ -107,7 +108,7 @@ func (r RespState) Events() ([]Event, error) {
}
// Check that a response to /state is valid.
func (r RespState) Check(keyRing KeyRing) error {
func (r RespState) Check(ctx context.Context, keyRing KeyRing) error {
var allEvents []Event
for _, event := range r.AuthEvents {
if event.StateKey() == nil {
......@@ -133,7 +134,7 @@ func (r RespState) Check(keyRing KeyRing) error {
}
// Check if the events pass signature checks.
if err := VerifyEventSignatures(allEvents, keyRing); err != nil {
if err := VerifyEventSignatures(ctx, allEvents, keyRing); err != nil {
return nil
}
......@@ -213,11 +214,11 @@ type respSendJoinFields struct {
// Check that a response to /send_join is valid.
// This checks that it would be valid as a response to /state
// This also checks that the join event is allowed by the state.
func (r RespSendJoin) Check(keyRing KeyRing, joinEvent Event) error {
func (r RespSendJoin) Check(ctx context.Context, keyRing KeyRing, joinEvent Event) error {
// First check that the state is valid.
// The response to /send_join has the same data as a response to /state
// and the checks for a response to /state also apply.
if err := RespState(r).Check(keyRing); err != nil {
if err := RespState(r).Check(ctx, keyRing); err != nil {
return err
}
......
......@@ -6,13 +6,13 @@ echo "Installing lint search engine..."
go get github.com/alecthomas/gometalinter/
gometalinter --config=linter.json --install --update
echo "Testing..."
go test
echo "Looking for lint..."
gometalinter --config=linter.json
echo "Double checking spelling..."
misspell -error src *.md
echo "Testing..."
go test
echo "Done!"
package gomatrixserverlib
import (
"context"
"fmt"
"strings"
"time"
......@@ -26,7 +27,7 @@ type KeyFetcher interface {
// The result may have fewer (server name, key ID) pairs than were in the request.
// The result may have more (server name, key ID) pairs than were in the request.
// Returns an error if there was a problem fetching the keys.
FetchKeys(requests map[PublicKeyRequest]Timestamp) (map[PublicKeyRequest]ServerKeys, error)
FetchKeys(ctx context.Context, requests map[PublicKeyRequest]Timestamp) (map[PublicKeyRequest]ServerKeys, error)
}
// A KeyDatabase is a store for caching public keys.
......@@ -39,7 +40,7 @@ type KeyDatabase interface {
// to a concurrent FetchKeys(). This is acceptable since the database is
// only used as a cache for the keys, so if a FetchKeys() races with a
// StoreKeys() and some of the keys are missing they will be just be refetched.
StoreKeys(map[PublicKeyRequest]ServerKeys) error
StoreKeys(ctx context.Context, results map[PublicKeyRequest]ServerKeys) error
}
// A KeyRing stores keys for matrix servers and provides methods for verifying JSON messages.
......@@ -73,7 +74,7 @@ type VerifyJSONResult struct {
// The caller should check the Result field for each entry to see if it was valid.
// Returns an error if there was a problem talking to the database or one of the other methods
// of fetching the public keys.
func (k *KeyRing) VerifyJSONs(requests []VerifyJSONRequest) ([]VerifyJSONResult, error) { // nolint: gocyclo
func (k *KeyRing) VerifyJSONs(ctx context.Context, requests []VerifyJSONRequest) ([]VerifyJSONResult, error) { // nolint: gocyclo
results := make([]VerifyJSONResult, len(requests))
keyIDs := make([][]KeyID, len(requests))
......@@ -109,7 +110,7 @@ func (k *KeyRing) VerifyJSONs(requests []VerifyJSONRequest) ([]VerifyJSONResult,
// This will happen if all the objects are missing supported signatures.
return results, nil
}
keysFromDatabase, err := k.KeyDatabase.FetchKeys(keyRequests)
keysFromDatabase, err := k.KeyDatabase.FetchKeys(ctx, keyRequests)
if err != nil {
return nil, err
}
......@@ -124,14 +125,14 @@ func (k *KeyRing) VerifyJSONs(requests []VerifyJSONRequest) ([]VerifyJSONResult,
}
// TODO: Coalesce in-flight requests for the same keys.
// Otherwise we risk spamming the servers we query the keys from.
keysFetched, err := k.KeyFetchers[i].FetchKeys(keyRequests)
keysFetched, err := k.KeyFetchers[i].FetchKeys(ctx, keyRequests)
if err != nil {
return nil, err
}
k.checkUsingKeys(requests, results, keyIDs, keysFetched)
// Add the keys to the database so that we won't need to fetch them again.
if err := k.KeyDatabase.StoreKeys(keysFetched); err != nil {
if err := k.KeyDatabase.StoreKeys(ctx, keysFetched); err != nil {
return nil, err
}
}
......@@ -143,7 +144,9 @@ func (k *KeyRing) isAlgorithmSupported(keyID KeyID) bool {
return strings.HasPrefix(string(keyID), "ed25519:")
}
func (k *KeyRing) publicKeyRequests(requests []VerifyJSONRequest, results []VerifyJSONResult, keyIDs [][]KeyID) map[PublicKeyRequest]Timestamp {
func (k *KeyRing) publicKeyRequests(
requests []VerifyJSONRequest, results []VerifyJSONResult, keyIDs [][]KeyID,
) map[PublicKeyRequest]Timestamp {
keyRequests := map[PublicKeyRequest]Timestamp{}
for i := range requests {
if results[i].Error == nil {
......@@ -218,8 +221,10 @@ type PerspectiveKeyFetcher struct {
}
// FetchKeys implements KeyFetcher
func (p *PerspectiveKeyFetcher) FetchKeys(requests map[PublicKeyRequest]Timestamp) (map[PublicKeyRequest]ServerKeys, error) {
results, err := p.Client.LookupServerKeys(p.PerspectiveServerName, requests)
func (p *PerspectiveKeyFetcher) FetchKeys(
ctx context.Context, requests map[PublicKeyRequest]Timestamp,
) (map[PublicKeyRequest]ServerKeys, error) {
results, err := p.Client.LookupServerKeys(ctx, p.PerspectiveServerName, requests)
if err != nil {
return nil, err
}
......@@ -269,7 +274,9 @@ type DirectKeyFetcher struct {
}
// FetchKeys implements KeyFetcher
func (d *DirectKeyFetcher) FetchKeys(requests map[PublicKeyRequest]Timestamp) (map[PublicKeyRequest]ServerKeys, error) {
func (d *DirectKeyFetcher) FetchKeys(
ctx context.Context, requests map[PublicKeyRequest]Timestamp,
) (map[PublicKeyRequest]ServerKeys, error) {
byServer := map[ServerName]map[PublicKeyRequest]Timestamp{}
for req, ts := range requests {
server := byServer[req.ServerName]
......@@ -283,7 +290,7 @@ func (d *DirectKeyFetcher) FetchKeys(requests map[PublicKeyRequest]Timestamp) (m
results := map[PublicKeyRequest]ServerKeys{}
for server, reqs := range byServer {
// TODO: make these requests in parallel
serverResults, err := d.fetchKeysForServer(server, reqs)
serverResults, err := d.fetchKeysForServer(ctx, server, reqs)
if err != nil {
// TODO: Should we actually be erroring here? or should we just drop those keys from the result map?
return nil, err
......@@ -296,9 +303,9 @@ func (d *DirectKeyFetcher) FetchKeys(requests map[PublicKeyRequest]Timestamp) (m
}
func (d *DirectKeyFetcher) fetchKeysForServer(
serverName ServerName, requests map[PublicKeyRequest]Timestamp,
ctx context.Context, serverName ServerName, requests map[PublicKeyRequest]Timestamp,
) (map[PublicKeyRequest]ServerKeys, error) {
results, err := d.Client.LookupServerKeys(serverName, requests)
results, err := d.Client.LookupServerKeys(ctx, serverName, requests)
if err != nil {
return nil, err
}
......
package gomatrixserverlib
import (
"context"
"encoding/json"
"testing"
)
......@@ -36,7 +37,9 @@ var testKeys = `{
type testKeyDatabase struct{}
func (db *testKeyDatabase) FetchKeys(requests map[PublicKeyRequest]Timestamp) (map[PublicKeyRequest]ServerKeys, error) {
func (db *testKeyDatabase) FetchKeys(
ctx context.Context, requests map[PublicKeyRequest]Timestamp,
) (map[PublicKeyRequest]ServerKeys, error) {
results := map[PublicKeyRequest]ServerKeys{}
var keys ServerKeys
if err := json.Unmarshal([]byte(testKeys), &keys); err != nil {
......@@ -54,14 +57,16 @@ func (db *testKeyDatabase) FetchKeys(requests map[PublicKeyRequest]Timestamp) (m
return results, nil
}
func (db *testKeyDatabase) StoreKeys(requests map[PublicKeyRequest]ServerKeys) error {
func (db *testKeyDatabase) StoreKeys(
ctx context.Context, requests map[PublicKeyRequest]ServerKeys,
) error {
return nil
}
func TestVerifyJSONsSuccess(t *testing.T) {
// Check that trying to verify the server key JSON works.
k := KeyRing{nil, &testKeyDatabase{}}
results, err := k.VerifyJSONs([]VerifyJSONRequest{{
results, err := k.VerifyJSONs(context.Background(), []VerifyJSONRequest{{
ServerName: "localhost:8800",
Message: []byte(testKeys),
AtTS: 1493142432964,
......@@ -77,7 +82,7 @@ func TestVerifyJSONsSuccess(t *testing.T) {
func TestVerifyJSONsUnknownServerFails(t *testing.T) {
// Check that trying to verify JSON for an unknown server fails.
k := KeyRing{nil, &testKeyDatabase{}}
results, err := k.VerifyJSONs([]VerifyJSONRequest{{
results, err := k.VerifyJSONs(context.Background(), []VerifyJSONRequest{{
ServerName: "unknown:8800",
Message: []byte(testKeys),
AtTS: 1493142432964,
......@@ -94,7 +99,7 @@ func TestVerifyJSONsDistantFutureFails(t *testing.T) {
// Check that trying to verify JSON from the distant future fails.
distantFuture := Timestamp(2000000000000)
k := KeyRing{nil, &testKeyDatabase{}}
results, err := k.VerifyJSONs([]VerifyJSONRequest{{
results, err := k.VerifyJSONs(context.Background(), []VerifyJSONRequest{{
ServerName: "unknown:8800",
Message: []byte(testKeys),
AtTS: distantFuture,
......@@ -110,7 +115,7 @@ func TestVerifyJSONsDistantFutureFails(t *testing.T) {
func TestVerifyJSONsFetcherError(t *testing.T) {
// Check that if the database errors then the attempt to verify JSON fails.
k := KeyRing{nil, &erroringKeyDatabase{}}
results, err := k.VerifyJSONs([]VerifyJSONRequest{{
results, err := k.VerifyJSONs(context.Background(), []VerifyJSONRequest{{
ServerName: "localhost:8800",
Message: []byte(testKeys),
AtTS: 1493142432964,
......@@ -129,10 +134,14 @@ func (e *erroringKeyDatabaseError) Error() string { return "An error with the ke
var testErrorFetch = erroringKeyDatabaseError(1)
var testErrorStore = erroringKeyDatabaseError(2)
func (e *erroringKeyDatabase) FetchKeys(requests map[PublicKeyRequest]Timestamp) (map[PublicKeyRequest]ServerKeys, error) {
func (e *erroringKeyDatabase) FetchKeys(
ctx context.Context, requests map[PublicKeyRequest]Timestamp,
) (map[PublicKeyRequest]ServerKeys, error) {
return nil, &testErrorFetch
}
func (e *erroringKeyDatabase) StoreKeys(keys map[PublicKeyRequest]ServerKeys) error {
func (e *erroringKeyDatabase) StoreKeys(
ctx context.Context, keys map[PublicKeyRequest]ServerKeys,
) error {
return &testErrorStore
}
{
"Deadline": "5m",
"Enable": [
"vet",
"vetshadow",
......
......@@ -215,7 +215,7 @@ func VerifyHTTPRequest(
return nil, util.MessageResponse(401, message)
}
results, err := keys.VerifyJSONs([]VerifyJSONRequest{{
results, err := keys.VerifyJSONs(req.Context(), []VerifyJSONRequest{{
ServerName: request.Origin(),
AtTS: AsTimestamp(now),
Message: toVerify,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册