devices_table.go 6.9 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package devices

import (
18
	"context"
19 20 21 22
	"database/sql"
	"fmt"
	"time"

23 24
	"github.com/matrix-org/dendrite/common"

25 26 27 28 29 30
	"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
	"github.com/matrix-org/gomatrixserverlib"
)

// devicesSchema is the DDL executed by prepare. It creates the device_devices
// table and the unique (localpart, device_id) index when they do not already
// exist. The text between the backticks is sent to the database verbatim, so
// it must contain nothing but SQL and SQL comments.
const devicesSchema = `
-- Stores data about devices.
CREATE TABLE IF NOT EXISTS device_devices (
    -- The access token granted to this device. This has to be the primary key
    -- so we can distinguish which device is making a given request.
    access_token TEXT NOT NULL PRIMARY KEY,
    -- The device identifier. This only needs to uniquely identify a device for a given user, not globally.
    -- access_tokens will be clobbered based on the device ID for a user.
    device_id TEXT NOT NULL,
    -- The Matrix user ID localpart for this device. This is preferable to storing the full user_id
    -- as it is smaller, makes it clearer that we only manage devices for our own users, and may make
    -- migration to different domain names easier.
    localpart TEXT NOT NULL,
    -- When this devices was first recognised on the network, as a unix timestamp (ms resolution).
    created_ts BIGINT NOT NULL,
    -- The display name, human friendlier than device_id and updatable
    display_name TEXT
    -- TODO: device keys, device display names, last used ts and IP address?, token restrictions (if 3rd-party OAuth app)
);

-- Device IDs must be unique for a given user.
CREATE UNIQUE INDEX IF NOT EXISTS device_localpart_id_idx ON device_devices(localpart, device_id);
`

// SQL statements run against the device_devices table. Each one is prepared
// once in devicesStatements.prepare and executed with positional parameters.
const (
	// insertDeviceSQL adds a device row.
	// Parameters: device_id, localpart, access_token, created_ts, display_name.
	insertDeviceSQL = "INSERT INTO device_devices(device_id, localpart, access_token, created_ts, display_name) VALUES ($1, $2, $3, $4, $5)"

	// selectDeviceByTokenSQL looks a device up by its access token.
	// Result columns: device_id, localpart, display_name.
	selectDeviceByTokenSQL = "SELECT device_id, localpart, display_name FROM device_devices WHERE access_token = $1"

	// selectDeviceByIDSQL looks a device up by (localpart, device_id).
	// Result columns: display_name.
	selectDeviceByIDSQL = "SELECT display_name FROM device_devices WHERE localpart = $1 and device_id = $2"

	// selectDevicesByLocalpartSQL lists every device of one localpart.
	// Result columns: device_id, display_name.
	selectDevicesByLocalpartSQL = "SELECT device_id, display_name FROM device_devices WHERE localpart = $1"

	// updateDeviceNameSQL changes the display name of one device.
	// Parameters: display_name, localpart, device_id.
	updateDeviceNameSQL = "UPDATE device_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3"

	// deleteDeviceSQL removes one device.
	// Parameters: device_id, localpart.
	deleteDeviceSQL = "DELETE FROM device_devices WHERE device_id = $1 AND localpart = $2"

	// deleteDevicesByLocalpartSQL removes every device of one localpart.
	deleteDevicesByLocalpartSQL = "DELETE FROM device_devices WHERE localpart = $1"
)

74
type devicesStatements struct {
R
Remi Reuvekamp 已提交
75 76
	insertDeviceStmt             *sql.Stmt
	selectDeviceByTokenStmt      *sql.Stmt
77 78
	selectDeviceByIDStmt         *sql.Stmt
	selectDevicesByLocalpartStmt *sql.Stmt
P
Paul Tötterman 已提交
79
	updateDeviceNameStmt         *sql.Stmt
R
Remi Reuvekamp 已提交
80 81
	deleteDeviceStmt             *sql.Stmt
	deleteDevicesByLocalpartStmt *sql.Stmt
82
	serverName                   gomatrixserverlib.ServerName
83 84 85 86 87 88 89 90 91 92 93 94 95
}

func (s *devicesStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) {
	_, err = db.Exec(devicesSchema)
	if err != nil {
		return
	}
	if s.insertDeviceStmt, err = db.Prepare(insertDeviceSQL); err != nil {
		return
	}
	if s.selectDeviceByTokenStmt, err = db.Prepare(selectDeviceByTokenSQL); err != nil {
		return
	}
96 97 98 99 100 101
	if s.selectDeviceByIDStmt, err = db.Prepare(selectDeviceByIDSQL); err != nil {
		return
	}
	if s.selectDevicesByLocalpartStmt, err = db.Prepare(selectDevicesByLocalpartSQL); err != nil {
		return
	}
P
Paul Tötterman 已提交
102 103 104
	if s.updateDeviceNameStmt, err = db.Prepare(updateDeviceNameSQL); err != nil {
		return
	}
105 106 107
	if s.deleteDeviceStmt, err = db.Prepare(deleteDeviceSQL); err != nil {
		return
	}
R
Remi Reuvekamp 已提交
108 109 110
	if s.deleteDevicesByLocalpartStmt, err = db.Prepare(deleteDevicesByLocalpartSQL); err != nil {
		return
	}
111 112 113 114 115 116 117
	s.serverName = server
	return
}

