// server_key_table.go
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package keydb

import (
	"database/sql"
	"encoding/json"
	"github.com/lib/pq"
	"github.com/matrix-org/gomatrixserverlib"
)

// serverKeysSchema creates the keydb_server_keys table and its lookup index.
// Each row caches one server key fetched from a remote server. The
// (server_name, server_key_id) pair is unique, which is the constraint the
// upsert statement relies on. server_name_and_key_id is indexed so bulk
// lookups can match on a single column.
const serverKeysSchema = `
-- A cache of server keys downloaded from remote servers.
CREATE TABLE IF NOT EXISTS keydb_server_keys (
	-- The name of the matrix server the key is for.
	server_name TEXT NOT NULL,
	-- The ID of the server key.
	server_key_id TEXT NOT NULL,
	-- Combined server name and key ID separated by the ASCII unit separator
	-- to make it easier to run bulk queries.
	server_name_and_key_id TEXT NOT NULL,
	-- When the keys are valid until as a millisecond timestamp.
	valid_until_ts BIGINT NOT NULL,
	-- The raw JSON for the server key.
	server_key_json TEXT NOT NULL,
	CONSTRAINT keydb_server_keys_unique UNIQUE (server_name, server_key_id)
);

CREATE INDEX IF NOT EXISTS keydb_server_name_and_key_id ON keydb_server_keys (server_name_and_key_id);
`

// bulkSelectServerKeysSQL fetches every cached key row whose
// server_name_and_key_id appears in the text array bound to $1.
const bulkSelectServerKeysSQL = "" +
	"SELECT server_name, server_key_id, server_key_json FROM keydb_server_keys" +
	" WHERE server_name_and_key_id = ANY($1)"

// upsertServerKeysSQL inserts a key row, or, when a row for the same
// (server_name, server_key_id) already exists (keydb_server_keys_unique),
// overwrites its validity timestamp ($4) and key JSON ($5).
const upsertServerKeysSQL = "" +
	"INSERT INTO keydb_server_keys (server_name, server_key_id," +
	" server_name_and_key_id, valid_until_ts, server_key_json)" +
	" VALUES ($1, $2, $3, $4, $5)" +
	" ON CONFLICT ON CONSTRAINT keydb_server_keys_unique" +
	" DO UPDATE SET valid_until_ts = $4, server_key_json = $5"

// serverKeyStatements holds the prepared statements for the
// keydb_server_keys table. Both fields are set by prepare.
type serverKeyStatements struct {
	bulkSelectServerKeysStmt *sql.Stmt // prepared from bulkSelectServerKeysSQL
	upsertServerKeysStmt     *sql.Stmt // prepared from upsertServerKeysSQL
}

// prepare creates the keydb_server_keys schema if needed, then prepares each
// statement used by serverKeyStatements against db. It stops at and returns
// the first error encountered; statements after the failing one are left
// untouched.
func (s *serverKeyStatements) prepare(db *sql.DB) (err error) {
	if _, err = db.Exec(serverKeysSchema); err != nil {
		return err
	}
	statements := []struct {
		target **sql.Stmt
		query  string
	}{
		{&s.bulkSelectServerKeysStmt, bulkSelectServerKeysSQL},
		{&s.upsertServerKeysStmt, upsertServerKeysSQL},
	}
	for _, st := range statements {
		if *st.target, err = db.Prepare(st.query); err != nil {
			return err
		}
	}
	return nil
}

func (s *serverKeyStatements) bulkSelectServerKeys(
	requests map[gomatrixserverlib.PublicKeyRequest]gomatrixserverlib.Timestamp,
) (map[gomatrixserverlib.PublicKeyRequest]gomatrixserverlib.ServerKeys, error) {
	var nameAndKeyIDs []string
	for request := range requests {
		nameAndKeyIDs = append(nameAndKeyIDs, nameAndKeyID(request))
	}
	rows, err := s.bulkSelectServerKeysStmt.Query(pq.StringArray(nameAndKeyIDs))
	if err != nil {
		return nil, err
	}
E
Erik Johnston 已提交
86
	defer rows.Close() // nolint: errcheck
87 88 89 90 91 92 93 94 95 96 97 98 99
	results := map[gomatrixserverlib.PublicKeyRequest]gomatrixserverlib.ServerKeys{}
	for rows.Next() {
		var serverName string
		var keyID string
		var keyJSON []byte
		if err := rows.Scan(&serverName, &keyID, &keyJSON); err != nil {
			return nil, err
		}
		var serverKeys gomatrixserverlib.ServerKeys
		if err := json.Unmarshal(keyJSON, &serverKeys); err != nil {
			return nil, err
		}
		r := gomatrixserverlib.PublicKeyRequest{
E
Erik Johnston 已提交
100 101
			ServerName: gomatrixserverlib.ServerName(serverName),
			KeyID:      gomatrixserverlib.KeyID(keyID),
102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127
		}
		results[r] = serverKeys
	}
	return results, nil
}

// upsertServerKeys stores keys as the cached entry for request. It inserts a
// new row, or, if a row for the same (server_name, server_key_id) exists,
// overwrites that row's valid_until_ts and server_key_json. It returns the
// JSON-encoding error or the statement execution error, if any.
func (s *serverKeyStatements) upsertServerKeys(
	request gomatrixserverlib.PublicKeyRequest, keys gomatrixserverlib.ServerKeys,
) error {
	encoded, marshalErr := json.Marshal(keys)
	if marshalErr != nil {
		return marshalErr
	}
	serverName := string(request.ServerName)
	keyID := string(request.KeyID)
	validUntil := int64(keys.ValidUntilTS)
	_, execErr := s.upsertServerKeysStmt.Exec(
		serverName, keyID, nameAndKeyID(request), validUntil, encoded,
	)
	return execErr
}

// nameAndKeyID builds the server_name_and_key_id column value for request:
// the server name and the key ID joined by the ASCII unit separator (0x1F).
func nameAndKeyID(request gomatrixserverlib.PublicKeyRequest) string {
	const unitSeparator = "\x1f"
	name, id := string(request.ServerName), string(request.KeyID)
	return name + unitSeparator + id
}