提交 238646ee 编写于 作者: M Mark Haines 提交者: GitHub

Add contexts to device database (#233)

* Add contexts to device database

* Remove spurious whitespace
上级 e28ee276
...@@ -7,6 +7,16 @@ export GOGC=400 ...@@ -7,6 +7,16 @@ export GOGC=400
export GOPATH="$(pwd):$(pwd)/vendor" export GOPATH="$(pwd):$(pwd)/vendor"
export PATH="$PATH:$(pwd)/vendor/bin:$(pwd)/bin" export PATH="$PATH:$(pwd)/vendor/bin:$(pwd)/bin"
echo "Checking that it builds"
gb build
# Check that all the packages can build.
# When `go build` is given multiple packages it won't output anything, and just
# checks that everything builds. This seems to do a better job of handling
# missing imports than `gb build` does.
echo "Double checking it builds..."
go build github.com/matrix-org/dendrite/cmd/...
echo "Installing lint search engine..." echo "Installing lint search engine..."
go install github.com/alecthomas/gometalinter/ go install github.com/alecthomas/gometalinter/
gometalinter --config=linter.json ./... --install gometalinter --config=linter.json ./... --install
...@@ -20,11 +30,5 @@ misspell -error src *.md ...@@ -20,11 +30,5 @@ misspell -error src *.md
echo "Testing..." echo "Testing..."
gb test gb test
# Check that all the packages can build.
# When `go build` is given multiple packages it won't output anything, and just
# checks that everything builds. This seems to do a better job of handling
# missing imports than `gb build` does.
echo "Double checking it builds..."
go build github.com/matrix-org/dendrite/cmd/...
echo "Done!" echo "Done!"
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
package auth package auth
import ( import (
"context"
"crypto/rand" "crypto/rand"
"database/sql" "database/sql"
"encoding/base64" "encoding/base64"
...@@ -42,7 +43,7 @@ var tokenByteLength = 32 ...@@ -42,7 +43,7 @@ var tokenByteLength = 32
// DeviceDatabase represents a device database. // DeviceDatabase represents a device database.
type DeviceDatabase interface { type DeviceDatabase interface {
// Look up the device matching the given access token. // Look up the device matching the given access token.
GetDeviceByAccessToken(token string) (*authtypes.Device, error) GetDeviceByAccessToken(ctx context.Context, token string) (*authtypes.Device, error)
} }
// VerifyAccessToken verifies that an access token was supplied in the given HTTP request // VerifyAccessToken verifies that an access token was supplied in the given HTTP request
...@@ -57,7 +58,7 @@ func VerifyAccessToken(req *http.Request, deviceDB DeviceDatabase) (device *auth ...@@ -57,7 +58,7 @@ func VerifyAccessToken(req *http.Request, deviceDB DeviceDatabase) (device *auth
} }
return return
} }
device, err = deviceDB.GetDeviceByAccessToken(token) device, err = deviceDB.GetDeviceByAccessToken(req.Context(), token)
if err != nil { if err != nil {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
resErr = &util.JSONResponse{ resErr = &util.JSONResponse{
......
...@@ -15,10 +15,13 @@ ...@@ -15,10 +15,13 @@
package devices package devices
import ( import (
"context"
"database/sql" "database/sql"
"fmt" "fmt"
"time" "time"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
...@@ -84,27 +87,36 @@ func (s *devicesStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerN ...@@ -84,27 +87,36 @@ func (s *devicesStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerN
// insertDevice creates a new device. Returns an error if any device with the same access token already exists. // insertDevice creates a new device. Returns an error if any device with the same access token already exists.
// Returns an error if the user already has a device with the given device ID. // Returns an error if the user already has a device with the given device ID.
// Returns the device on success. // Returns the device on success.
func (s *devicesStatements) insertDevice(txn *sql.Tx, id, localpart, accessToken string) (dev *authtypes.Device, err error) { func (s *devicesStatements) insertDevice(
ctx context.Context, txn *sql.Tx, id, localpart, accessToken string,
) (*authtypes.Device, error) {
createdTimeMS := time.Now().UnixNano() / 1000000 createdTimeMS := time.Now().UnixNano() / 1000000
if _, err = txn.Stmt(s.insertDeviceStmt).Exec(id, localpart, accessToken, createdTimeMS); err == nil { stmt := common.TxStmt(txn, s.insertDeviceStmt)
dev = &authtypes.Device{ if _, err := stmt.ExecContext(ctx, id, localpart, accessToken, createdTimeMS); err != nil {
return nil, err
}
return &authtypes.Device{
ID: id, ID: id,
UserID: makeUserID(localpart, s.serverName), UserID: makeUserID(localpart, s.serverName),
AccessToken: accessToken, AccessToken: accessToken,
} }, nil
}
return
} }
func (s *devicesStatements) deleteDevice(txn *sql.Tx, id, localpart string) error { func (s *devicesStatements) deleteDevice(
_, err := txn.Stmt(s.deleteDeviceStmt).Exec(id, localpart) ctx context.Context, txn *sql.Tx, id, localpart string,
) error {
stmt := common.TxStmt(txn, s.deleteDeviceStmt)
_, err := stmt.ExecContext(ctx, id, localpart)
return err return err
} }
func (s *devicesStatements) selectDeviceByToken(accessToken string) (*authtypes.Device, error) { func (s *devicesStatements) selectDeviceByToken(
ctx context.Context, accessToken string,
) (*authtypes.Device, error) {
var dev authtypes.Device var dev authtypes.Device
var localpart string var localpart string
err := s.selectDeviceByTokenStmt.QueryRow(accessToken).Scan(&dev.ID, &localpart) stmt := s.selectDeviceByTokenStmt
err := stmt.QueryRowContext(ctx, accessToken).Scan(&dev.ID, &localpart)
if err == nil { if err == nil {
dev.UserID = makeUserID(localpart, s.serverName) dev.UserID = makeUserID(localpart, s.serverName)
dev.AccessToken = accessToken dev.AccessToken = accessToken
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
package devices package devices
import ( import (
"context"
"database/sql" "database/sql"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
...@@ -44,8 +45,10 @@ func NewDatabase(dataSourceName string, serverName gomatrixserverlib.ServerName) ...@@ -44,8 +45,10 @@ func NewDatabase(dataSourceName string, serverName gomatrixserverlib.ServerName)
// GetDeviceByAccessToken returns the device matching the given access token. // GetDeviceByAccessToken returns the device matching the given access token.
// Returns sql.ErrNoRows if no matching device was found. // Returns sql.ErrNoRows if no matching device was found.
func (d *Database) GetDeviceByAccessToken(token string) (*authtypes.Device, error) { func (d *Database) GetDeviceByAccessToken(
return d.devices.selectDeviceByToken(token) ctx context.Context, token string,
) (*authtypes.Device, error) {
return d.devices.selectDeviceByToken(ctx, token)
} }
// CreateDevice makes a new device associated with the given user ID localpart. // CreateDevice makes a new device associated with the given user ID localpart.
...@@ -53,15 +56,17 @@ func (d *Database) GetDeviceByAccessToken(token string) (*authtypes.Device, erro ...@@ -53,15 +56,17 @@ func (d *Database) GetDeviceByAccessToken(token string) (*authtypes.Device, erro
// and replaced with the given accessToken. If the given accessToken is already in use for another device, // and replaced with the given accessToken. If the given accessToken is already in use for another device,
// an error will be returned. // an error will be returned.
// Returns the device on success. // Returns the device on success.
func (d *Database) CreateDevice(localpart, deviceID, accessToken string) (dev *authtypes.Device, returnErr error) { func (d *Database) CreateDevice(
ctx context.Context, localpart, deviceID, accessToken string,
) (dev *authtypes.Device, returnErr error) {
returnErr = common.WithTransaction(d.db, func(txn *sql.Tx) error { returnErr = common.WithTransaction(d.db, func(txn *sql.Tx) error {
var err error var err error
// Revoke existing token for this device // Revoke existing token for this device
if err = d.devices.deleteDevice(txn, deviceID, localpart); err != nil { if err = d.devices.deleteDevice(ctx, txn, deviceID, localpart); err != nil {
return err return err
} }
dev, err = d.devices.insertDevice(txn, deviceID, localpart, accessToken) dev, err = d.devices.insertDevice(ctx, txn, deviceID, localpart, accessToken)
if err != nil { if err != nil {
return err return err
} }
...@@ -74,9 +79,11 @@ func (d *Database) CreateDevice(localpart, deviceID, accessToken string) (dev *a ...@@ -74,9 +79,11 @@ func (d *Database) CreateDevice(localpart, deviceID, accessToken string) (dev *a
// matching with the given device ID and user ID localpart // matching with the given device ID and user ID localpart
// If the device doesn't exist, it will not return an error // If the device doesn't exist, it will not return an error
// If something went wrong during the deletion, it will return the SQL error // If something went wrong during the deletion, it will return the SQL error
func (d *Database) RemoveDevice(deviceID string, localpart string) error { func (d *Database) RemoveDevice(
ctx context.Context, deviceID, localpart string,
) error {
return common.WithTransaction(d.db, func(txn *sql.Tx) error { return common.WithTransaction(d.db, func(txn *sql.Tx) error {
if err := d.devices.deleteDevice(txn, deviceID, localpart); err != sql.ErrNoRows { if err := d.devices.deleteDevice(ctx, txn, deviceID, localpart); err != sql.ErrNoRows {
return err return err
} }
return nil return nil
......
...@@ -98,7 +98,9 @@ func Login( ...@@ -98,7 +98,9 @@ func Login(
} }
// TODO: Use the device ID in the request // TODO: Use the device ID in the request
dev, err := deviceDB.CreateDevice(acc.Localpart, auth.UnknownDeviceID, token) dev, err := deviceDB.CreateDevice(
req.Context(), acc.Localpart, auth.UnknownDeviceID, token,
)
if err != nil { if err != nil {
return util.JSONResponse{ return util.JSONResponse{
Code: 500, Code: 500,
......
...@@ -41,7 +41,7 @@ func Logout( ...@@ -41,7 +41,7 @@ func Logout(
return httputil.LogThenError(req, err) return httputil.LogThenError(req, err)
} }
if err := deviceDB.RemoveDevice(device.ID, localpart); err != nil { if err := deviceDB.RemoveDevice(req.Context(), device.ID, localpart); err != nil {
return httputil.LogThenError(req, err) return httputil.LogThenError(req, err)
} }
......
...@@ -135,9 +135,7 @@ func Register(req *http.Request, accountDB *accounts.Database, deviceDB *devices ...@@ -135,9 +135,7 @@ func Register(req *http.Request, accountDB *accounts.Database, deviceDB *devices
switch r.Auth.Type { switch r.Auth.Type {
case authtypes.LoginTypeDummy: case authtypes.LoginTypeDummy:
// there is nothing to do // there is nothing to do
return completeRegistration( return completeRegistration(req.Context(), accountDB, deviceDB, r.Username, r.Password)
req.Context(), accountDB, deviceDB, r.Username, r.Password,
)
default: default:
return util.JSONResponse{ return util.JSONResponse{
Code: 501, Code: 501,
...@@ -182,7 +180,7 @@ func completeRegistration( ...@@ -182,7 +180,7 @@ func completeRegistration(
} }
// // TODO: Use the device ID in the request. // // TODO: Use the device ID in the request.
dev, err := deviceDB.CreateDevice(username, auth.UnknownDeviceID, token) dev, err := deviceDB.CreateDevice(ctx, username, auth.UnknownDeviceID, token)
if err != nil { if err != nil {
return util.JSONResponse{ return util.JSONResponse{
Code: 500, Code: 500,
......
...@@ -86,7 +86,9 @@ func main() { ...@@ -86,7 +86,9 @@ func main() {
accessToken = &t accessToken = &t
} }
device, err := deviceDB.CreateDevice(*username, "create-account-script", *accessToken) device, err := deviceDB.CreateDevice(
context.Background(), *username, "create-account-script", *accessToken,
)
if err != nil { if err != nil {
fmt.Println(err.Error()) fmt.Println(err.Error())
os.Exit(1) os.Exit(1)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册