// insertDevice creates a new device. Returns an error if any device with the same access token already exists.
// Returns an error if the user already has a device with the given device ID.
// Returns the device on success.
118 119
func (s *devicesStatements) insertDevice(
	ctx context.Context, txn *sql.Tx, id, localpart, accessToken string,
P
Paul Tötterman 已提交
120
	displayName *string,
121
) (*authtypes.Device, error) {
122
	createdTimeMS := time.Now().UnixNano() / 1000000
123
	stmt := common.TxStmt(txn, s.insertDeviceStmt)
P
Paul Tötterman 已提交
124
	if _, err := stmt.ExecContext(ctx, id, localpart, accessToken, createdTimeMS, displayName); err != nil {
125
		return nil, err
126
	}
127 128 129 130 131
	return &authtypes.Device{
		ID:          id,
		UserID:      makeUserID(localpart, s.serverName),
		AccessToken: accessToken,
	}, nil
132 133
}

134 135 136 137 138
func (s *devicesStatements) deleteDevice(
	ctx context.Context, txn *sql.Tx, id, localpart string,
) error {
	stmt := common.TxStmt(txn, s.deleteDeviceStmt)
	_, err := stmt.ExecContext(ctx, id, localpart)
139 140 141
	return err
}

R
Remi Reuvekamp 已提交
142 143 144 145 146 147 148 149
// deleteDevicesByLocalpart removes every device row stored for localpart and
// returns any error reported by the database.
func (s *devicesStatements) deleteDevicesByLocalpart(
	ctx context.Context, txn *sql.Tx, localpart string,
) error {
	if _, err := common.TxStmt(txn, s.deleteDevicesByLocalpartStmt).ExecContext(ctx, localpart); err != nil {
		return err
	}
	return nil
}

P
Paul Tötterman 已提交
150 151 152 153 154 155 156 157
// updateDeviceName sets display_name on the row identified by localpart and
// deviceID, and returns any error reported by the database. displayName is
// handed to the driver as given.
func (s *devicesStatements) updateDeviceName(
	ctx context.Context, txn *sql.Tx, localpart, deviceID string, displayName *string,
) error {
	// Argument order follows updateDeviceNameSQL: the new name comes first,
	// followed by the two key columns.
	if _, err := common.TxStmt(txn, s.updateDeviceNameStmt).ExecContext(ctx, displayName, localpart, deviceID); err != nil {
		return err
	}
	return nil
}

158 159 160
func (s *devicesStatements) selectDeviceByToken(
	ctx context.Context, accessToken string,
) (*authtypes.Device, error) {
161 162
	var dev authtypes.Device
	var localpart string
163 164
	stmt := s.selectDeviceByTokenStmt
	err := stmt.QueryRowContext(ctx, accessToken).Scan(&dev.ID, &localpart)
165 166 167 168 169 170 171
	if err == nil {
		dev.UserID = makeUserID(localpart, s.serverName)
		dev.AccessToken = accessToken
	}
	return &dev, err
}

172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209
// selectDeviceByID looks up the device identified by localpart and deviceID.
//
// On success the returned device carries deviceID and the full user ID built
// from localpart and the configured server name. When the query or scan fails
// (including sql.ErrNoRows when no such device exists) the error is returned
// alongside a pointer to a zero-valued device.
func (s *devicesStatements) selectDeviceByID(
	ctx context.Context, localpart, deviceID string,
) (*authtypes.Device, error) {
	var dev authtypes.Device
	// The statement returns the nullable TEXT column display_name. Scanning it
	// into an integer fails for NULL and for any non-numeric name, so a
	// null-aware string destination is used instead.
	// NOTE(review): authtypes.Device is not visible from this file; if it has
	// a display-name field, populate it from displayName below.
	var displayName sql.NullString

	stmt := s.selectDeviceByIDStmt
	err := stmt.QueryRowContext(ctx, localpart, deviceID).Scan(&displayName)
	if err == nil {
		dev.ID = deviceID
		dev.UserID = makeUserID(localpart, s.serverName)
	}
	return &dev, err
}

// selectDevicesByLocalpart returns every device stored for localpart.
//
// Each returned device carries its device ID and the full user ID built from
// localpart and the configured server name. The result is a non-nil, possibly
// empty slice. If the query, a row scan, or the iteration itself fails, the
// devices collected so far are returned together with the error.
func (s *devicesStatements) selectDevicesByLocalpart(
	ctx context.Context, localpart string,
) ([]authtypes.Device, error) {
	// Kept as an empty (non-nil) slice, as in the original; callers may rely
	// on the distinction from nil.
	devices := []authtypes.Device{}

	rows, err := s.selectDevicesByLocalpartStmt.QueryContext(ctx, localpart)
	if err != nil {
		return devices, err
	}
	// Release the underlying connection on every exit path, including the
	// early return on a scan failure.
	defer rows.Close() // nolint: errcheck

	for rows.Next() {
		var dev authtypes.Device
		// The statement returns two columns (device_id, display_name) and
		// Scan requires one destination per column. display_name is nullable.
		// NOTE(review): authtypes.Device is not visible from this file; if it
		// has a display-name field, populate it from displayName.
		var displayName sql.NullString
		if err = rows.Scan(&dev.ID, &displayName); err != nil {
			return devices, err
		}
		dev.UserID = makeUserID(localpart, s.serverName)
		devices = append(devices, dev)
	}

	// rows.Next returning false can mean an error as well as end-of-data.
	return devices, rows.Err()
}

210 211 212
// makeUserID formats a full user ID of the form "@localpart:server" from a
// localpart and a server name.
func makeUserID(localpart string, server gomatrixserverlib.ServerName) string {
	domain := string(server)
	return fmt.Sprintf("@%s:%s", localpart, domain)
}