提交 197d609b 编写于 作者: L lash 提交者: Anton Evangelatov

swarm/pss: Message handler refactor (#18169)

上级 ca228569
......@@ -81,14 +81,15 @@ func NewKadParams() *KadParams {
// Kademlia is a table of live peers and a db of known peers (node records)
type Kademlia struct {
lock sync.RWMutex
*KadParams // Kademlia configuration parameters
base []byte // immutable baseaddress of the table
addrs *pot.Pot // pots container for known peer addresses
conns *pot.Pot // pots container for live peer connections
depth uint8 // stores the last current depth of saturation
nDepth int // stores the last neighbourhood depth
nDepthC chan int // returned by DepthC function to signal neighbourhood depth change
addrCountC chan int // returned by AddrCountC function to signal peer count change
*KadParams // Kademlia configuration parameters
base []byte // immutable baseaddress of the table
addrs *pot.Pot // pots container for known peer addresses
conns *pot.Pot // pots container for live peer connections
depth uint8 // stores the last current depth of saturation
nDepth int // stores the last neighbourhood depth
nDepthC chan int // returned by DepthC function to signal neighbourhood depth change
addrCountC chan int // returned by AddrCountC function to signal peer count change
Pof func(pot.Val, pot.Val, int) (int, bool) // function for calculating kademlia routing distance between two addresses
}
// NewKademlia creates a Kademlia table for base address addr
......@@ -103,6 +104,7 @@ func NewKademlia(addr []byte, params *KadParams) *Kademlia {
KadParams: params,
addrs: pot.NewPot(nil, 0),
conns: pot.NewPot(nil, 0),
Pof: pof,
}
}
......@@ -289,6 +291,7 @@ func (k *Kademlia) On(p *Peer) (uint8, bool) {
// neighbourhood depth on each change.
// Not receiving from the returned channel will block On function
// when the neighbourhood depth is changed.
// TODO: Why is this exported, and if it should be; why can't we have more subscribers than one?
func (k *Kademlia) NeighbourhoodDepthC() <-chan int {
k.lock.Lock()
defer k.lock.Unlock()
......@@ -429,7 +432,12 @@ func (k *Kademlia) eachAddr(base []byte, o int, f func(*BzzAddr, int, bool) bool
// neighbourhoodDepth returns the proximity order that defines the distance of
// the nearest neighbour set with cardinality >= MinProxBinSize
// if there is altogether less than MinProxBinSize peers it returns 0
// caller must hold the lock
func (k *Kademlia) NeighbourhoodDepth() (depth int) {
k.lock.RLock()
defer k.lock.RUnlock()
return k.neighbourhoodDepth()
}
func (k *Kademlia) neighbourhoodDepth() (depth int) {
if k.conns.Size() < k.MinProxBinSize {
return 0
......
......@@ -51,7 +51,7 @@ func NewAPI(ps *Pss) *API {
//
// All incoming messages to the node matching this topic will be encapsulated in the APIMsg
// struct and sent to the subscriber
func (pssapi *API) Receive(ctx context.Context, topic Topic) (*rpc.Subscription, error) {
func (pssapi *API) Receive(ctx context.Context, topic Topic, raw bool, prox bool) (*rpc.Subscription, error) {
notifier, supported := rpc.NotifierFromContext(ctx)
if !supported {
return nil, fmt.Errorf("Subscribe not supported")
......@@ -59,7 +59,7 @@ func (pssapi *API) Receive(ctx context.Context, topic Topic) (*rpc.Subscription,
psssub := notifier.CreateSubscription()
handler := func(msg []byte, p *p2p.Peer, asymmetric bool, keyid string) error {
hndlr := NewHandler(func(msg []byte, p *p2p.Peer, asymmetric bool, keyid string) error {
apimsg := &APIMsg{
Msg: hexutil.Bytes(msg),
Asymmetric: asymmetric,
......@@ -69,9 +69,15 @@ func (pssapi *API) Receive(ctx context.Context, topic Topic) (*rpc.Subscription,
log.Warn(fmt.Sprintf("notification on pss sub topic rpc (sub %v) msg %v failed!", psssub.ID, msg))
}
return nil
})
if raw {
hndlr.caps.raw = true
}
if prox {
hndlr.caps.prox = true
}
deregf := pssapi.Register(&topic, handler)
deregf := pssapi.Register(&topic, hndlr)
go func() {
defer deregf()
select {
......
......@@ -236,7 +236,7 @@ func (c *Client) RunProtocol(ctx context.Context, proto *p2p.Protocol) error {
topichex := topicobj.String()
msgC := make(chan pss.APIMsg)
c.peerPool[topicobj] = make(map[string]*pssRPCRW)
sub, err := c.rpc.Subscribe(ctx, "pss", msgC, "receive", topichex)
sub, err := c.rpc.Subscribe(ctx, "pss", msgC, "receive", topichex, false, false)
if err != nil {
return fmt.Errorf("pss event subscription failed: %v", err)
}
......
......@@ -486,7 +486,7 @@ func (api *HandshakeAPI) Handshake(pubkeyid string, topic Topic, sync bool, flus
// Activate handshake functionality on a topic
func (api *HandshakeAPI) AddHandshake(topic Topic) error {
api.ctrl.deregisterFuncs[topic] = api.ctrl.pss.Register(&topic, api.ctrl.handler)
api.ctrl.deregisterFuncs[topic] = api.ctrl.pss.Register(&topic, NewHandler(api.ctrl.handler))
return nil
}
......
......@@ -113,7 +113,7 @@ func NewController(ps *pss.Pss) *Controller {
notifiers: make(map[string]*notifier),
subscriptions: make(map[string]*subscription),
}
ctrl.pss.Register(&controlTopic, ctrl.Handler)
ctrl.pss.Register(&controlTopic, pss.NewHandler(ctrl.Handler))
return ctrl
}
......@@ -336,7 +336,7 @@ func (c *Controller) handleNotifyWithKeyMsg(msg *Msg) error {
// \TODO keep track of and add actual address
updaterAddr := pss.PssAddress([]byte{})
c.pss.SetSymmetricKey(symkey, topic, &updaterAddr, true)
c.pss.Register(&topic, c.Handler)
c.pss.Register(&topic, pss.NewHandler(c.Handler))
return c.subscriptions[msg.namestring].handler(msg.namestring, msg.Payload[:len(msg.Payload)-symKeyLength])
}
......
......@@ -121,7 +121,7 @@ func TestStart(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*2)
defer cancel()
rmsgC := make(chan *pss.APIMsg)
rightSub, err := rightRpc.Subscribe(ctx, "pss", rmsgC, "receive", controlTopic)
rightSub, err := rightRpc.Subscribe(ctx, "pss", rmsgC, "receive", controlTopic, false, false)
if err != nil {
t.Fatal(err)
}
......@@ -174,7 +174,7 @@ func TestStart(t *testing.T) {
t.Fatalf("expected payload length %d, have %d", len(updateMsg)+symKeyLength, len(dMsg.Payload))
}
rightSubUpdate, err := rightRpc.Subscribe(ctx, "pss", rmsgC, "receive", rsrcTopic)
rightSubUpdate, err := rightRpc.Subscribe(ctx, "pss", rmsgC, "receive", rsrcTopic, false, false)
if err != nil {
t.Fatal(err)
}
......
......@@ -92,7 +92,7 @@ func testProtocol(t *testing.T) {
lmsgC := make(chan APIMsg)
lctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
defer cancel()
lsub, err := clients[0].Subscribe(lctx, "pss", lmsgC, "receive", topic)
lsub, err := clients[0].Subscribe(lctx, "pss", lmsgC, "receive", topic, false, false)
if err != nil {
t.Fatal(err)
}
......@@ -100,7 +100,7 @@ func testProtocol(t *testing.T) {
rmsgC := make(chan APIMsg)
rctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
defer cancel()
rsub, err := clients[1].Subscribe(rctx, "pss", rmsgC, "receive", topic)
rsub, err := clients[1].Subscribe(rctx, "pss", rmsgC, "receive", topic, false, false)
if err != nil {
t.Fatal(err)
}
......@@ -130,6 +130,7 @@ func testProtocol(t *testing.T) {
log.Debug("lnode ok")
case cerr := <-lctx.Done():
t.Fatalf("test message timed out: %v", cerr)
return
}
select {
case <-rmsgC:
......
......@@ -23,11 +23,13 @@ import (
"crypto/rand"
"errors"
"fmt"
"hash"
"sync"
"time"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/crypto/sha3"
"github.com/ethereum/go-ethereum/metrics"
"github.com/ethereum/go-ethereum/p2p"
"github.com/ethereum/go-ethereum/p2p/enode"
......@@ -136,10 +138,10 @@ type Pss struct {
symKeyDecryptCacheCapacity int // max amount of symkeys to keep.
// message handling
handlers map[Topic]map[*Handler]bool // topic and version based pss payload handlers. See pss.Handle()
handlersMu sync.RWMutex
allowRaw bool
hashPool sync.Pool
handlers map[Topic]map[*handler]bool // topic and version based pss payload handlers. See pss.Handle()
handlersMu sync.RWMutex
hashPool sync.Pool
topicHandlerCaps map[Topic]*handlerCaps // caches capabilities of each topic's handlers (see handlerCap* consts in types.go)
// process
quitC chan struct{}
......@@ -180,11 +182,12 @@ func NewPss(k *network.Kademlia, params *PssParams) (*Pss, error) {
symKeyDecryptCache: make([]*string, params.SymKeyCacheCapacity),
symKeyDecryptCacheCapacity: params.SymKeyCacheCapacity,
handlers: make(map[Topic]map[*Handler]bool),
allowRaw: params.AllowRaw,
handlers: make(map[Topic]map[*handler]bool),
topicHandlerCaps: make(map[Topic]*handlerCaps),
hashPool: sync.Pool{
New: func() interface{} {
return storage.MakeHashFunc(storage.DefaultHash)()
return sha3.NewKeccak256()
},
},
}
......@@ -313,30 +316,54 @@ func (p *Pss) PublicKey() *ecdsa.PublicKey {
//
// Returns a deregister function which needs to be called to
// deregister the handler,
func (p *Pss) Register(topic *Topic, handler Handler) func() {
func (p *Pss) Register(topic *Topic, hndlr *handler) func() {
p.handlersMu.Lock()
defer p.handlersMu.Unlock()
handlers := p.handlers[*topic]
if handlers == nil {
handlers = make(map[*Handler]bool)
handlers = make(map[*handler]bool)
p.handlers[*topic] = handlers
log.Debug("registered handler", "caps", hndlr.caps)
}
if hndlr.caps == nil {
hndlr.caps = &handlerCaps{}
}
handlers[hndlr] = true
if _, ok := p.topicHandlerCaps[*topic]; !ok {
p.topicHandlerCaps[*topic] = &handlerCaps{}
}
handlers[&handler] = true
return func() { p.deregister(topic, &handler) }
if hndlr.caps.raw {
p.topicHandlerCaps[*topic].raw = true
}
if hndlr.caps.prox {
p.topicHandlerCaps[*topic].prox = true
}
return func() { p.deregister(topic, hndlr) }
}
func (p *Pss) deregister(topic *Topic, h *Handler) {
func (p *Pss) deregister(topic *Topic, hndlr *handler) {
p.handlersMu.Lock()
defer p.handlersMu.Unlock()
handlers := p.handlers[*topic]
if len(handlers) == 1 {
if len(handlers) > 1 {
delete(p.handlers, *topic)
// topic caps might have changed now that a handler is gone
caps := &handlerCaps{}
for h := range handlers {
if h.caps.raw {
caps.raw = true
}
if h.caps.prox {
caps.prox = true
}
}
p.topicHandlerCaps[*topic] = caps
return
}
delete(handlers, h)
delete(handlers, hndlr)
}
// get all registered handlers for respective topics
func (p *Pss) getHandlers(topic Topic) map[*Handler]bool {
func (p *Pss) getHandlers(topic Topic) map[*handler]bool {
p.handlersMu.RLock()
defer p.handlersMu.RUnlock()
return p.handlers[topic]
......@@ -348,12 +375,11 @@ func (p *Pss) getHandlers(topic Topic) map[*Handler]bool {
// Only passes error to pss protocol handler if payload is not valid pssmsg
func (p *Pss) handlePssMsg(ctx context.Context, msg interface{}) error {
metrics.GetOrRegisterCounter("pss.handlepssmsg", nil).Inc(1)
pssmsg, ok := msg.(*PssMsg)
if !ok {
return fmt.Errorf("invalid message type. Expected *PssMsg, got %T ", msg)
}
log.Trace("handler", "self", label(p.Kademlia.BaseAddr()), "topic", label(pssmsg.Payload.Topic[:]))
if int64(pssmsg.Expire) < time.Now().Unix() {
metrics.GetOrRegisterCounter("pss.expire", nil).Inc(1)
log.Warn("pss filtered expired message", "from", common.ToHex(p.Kademlia.BaseAddr()), "to", common.ToHex(pssmsg.To))
......@@ -365,13 +391,34 @@ func (p *Pss) handlePssMsg(ctx context.Context, msg interface{}) error {
}
p.addFwdCache(pssmsg)
if !p.isSelfPossibleRecipient(pssmsg) {
log.Trace("pss was for someone else :'( ... forwarding", "pss", common.ToHex(p.BaseAddr()))
psstopic := Topic(pssmsg.Payload.Topic)
// raw is simplest handler contingency to check, so check that first
var isRaw bool
if pssmsg.isRaw() {
if !p.topicHandlerCaps[psstopic].raw {
log.Debug("No handler for raw message", "topic", psstopic)
return nil
}
isRaw = true
}
// check if we can be recipient:
// - no prox handler on message and partial address matches
// - prox handler on message and we are in prox regardless of partial address match
// store this result so we don't calculate again on every handler
var isProx bool
if _, ok := p.topicHandlerCaps[psstopic]; ok {
isProx = p.topicHandlerCaps[psstopic].prox
}
isRecipient := p.isSelfPossibleRecipient(pssmsg, isProx)
if !isRecipient {
log.Trace("pss was for someone else :'( ... forwarding", "pss", common.ToHex(p.BaseAddr()), "prox", isProx)
return p.enqueue(pssmsg)
}
log.Trace("pss for us, yay! ... let's process!", "pss", common.ToHex(p.BaseAddr()))
if err := p.process(pssmsg); err != nil {
log.Trace("pss for us, yay! ... let's process!", "pss", common.ToHex(p.BaseAddr()), "prox", isProx, "raw", isRaw, "topic", label(pssmsg.Payload.Topic[:]))
if err := p.process(pssmsg, isRaw, isProx); err != nil {
qerr := p.enqueue(pssmsg)
if qerr != nil {
return fmt.Errorf("process fail: processerr %v, queueerr: %v", err, qerr)
......@@ -384,7 +431,7 @@ func (p *Pss) handlePssMsg(ctx context.Context, msg interface{}) error {
// Entry point to processing a message for which the current node can be the intended recipient.
// Attempts symmetric and asymmetric decryption with stored keys.
// Dispatches message to all handlers matching the message topic
func (p *Pss) process(pssmsg *PssMsg) error {
func (p *Pss) process(pssmsg *PssMsg, raw bool, prox bool) error {
metrics.GetOrRegisterCounter("pss.process", nil).Inc(1)
var err error
......@@ -397,10 +444,8 @@ func (p *Pss) process(pssmsg *PssMsg) error {
envelope := pssmsg.Payload
psstopic := Topic(envelope.Topic)
if pssmsg.isRaw() {
if !p.allowRaw {
return errors.New("raw message support disabled")
}
if raw {
payload = pssmsg.Payload.Data
} else {
if pssmsg.isSym() {
......@@ -422,19 +467,27 @@ func (p *Pss) process(pssmsg *PssMsg) error {
return err
}
}
p.executeHandlers(psstopic, payload, from, asymmetric, keyid)
p.executeHandlers(psstopic, payload, from, raw, prox, asymmetric, keyid)
return nil
}
func (p *Pss) executeHandlers(topic Topic, payload []byte, from *PssAddress, asymmetric bool, keyid string) {
func (p *Pss) executeHandlers(topic Topic, payload []byte, from *PssAddress, raw bool, prox bool, asymmetric bool, keyid string) {
handlers := p.getHandlers(topic)
peer := p2p.NewPeer(enode.ID{}, fmt.Sprintf("%x", from), []p2p.Cap{})
for f := range handlers {
err := (*f)(payload, peer, asymmetric, keyid)
for h := range handlers {
if !h.caps.raw && raw {
log.Warn("norawhandler")
continue
}
if !h.caps.prox && prox {
log.Warn("noproxhandler")
continue
}
err := (h.f)(payload, peer, asymmetric, keyid)
if err != nil {
log.Warn("Pss handler %p failed: %v", f, err)
log.Warn("Pss handler failed", "err", err)
}
}
}
......@@ -445,9 +498,23 @@ func (p *Pss) isSelfRecipient(msg *PssMsg) bool {
}
// test match of leftmost bytes in given message to node's Kademlia address
func (p *Pss) isSelfPossibleRecipient(msg *PssMsg) bool {
func (p *Pss) isSelfPossibleRecipient(msg *PssMsg, prox bool) bool {
local := p.Kademlia.BaseAddr()
return bytes.Equal(msg.To, local[:len(msg.To)])
// if a partial address matches we are possible recipient regardless of prox
// if not and prox is not set, we are surely not
if bytes.Equal(msg.To, local[:len(msg.To)]) {
return true
} else if !prox {
return false
}
depth := p.Kademlia.NeighbourhoodDepth()
po, _ := p.Kademlia.Pof(p.Kademlia.BaseAddr(), msg.To, 0)
log.Trace("selfpossible", "po", po, "depth", depth)
return depth <= po
}
/////////////////////////////////////////////////////////////////////
......@@ -684,9 +751,6 @@ func (p *Pss) enqueue(msg *PssMsg) error {
//
// Will fail if raw messages are disallowed
func (p *Pss) SendRaw(address PssAddress, topic Topic, msg []byte) error {
if !p.allowRaw {
return errors.New("Raw messages not enabled")
}
pssMsgParams := &msgParams{
raw: true,
}
......@@ -699,7 +763,17 @@ func (p *Pss) SendRaw(address PssAddress, topic Topic, msg []byte) error {
pssMsg.Expire = uint32(time.Now().Add(p.msgTTL).Unix())
pssMsg.Payload = payload
p.addFwdCache(pssMsg)
return p.enqueue(pssMsg)
err := p.enqueue(pssMsg)
if err != nil {
return err
}
// if we have a proxhandler on this topic
// also deliver message to ourselves
if p.isSelfPossibleRecipient(pssMsg, true) && p.topicHandlerCaps[topic].prox {
return p.process(pssMsg, true, true)
}
return nil
}
// Send a message using symmetric encryption
......@@ -800,7 +874,16 @@ func (p *Pss) send(to []byte, topic Topic, msg []byte, asymmetric bool, key []by
pssMsg.To = to
pssMsg.Expire = uint32(time.Now().Add(p.msgTTL).Unix())
pssMsg.Payload = envelope
return p.enqueue(pssMsg)
err = p.enqueue(pssMsg)
if err != nil {
return err
}
if _, ok := p.topicHandlerCaps[topic]; ok {
if p.isSelfPossibleRecipient(pssMsg, true) && p.topicHandlerCaps[topic].prox {
return p.process(pssMsg, true, true)
}
}
return nil
}
// Forwards a pss message to the peer(s) closest to the to recipient address in the PssMsg struct
......@@ -895,6 +978,10 @@ func (p *Pss) cleanFwdCache() {
}
}
func label(b []byte) string {
return fmt.Sprintf("%04x", b[:2])
}
// add a message to the cache
func (p *Pss) addFwdCache(msg *PssMsg) error {
metrics.GetOrRegisterCounter("pss.addfwdcache", nil).Inc(1)
......@@ -934,10 +1021,14 @@ func (p *Pss) checkFwdCache(msg *PssMsg) bool {
// Digest of message
func (p *Pss) digest(msg *PssMsg) pssDigest {
hasher := p.hashPool.Get().(storage.SwarmHash)
return p.digestBytes(msg.serialize())
}
func (p *Pss) digestBytes(msg []byte) pssDigest {
hasher := p.hashPool.Get().(hash.Hash)
defer p.hashPool.Put(hasher)
hasher.Reset()
hasher.Write(msg.serialize())
hasher.Write(msg)
digest := pssDigest{}
key := hasher.Sum(nil)
copy(digest[:], key[:digestLength])
......
此差异已折叠。
......@@ -159,9 +159,39 @@ func (msg *PssMsg) String() string {
}
// Signature for a message handler function for a PssMsg
//
// Implementations of this type are passed to Pss.Register together with a topic,
type Handler func(msg []byte, p *p2p.Peer, asymmetric bool, keyid string) error
type HandlerFunc func(msg []byte, p *p2p.Peer, asymmetric bool, keyid string) error
type handlerCaps struct {
raw bool
prox bool
}
// Handler defines code to be executed upon reception of content.
type handler struct {
f HandlerFunc
caps *handlerCaps
}
// NewHandler returns a new message handler
func NewHandler(f HandlerFunc) *handler {
return &handler{
f: f,
caps: &handlerCaps{},
}
}
// WithRaw is a chainable method that allows raw messages to be handled.
func (h *handler) WithRaw() *handler {
h.caps.raw = true
return h
}
// WithProxBin is a chainable method that allows sending messages with full addresses to neighbourhoods using the kademlia depth as reference
func (h *handler) WithProxBin() *handler {
h.caps.prox = true
return h
}
// the stateStore handles saving and loading PSS peers and their corresponding keys
// it is currently unimplemented
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册