diff --git a/p2p/dial.go b/p2p/dial.go new file mode 100644 index 0000000000000000000000000000000000000000..71065c5eedc3d95c516fd378defe46f1833d96cf --- /dev/null +++ b/p2p/dial.go @@ -0,0 +1,276 @@ +package p2p + +import ( + "container/heap" + "crypto/rand" + "fmt" + "net" + "time" + + "github.com/ethereum/go-ethereum/logger" + "github.com/ethereum/go-ethereum/logger/glog" + "github.com/ethereum/go-ethereum/p2p/discover" +) + +const ( + // This is the amount of time spent waiting in between + // redialing a certain node. + dialHistoryExpiration = 30 * time.Second + + // Discovery lookup tasks will wait for this long when + // no results are returned. This can happen if the table + // becomes empty (i.e. not often). + emptyLookupDelay = 10 * time.Second +) + +// dialstate schedules dials and discovery lookups. +// it get's a chance to compute new tasks on every iteration +// of the main loop in Server.run. +type dialstate struct { + maxDynDials int + ntab discoverTable + + lookupRunning bool + bootstrapped bool + + dialing map[discover.NodeID]connFlag + lookupBuf []*discover.Node // current discovery lookup results + randomNodes []*discover.Node // filled from Table + static map[discover.NodeID]*discover.Node + hist *dialHistory +} + +type discoverTable interface { + Self() *discover.Node + Close() + Bootstrap([]*discover.Node) + Lookup(target discover.NodeID) []*discover.Node + ReadRandomNodes([]*discover.Node) int +} + +// the dial history remembers recent dials. +type dialHistory []pastDial + +// pastDial is an entry in the dial history. +type pastDial struct { + id discover.NodeID + exp time.Time +} + +type task interface { + Do(*Server) +} + +// A dialTask is generated for each node that is dialed. +type dialTask struct { + flags connFlag + dest *discover.Node +} + +// discoverTask runs discovery table operations. +// Only one discoverTask is active at any time. +// +// If bootstrap is true, the task runs Table.Bootstrap, +// otherwise it performs a random lookup and leaves the +// results in the task. +type discoverTask struct { + bootstrap bool + results []*discover.Node +} + +// A waitExpireTask is generated if there are no other tasks +// to keep the loop in Server.run ticking. +type waitExpireTask struct { + time.Duration +} + +func newDialState(static []*discover.Node, ntab discoverTable, maxdyn int) *dialstate { + s := &dialstate{ + maxDynDials: maxdyn, + ntab: ntab, + static: make(map[discover.NodeID]*discover.Node), + dialing: make(map[discover.NodeID]connFlag), + randomNodes: make([]*discover.Node, maxdyn/2), + hist: new(dialHistory), + } + for _, n := range static { + s.static[n.ID] = n + } + return s +} + +func (s *dialstate) addStatic(n *discover.Node) { + s.static[n.ID] = n +} + +func (s *dialstate) newTasks(nRunning int, peers map[discover.NodeID]*Peer, now time.Time) []task { + var newtasks []task + addDial := func(flag connFlag, n *discover.Node) bool { + _, dialing := s.dialing[n.ID] + if dialing || peers[n.ID] != nil || s.hist.contains(n.ID) { + return false + } + s.dialing[n.ID] = flag + newtasks = append(newtasks, &dialTask{flags: flag, dest: n}) + return true + } + + // Compute number of dynamic dials necessary at this point. + needDynDials := s.maxDynDials + for _, p := range peers { + if p.rw.is(dynDialedConn) { + needDynDials-- + } + } + for _, flag := range s.dialing { + if flag&dynDialedConn != 0 { + needDynDials-- + } + } + + // Expire the dial history on every invocation. + s.hist.expire(now) + + // Create dials for static nodes if they are not connected. + for _, n := range s.static { + addDial(staticDialedConn, n) + } + + // Use random nodes from the table for half of the necessary + // dynamic dials. + randomCandidates := needDynDials / 2 + if randomCandidates > 0 && s.bootstrapped { + n := s.ntab.ReadRandomNodes(s.randomNodes) + for i := 0; i < randomCandidates && i < n; i++ { + if addDial(dynDialedConn, s.randomNodes[i]) { + needDynDials-- + } + } + } + // Create dynamic dials from random lookup results, removing tried + // items from the result buffer. + i := 0 + for ; i < len(s.lookupBuf) && needDynDials > 0; i++ { + if addDial(dynDialedConn, s.lookupBuf[i]) { + needDynDials-- + } + } + s.lookupBuf = s.lookupBuf[:copy(s.lookupBuf, s.lookupBuf[i:])] + // Launch a discovery lookup if more candidates are needed. The + // first discoverTask bootstraps the table and won't return any + // results. + if len(s.lookupBuf) < needDynDials && !s.lookupRunning { + s.lookupRunning = true + newtasks = append(newtasks, &discoverTask{bootstrap: !s.bootstrapped}) + } + + // Launch a timer to wait for the next node to expire if all + // candidates have been tried and no task is currently active. + // This should prevent cases where the dialer logic is not ticked + // because there are no pending events. + if nRunning == 0 && len(newtasks) == 0 && s.hist.Len() > 0 { + t := &waitExpireTask{s.hist.min().exp.Sub(now)} + newtasks = append(newtasks, t) + } + return newtasks +} + +func (s *dialstate) taskDone(t task, now time.Time) { + switch t := t.(type) { + case *dialTask: + s.hist.add(t.dest.ID, now.Add(dialHistoryExpiration)) + delete(s.dialing, t.dest.ID) + case *discoverTask: + if t.bootstrap { + s.bootstrapped = true + } + s.lookupRunning = false + s.lookupBuf = append(s.lookupBuf, t.results...) + } +} + +func (t *dialTask) Do(srv *Server) { + addr := &net.TCPAddr{IP: t.dest.IP, Port: int(t.dest.TCP)} + glog.V(logger.Debug).Infof("dialing %v\n", t.dest) + fd, err := srv.Dialer.Dial("tcp", addr.String()) + if err != nil { + glog.V(logger.Detail).Infof("dial error: %v", err) + return + } + srv.setupConn(fd, t.flags, t.dest) +} +func (t *dialTask) String() string { + return fmt.Sprintf("%v %x %v:%d", t.flags, t.dest.ID[:8], t.dest.IP, t.dest.TCP) +} + +func (t *discoverTask) Do(srv *Server) { + if t.bootstrap { + srv.ntab.Bootstrap(srv.BootstrapNodes) + } else { + var target discover.NodeID + rand.Read(target[:]) + t.results = srv.ntab.Lookup(target) + // newTasks generates a lookup task whenever dynamic dials are + // necessary. Lookups need to take some time, otherwise the + // event loop spins too fast. An empty result can only be + // returned if the table is empty. + if len(t.results) == 0 { + time.Sleep(emptyLookupDelay) + } + } +} + +func (t *discoverTask) String() (s string) { + if t.bootstrap { + s = "discovery bootstrap" + } else { + s = "discovery lookup" + } + if len(t.results) > 0 { + s += fmt.Sprintf(" (%d results)", len(t.results)) + } + return s +} + +func (t waitExpireTask) Do(*Server) { + time.Sleep(t.Duration) +} +func (t waitExpireTask) String() string { + return fmt.Sprintf("wait for dial hist expire (%v)", t.Duration) +} + +// Use only these methods to access or modify dialHistory. +func (h dialHistory) min() pastDial { + return h[0] +} +func (h *dialHistory) add(id discover.NodeID, exp time.Time) { + heap.Push(h, pastDial{id, exp}) +} +func (h dialHistory) contains(id discover.NodeID) bool { + for _, v := range h { + if v.id == id { + return true + } + } + return false +} +func (h *dialHistory) expire(now time.Time) { + for h.Len() > 0 && h.min().exp.Before(now) { + heap.Pop(h) + } +} + +// heap.Interface boilerplate +func (h dialHistory) Len() int { return len(h) } +func (h dialHistory) Less(i, j int) bool { return h[i].exp.Before(h[j].exp) } +func (h dialHistory) Swap(i, j int) { h[i], h[j] = h[j], h[i] } +func (h *dialHistory) Push(x interface{}) { + *h = append(*h, x.(pastDial)) +} +func (h *dialHistory) Pop() interface{} { + old := *h + n := len(old) + x := old[n-1] + *h = old[0 : n-1] + return x +} diff --git a/p2p/dial_test.go b/p2p/dial_test.go new file mode 100644 index 0000000000000000000000000000000000000000..78568c5edd9441257a4e3e818d005cf6e83503e4 --- /dev/null +++ b/p2p/dial_test.go @@ -0,0 +1,482 @@ +package p2p + +import ( + "encoding/binary" + "reflect" + "testing" + "time" + + "github.com/davecgh/go-spew/spew" + "github.com/ethereum/go-ethereum/p2p/discover" +) + +func init() { + spew.Config.Indent = "\t" +} + +type dialtest struct { + init *dialstate // state before and after the test. + rounds []round +} + +type round struct { + peers []*Peer // current peer set + done []task // tasks that got done this round + new []task // the result must match this one +} + +func runDialTest(t *testing.T, test dialtest) { + var ( + vtime time.Time + running int + ) + pm := func(ps []*Peer) map[discover.NodeID]*Peer { + m := make(map[discover.NodeID]*Peer) + for _, p := range ps { + m[p.rw.id] = p + } + return m + } + for i, round := range test.rounds { + for _, task := range round.done { + running-- + if running < 0 { + panic("running task counter underflow") + } + test.init.taskDone(task, vtime) + } + + new := test.init.newTasks(running, pm(round.peers), vtime) + if !sametasks(new, round.new) { + t.Errorf("round %d: new tasks mismatch:\ngot %v\nwant %v\nstate: %v\nrunning: %v\n", + i, spew.Sdump(new), spew.Sdump(round.new), spew.Sdump(test.init), spew.Sdump(running)) + } + + // Time advances by 16 seconds on every round. + vtime = vtime.Add(16 * time.Second) + running += len(new) + } +} + +type fakeTable []*discover.Node + +func (t fakeTable) Self() *discover.Node { return new(discover.Node) } +func (t fakeTable) Close() {} +func (t fakeTable) Bootstrap([]*discover.Node) {} +func (t fakeTable) Lookup(target discover.NodeID) []*discover.Node { + return nil +} +func (t fakeTable) ReadRandomNodes(buf []*discover.Node) int { + return copy(buf, t) +} + +// This test checks that dynamic dials are launched from discovery results. +func TestDialStateDynDial(t *testing.T) { + runDialTest(t, dialtest{ + init: newDialState(nil, fakeTable{}, 5), + rounds: []round{ + // A discovery query is launched. + { + peers: []*Peer{ + {rw: &conn{flags: staticDialedConn, id: uintID(0)}}, + {rw: &conn{flags: dynDialedConn, id: uintID(1)}}, + {rw: &conn{flags: dynDialedConn, id: uintID(2)}}, + }, + new: []task{&discoverTask{bootstrap: true}}, + }, + // Dynamic dials are launched when it completes. + { + peers: []*Peer{ + {rw: &conn{flags: staticDialedConn, id: uintID(0)}}, + {rw: &conn{flags: dynDialedConn, id: uintID(1)}}, + {rw: &conn{flags: dynDialedConn, id: uintID(2)}}, + }, + done: []task{ + &discoverTask{bootstrap: true, results: []*discover.Node{ + {ID: uintID(2)}, // this one is already connected and not dialed. + {ID: uintID(3)}, + {ID: uintID(4)}, + {ID: uintID(5)}, + {ID: uintID(6)}, // these are not tried because max dyn dials is 5 + {ID: uintID(7)}, // ... + }}, + }, + new: []task{ + &dialTask{dynDialedConn, &discover.Node{ID: uintID(3)}}, + &dialTask{dynDialedConn, &discover.Node{ID: uintID(4)}}, + &dialTask{dynDialedConn, &discover.Node{ID: uintID(5)}}, + }, + }, + // Some of the dials complete but no new ones are launched yet because + // the sum of active dial count and dynamic peer count is == maxDynDials. + { + peers: []*Peer{ + {rw: &conn{flags: staticDialedConn, id: uintID(0)}}, + {rw: &conn{flags: dynDialedConn, id: uintID(1)}}, + {rw: &conn{flags: dynDialedConn, id: uintID(2)}}, + {rw: &conn{flags: dynDialedConn, id: uintID(3)}}, + {rw: &conn{flags: dynDialedConn, id: uintID(4)}}, + }, + done: []task{ + &dialTask{dynDialedConn, &discover.Node{ID: uintID(3)}}, + &dialTask{dynDialedConn, &discover.Node{ID: uintID(4)}}, + }, + }, + // No new dial tasks are launched in the this round because + // maxDynDials has been reached. + { + peers: []*Peer{ + {rw: &conn{flags: staticDialedConn, id: uintID(0)}}, + {rw: &conn{flags: dynDialedConn, id: uintID(1)}}, + {rw: &conn{flags: dynDialedConn, id: uintID(2)}}, + {rw: &conn{flags: dynDialedConn, id: uintID(3)}}, + {rw: &conn{flags: dynDialedConn, id: uintID(4)}}, + {rw: &conn{flags: dynDialedConn, id: uintID(5)}}, + }, + done: []task{ + &dialTask{dynDialedConn, &discover.Node{ID: uintID(5)}}, + }, + new: []task{ + &waitExpireTask{Duration: 14 * time.Second}, + }, + }, + // In this round, the peer with id 2 drops off. The query + // results from last discovery lookup are reused. + { + peers: []*Peer{ + {rw: &conn{flags: staticDialedConn, id: uintID(0)}}, + {rw: &conn{flags: dynDialedConn, id: uintID(1)}}, + {rw: &conn{flags: dynDialedConn, id: uintID(3)}}, + {rw: &conn{flags: dynDialedConn, id: uintID(4)}}, + {rw: &conn{flags: dynDialedConn, id: uintID(5)}}, + }, + new: []task{ + &dialTask{dynDialedConn, &discover.Node{ID: uintID(6)}}, + }, + }, + // More peers (3,4) drop off and dial for ID 6 completes. + // The last query result from the discovery lookup is reused + // and a new one is spawned because more candidates are needed. + { + peers: []*Peer{ + {rw: &conn{flags: staticDialedConn, id: uintID(0)}}, + {rw: &conn{flags: dynDialedConn, id: uintID(1)}}, + {rw: &conn{flags: dynDialedConn, id: uintID(5)}}, + }, + done: []task{ + &dialTask{dynDialedConn, &discover.Node{ID: uintID(6)}}, + }, + new: []task{ + &dialTask{dynDialedConn, &discover.Node{ID: uintID(7)}}, + &discoverTask{}, + }, + }, + // Peer 7 is connected, but there still aren't enough dynamic peers + // (4 out of 5). However, a discovery is already running, so ensure + // no new is started. + { + peers: []*Peer{ + {rw: &conn{flags: staticDialedConn, id: uintID(0)}}, + {rw: &conn{flags: dynDialedConn, id: uintID(1)}}, + {rw: &conn{flags: dynDialedConn, id: uintID(5)}}, + {rw: &conn{flags: dynDialedConn, id: uintID(7)}}, + }, + done: []task{ + &dialTask{dynDialedConn, &discover.Node{ID: uintID(7)}}, + }, + }, + // Finish the running node discovery with an empty set. A new lookup + // should be immediately requested. + { + peers: []*Peer{ + {rw: &conn{flags: staticDialedConn, id: uintID(0)}}, + {rw: &conn{flags: dynDialedConn, id: uintID(1)}}, + {rw: &conn{flags: dynDialedConn, id: uintID(5)}}, + {rw: &conn{flags: dynDialedConn, id: uintID(7)}}, + }, + done: []task{ + &discoverTask{}, + }, + new: []task{ + &discoverTask{}, + }, + }, + }, + }) +} + +func TestDialStateDynDialFromTable(t *testing.T) { + // This table always returns the same random nodes + // in the order given below. + table := fakeTable{ + {ID: uintID(1)}, + {ID: uintID(2)}, + {ID: uintID(3)}, + {ID: uintID(4)}, + {ID: uintID(5)}, + {ID: uintID(6)}, + {ID: uintID(7)}, + {ID: uintID(8)}, + } + + runDialTest(t, dialtest{ + init: newDialState(nil, table, 10), + rounds: []round{ + // Discovery bootstrap is launched. + { + new: []task{&discoverTask{bootstrap: true}}, + }, + // 5 out of 8 of the nodes returned by ReadRandomNodes are dialed. + { + done: []task{ + &discoverTask{bootstrap: true}, + }, + new: []task{ + &dialTask{dynDialedConn, &discover.Node{ID: uintID(1)}}, + &dialTask{dynDialedConn, &discover.Node{ID: uintID(2)}}, + &dialTask{dynDialedConn, &discover.Node{ID: uintID(3)}}, + &dialTask{dynDialedConn, &discover.Node{ID: uintID(4)}}, + &dialTask{dynDialedConn, &discover.Node{ID: uintID(5)}}, + &discoverTask{bootstrap: false}, + }, + }, + // Dialing nodes 1,2 succeeds. Dials from the lookup are launched. + { + peers: []*Peer{ + {rw: &conn{flags: dynDialedConn, id: uintID(1)}}, + {rw: &conn{flags: dynDialedConn, id: uintID(2)}}, + }, + done: []task{ + &dialTask{dynDialedConn, &discover.Node{ID: uintID(1)}}, + &dialTask{dynDialedConn, &discover.Node{ID: uintID(2)}}, + &discoverTask{results: []*discover.Node{ + {ID: uintID(10)}, + {ID: uintID(11)}, + {ID: uintID(12)}, + }}, + }, + new: []task{ + &dialTask{dynDialedConn, &discover.Node{ID: uintID(10)}}, + &dialTask{dynDialedConn, &discover.Node{ID: uintID(11)}}, + &dialTask{dynDialedConn, &discover.Node{ID: uintID(12)}}, + &discoverTask{bootstrap: false}, + }, + }, + // Dialing nodes 3,4,5 fails. The dials from the lookup succeed. + { + peers: []*Peer{ + {rw: &conn{flags: dynDialedConn, id: uintID(1)}}, + {rw: &conn{flags: dynDialedConn, id: uintID(2)}}, + {rw: &conn{flags: dynDialedConn, id: uintID(10)}}, + {rw: &conn{flags: dynDialedConn, id: uintID(11)}}, + {rw: &conn{flags: dynDialedConn, id: uintID(12)}}, + }, + done: []task{ + &dialTask{dynDialedConn, &discover.Node{ID: uintID(3)}}, + &dialTask{dynDialedConn, &discover.Node{ID: uintID(4)}}, + &dialTask{dynDialedConn, &discover.Node{ID: uintID(5)}}, + &dialTask{dynDialedConn, &discover.Node{ID: uintID(10)}}, + &dialTask{dynDialedConn, &discover.Node{ID: uintID(11)}}, + &dialTask{dynDialedConn, &discover.Node{ID: uintID(12)}}, + }, + }, + // Waiting for expiry. No waitExpireTask is launched because the + // discovery query is still running. + { + peers: []*Peer{ + {rw: &conn{flags: dynDialedConn, id: uintID(1)}}, + {rw: &conn{flags: dynDialedConn, id: uintID(2)}}, + {rw: &conn{flags: dynDialedConn, id: uintID(10)}}, + {rw: &conn{flags: dynDialedConn, id: uintID(11)}}, + {rw: &conn{flags: dynDialedConn, id: uintID(12)}}, + }, + }, + // Nodes 3,4 are not tried again because only the first two + // returned random nodes (nodes 1,2) are tried and they're + // already connected. + { + peers: []*Peer{ + {rw: &conn{flags: dynDialedConn, id: uintID(1)}}, + {rw: &conn{flags: dynDialedConn, id: uintID(2)}}, + {rw: &conn{flags: dynDialedConn, id: uintID(10)}}, + {rw: &conn{flags: dynDialedConn, id: uintID(11)}}, + {rw: &conn{flags: dynDialedConn, id: uintID(12)}}, + }, + }, + }, + }) +} + +// This test checks that static dials are launched. +func TestDialStateStaticDial(t *testing.T) { + wantStatic := []*discover.Node{ + {ID: uintID(1)}, + {ID: uintID(2)}, + {ID: uintID(3)}, + {ID: uintID(4)}, + {ID: uintID(5)}, + } + + runDialTest(t, dialtest{ + init: newDialState(wantStatic, fakeTable{}, 0), + rounds: []round{ + // Static dials are launched for the nodes that + // aren't yet connected. + { + peers: []*Peer{ + {rw: &conn{flags: dynDialedConn, id: uintID(1)}}, + {rw: &conn{flags: dynDialedConn, id: uintID(2)}}, + }, + new: []task{ + &dialTask{staticDialedConn, &discover.Node{ID: uintID(3)}}, + &dialTask{staticDialedConn, &discover.Node{ID: uintID(4)}}, + &dialTask{staticDialedConn, &discover.Node{ID: uintID(5)}}, + }, + }, + // No new tasks are launched in this round because all static + // nodes are either connected or still being dialed. + { + peers: []*Peer{ + {rw: &conn{flags: dynDialedConn, id: uintID(1)}}, + {rw: &conn{flags: dynDialedConn, id: uintID(2)}}, + {rw: &conn{flags: staticDialedConn, id: uintID(3)}}, + }, + done: []task{ + &dialTask{staticDialedConn, &discover.Node{ID: uintID(3)}}, + }, + }, + // No new dial tasks are launched because all static + // nodes are now connected. + { + peers: []*Peer{ + {rw: &conn{flags: dynDialedConn, id: uintID(1)}}, + {rw: &conn{flags: dynDialedConn, id: uintID(2)}}, + {rw: &conn{flags: staticDialedConn, id: uintID(3)}}, + {rw: &conn{flags: staticDialedConn, id: uintID(4)}}, + {rw: &conn{flags: staticDialedConn, id: uintID(5)}}, + }, + done: []task{ + &dialTask{staticDialedConn, &discover.Node{ID: uintID(4)}}, + &dialTask{staticDialedConn, &discover.Node{ID: uintID(5)}}, + }, + new: []task{ + &waitExpireTask{Duration: 14 * time.Second}, + }, + }, + // Wait a round for dial history to expire, no new tasks should spawn. + { + peers: []*Peer{ + {rw: &conn{flags: dynDialedConn, id: uintID(1)}}, + {rw: &conn{flags: dynDialedConn, id: uintID(2)}}, + {rw: &conn{flags: staticDialedConn, id: uintID(3)}}, + {rw: &conn{flags: staticDialedConn, id: uintID(4)}}, + {rw: &conn{flags: staticDialedConn, id: uintID(5)}}, + }, + }, + // If a static node is dropped, it should be immediately redialed, + // irrespective whether it was originally static or dynamic. + { + peers: []*Peer{ + {rw: &conn{flags: dynDialedConn, id: uintID(1)}}, + {rw: &conn{flags: staticDialedConn, id: uintID(3)}}, + {rw: &conn{flags: staticDialedConn, id: uintID(5)}}, + }, + new: []task{ + &dialTask{staticDialedConn, &discover.Node{ID: uintID(2)}}, + &dialTask{staticDialedConn, &discover.Node{ID: uintID(4)}}, + }, + }, + }, + }) +} + +// This test checks that past dials are not retried for some time. +func TestDialStateCache(t *testing.T) { + wantStatic := []*discover.Node{ + {ID: uintID(1)}, + {ID: uintID(2)}, + {ID: uintID(3)}, + } + + runDialTest(t, dialtest{ + init: newDialState(wantStatic, fakeTable{}, 0), + rounds: []round{ + // Static dials are launched for the nodes that + // aren't yet connected. + { + peers: nil, + new: []task{ + &dialTask{staticDialedConn, &discover.Node{ID: uintID(1)}}, + &dialTask{staticDialedConn, &discover.Node{ID: uintID(2)}}, + &dialTask{staticDialedConn, &discover.Node{ID: uintID(3)}}, + }, + }, + // No new tasks are launched in this round because all static + // nodes are either connected or still being dialed. + { + peers: []*Peer{ + {rw: &conn{flags: staticDialedConn, id: uintID(1)}}, + {rw: &conn{flags: staticDialedConn, id: uintID(2)}}, + }, + done: []task{ + &dialTask{staticDialedConn, &discover.Node{ID: uintID(1)}}, + &dialTask{staticDialedConn, &discover.Node{ID: uintID(2)}}, + }, + }, + // A salvage task is launched to wait for node 3's history + // entry to expire. + { + peers: []*Peer{ + {rw: &conn{flags: dynDialedConn, id: uintID(1)}}, + {rw: &conn{flags: dynDialedConn, id: uintID(2)}}, + }, + done: []task{ + &dialTask{staticDialedConn, &discover.Node{ID: uintID(3)}}, + }, + new: []task{ + &waitExpireTask{Duration: 14 * time.Second}, + }, + }, + // Still waiting for node 3's entry to expire in the cache. + { + peers: []*Peer{ + {rw: &conn{flags: dynDialedConn, id: uintID(1)}}, + {rw: &conn{flags: dynDialedConn, id: uintID(2)}}, + }, + }, + // The cache entry for node 3 has expired and is retried. + { + peers: []*Peer{ + {rw: &conn{flags: dynDialedConn, id: uintID(1)}}, + {rw: &conn{flags: dynDialedConn, id: uintID(2)}}, + }, + new: []task{ + &dialTask{staticDialedConn, &discover.Node{ID: uintID(3)}}, + }, + }, + }, + }) +} + +// compares task lists but doesn't care about the order. +func sametasks(a, b []task) bool { + if len(a) != len(b) { + return false + } +next: + for _, ta := range a { + for _, tb := range b { + if reflect.DeepEqual(ta, tb) { + continue next + } + } + return false + } + return true +} + +func uintID(i uint32) discover.NodeID { + var id discover.NodeID + binary.BigEndian.PutUint32(id[:], i) + return id +} diff --git a/p2p/handshake.go b/p2p/handshake.go deleted file mode 100644 index 4cdcee6d4dc49d29fbcde3ff3e9845d70961731d..0000000000000000000000000000000000000000 --- a/p2p/handshake.go +++ /dev/null @@ -1,448 +0,0 @@ -package p2p - -import ( - "crypto/ecdsa" - "crypto/elliptic" - "crypto/rand" - "errors" - "fmt" - "hash" - "io" - "net" - - "github.com/ethereum/go-ethereum/crypto" - "github.com/ethereum/go-ethereum/crypto/ecies" - "github.com/ethereum/go-ethereum/crypto/secp256k1" - "github.com/ethereum/go-ethereum/crypto/sha3" - "github.com/ethereum/go-ethereum/p2p/discover" - "github.com/ethereum/go-ethereum/rlp" -) - -const ( - sskLen = 16 // ecies.MaxSharedKeyLength(pubKey) / 2 - sigLen = 65 // elliptic S256 - pubLen = 64 // 512 bit pubkey in uncompressed representation without format byte - shaLen = 32 // hash length (for nonce etc) - - authMsgLen = sigLen + shaLen + pubLen + shaLen + 1 - authRespLen = pubLen + shaLen + 1 - - eciesBytes = 65 + 16 + 32 - encAuthMsgLen = authMsgLen + eciesBytes // size of the final ECIES payload sent as initiator's handshake - encAuthRespLen = authRespLen + eciesBytes // size of the final ECIES payload sent as receiver's handshake -) - -// conn represents a remote connection after encryption handshake -// and protocol handshake have completed. -// -// The MsgReadWriter is usually layered as follows: -// -// netWrapper (I/O timeouts, thread-safe ReadMsg, WriteMsg) -// rlpxFrameRW (message encoding, encryption, authentication) -// bufio.ReadWriter (buffering) -// net.Conn (network I/O) -// -type conn struct { - MsgReadWriter - *protoHandshake -} - -// secrets represents the connection secrets -// which are negotiated during the encryption handshake. -type secrets struct { - RemoteID discover.NodeID - AES, MAC []byte - EgressMAC, IngressMAC hash.Hash - Token []byte -} - -// protoHandshake is the RLP structure of the protocol handshake. -type protoHandshake struct { - Version uint64 - Name string - Caps []Cap - ListenPort uint64 - ID discover.NodeID -} - -// setupConn starts a protocol session on the given connection. It -// runs the encryption handshake and the protocol handshake. If dial -// is non-nil, the connection the local node is the initiator. If -// keepconn returns false, the connection will be disconnected with -// DiscTooManyPeers after the key exchange. -func setupConn(fd net.Conn, prv *ecdsa.PrivateKey, our *protoHandshake, dial *discover.Node, keepconn func(discover.NodeID) bool) (*conn, error) { - if dial == nil { - return setupInboundConn(fd, prv, our, keepconn) - } else { - return setupOutboundConn(fd, prv, our, dial, keepconn) - } -} - -func setupInboundConn(fd net.Conn, prv *ecdsa.PrivateKey, our *protoHandshake, keepconn func(discover.NodeID) bool) (*conn, error) { - secrets, err := receiverEncHandshake(fd, prv, nil) - if err != nil { - return nil, fmt.Errorf("encryption handshake failed: %v", err) - } - rw := newRlpxFrameRW(fd, secrets) - if !keepconn(secrets.RemoteID) { - SendItems(rw, discMsg, DiscTooManyPeers) - return nil, errors.New("we have too many peers") - } - // Run the protocol handshake using authenticated messages. - rhs, err := readProtocolHandshake(rw, secrets.RemoteID, our) - if err != nil { - return nil, err - } - if err := Send(rw, handshakeMsg, our); err != nil { - return nil, fmt.Errorf("protocol handshake write error: %v", err) - } - return &conn{rw, rhs}, nil -} - -func setupOutboundConn(fd net.Conn, prv *ecdsa.PrivateKey, our *protoHandshake, dial *discover.Node, keepconn func(discover.NodeID) bool) (*conn, error) { - secrets, err := initiatorEncHandshake(fd, prv, dial.ID, nil) - if err != nil { - return nil, fmt.Errorf("encryption handshake failed: %v", err) - } - rw := newRlpxFrameRW(fd, secrets) - if !keepconn(secrets.RemoteID) { - SendItems(rw, discMsg, DiscTooManyPeers) - return nil, errors.New("we have too many peers") - } - // Run the protocol handshake using authenticated messages. - // - // Note that even though writing the handshake is first, we prefer - // returning the handshake read error. If the remote side - // disconnects us early with a valid reason, we should return it - // as the error so it can be tracked elsewhere. - werr := make(chan error, 1) - go func() { werr <- Send(rw, handshakeMsg, our) }() - rhs, err := readProtocolHandshake(rw, secrets.RemoteID, our) - if err != nil { - return nil, err - } - if err := <-werr; err != nil { - return nil, fmt.Errorf("protocol handshake write error: %v", err) - } - if rhs.ID != dial.ID { - return nil, errors.New("dialed node id mismatch") - } - return &conn{rw, rhs}, nil -} - -// encHandshake contains the state of the encryption handshake. -type encHandshake struct { - initiator bool - remoteID discover.NodeID - - remotePub *ecies.PublicKey // remote-pubk - initNonce, respNonce []byte // nonce - randomPrivKey *ecies.PrivateKey // ecdhe-random - remoteRandomPub *ecies.PublicKey // ecdhe-random-pubk -} - -// secrets is called after the handshake is completed. -// It extracts the connection secrets from the handshake values. -func (h *encHandshake) secrets(auth, authResp []byte) (secrets, error) { - ecdheSecret, err := h.randomPrivKey.GenerateShared(h.remoteRandomPub, sskLen, sskLen) - if err != nil { - return secrets{}, err - } - - // derive base secrets from ephemeral key agreement - sharedSecret := crypto.Sha3(ecdheSecret, crypto.Sha3(h.respNonce, h.initNonce)) - aesSecret := crypto.Sha3(ecdheSecret, sharedSecret) - s := secrets{ - RemoteID: h.remoteID, - AES: aesSecret, - MAC: crypto.Sha3(ecdheSecret, aesSecret), - Token: crypto.Sha3(sharedSecret), - } - - // setup sha3 instances for the MACs - mac1 := sha3.NewKeccak256() - mac1.Write(xor(s.MAC, h.respNonce)) - mac1.Write(auth) - mac2 := sha3.NewKeccak256() - mac2.Write(xor(s.MAC, h.initNonce)) - mac2.Write(authResp) - if h.initiator { - s.EgressMAC, s.IngressMAC = mac1, mac2 - } else { - s.EgressMAC, s.IngressMAC = mac2, mac1 - } - - return s, nil -} - -func (h *encHandshake) ecdhShared(prv *ecdsa.PrivateKey) ([]byte, error) { - return ecies.ImportECDSA(prv).GenerateShared(h.remotePub, sskLen, sskLen) -} - -// initiatorEncHandshake negotiates a session token on conn. -// it should be called on the dialing side of the connection. -// -// prv is the local client's private key. -// token is the token from a previous session with this node. -func initiatorEncHandshake(conn io.ReadWriter, prv *ecdsa.PrivateKey, remoteID discover.NodeID, token []byte) (s secrets, err error) { - h, err := newInitiatorHandshake(remoteID) - if err != nil { - return s, err - } - auth, err := h.authMsg(prv, token) - if err != nil { - return s, err - } - if _, err = conn.Write(auth); err != nil { - return s, err - } - - response := make([]byte, encAuthRespLen) - if _, err = io.ReadFull(conn, response); err != nil { - return s, err - } - if err := h.decodeAuthResp(response, prv); err != nil { - return s, err - } - return h.secrets(auth, response) -} - -func newInitiatorHandshake(remoteID discover.NodeID) (*encHandshake, error) { - // generate random initiator nonce - n := make([]byte, shaLen) - if _, err := rand.Read(n); err != nil { - return nil, err - } - // generate random keypair to use for signing - randpriv, err := ecies.GenerateKey(rand.Reader, crypto.S256(), nil) - if err != nil { - return nil, err - } - rpub, err := remoteID.Pubkey() - if err != nil { - return nil, fmt.Errorf("bad remoteID: %v", err) - } - h := &encHandshake{ - initiator: true, - remoteID: remoteID, - remotePub: ecies.ImportECDSAPublic(rpub), - initNonce: n, - randomPrivKey: randpriv, - } - return h, nil -} - -// authMsg creates an encrypted initiator handshake message. -func (h *encHandshake) authMsg(prv *ecdsa.PrivateKey, token []byte) ([]byte, error) { - var tokenFlag byte - if token == nil { - // no session token found means we need to generate shared secret. - // ecies shared secret is used as initial session token for new peers - // generate shared key from prv and remote pubkey - var err error - if token, err = h.ecdhShared(prv); err != nil { - return nil, err - } - } else { - // for known peers, we use stored token from the previous session - tokenFlag = 0x01 - } - - // sign known message: - // ecdh-shared-secret^nonce for new peers - // token^nonce for old peers - signed := xor(token, h.initNonce) - signature, err := crypto.Sign(signed, h.randomPrivKey.ExportECDSA()) - if err != nil { - return nil, err - } - - // encode auth message - // signature || sha3(ecdhe-random-pubk) || pubk || nonce || token-flag - msg := make([]byte, authMsgLen) - n := copy(msg, signature) - n += copy(msg[n:], crypto.Sha3(exportPubkey(&h.randomPrivKey.PublicKey))) - n += copy(msg[n:], crypto.FromECDSAPub(&prv.PublicKey)[1:]) - n += copy(msg[n:], h.initNonce) - msg[n] = tokenFlag - - // encrypt auth message using remote-pubk - return ecies.Encrypt(rand.Reader, h.remotePub, msg, nil, nil) -} - -// decodeAuthResp decode an encrypted authentication response message. -func (h *encHandshake) decodeAuthResp(auth []byte, prv *ecdsa.PrivateKey) error { - msg, err := crypto.Decrypt(prv, auth) - if err != nil { - return fmt.Errorf("could not decrypt auth response (%v)", err) - } - h.respNonce = msg[pubLen : pubLen+shaLen] - h.remoteRandomPub, err = importPublicKey(msg[:pubLen]) - if err != nil { - return err - } - // ignore token flag for now - return nil -} - -// receiverEncHandshake negotiates a session token on conn. -// it should be called on the listening side of the connection. -// -// prv is the local client's private key. -// token is the token from a previous session with this node. -func receiverEncHandshake(conn io.ReadWriter, prv *ecdsa.PrivateKey, token []byte) (s secrets, err error) { - // read remote auth sent by initiator. - auth := make([]byte, encAuthMsgLen) - if _, err := io.ReadFull(conn, auth); err != nil { - return s, err - } - h, err := decodeAuthMsg(prv, token, auth) - if err != nil { - return s, err - } - - // send auth response - resp, err := h.authResp(prv, token) - if err != nil { - return s, err - } - if _, err = conn.Write(resp); err != nil { - return s, err - } - - return h.secrets(auth, resp) -} - -func decodeAuthMsg(prv *ecdsa.PrivateKey, token []byte, auth []byte) (*encHandshake, error) { - var err error - h := new(encHandshake) - // generate random keypair for session - h.randomPrivKey, err = ecies.GenerateKey(rand.Reader, crypto.S256(), nil) - if err != nil { - return nil, err - } - // generate random nonce - h.respNonce = make([]byte, shaLen) - if _, err = rand.Read(h.respNonce); err != nil { - return nil, err - } - - msg, err := crypto.Decrypt(prv, auth) - if err != nil { - return nil, fmt.Errorf("could not decrypt auth message (%v)", err) - } - - // decode message parameters - // signature || sha3(ecdhe-random-pubk) || pubk || nonce || token-flag - h.initNonce = msg[authMsgLen-shaLen-1 : authMsgLen-1] - copy(h.remoteID[:], msg[sigLen+shaLen:sigLen+shaLen+pubLen]) - rpub, err := h.remoteID.Pubkey() - if err != nil { - return nil, fmt.Errorf("bad remoteID: %#v", err) - } - h.remotePub = ecies.ImportECDSAPublic(rpub) - - // recover remote random pubkey from signed message. - if token == nil { - // TODO: it is an error if the initiator has a token and we don't. check that. - - // no session token means we need to generate shared secret. - // ecies shared secret is used as initial session token for new peers. - // generate shared key from prv and remote pubkey. - if token, err = h.ecdhShared(prv); err != nil { - return nil, err - } - } - signedMsg := xor(token, h.initNonce) - remoteRandomPub, err := secp256k1.RecoverPubkey(signedMsg, msg[:sigLen]) - if err != nil { - return nil, err - } - h.remoteRandomPub, _ = importPublicKey(remoteRandomPub) - return h, nil -} - -// authResp generates the encrypted authentication response message. -func (h *encHandshake) authResp(prv *ecdsa.PrivateKey, token []byte) ([]byte, error) { - // responder auth message - // E(remote-pubk, ecdhe-random-pubk || nonce || 0x0) - resp := make([]byte, authRespLen) - n := copy(resp, exportPubkey(&h.randomPrivKey.PublicKey)) - n += copy(resp[n:], h.respNonce) - if token == nil { - resp[n] = 0 - } else { - resp[n] = 1 - } - // encrypt using remote-pubk - return ecies.Encrypt(rand.Reader, h.remotePub, resp, nil, nil) -} - -// importPublicKey unmarshals 512 bit public keys. -func importPublicKey(pubKey []byte) (*ecies.PublicKey, error) { - var pubKey65 []byte - switch len(pubKey) { - case 64: - // add 'uncompressed key' flag - pubKey65 = append([]byte{0x04}, pubKey...) - case 65: - pubKey65 = pubKey - default: - return nil, fmt.Errorf("invalid public key length %v (expect 64/65)", len(pubKey)) - } - // TODO: fewer pointless conversions - return ecies.ImportECDSAPublic(crypto.ToECDSAPub(pubKey65)), nil -} - -func exportPubkey(pub *ecies.PublicKey) []byte { - if pub == nil { - panic("nil pubkey") - } - return elliptic.Marshal(pub.Curve, pub.X, pub.Y)[1:] -} - -func xor(one, other []byte) (xor []byte) { - xor = make([]byte, len(one)) - for i := 0; i < len(one); i++ { - xor[i] = one[i] ^ other[i] - } - return xor -} - -func readProtocolHandshake(rw MsgReadWriter, wantID discover.NodeID, our *protoHandshake) (*protoHandshake, error) { - msg, err := rw.ReadMsg() - if err != nil { - return nil, err - } - if msg.Code == discMsg { - // disconnect before protocol handshake is valid according to the - // spec and we send it ourself if Server.addPeer fails. - var reason [1]DiscReason - rlp.Decode(msg.Payload, &reason) - return nil, reason[0] - } - if msg.Code != handshakeMsg { - return nil, fmt.Errorf("expected handshake, got %x", msg.Code) - } - if msg.Size > baseProtocolMaxMsgSize { - return nil, fmt.Errorf("message too big (%d > %d)", msg.Size, baseProtocolMaxMsgSize) - } - var hs protoHandshake - if err := msg.Decode(&hs); err != nil { - return nil, err - } - // validate handshake info - if hs.Version != our.Version { - SendItems(rw, discMsg, DiscIncompatibleVersion) - return nil, fmt.Errorf("required version %d, received %d\n", baseProtocolVersion, hs.Version) - } - if (hs.ID == discover.NodeID{}) { - SendItems(rw, discMsg, DiscInvalidIdentity) - return nil, errors.New("invalid public key in handshake") - } - if hs.ID != wantID { - SendItems(rw, discMsg, DiscUnexpectedIdentity) - return nil, errors.New("handshake node ID does not match encryption handshake") - } - return &hs, nil -} diff --git a/p2p/handshake_test.go b/p2p/handshake_test.go deleted file mode 100644 index ab75921a366150c309d0588b340d3a3d7ac4dceb..0000000000000000000000000000000000000000 --- a/p2p/handshake_test.go +++ /dev/null @@ -1,172 +0,0 @@ -package p2p - -import ( - "bytes" - "crypto/rand" - "fmt" - "net" - "reflect" - "testing" - "time" - - "github.com/ethereum/go-ethereum/crypto" - "github.com/ethereum/go-ethereum/crypto/ecies" - "github.com/ethereum/go-ethereum/p2p/discover" -) - -func TestSharedSecret(t *testing.T) { - prv0, _ := crypto.GenerateKey() // = ecdsa.GenerateKey(crypto.S256(), rand.Reader) - pub0 := &prv0.PublicKey - prv1, _ := crypto.GenerateKey() - pub1 := &prv1.PublicKey - - ss0, err := ecies.ImportECDSA(prv0).GenerateShared(ecies.ImportECDSAPublic(pub1), sskLen, sskLen) - if err != nil { - return - } - ss1, err := ecies.ImportECDSA(prv1).GenerateShared(ecies.ImportECDSAPublic(pub0), sskLen, sskLen) - if err != nil { - return - } - t.Logf("Secret:\n%v %x\n%v %x", len(ss0), ss0, len(ss0), ss1) - if !bytes.Equal(ss0, ss1) { - t.Errorf("dont match :(") - } -} - -func TestEncHandshake(t *testing.T) { - for i := 0; i < 20; i++ { - start := time.Now() - if err := testEncHandshake(nil); err != nil { - t.Fatalf("i=%d %v", i, err) - } - t.Logf("(without token) %d %v\n", i+1, time.Since(start)) - } - - for i := 0; i < 20; i++ { - tok := make([]byte, shaLen) - rand.Reader.Read(tok) - start := time.Now() - if err := testEncHandshake(tok); err != nil { - t.Fatalf("i=%d %v", i, err) - } - t.Logf("(with token) %d %v\n", i+1, time.Since(start)) - } -} - -func testEncHandshake(token []byte) error { - type result struct { - side string - s secrets - err error - } - var ( - prv0, _ = crypto.GenerateKey() - prv1, _ = crypto.GenerateKey() - rw0, rw1 = net.Pipe() - output = make(chan result) - ) - - go func() { - r := result{side: "initiator"} - defer func() { output <- r }() - - pub1s := discover.PubkeyID(&prv1.PublicKey) - r.s, r.err = initiatorEncHandshake(rw0, prv0, pub1s, token) - if r.err != nil { - return - } - id1 := discover.PubkeyID(&prv1.PublicKey) - if r.s.RemoteID != id1 { - r.err = fmt.Errorf("remote ID mismatch: got %v, want: %v", r.s.RemoteID, id1) - } - }() - go func() { - r := result{side: "receiver"} - defer func() { output <- r }() - - r.s, r.err = receiverEncHandshake(rw1, prv1, token) - if r.err != nil { - return - } - id0 := discover.PubkeyID(&prv0.PublicKey) - if r.s.RemoteID != id0 { - r.err = fmt.Errorf("remote ID mismatch: got %v, want: %v", r.s.RemoteID, id0) - } - }() - - // wait for results from both sides - r1, r2 := <-output, <-output - - if r1.err != nil { - return fmt.Errorf("%s side error: %v", r1.side, r1.err) - } - if r2.err != nil { - return fmt.Errorf("%s side error: %v", r2.side, r2.err) - } - - // don't compare remote node IDs - r1.s.RemoteID, r2.s.RemoteID = discover.NodeID{}, discover.NodeID{} - // flip MACs on one of them so they compare equal - r1.s.EgressMAC, r1.s.IngressMAC = r1.s.IngressMAC, r1.s.EgressMAC - if !reflect.DeepEqual(r1.s, r2.s) { - return fmt.Errorf("secrets mismatch:\n t1: %#v\n t2: %#v", r1.s, r2.s) - } - return nil -} - -func TestSetupConn(t *testing.T) { - prv0, _ := crypto.GenerateKey() - prv1, _ := crypto.GenerateKey() - node0 := &discover.Node{ - ID: discover.PubkeyID(&prv0.PublicKey), - IP: net.IP{1, 2, 3, 4}, - TCP: 33, - } - node1 := &discover.Node{ - ID: discover.PubkeyID(&prv1.PublicKey), - IP: net.IP{5, 6, 7, 8}, - TCP: 44, - } - hs0 := &protoHandshake{ - Version: baseProtocolVersion, - ID: node0.ID, - Caps: []Cap{{"a", 0}, {"b", 2}}, - } - hs1 := &protoHandshake{ - Version: baseProtocolVersion, - ID: node1.ID, - Caps: []Cap{{"c", 1}, {"d", 3}}, - } - fd0, fd1 := net.Pipe() - - done := make(chan struct{}) - keepalways := func(discover.NodeID) bool { return true } - go func() { - defer close(done) - conn0, err := setupConn(fd0, prv0, hs0, node1, keepalways) - if err != nil { - t.Errorf("outbound side error: %v", err) - return - } - if conn0.ID != node1.ID { - t.Errorf("outbound conn id mismatch: got %v, want %v", conn0.ID, node1.ID) - } - if !reflect.DeepEqual(conn0.Caps, hs1.Caps) { - t.Errorf("outbound caps mismatch: got %v, want %v", conn0.Caps, hs1.Caps) - } - }() - - conn1, err := setupConn(fd1, prv1, hs1, nil, keepalways) - if err != nil { - t.Fatalf("inbound side error: %v", err) - } - if conn1.ID != node0.ID { - t.Errorf("inbound conn id mismatch: got %v, want %v", conn1.ID, node0.ID) - } - if !reflect.DeepEqual(conn1.Caps, hs0.Caps) { - t.Errorf("inbound caps mismatch: got %v, want %v", conn1.Caps, hs0.Caps) - } - - <-done -} diff --git a/p2p/peer.go b/p2p/peer.go index 87a91d406df85ddec2d7e7961567e08a7655a123..cbe5ccc84aa448bee0f60aa070ff233977fe2a9c 100644 --- a/p2p/peer.go +++ b/p2p/peer.go @@ -33,9 +33,17 @@ const ( peersMsg = 0x05 ) +// protoHandshake is the RLP structure of the protocol handshake. +type protoHandshake struct { + Version uint64 + Name string + Caps []Cap + ListenPort uint64 + ID discover.NodeID +} + // Peer represents a connected remote node. type Peer struct { - conn net.Conn rw *conn running map[string]*protoRW @@ -48,37 +56,36 @@ type Peer struct { // NewPeer returns a peer for testing purposes. func NewPeer(id discover.NodeID, name string, caps []Cap) *Peer { pipe, _ := net.Pipe() - msgpipe, _ := MsgPipe() - conn := &conn{msgpipe, &protoHandshake{ID: id, Name: name, Caps: caps}} - peer := newPeer(pipe, conn, nil) + conn := &conn{fd: pipe, transport: nil, id: id, caps: caps, name: name} + peer := newPeer(conn, nil) close(peer.closed) // ensures Disconnect doesn't block return peer } // ID returns the node's public key. func (p *Peer) ID() discover.NodeID { - return p.rw.ID + return p.rw.id } // Name returns the node name that the remote node advertised. func (p *Peer) Name() string { - return p.rw.Name + return p.rw.name } // Caps returns the capabilities (supported subprotocols) of the remote peer. func (p *Peer) Caps() []Cap { // TODO: maybe return copy - return p.rw.Caps + return p.rw.caps } // RemoteAddr returns the remote address of the network connection. func (p *Peer) RemoteAddr() net.Addr { - return p.conn.RemoteAddr() + return p.rw.fd.RemoteAddr() } // LocalAddr returns the local address of the network connection. func (p *Peer) LocalAddr() net.Addr { - return p.conn.LocalAddr() + return p.rw.fd.LocalAddr() } // Disconnect terminates the peer connection with the given reason. @@ -92,13 +99,12 @@ func (p *Peer) Disconnect(reason DiscReason) { // String implements fmt.Stringer. func (p *Peer) String() string { - return fmt.Sprintf("Peer %.8x %v", p.rw.ID[:], p.RemoteAddr()) + return fmt.Sprintf("Peer %x %v", p.rw.id[:8], p.RemoteAddr()) } -func newPeer(fd net.Conn, conn *conn, protocols []Protocol) *Peer { - protomap := matchProtocols(protocols, conn.Caps, conn) +func newPeer(conn *conn, protocols []Protocol) *Peer { + protomap := matchProtocols(protocols, conn.caps, conn) p := &Peer{ - conn: fd, rw: conn, running: protomap, disc: make(chan DiscReason), @@ -117,7 +123,10 @@ func (p *Peer) run() DiscReason { p.startProtocols() // Wait for an error or disconnect. - var reason DiscReason + var ( + reason DiscReason + requested bool + ) select { case err := <-readErr: if r, ok := err.(DiscReason); ok { @@ -131,21 +140,17 @@ func (p *Peer) run() DiscReason { case err := <-p.protoErr: reason = discReasonForError(err) case reason = <-p.disc: - p.politeDisconnect(reason) - reason = DiscRequested + requested = true } - close(p.closed) + p.rw.close(reason) p.wg.Wait() - glog.V(logger.Debug).Infof("%v: Disconnected: %v\n", p, reason) - return reason -} -func (p *Peer) politeDisconnect(reason DiscReason) { - if reason != DiscNetworkError { - SendItems(p.rw, discMsg, uint(reason)) + if requested { + reason = DiscRequested } - p.conn.Close() + glog.V(logger.Debug).Infof("%v: Disconnected: %v\n", p, reason) + return reason } func (p *Peer) pingLoop() { @@ -254,7 +259,7 @@ func (p *Peer) startProtocols() { glog.V(logger.Detail).Infof("%v: Protocol %s/%d returned\n", p, proto.Name, proto.Version) err = errors.New("protocol returned") } else if err != io.EOF { - glog.V(logger.Detail).Infof("%v: Protocol %s/%d error: \n", p, proto.Name, proto.Version, err) + glog.V(logger.Detail).Infof("%v: Protocol %s/%d error: %v\n", p, proto.Name, proto.Version, err) } p.protoErr <- err p.wg.Done() diff --git a/p2p/peer_error.go b/p2p/peer_error.go index a912f60644e90e8c93633195131565d219926065..6938a9801fc1f3a285817927d4a43bd3faaa7575 100644 --- a/p2p/peer_error.go +++ b/p2p/peer_error.go @@ -5,39 +5,17 @@ import ( ) const ( - errMagicTokenMismatch = iota - errRead - errWrite - errMisc - errInvalidMsgCode + errInvalidMsgCode = iota errInvalidMsg - errP2PVersionMismatch - errPubkeyInvalid - errPubkeyForbidden - errProtocolBreach - errPingTimeout - errInvalidNetworkId - errInvalidProtocolVersion ) var errorToString = map[int]string{ - errMagicTokenMismatch: "magic token mismatch", - errRead: "read error", - errWrite: "write error", - errMisc: "misc error", - errInvalidMsgCode: "invalid message code", - errInvalidMsg: "invalid message", - errP2PVersionMismatch: "P2P Version Mismatch", - errPubkeyInvalid: "public key invalid", - errPubkeyForbidden: "public key forbidden", - errProtocolBreach: "protocol Breach", - errPingTimeout: "ping timeout", - errInvalidNetworkId: "invalid network id", - errInvalidProtocolVersion: "invalid protocol version", + errInvalidMsgCode: "invalid message code", + errInvalidMsg: "invalid message", } type peerError struct { - Code int + code int message string } @@ -107,23 +85,13 @@ func discReasonForError(err error) DiscReason { return reason } peerError, ok := err.(*peerError) - if !ok { - return DiscSubprotocolError - } - switch peerError.Code { - case errP2PVersionMismatch: - return DiscIncompatibleVersion - case errPubkeyInvalid: - return DiscInvalidIdentity - case errPubkeyForbidden: - return DiscUselessPeer - case errInvalidMsgCode, errMagicTokenMismatch, errProtocolBreach: - return DiscProtocolError - case errPingTimeout: - return DiscReadTimeout - case errRead, errWrite: - return DiscNetworkError - default: - return DiscSubprotocolError + if ok { + switch peerError.code { + case errInvalidMsgCode, errInvalidMsg: + return DiscProtocolError + default: + return DiscSubprotocolError + } } + return DiscSubprotocolError } diff --git a/p2p/peer_test.go b/p2p/peer_test.go index 0ac032ab7a702a5b84152a6f4dbc8d311d0ecaca..7b772e1988652de467d422829e7820f27ee5789d 100644 --- a/p2p/peer_test.go +++ b/p2p/peer_test.go @@ -28,24 +28,20 @@ var discard = Protocol{ } func testPeer(protos []Protocol) (func(), *conn, *Peer, <-chan DiscReason) { - fd1, _ := net.Pipe() - hs1 := &protoHandshake{ID: randomID(), Version: baseProtocolVersion} - hs2 := &protoHandshake{ID: randomID(), Version: baseProtocolVersion} + fd1, fd2 := net.Pipe() + c1 := &conn{fd: fd1, transport: newTestTransport(randomID(), fd1)} + c2 := &conn{fd: fd2, transport: newTestTransport(randomID(), fd2)} for _, p := range protos { - hs1.Caps = append(hs1.Caps, p.cap()) - hs2.Caps = append(hs2.Caps, p.cap()) + c1.caps = append(c1.caps, p.cap()) + c2.caps = append(c2.caps, p.cap()) } - p1, p2 := MsgPipe() - peer := newPeer(fd1, &conn{p1, hs1}, protos) + peer := newPeer(c1, protos) errc := make(chan DiscReason, 1) go func() { errc <- peer.run() }() - closer := func() { - p1.Close() - fd1.Close() - } - return closer, &conn{p2, hs2}, peer, errc + closer := func() { c2.close(errors.New("close func called")) } + return closer, c2, peer, errc } func TestPeerProtoReadMsg(t *testing.T) { diff --git a/p2p/rlpx.go b/p2p/rlpx.go index 6b533e2751b6c0c1466c59ca8e57dc469524b104..e1cb13aae6f3a9298b6ae67ad4f90adc2ea711d8 100644 --- a/p2p/rlpx.go +++ b/p2p/rlpx.go @@ -4,23 +4,459 @@ import ( "bytes" "crypto/aes" "crypto/cipher" + "crypto/ecdsa" + "crypto/elliptic" "crypto/hmac" + "crypto/rand" "errors" + "fmt" "hash" "io" + "net" + "sync" + "time" + "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/crypto/ecies" + "github.com/ethereum/go-ethereum/crypto/secp256k1" + "github.com/ethereum/go-ethereum/crypto/sha3" + "github.com/ethereum/go-ethereum/p2p/discover" "github.com/ethereum/go-ethereum/rlp" ) +const ( + maxUint24 = ^uint32(0) >> 8 + + sskLen = 16 // ecies.MaxSharedKeyLength(pubKey) / 2 + sigLen = 65 // elliptic S256 + pubLen = 64 // 512 bit pubkey in uncompressed representation without format byte + shaLen = 32 // hash length (for nonce etc) + + authMsgLen = sigLen + shaLen + pubLen + shaLen + 1 + authRespLen = pubLen + shaLen + 1 + + eciesBytes = 65 + 16 + 32 + encAuthMsgLen = authMsgLen + eciesBytes // size of the final ECIES payload sent as initiator's handshake + encAuthRespLen = authRespLen + eciesBytes // size of the final ECIES payload sent as receiver's handshake + + // total timeout for encryption handshake and protocol + // handshake in both directions. + handshakeTimeout = 5 * time.Second + + // This is the timeout for sending the disconnect reason. + // This is shorter than the usual timeout because we don't want + // to wait if the connection is known to be bad anyway. + discWriteTimeout = 1 * time.Second +) + +// rlpx is the transport protocol used by actual (non-test) connections. +// It wraps the frame encoder with locks and read/write deadlines. +type rlpx struct { + fd net.Conn + + rmu, wmu sync.Mutex + rw *rlpxFrameRW +} + +func newRLPX(fd net.Conn) transport { + fd.SetDeadline(time.Now().Add(handshakeTimeout)) + return &rlpx{fd: fd} +} + +func (t *rlpx) ReadMsg() (Msg, error) { + t.rmu.Lock() + defer t.rmu.Unlock() + t.fd.SetReadDeadline(time.Now().Add(frameReadTimeout)) + return t.rw.ReadMsg() +} + +func (t *rlpx) WriteMsg(msg Msg) error { + t.wmu.Lock() + defer t.wmu.Unlock() + t.fd.SetWriteDeadline(time.Now().Add(frameWriteTimeout)) + return t.rw.WriteMsg(msg) +} + +func (t *rlpx) close(err error) { + t.wmu.Lock() + defer t.wmu.Unlock() + // Tell the remote end why we're disconnecting if possible. + if t.rw != nil { + if r, ok := err.(DiscReason); ok && r != DiscNetworkError { + t.fd.SetWriteDeadline(time.Now().Add(discWriteTimeout)) + SendItems(t.rw, discMsg, r) + } + } + t.fd.Close() +} + +// doEncHandshake runs the protocol handshake using authenticated +// messages. the protocol handshake is the first authenticated message +// and also verifies whether the encryption handshake 'worked' and the +// remote side actually provided the right public key. +func (t *rlpx) doProtoHandshake(our *protoHandshake) (their *protoHandshake, err error) { + // Writing our handshake happens concurrently, we prefer + // returning the handshake read error. If the remote side + // disconnects us early with a valid reason, we should return it + // as the error so it can be tracked elsewhere. + werr := make(chan error, 1) + go func() { werr <- Send(t.rw, handshakeMsg, our) }() + if their, err = readProtocolHandshake(t.rw, our); err != nil { + return nil, err + } + if err := <-werr; err != nil { + return nil, fmt.Errorf("write error: %v", err) + } + return their, nil +} + +func readProtocolHandshake(rw MsgReader, our *protoHandshake) (*protoHandshake, error) { + msg, err := rw.ReadMsg() + if err != nil { + return nil, err + } + if msg.Size > baseProtocolMaxMsgSize { + return nil, fmt.Errorf("message too big") + } + if msg.Code == discMsg { + // Disconnect before protocol handshake is valid according to the + // spec and we send it ourself if the posthanshake checks fail. + // We can't return the reason directly, though, because it is echoed + // back otherwise. Wrap it in a string instead. + var reason [1]DiscReason + rlp.Decode(msg.Payload, &reason) + return nil, reason[0] + } + if msg.Code != handshakeMsg { + return nil, fmt.Errorf("expected handshake, got %x", msg.Code) + } + var hs protoHandshake + if err := msg.Decode(&hs); err != nil { + return nil, err + } + // validate handshake info + if hs.Version != our.Version { + return nil, DiscIncompatibleVersion + } + if (hs.ID == discover.NodeID{}) { + return nil, DiscInvalidIdentity + } + return &hs, nil +} + +func (t *rlpx) doEncHandshake(prv *ecdsa.PrivateKey, dial *discover.Node) (discover.NodeID, error) { + var ( + sec secrets + err error + ) + if dial == nil { + sec, err = receiverEncHandshake(t.fd, prv, nil) + } else { + sec, err = initiatorEncHandshake(t.fd, prv, dial.ID, nil) + } + if err != nil { + return discover.NodeID{}, err + } + t.wmu.Lock() + t.rw = newRLPXFrameRW(t.fd, sec) + t.wmu.Unlock() + return sec.RemoteID, nil +} + +// encHandshake contains the state of the encryption handshake. +type encHandshake struct { + initiator bool + remoteID discover.NodeID + + remotePub *ecies.PublicKey // remote-pubk + initNonce, respNonce []byte // nonce + randomPrivKey *ecies.PrivateKey // ecdhe-random + remoteRandomPub *ecies.PublicKey // ecdhe-random-pubk +} + +// secrets represents the connection secrets +// which are negotiated during the encryption handshake. +type secrets struct { + RemoteID discover.NodeID + AES, MAC []byte + EgressMAC, IngressMAC hash.Hash + Token []byte +} + +// secrets is called after the handshake is completed. +// It extracts the connection secrets from the handshake values. +func (h *encHandshake) secrets(auth, authResp []byte) (secrets, error) { + ecdheSecret, err := h.randomPrivKey.GenerateShared(h.remoteRandomPub, sskLen, sskLen) + if err != nil { + return secrets{}, err + } + + // derive base secrets from ephemeral key agreement + sharedSecret := crypto.Sha3(ecdheSecret, crypto.Sha3(h.respNonce, h.initNonce)) + aesSecret := crypto.Sha3(ecdheSecret, sharedSecret) + s := secrets{ + RemoteID: h.remoteID, + AES: aesSecret, + MAC: crypto.Sha3(ecdheSecret, aesSecret), + Token: crypto.Sha3(sharedSecret), + } + + // setup sha3 instances for the MACs + mac1 := sha3.NewKeccak256() + mac1.Write(xor(s.MAC, h.respNonce)) + mac1.Write(auth) + mac2 := sha3.NewKeccak256() + mac2.Write(xor(s.MAC, h.initNonce)) + mac2.Write(authResp) + if h.initiator { + s.EgressMAC, s.IngressMAC = mac1, mac2 + } else { + s.EgressMAC, s.IngressMAC = mac2, mac1 + } + + return s, nil +} + +func (h *encHandshake) ecdhShared(prv *ecdsa.PrivateKey) ([]byte, error) { + return ecies.ImportECDSA(prv).GenerateShared(h.remotePub, sskLen, sskLen) +} + +// initiatorEncHandshake negotiates a session token on conn. +// it should be called on the dialing side of the connection. +// +// prv is the local client's private key. +// token is the token from a previous session with this node. +func initiatorEncHandshake(conn io.ReadWriter, prv *ecdsa.PrivateKey, remoteID discover.NodeID, token []byte) (s secrets, err error) { + h, err := newInitiatorHandshake(remoteID) + if err != nil { + return s, err + } + auth, err := h.authMsg(prv, token) + if err != nil { + return s, err + } + if _, err = conn.Write(auth); err != nil { + return s, err + } + + response := make([]byte, encAuthRespLen) + if _, err = io.ReadFull(conn, response); err != nil { + return s, err + } + if err := h.decodeAuthResp(response, prv); err != nil { + return s, err + } + return h.secrets(auth, response) +} + +func newInitiatorHandshake(remoteID discover.NodeID) (*encHandshake, error) { + // generate random initiator nonce + n := make([]byte, shaLen) + if _, err := rand.Read(n); err != nil { + return nil, err + } + // generate random keypair to use for signing + randpriv, err := ecies.GenerateKey(rand.Reader, crypto.S256(), nil) + if err != nil { + return nil, err + } + rpub, err := remoteID.Pubkey() + if err != nil { + return nil, fmt.Errorf("bad remoteID: %v", err) + } + h := &encHandshake{ + initiator: true, + remoteID: remoteID, + remotePub: ecies.ImportECDSAPublic(rpub), + initNonce: n, + randomPrivKey: randpriv, + } + return h, nil +} + +// authMsg creates an encrypted initiator handshake message. +func (h *encHandshake) authMsg(prv *ecdsa.PrivateKey, token []byte) ([]byte, error) { + var tokenFlag byte + if token == nil { + // no session token found means we need to generate shared secret. + // ecies shared secret is used as initial session token for new peers + // generate shared key from prv and remote pubkey + var err error + if token, err = h.ecdhShared(prv); err != nil { + return nil, err + } + } else { + // for known peers, we use stored token from the previous session + tokenFlag = 0x01 + } + + // sign known message: + // ecdh-shared-secret^nonce for new peers + // token^nonce for old peers + signed := xor(token, h.initNonce) + signature, err := crypto.Sign(signed, h.randomPrivKey.ExportECDSA()) + if err != nil { + return nil, err + } + + // encode auth message + // signature || sha3(ecdhe-random-pubk) || pubk || nonce || token-flag + msg := make([]byte, authMsgLen) + n := copy(msg, signature) + n += copy(msg[n:], crypto.Sha3(exportPubkey(&h.randomPrivKey.PublicKey))) + n += copy(msg[n:], crypto.FromECDSAPub(&prv.PublicKey)[1:]) + n += copy(msg[n:], h.initNonce) + msg[n] = tokenFlag + + // encrypt auth message using remote-pubk + return ecies.Encrypt(rand.Reader, h.remotePub, msg, nil, nil) +} + +// decodeAuthResp decode an encrypted authentication response message. +func (h *encHandshake) decodeAuthResp(auth []byte, prv *ecdsa.PrivateKey) error { + msg, err := crypto.Decrypt(prv, auth) + if err != nil { + return fmt.Errorf("could not decrypt auth response (%v)", err) + } + h.respNonce = msg[pubLen : pubLen+shaLen] + h.remoteRandomPub, err = importPublicKey(msg[:pubLen]) + if err != nil { + return err + } + // ignore token flag for now + return nil +} + +// receiverEncHandshake negotiates a session token on conn. +// it should be called on the listening side of the connection. +// +// prv is the local client's private key. +// token is the token from a previous session with this node. +func receiverEncHandshake(conn io.ReadWriter, prv *ecdsa.PrivateKey, token []byte) (s secrets, err error) { + // read remote auth sent by initiator. + auth := make([]byte, encAuthMsgLen) + if _, err := io.ReadFull(conn, auth); err != nil { + return s, err + } + h, err := decodeAuthMsg(prv, token, auth) + if err != nil { + return s, err + } + + // send auth response + resp, err := h.authResp(prv, token) + if err != nil { + return s, err + } + if _, err = conn.Write(resp); err != nil { + return s, err + } + + return h.secrets(auth, resp) +} + +func decodeAuthMsg(prv *ecdsa.PrivateKey, token []byte, auth []byte) (*encHandshake, error) { + var err error + h := new(encHandshake) + // generate random keypair for session + h.randomPrivKey, err = ecies.GenerateKey(rand.Reader, crypto.S256(), nil) + if err != nil { + return nil, err + } + // generate random nonce + h.respNonce = make([]byte, shaLen) + if _, err = rand.Read(h.respNonce); err != nil { + return nil, err + } + + msg, err := crypto.Decrypt(prv, auth) + if err != nil { + return nil, fmt.Errorf("could not decrypt auth message (%v)", err) + } + + // decode message parameters + // signature || sha3(ecdhe-random-pubk) || pubk || nonce || token-flag + h.initNonce = msg[authMsgLen-shaLen-1 : authMsgLen-1] + copy(h.remoteID[:], msg[sigLen+shaLen:sigLen+shaLen+pubLen]) + rpub, err := h.remoteID.Pubkey() + if err != nil { + return nil, fmt.Errorf("bad remoteID: %#v", err) + } + h.remotePub = ecies.ImportECDSAPublic(rpub) + + // recover remote random pubkey from signed message. + if token == nil { + // TODO: it is an error if the initiator has a token and we don't. check that. + + // no session token means we need to generate shared secret. + // ecies shared secret is used as initial session token for new peers. + // generate shared key from prv and remote pubkey. + if token, err = h.ecdhShared(prv); err != nil { + return nil, err + } + } + signedMsg := xor(token, h.initNonce) + remoteRandomPub, err := secp256k1.RecoverPubkey(signedMsg, msg[:sigLen]) + if err != nil { + return nil, err + } + h.remoteRandomPub, _ = importPublicKey(remoteRandomPub) + return h, nil +} + +// authResp generates the encrypted authentication response message. +func (h *encHandshake) authResp(prv *ecdsa.PrivateKey, token []byte) ([]byte, error) { + // responder auth message + // E(remote-pubk, ecdhe-random-pubk || nonce || 0x0) + resp := make([]byte, authRespLen) + n := copy(resp, exportPubkey(&h.randomPrivKey.PublicKey)) + n += copy(resp[n:], h.respNonce) + if token == nil { + resp[n] = 0 + } else { + resp[n] = 1 + } + // encrypt using remote-pubk + return ecies.Encrypt(rand.Reader, h.remotePub, resp, nil, nil) +} + +// importPublicKey unmarshals 512 bit public keys. +func importPublicKey(pubKey []byte) (*ecies.PublicKey, error) { + var pubKey65 []byte + switch len(pubKey) { + case 64: + // add 'uncompressed key' flag + pubKey65 = append([]byte{0x04}, pubKey...) + case 65: + pubKey65 = pubKey + default: + return nil, fmt.Errorf("invalid public key length %v (expect 64/65)", len(pubKey)) + } + // TODO: fewer pointless conversions + return ecies.ImportECDSAPublic(crypto.ToECDSAPub(pubKey65)), nil +} + +func exportPubkey(pub *ecies.PublicKey) []byte { + if pub == nil { + panic("nil pubkey") + } + return elliptic.Marshal(pub.Curve, pub.X, pub.Y)[1:] +} + +func xor(one, other []byte) (xor []byte) { + xor = make([]byte, len(one)) + for i := 0; i < len(one); i++ { + xor[i] = one[i] ^ other[i] + } + return xor +} + var ( // this is used in place of actual frame header data. // TODO: replace this when Msg contains the protocol type code. zeroHeader = []byte{0xC2, 0x80, 0x80} - // sixteen zero bytes zero16 = make([]byte, 16) - - maxUint24 = ^uint32(0) >> 8 ) // rlpxFrameRW implements a simplified version of RLPx framing. @@ -38,7 +474,7 @@ type rlpxFrameRW struct { ingressMAC hash.Hash } -func newRlpxFrameRW(conn io.ReadWriter, s secrets) *rlpxFrameRW { +func newRLPXFrameRW(conn io.ReadWriter, s secrets) *rlpxFrameRW { macc, err := aes.NewCipher(s.MAC) if err != nil { panic("invalid MAC secret: " + err.Error()) diff --git a/p2p/rlpx_test.go b/p2p/rlpx_test.go index d98f1c2cd7ea34bf72eedca43f51861f2265e93b..44be46a99dee21439ca13dc25bade1b4597acde5 100644 --- a/p2p/rlpx_test.go +++ b/p2p/rlpx_test.go @@ -3,19 +3,253 @@ package p2p import ( "bytes" "crypto/rand" + "errors" + "fmt" "io/ioutil" + "net" + "reflect" "strings" + "sync" "testing" + "time" + "github.com/davecgh/go-spew/spew" "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/crypto/ecies" "github.com/ethereum/go-ethereum/crypto/sha3" + "github.com/ethereum/go-ethereum/p2p/discover" "github.com/ethereum/go-ethereum/rlp" ) -func TestRlpxFrameFake(t *testing.T) { +func TestSharedSecret(t *testing.T) { + prv0, _ := crypto.GenerateKey() // = ecdsa.GenerateKey(crypto.S256(), rand.Reader) + pub0 := &prv0.PublicKey + prv1, _ := crypto.GenerateKey() + pub1 := &prv1.PublicKey + + ss0, err := ecies.ImportECDSA(prv0).GenerateShared(ecies.ImportECDSAPublic(pub1), sskLen, sskLen) + if err != nil { + return + } + ss1, err := ecies.ImportECDSA(prv1).GenerateShared(ecies.ImportECDSAPublic(pub0), sskLen, sskLen) + if err != nil { + return + } + t.Logf("Secret:\n%v %x\n%v %x", len(ss0), ss0, len(ss0), ss1) + if !bytes.Equal(ss0, ss1) { + t.Errorf("dont match :(") + } +} + +func TestEncHandshake(t *testing.T) { + for i := 0; i < 10; i++ { + start := time.Now() + if err := testEncHandshake(nil); err != nil { + t.Fatalf("i=%d %v", i, err) + } + t.Logf("(without token) %d %v\n", i+1, time.Since(start)) + } + for i := 0; i < 10; i++ { + tok := make([]byte, shaLen) + rand.Reader.Read(tok) + start := time.Now() + if err := testEncHandshake(tok); err != nil { + t.Fatalf("i=%d %v", i, err) + } + t.Logf("(with token) %d %v\n", i+1, time.Since(start)) + } +} + +func testEncHandshake(token []byte) error { + type result struct { + side string + id discover.NodeID + err error + } + var ( + prv0, _ = crypto.GenerateKey() + prv1, _ = crypto.GenerateKey() + fd0, fd1 = net.Pipe() + c0, c1 = newRLPX(fd0).(*rlpx), newRLPX(fd1).(*rlpx) + output = make(chan result) + ) + + go func() { + r := result{side: "initiator"} + defer func() { output <- r }() + + dest := &discover.Node{ID: discover.PubkeyID(&prv1.PublicKey)} + r.id, r.err = c0.doEncHandshake(prv0, dest) + if r.err != nil { + return + } + id1 := discover.PubkeyID(&prv1.PublicKey) + if r.id != id1 { + r.err = fmt.Errorf("remote ID mismatch: got %v, want: %v", r.id, id1) + } + }() + go func() { + r := result{side: "receiver"} + defer func() { output <- r }() + + r.id, r.err = c1.doEncHandshake(prv1, nil) + if r.err != nil { + return + } + id0 := discover.PubkeyID(&prv0.PublicKey) + if r.id != id0 { + r.err = fmt.Errorf("remote ID mismatch: got %v, want: %v", r.id, id0) + } + }() + + // wait for results from both sides + r1, r2 := <-output, <-output + if r1.err != nil { + return fmt.Errorf("%s side error: %v", r1.side, r1.err) + } + if r2.err != nil { + return fmt.Errorf("%s side error: %v", r2.side, r2.err) + } + + // compare derived secrets + if !reflect.DeepEqual(c0.rw.egressMAC, c1.rw.ingressMAC) { + return fmt.Errorf("egress mac mismatch:\n c0.rw: %#v\n c1.rw: %#v", c0.rw.egressMAC, c1.rw.ingressMAC) + } + if !reflect.DeepEqual(c0.rw.ingressMAC, c1.rw.egressMAC) { + return fmt.Errorf("ingress mac mismatch:\n c0.rw: %#v\n c1.rw: %#v", c0.rw.ingressMAC, c1.rw.egressMAC) + } + if !reflect.DeepEqual(c0.rw.enc, c1.rw.enc) { + return fmt.Errorf("enc cipher mismatch:\n c0.rw: %#v\n c1.rw: %#v", c0.rw.enc, c1.rw.enc) + } + if !reflect.DeepEqual(c0.rw.dec, c1.rw.dec) { + return fmt.Errorf("dec cipher mismatch:\n c0.rw: %#v\n c1.rw: %#v", c0.rw.dec, c1.rw.dec) + } + return nil +} + +func TestProtocolHandshake(t *testing.T) { + var ( + prv0, _ = crypto.GenerateKey() + node0 = &discover.Node{ID: discover.PubkeyID(&prv0.PublicKey), IP: net.IP{1, 2, 3, 4}, TCP: 33} + hs0 = &protoHandshake{Version: 3, ID: node0.ID, Caps: []Cap{{"a", 0}, {"b", 2}}} + + prv1, _ = crypto.GenerateKey() + node1 = &discover.Node{ID: discover.PubkeyID(&prv1.PublicKey), IP: net.IP{5, 6, 7, 8}, TCP: 44} + hs1 = &protoHandshake{Version: 3, ID: node1.ID, Caps: []Cap{{"c", 1}, {"d", 3}}} + + fd0, fd1 = net.Pipe() + wg sync.WaitGroup + ) + + wg.Add(2) + go func() { + defer wg.Done() + rlpx := newRLPX(fd0) + remid, err := rlpx.doEncHandshake(prv0, node1) + if err != nil { + t.Errorf("dial side enc handshake failed: %v", err) + return + } + if remid != node1.ID { + t.Errorf("dial side remote id mismatch: got %v, want %v", remid, node1.ID) + return + } + + phs, err := rlpx.doProtoHandshake(hs0) + if err != nil { + t.Errorf("dial side proto handshake error: %v", err) + return + } + if !reflect.DeepEqual(phs, hs1) { + t.Errorf("dial side proto handshake mismatch:\ngot: %s\nwant: %s\n", spew.Sdump(phs), spew.Sdump(hs1)) + return + } + rlpx.close(DiscQuitting) + }() + go func() { + defer wg.Done() + rlpx := newRLPX(fd1) + remid, err := rlpx.doEncHandshake(prv1, nil) + if err != nil { + t.Errorf("listen side enc handshake failed: %v", err) + return + } + if remid != node0.ID { + t.Errorf("listen side remote id mismatch: got %v, want %v", remid, node0.ID) + return + } + + phs, err := rlpx.doProtoHandshake(hs1) + if err != nil { + t.Errorf("listen side proto handshake error: %v", err) + return + } + if !reflect.DeepEqual(phs, hs0) { + t.Errorf("listen side proto handshake mismatch:\ngot: %s\nwant: %s\n", spew.Sdump(phs), spew.Sdump(hs0)) + return + } + + if err := ExpectMsg(rlpx, discMsg, []DiscReason{DiscQuitting}); err != nil { + t.Errorf("error receiving disconnect: %v", err) + } + }() + wg.Wait() +} + +func TestProtocolHandshakeErrors(t *testing.T) { + our := &protoHandshake{Version: 3, Caps: []Cap{{"foo", 2}, {"bar", 3}}, Name: "quux"} + id := randomID() + tests := []struct { + code uint64 + msg interface{} + err error + }{ + { + code: discMsg, + msg: []DiscReason{DiscQuitting}, + err: DiscQuitting, + }, + { + code: 0x989898, + msg: []byte{1}, + err: errors.New("expected handshake, got 989898"), + }, + { + code: handshakeMsg, + msg: make([]byte, baseProtocolMaxMsgSize+2), + err: errors.New("message too big"), + }, + { + code: handshakeMsg, + msg: []byte{1, 2, 3}, + err: newPeerError(errInvalidMsg, "(code 0) (size 4) rlp: expected input list for p2p.protoHandshake"), + }, + { + code: handshakeMsg, + msg: &protoHandshake{Version: 9944, ID: id}, + err: DiscIncompatibleVersion, + }, + { + code: handshakeMsg, + msg: &protoHandshake{Version: 3}, + err: DiscInvalidIdentity, + }, + } + + for i, test := range tests { + p1, p2 := MsgPipe() + go Send(p1, test.code, test.msg) + _, err := readProtocolHandshake(p2, our) + if !reflect.DeepEqual(err, test.err) { + t.Errorf("test %d: error mismatch: got %q, want %q", i, err, test.err) + } + } +} + +func TestRLPXFrameFake(t *testing.T) { buf := new(bytes.Buffer) hash := fakeHash([]byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}) - rw := newRlpxFrameRW(buf, secrets{ + rw := newRLPXFrameRW(buf, secrets{ AES: crypto.Sha3(), MAC: crypto.Sha3(), IngressMAC: hash, @@ -66,7 +300,7 @@ func (fakeHash) BlockSize() int { return 0 } func (h fakeHash) Size() int { return len(h) } func (h fakeHash) Sum(b []byte) []byte { return append(b, h...) } -func TestRlpxFrameRW(t *testing.T) { +func TestRLPXFrameRW(t *testing.T) { var ( aesSecret = make([]byte, 16) macSecret = make([]byte, 16) @@ -86,7 +320,7 @@ func TestRlpxFrameRW(t *testing.T) { } s1.EgressMAC.Write(egressMACinit) s1.IngressMAC.Write(ingressMACinit) - rw1 := newRlpxFrameRW(conn, s1) + rw1 := newRLPXFrameRW(conn, s1) s2 := secrets{ AES: aesSecret, @@ -96,7 +330,7 @@ func TestRlpxFrameRW(t *testing.T) { } s2.EgressMAC.Write(ingressMACinit) s2.IngressMAC.Write(egressMACinit) - rw2 := newRlpxFrameRW(conn, s2) + rw2 := newRLPXFrameRW(conn, s2) // send some messages for i := 0; i < 10; i++ { diff --git a/p2p/server.go b/p2p/server.go index 529fedbca4c2f508da9e2492c74ca1ebc93e783a..27e617610c03de0618d2917716cef9301a8ea181 100644 --- a/p2p/server.go +++ b/p2p/server.go @@ -2,7 +2,6 @@ package p2p import ( "crypto/ecdsa" - "crypto/rand" "errors" "fmt" "net" @@ -24,11 +23,8 @@ const ( maxAcceptConns = 50 // Maximum number of concurrently dialing outbound connections. - maxDialingConns = 10 + maxActiveDialTasks = 16 - // total timeout for encryption handshake and protocol - // handshake in both directions. - handshakeTimeout = 5 * time.Second // maximum time allowed for reading a complete message. // this is effectively the amount of time a connection can be idle. frameReadTimeout = 1 * time.Minute @@ -36,6 +32,8 @@ const ( frameWriteTimeout = 5 * time.Second ) +var errServerStopped = errors.New("server stopped") + var srvjslog = logger.NewJsonLogger() // Server manages all peer connections. @@ -103,68 +101,173 @@ type Server struct { // Hooks for testing. These are useful because we can inhibit // the whole protocol stack. - setupFunc - newPeerHook + newTransport func(net.Conn) transport + newPeerHook func(*Peer) + + lock sync.Mutex // protects running + running bool + ntab discoverTable + listener net.Listener ourHandshake *protoHandshake - lock sync.RWMutex // protects running, peers and the trust fields - running bool - peers map[discover.NodeID]*Peer - staticNodes map[discover.NodeID]*discover.Node // Map of currently maintained static remote nodes - staticDial chan *discover.Node // Dial request channel reserved for the static nodes - staticCycle time.Duration // Overrides staticPeerCheckInterval, used for testing - trustedNodes map[discover.NodeID]bool // Set of currently trusted remote nodes + // These are for Peers, PeerCount (and nothing else). + peerOp chan peerOpFunc + peerOpDone chan struct{} + + quit chan struct{} + addstatic chan *discover.Node + posthandshake chan *conn + addpeer chan *conn + delpeer chan *Peer + loopWG sync.WaitGroup // loop, listenLoop +} + +type peerOpFunc func(map[discover.NodeID]*Peer) + +type connFlag int - ntab *discover.Table - listener net.Listener +const ( + dynDialedConn connFlag = 1 << iota + staticDialedConn + inboundConn + trustedConn +) + +// conn wraps a network connection with information gathered +// during the two handshakes. +type conn struct { + fd net.Conn + transport + flags connFlag + cont chan error // The run loop uses cont to signal errors to setupConn. + id discover.NodeID // valid after the encryption handshake + caps []Cap // valid after the protocol handshake + name string // valid after the protocol handshake +} - quit chan struct{} - loopWG sync.WaitGroup // {dial,listen,nat}Loop - peerWG sync.WaitGroup // active peer goroutines +type transport interface { + // The two handshakes. + doEncHandshake(prv *ecdsa.PrivateKey, dialDest *discover.Node) (discover.NodeID, error) + doProtoHandshake(our *protoHandshake) (*protoHandshake, error) + // The MsgReadWriter can only be used after the encryption + // handshake has completed. The code uses conn.id to track this + // by setting it to a non-nil value after the encryption handshake. + MsgReadWriter + // transports must provide Close because we use MsgPipe in some of + // the tests. Closing the actual network connection doesn't do + // anything in those tests because NsgPipe doesn't use it. + close(err error) } -type setupFunc func(net.Conn, *ecdsa.PrivateKey, *protoHandshake, *discover.Node, func(discover.NodeID) bool) (*conn, error) -type newPeerHook func(*Peer) +func (c *conn) String() string { + s := c.flags.String() + " conn" + if (c.id != discover.NodeID{}) { + s += fmt.Sprintf(" %x", c.id[:8]) + } + s += " " + c.fd.RemoteAddr().String() + return s +} + +func (f connFlag) String() string { + s := "" + if f&trustedConn != 0 { + s += " trusted" + } + if f&dynDialedConn != 0 { + s += " dyn dial" + } + if f&staticDialedConn != 0 { + s += " static dial" + } + if f&inboundConn != 0 { + s += " inbound" + } + if s != "" { + s = s[1:] + } + return s +} + +func (c *conn) is(f connFlag) bool { + return c.flags&f != 0 +} // Peers returns all connected peers. -func (srv *Server) Peers() (peers []*Peer) { - srv.lock.RLock() - defer srv.lock.RUnlock() - for _, peer := range srv.peers { - if peer != nil { - peers = append(peers, peer) +func (srv *Server) Peers() []*Peer { + var ps []*Peer + select { + // Note: We'd love to put this function into a variable but + // that seems to cause a weird compiler error in some + // environments. + case srv.peerOp <- func(peers map[discover.NodeID]*Peer) { + for _, p := range peers { + ps = append(ps, p) } + }: + <-srv.peerOpDone + case <-srv.quit: } - return + return ps } // PeerCount returns the number of connected peers. func (srv *Server) PeerCount() int { - srv.lock.RLock() - n := len(srv.peers) - srv.lock.RUnlock() - return n + var count int + select { + case srv.peerOp <- func(ps map[discover.NodeID]*Peer) { count = len(ps) }: + <-srv.peerOpDone + case <-srv.quit: + } + return count } // AddPeer connects to the given node and maintains the connection until the // server is shut down. If the connection fails for any reason, the server will // attempt to reconnect the peer. func (srv *Server) AddPeer(node *discover.Node) { + select { + case srv.addstatic <- node: + case <-srv.quit: + } +} + +// Self returns the local node's endpoint information. +func (srv *Server) Self() *discover.Node { srv.lock.Lock() defer srv.lock.Unlock() + if !srv.running { + return &discover.Node{IP: net.ParseIP("0.0.0.0")} + } + return srv.ntab.Self() +} - srv.staticNodes[node.ID] = node +// Stop terminates the server and all active peer connections. +// It blocks until all active connections have been closed. +func (srv *Server) Stop() { + srv.lock.Lock() + defer srv.lock.Unlock() + if !srv.running { + return + } + srv.running = false + if srv.listener != nil { + // this unblocks listener Accept + srv.listener.Close() + } + close(srv.quit) + srv.loopWG.Wait() } // Start starts running the server. -// Servers can be re-used and started again after stopping. +// Servers can not be re-used after stopping. func (srv *Server) Start() (err error) { srv.lock.Lock() defer srv.lock.Unlock() if srv.running { return errors.New("server already running") } + srv.running = true glog.V(logger.Info).Infoln("Starting Server") // static fields @@ -174,23 +277,19 @@ func (srv *Server) Start() (err error) { if srv.MaxPeers <= 0 { return fmt.Errorf("Server.MaxPeers must be > 0") } - srv.quit = make(chan struct{}) - srv.peers = make(map[discover.NodeID]*Peer) - - // Create the current trust maps, and the associated dialing channel - srv.trustedNodes = make(map[discover.NodeID]bool) - for _, node := range srv.TrustedNodes { - srv.trustedNodes[node.ID] = true - } - srv.staticNodes = make(map[discover.NodeID]*discover.Node) - for _, node := range srv.StaticNodes { - srv.staticNodes[node.ID] = node + if srv.newTransport == nil { + srv.newTransport = newRLPX } - srv.staticDial = make(chan *discover.Node) - - if srv.setupFunc == nil { - srv.setupFunc = setupConn + if srv.Dialer == nil { + srv.Dialer = &net.Dialer{Timeout: defaultDialTimeout} } + srv.quit = make(chan struct{}) + srv.addpeer = make(chan *conn) + srv.delpeer = make(chan *Peer) + srv.posthandshake = make(chan *conn) + srv.addstatic = make(chan *discover.Node) + srv.peerOp = make(chan peerOpFunc) + srv.peerOpDone = make(chan struct{}) // node table ntab, err := discover.ListenUDP(srv.PrivateKey, srv.ListenAddr, srv.NAT, srv.NodeDatabase) @@ -198,37 +297,31 @@ func (srv *Server) Start() (err error) { return err } srv.ntab = ntab + dialer := newDialState(srv.StaticNodes, srv.ntab, srv.MaxPeers/2) // handshake srv.ourHandshake = &protoHandshake{Version: baseProtocolVersion, Name: srv.Name, ID: ntab.Self().ID} for _, p := range srv.Protocols { srv.ourHandshake.Caps = append(srv.ourHandshake.Caps, p.cap()) } - // listen/dial if srv.ListenAddr != "" { if err := srv.startListening(); err != nil { return err } } - if srv.Dialer == nil { - srv.Dialer = &net.Dialer{Timeout: defaultDialTimeout} - } - if !srv.NoDial { - srv.loopWG.Add(1) - go srv.dialLoop() - } if srv.NoDial && srv.ListenAddr == "" { glog.V(logger.Warn).Infoln("I will be kind-of useless, neither dialing nor listening.") } - // maintain the static peers - go srv.staticNodesLoop() + srv.loopWG.Add(1) + go srv.run(dialer) srv.running = true return nil } func (srv *Server) startListening() error { + // Launch the TCP listener. listener, err := net.Listen("tcp", srv.ListenAddr) if err != nil { return err @@ -238,6 +331,7 @@ func (srv *Server) startListening() error { srv.listener = listener srv.loopWG.Add(1) go srv.listenLoop() + // Map the TCP listening port if NAT is configured. if !laddr.IP.IsLoopback() && srv.NAT != nil { srv.loopWG.Add(1) go func() { @@ -248,50 +342,164 @@ func (srv *Server) startListening() error { return nil } -// Stop terminates the server and all active peer connections. -// It blocks until all active connections have been closed. -func (srv *Server) Stop() { - srv.lock.Lock() - if !srv.running { - srv.lock.Unlock() - return +type dialer interface { + newTasks(running int, peers map[discover.NodeID]*Peer, now time.Time) []task + taskDone(task, time.Time) + addStatic(*discover.Node) +} + +func (srv *Server) run(dialstate dialer) { + defer srv.loopWG.Done() + var ( + peers = make(map[discover.NodeID]*Peer) + trusted = make(map[discover.NodeID]bool, len(srv.TrustedNodes)) + + tasks []task + pendingTasks []task + taskdone = make(chan task, maxActiveDialTasks) + ) + // Put trusted nodes into a map to speed up checks. + // Trusted peers are loaded on startup and cannot be + // modified while the server is running. + for _, n := range srv.TrustedNodes { + trusted[n.ID] = true + } + + // Some task list helpers. + delTask := func(t task) { + for i := range tasks { + if tasks[i] == t { + tasks = append(tasks[:i], tasks[i+1:]...) + break + } + } } - srv.running = false - srv.lock.Unlock() + scheduleTasks := func(new []task) { + pt := append(pendingTasks, new...) + start := maxActiveDialTasks - len(tasks) + if len(pt) < start { + start = len(pt) + } + if start > 0 { + tasks = append(tasks, pt[:start]...) + for _, t := range pt[:start] { + t := t + glog.V(logger.Detail).Infoln("new task:", t) + go func() { t.Do(srv); taskdone <- t }() + } + copy(pt, pt[start:]) + pendingTasks = pt[:len(pt)-start] + } + } + +running: + for { + // Query the dialer for new tasks and launch them. + now := time.Now() + nt := dialstate.newTasks(len(pendingTasks)+len(tasks), peers, now) + scheduleTasks(nt) - glog.V(logger.Info).Infoln("Stopping Server") + select { + case <-srv.quit: + // The server was stopped. Run the cleanup logic. + glog.V(logger.Detail).Infoln("<-quit: spinning down") + break running + case n := <-srv.addstatic: + // This channel is used by AddPeer to add to the + // ephemeral static peer list. Add it to the dialer, + // it will keep the node connected. + glog.V(logger.Detail).Infoln("<-addstatic:", n) + dialstate.addStatic(n) + case op := <-srv.peerOp: + // This channel is used by Peers and PeerCount. + op(peers) + srv.peerOpDone <- struct{}{} + case t := <-taskdone: + // A task got done. Tell dialstate about it so it + // can update its state and remove it from the active + // tasks list. + glog.V(logger.Detail).Infoln("<-taskdone:", t) + dialstate.taskDone(t, now) + delTask(t) + case c := <-srv.posthandshake: + // A connection has passed the encryption handshake so + // the remote identity is known (but hasn't been verified yet). + if trusted[c.id] { + // Ensure that the trusted flag is set before checking against MaxPeers. + c.flags |= trustedConn + } + glog.V(logger.Detail).Infoln("<-posthandshake:", c) + // TODO: track in-progress inbound node IDs (pre-Peer) to avoid dialing them. + c.cont <- srv.encHandshakeChecks(peers, c) + case c := <-srv.addpeer: + // At this point the connection is past the protocol handshake. + // Its capabilities are known and the remote identity is verified. + glog.V(logger.Detail).Infoln("<-addpeer:", c) + err := srv.protoHandshakeChecks(peers, c) + if err != nil { + glog.V(logger.Detail).Infof("Not adding %v as peer: %v", c, err) + } else { + // The handshakes are done and it passed all checks. + p := newPeer(c, srv.Protocols) + peers[c.id] = p + go srv.runPeer(p) + } + // The dialer logic relies on the assumption that + // dial tasks complete after the peer has been added or + // discarded. Unblock the task last. + c.cont <- err + case p := <-srv.delpeer: + // A peer disconnected. + glog.V(logger.Detail).Infoln("<-delpeer:", p) + delete(peers, p.ID()) + } + } + + // Terminate discovery. If there is a running lookup it will terminate soon. srv.ntab.Close() - if srv.listener != nil { - // this unblocks listener Accept - srv.listener.Close() + // Disconnect all peers. + for _, p := range peers { + p.Disconnect(DiscQuitting) + } + // Wait for peers to shut down. Pending connections and tasks are + // not handled here and will terminate soon-ish because srv.quit + // is closed. + glog.V(logger.Detail).Infof("ignoring %d pending tasks at spindown", len(tasks)) + for len(peers) > 0 { + p := <-srv.delpeer + glog.V(logger.Detail).Infoln("<-delpeer (spindown):", p) + delete(peers, p.ID()) } - close(srv.quit) - srv.loopWG.Wait() +} - // No new peers can be added at this point because dialLoop and - // listenLoop are down. It is safe to call peerWG.Wait because - // peerWG.Add is not called outside of those loops. - srv.lock.Lock() - for _, peer := range srv.peers { - peer.Disconnect(DiscQuitting) +func (srv *Server) protoHandshakeChecks(peers map[discover.NodeID]*Peer, c *conn) error { + // Drop connections with no matching protocols. + if len(srv.Protocols) > 0 && countMatchingProtocols(srv.Protocols, c.caps) == 0 { + return DiscUselessPeer } - srv.lock.Unlock() - srv.peerWG.Wait() + // Repeat the encryption handshake checks because the + // peer set might have changed between the handshakes. + return srv.encHandshakeChecks(peers, c) } -// Self returns the local node's endpoint information. -func (srv *Server) Self() *discover.Node { - srv.lock.RLock() - defer srv.lock.RUnlock() - if !srv.running { - return &discover.Node{IP: net.ParseIP("0.0.0.0")} +func (srv *Server) encHandshakeChecks(peers map[discover.NodeID]*Peer, c *conn) error { + switch { + case !c.is(trustedConn|staticDialedConn) && len(peers) >= srv.MaxPeers: + return DiscTooManyPeers + case peers[c.id] != nil: + return DiscAlreadyConnected + case c.id == srv.ntab.Self().ID: + return DiscSelf + default: + return nil } - return srv.ntab.Self() } -// main loop for adding connections via listening +// listenLoop runs in its own goroutine and accepts +// inbound connections. func (srv *Server) listenLoop() { defer srv.loopWG.Done() + glog.V(logger.Info).Infoln("Listening on", srv.listener.Addr()) // This channel acts as a semaphore limiting // active inbound connections that are lingering pre-handshake. @@ -305,204 +513,92 @@ func (srv *Server) listenLoop() { slots <- struct{}{} } - glog.V(logger.Info).Infoln("Listening on", srv.listener.Addr()) for { <-slots - conn, err := srv.listener.Accept() + fd, err := srv.listener.Accept() if err != nil { return } - glog.V(logger.Debug).Infof("Accepted conn %v\n", conn.RemoteAddr()) - srv.peerWG.Add(1) + glog.V(logger.Debug).Infof("Accepted conn %v\n", fd.RemoteAddr()) go func() { - srv.startPeer(conn, nil) + srv.setupConn(fd, inboundConn, nil) slots <- struct{}{} }() } } -// staticNodesLoop is responsible for periodically checking that static -// connections are actually live, and requests dialing if not. -func (srv *Server) staticNodesLoop() { - // Create a default maintenance ticker, but override it requested - cycle := staticPeerCheckInterval - if srv.staticCycle != 0 { - cycle = srv.staticCycle - } - tick := time.NewTicker(cycle) - - for { - select { - case <-srv.quit: - return - - case <-tick.C: - // Collect all the non-connected static nodes - needed := []*discover.Node{} - srv.lock.RLock() - for id, node := range srv.staticNodes { - if _, ok := srv.peers[id]; !ok { - needed = append(needed, node) - } - } - srv.lock.RUnlock() - - // Try to dial each of them (don't hang if server terminates) - for _, node := range needed { - glog.V(logger.Debug).Infof("Dialing static peer %v", node) - select { - case srv.staticDial <- node: - case <-srv.quit: - return - } - } - } - } -} - -func (srv *Server) dialLoop() { - var ( - dialed = make(chan *discover.Node) - dialing = make(map[discover.NodeID]bool) - findresults = make(chan []*discover.Node) - refresh = time.NewTimer(0) - ) - defer srv.loopWG.Done() - defer refresh.Stop() - - // Limit the number of concurrent dials - tokens := maxDialingConns - if srv.MaxPendingPeers > 0 { - tokens = srv.MaxPendingPeers - } - slots := make(chan struct{}, tokens) - for i := 0; i < tokens; i++ { - slots <- struct{}{} +// setupConn runs the handshakes and attempts to add the connection +// as a peer. It returns when the connection has been added as a peer +// or the handshakes have failed. +func (srv *Server) setupConn(fd net.Conn, flags connFlag, dialDest *discover.Node) { + // Prevent leftover pending conns from entering the handshake. + srv.lock.Lock() + running := srv.running + srv.lock.Unlock() + c := &conn{fd: fd, transport: srv.newTransport(fd), flags: flags, cont: make(chan error)} + if !running { + c.close(errServerStopped) + return } - dial := func(dest *discover.Node) { - // Don't dial nodes that would fail the checks in addPeer. - // This is important because the connection handshake is a lot - // of work and we'd rather avoid doing that work for peers - // that can't be added. - srv.lock.RLock() - ok, _ := srv.checkPeer(dest.ID) - srv.lock.RUnlock() - if !ok || dialing[dest.ID] { - return - } - // Request a dial slot to prevent CPU exhaustion - <-slots - - dialing[dest.ID] = true - srv.peerWG.Add(1) - go func() { - srv.dialNode(dest) - slots <- struct{}{} - dialed <- dest - }() + // Run the encryption handshake. + var err error + if c.id, err = c.doEncHandshake(srv.PrivateKey, dialDest); err != nil { + glog.V(logger.Debug).Infof("%v faild enc handshake: %v", c, err) + c.close(err) + return } - - srv.ntab.Bootstrap(srv.BootstrapNodes) - for { - select { - case <-refresh.C: - // Grab some nodes to connect to if we're not at capacity. - srv.lock.RLock() - needpeers := len(srv.peers) < srv.MaxPeers/2 - srv.lock.RUnlock() - if needpeers { - go func() { - var target discover.NodeID - rand.Read(target[:]) - findresults <- srv.ntab.Lookup(target) - }() - } else { - // Make sure we check again if the peer count falls - // below MaxPeers. - refresh.Reset(refreshPeersInterval) - } - case dest := <-srv.staticDial: - dial(dest) - case dests := <-findresults: - for _, dest := range dests { - dial(dest) - } - refresh.Reset(refreshPeersInterval) - case dest := <-dialed: - delete(dialing, dest.ID) - if len(dialing) == 0 { - // Check again immediately after dialing all current candidates. - refresh.Reset(0) - } - case <-srv.quit: - // TODO: maybe wait for active dials - return - } + // For dialed connections, check that the remote public key matches. + if dialDest != nil && c.id != dialDest.ID { + c.close(DiscUnexpectedIdentity) + glog.V(logger.Debug).Infof("%v dialed identity mismatch, want %x", c, dialDest.ID[:8]) + return } -} - -func (srv *Server) dialNode(dest *discover.Node) { - addr := &net.TCPAddr{IP: dest.IP, Port: int(dest.TCP)} - glog.V(logger.Debug).Infof("Dialing %v\n", dest) - conn, err := srv.Dialer.Dial("tcp", addr.String()) - if err != nil { - // dialLoop adds to the wait group counter when launching - // dialNode, so we need to count it down again. startPeer also - // does that when an error occurs. - srv.peerWG.Done() - glog.V(logger.Detail).Infof("dial error: %v", err) + if err := srv.checkpoint(c, srv.posthandshake); err != nil { + glog.V(logger.Debug).Infof("%v failed checkpoint posthandshake: %v", c, err) + c.close(err) return } - srv.startPeer(conn, dest) -} - -func (srv *Server) startPeer(fd net.Conn, dest *discover.Node) { - // TODO: handle/store session token - - // Run setupFunc, which should create an authenticated connection - // and run the capability exchange. Note that any early error - // returns during that exchange need to call peerWG.Done because - // the callers of startPeer added the peer to the wait group already. - fd.SetDeadline(time.Now().Add(handshakeTimeout)) - - conn, err := srv.setupFunc(fd, srv.PrivateKey, srv.ourHandshake, dest, srv.keepconn) + // Run the protocol handshake + phs, err := c.doProtoHandshake(srv.ourHandshake) if err != nil { - fd.Close() - glog.V(logger.Debug).Infof("Handshake with %v failed: %v", fd.RemoteAddr(), err) - srv.peerWG.Done() + glog.V(logger.Debug).Infof("%v failed proto handshake: %v", c, err) + c.close(err) return } - conn.MsgReadWriter = &netWrapper{ - wrapped: conn.MsgReadWriter, - conn: fd, rtimeout: frameReadTimeout, wtimeout: frameWriteTimeout, + if phs.ID != c.id { + glog.V(logger.Debug).Infof("%v wrong proto handshake identity: %x", c, phs.ID[:8]) + c.close(DiscUnexpectedIdentity) + return } - p := newPeer(fd, conn, srv.Protocols) - if ok, reason := srv.addPeer(conn, p); !ok { - glog.V(logger.Detail).Infof("Not adding %v (%v)\n", p, reason) - p.politeDisconnect(reason) - srv.peerWG.Done() + c.caps, c.name = phs.Caps, phs.Name + if err := srv.checkpoint(c, srv.addpeer); err != nil { + glog.V(logger.Debug).Infof("%v failed checkpoint addpeer: %v", c, err) + c.close(err) return } - // The handshakes are done and it passed all checks. - // Spawn the Peer loops. - go srv.runPeer(p) + // If the checks completed successfully, runPeer has now been + // launched by run. } -// preflight checks whether a connection should be kept. it runs -// after the encryption handshake, as soon as the remote identity is -// known. -func (srv *Server) keepconn(id discover.NodeID) bool { - srv.lock.RLock() - defer srv.lock.RUnlock() - if _, ok := srv.staticNodes[id]; ok { - return true // static nodes are always allowed +// checkpoint sends the conn to run, which performs the +// post-handshake checks for the stage (posthandshake, addpeer). +func (srv *Server) checkpoint(c *conn, stage chan<- *conn) error { + select { + case stage <- c: + case <-srv.quit: + return errServerStopped } - if _, ok := srv.trustedNodes[id]; ok { - return true // trusted nodes are always allowed + select { + case err := <-c.cont: + return err + case <-srv.quit: + return errServerStopped } - return len(srv.peers) < srv.MaxPeers } +// runPeer runs in its own goroutine for each peer. +// it waits until the Peer logic returns and removes +// the peer. func (srv *Server) runPeer(p *Peer) { glog.V(logger.Debug).Infof("Added %v\n", p) srvjslog.LogJson(&logger.P2PConnected{ @@ -511,58 +607,18 @@ func (srv *Server) runPeer(p *Peer) { RemoteVersionString: p.Name(), NumConnections: srv.PeerCount(), }) + if srv.newPeerHook != nil { srv.newPeerHook(p) } discreason := p.run() - srv.removePeer(p) + // Note: run waits for existing peers to be sent on srv.delpeer + // before returning, so this send should not select on srv.quit. + srv.delpeer <- p + glog.V(logger.Debug).Infof("Removed %v (%v)\n", p, discreason) srvjslog.LogJson(&logger.P2PDisconnected{ RemoteId: p.ID().String(), NumConnections: srv.PeerCount(), }) } - -func (srv *Server) addPeer(conn *conn, p *Peer) (bool, DiscReason) { - // drop connections with no matching protocols. - if len(srv.Protocols) > 0 && countMatchingProtocols(srv.Protocols, conn.protoHandshake.Caps) == 0 { - return false, DiscUselessPeer - } - // add the peer if it passes the other checks. - srv.lock.Lock() - defer srv.lock.Unlock() - if ok, reason := srv.checkPeer(conn.ID); !ok { - return false, reason - } - srv.peers[conn.ID] = p - return true, 0 -} - -// checkPeer verifies whether a peer looks promising and should be allowed/kept -// in the pool, or if it's of no use. -func (srv *Server) checkPeer(id discover.NodeID) (bool, DiscReason) { - // First up, figure out if the peer is static or trusted - _, static := srv.staticNodes[id] - trusted := srv.trustedNodes[id] - - // Make sure the peer passes all required checks - switch { - case !srv.running: - return false, DiscQuitting - case !static && !trusted && len(srv.peers) >= srv.MaxPeers: - return false, DiscTooManyPeers - case srv.peers[id] != nil: - return false, DiscAlreadyConnected - case id == srv.ntab.Self().ID: - return false, DiscSelf - default: - return true, 0 - } -} - -func (srv *Server) removePeer(p *Peer) { - srv.lock.Lock() - delete(srv.peers, p.ID()) - srv.lock.Unlock() - srv.peerWG.Done() -} diff --git a/p2p/server_test.go b/p2p/server_test.go index 6f7aaf8e1fa31e2cffae8cc5b9d945dfd11f2757..01448cc7bb5559ad6efa7d93550d011a65cdca27 100644 --- a/p2p/server_test.go +++ b/p2p/server_test.go @@ -2,8 +2,10 @@ package p2p import ( "crypto/ecdsa" + "errors" "math/rand" "net" + "reflect" "testing" "time" @@ -12,29 +14,50 @@ import ( "github.com/ethereum/go-ethereum/p2p/discover" ) -func startTestServer(t *testing.T, pf newPeerHook) *Server { +func init() { + // glog.SetV(6) + // glog.SetToStderr(true) +} + +type testTransport struct { + id discover.NodeID + *rlpx + + closeErr error +} + +func newTestTransport(id discover.NodeID, fd net.Conn) transport { + wrapped := newRLPX(fd).(*rlpx) + wrapped.rw = newRLPXFrameRW(fd, secrets{ + MAC: zero16, + AES: zero16, + IngressMAC: sha3.NewKeccak256(), + EgressMAC: sha3.NewKeccak256(), + }) + return &testTransport{id: id, rlpx: wrapped} +} + +func (c *testTransport) doEncHandshake(prv *ecdsa.PrivateKey, dialDest *discover.Node) (discover.NodeID, error) { + return c.id, nil +} + +func (c *testTransport) doProtoHandshake(our *protoHandshake) (*protoHandshake, error) { + return &protoHandshake{ID: c.id, Name: "test"}, nil +} + +func (c *testTransport) close(err error) { + c.rlpx.fd.Close() + c.closeErr = err +} + +func startTestServer(t *testing.T, id discover.NodeID, pf func(*Peer)) *Server { server := &Server{ - Name: "test", - MaxPeers: 10, - ListenAddr: "127.0.0.1:0", - PrivateKey: newkey(), - newPeerHook: pf, - setupFunc: func(fd net.Conn, prv *ecdsa.PrivateKey, our *protoHandshake, dial *discover.Node, keepconn func(discover.NodeID) bool) (*conn, error) { - id := randomID() - if !keepconn(id) { - return nil, DiscAlreadyConnected - } - rw := newRlpxFrameRW(fd, secrets{ - MAC: zero16, - AES: zero16, - IngressMAC: sha3.NewKeccak256(), - EgressMAC: sha3.NewKeccak256(), - }) - return &conn{ - MsgReadWriter: rw, - protoHandshake: &protoHandshake{ID: id, Version: baseProtocolVersion}, - }, nil - }, + Name: "test", + MaxPeers: 10, + ListenAddr: "127.0.0.1:0", + PrivateKey: newkey(), + newPeerHook: pf, + newTransport: func(fd net.Conn) transport { return newTestTransport(id, fd) }, } if err := server.Start(); err != nil { t.Fatalf("Could not start server: %v", err) @@ -45,7 +68,11 @@ func startTestServer(t *testing.T, pf newPeerHook) *Server { func TestServerListen(t *testing.T) { // start the test server connected := make(chan *Peer) - srv := startTestServer(t, func(p *Peer) { + remid := randomID() + srv := startTestServer(t, remid, func(p *Peer) { + if p.ID() != remid { + t.Error("peer func called with wrong node id") + } if p == nil { t.Error("peer func called with nil conn") } @@ -67,6 +94,10 @@ func TestServerListen(t *testing.T) { t.Errorf("peer started with wrong conn: got %v, want %v", peer.LocalAddr(), conn.RemoteAddr()) } + peers := srv.Peers() + if !reflect.DeepEqual(peers, []*Peer{peer}) { + t.Errorf("Peers mismatch: got %v, want %v", peers, []*Peer{peer}) + } case <-time.After(1 * time.Second): t.Error("server did not accept within one second") } @@ -92,23 +123,33 @@ func TestServerDial(t *testing.T) { // start the server connected := make(chan *Peer) - srv := startTestServer(t, func(p *Peer) { connected <- p }) + remid := randomID() + srv := startTestServer(t, remid, func(p *Peer) { connected <- p }) defer close(connected) defer srv.Stop() // tell the server to connect tcpAddr := listener.Addr().(*net.TCPAddr) - srv.staticDial <- &discover.Node{IP: tcpAddr.IP, TCP: uint16(tcpAddr.Port)} + srv.AddPeer(&discover.Node{ID: remid, IP: tcpAddr.IP, TCP: uint16(tcpAddr.Port)}) select { case conn := <-accepted: select { case peer := <-connected: + if peer.ID() != remid { + t.Errorf("peer has wrong id") + } + if peer.Name() != "test" { + t.Errorf("peer has wrong name") + } if peer.RemoteAddr().String() != conn.LocalAddr().String() { t.Errorf("peer started with wrong conn: got %v, want %v", peer.RemoteAddr(), conn.LocalAddr()) } - // TODO: validate more fields + peers := srv.Peers() + if !reflect.DeepEqual(peers, []*Peer{peer}) { + t.Errorf("Peers mismatch: got %v, want %v", peers, []*Peer{peer}) + } case <-time.After(1 * time.Second): t.Error("server did not launch peer within one second") } @@ -118,331 +159,250 @@ func TestServerDial(t *testing.T) { } } -// This test checks that connections are disconnected -// just after the encryption handshake when the server is -// at capacity. -// -// It also serves as a light-weight integration test. -func TestServerDisconnectAtCap(t *testing.T) { - started := make(chan *Peer) - srv := &Server{ - ListenAddr: "127.0.0.1:0", - PrivateKey: newkey(), - MaxPeers: 10, - NoDial: true, - // This hook signals that the peer was actually started. We - // need to wait for the peer to be started before dialing the - // next connection to get a deterministic peer count. - newPeerHook: func(p *Peer) { started <- p }, - } - if err := srv.Start(); err != nil { - t.Fatal(err) - } - defer srv.Stop() - - nconns := srv.MaxPeers + 1 - dialer := &net.Dialer{Deadline: time.Now().Add(3 * time.Second)} - for i := 0; i < nconns; i++ { - conn, err := dialer.Dial("tcp", srv.ListenAddr) - if err != nil { - t.Fatalf("conn %d: dial error: %v", i, err) - } - // Close the connection when the test ends, before - // shutting down the server. - defer conn.Close() - // Run the handshakes just like a real peer would. - key := newkey() - hs := &protoHandshake{Version: baseProtocolVersion, ID: discover.PubkeyID(&key.PublicKey)} - _, err = setupConn(conn, key, hs, srv.Self(), keepalways) - if i == nconns-1 { - // When handling the last connection, the server should - // disconnect immediately instead of running the protocol - // handshake. - if err != DiscTooManyPeers { - t.Errorf("conn %d: got error %q, expected %q", i, err, DiscTooManyPeers) - } - } else { - // For all earlier connections, the handshake should go through. - if err != nil { - t.Fatalf("conn %d: unexpected error: %v", i, err) - } - // Wait for runPeer to be started. - <-started +// This test checks that tasks generated by dialstate are +// actually executed and taskdone is called for them. +func TestServerTaskScheduling(t *testing.T) { + var ( + done = make(chan *testTask) + quit, returned = make(chan struct{}), make(chan struct{}) + tc = 0 + tg = taskgen{ + newFunc: func(running int, peers map[discover.NodeID]*Peer) []task { + tc++ + return []task{&testTask{index: tc - 1}} + }, + doneFunc: func(t task) { + select { + case done <- t.(*testTask): + case <-quit: + } + }, } - } -} + ) -// Tests that static peers are (re)connected, and done so even above max peers. -func TestServerStaticPeers(t *testing.T) { - // Create a test server with limited connection slots - started := make(chan *Peer) - server := &Server{ - ListenAddr: "127.0.0.1:0", - PrivateKey: newkey(), - MaxPeers: 3, - newPeerHook: func(p *Peer) { started <- p }, - staticCycle: time.Second, - } - if err := server.Start(); err != nil { - t.Fatal(err) + // The Server in this test isn't actually running + // because we're only interested in what run does. + srv := &Server{ + MaxPeers: 10, + quit: make(chan struct{}), + ntab: fakeTable{}, + running: true, } - defer server.Stop() - - // Fill up all the slots on the server - dialer := &net.Dialer{Deadline: time.Now().Add(3 * time.Second)} - for i := 0; i < server.MaxPeers; i++ { - // Establish a new connection - conn, err := dialer.Dial("tcp", server.ListenAddr) - if err != nil { - t.Fatalf("conn %d: dial error: %v", i, err) - } - defer conn.Close() + srv.loopWG.Add(1) + go func() { + srv.run(tg) + close(returned) + }() - // Run the handshakes just like a real peer would, and wait for completion - key := newkey() - shake := &protoHandshake{Version: baseProtocolVersion, ID: discover.PubkeyID(&key.PublicKey)} - if _, err = setupConn(conn, key, shake, server.Self(), keepalways); err != nil { - t.Fatalf("conn %d: unexpected error: %v", i, err) - } - <-started + var gotdone []*testTask + for i := 0; i < 100; i++ { + gotdone = append(gotdone, <-done) } - // Open a TCP listener to accept static connections - listener, err := net.Listen("tcp", "127.0.0.1:0") - if err != nil { - t.Fatalf("failed to setup listener: %v", err) - } - defer listener.Close() - - connected := make(chan net.Conn) - go func() { - for i := 0; i < 3; i++ { - conn, err := listener.Accept() - if err == nil { - connected <- conn - } + for i, task := range gotdone { + if task.index != i { + t.Errorf("task %d has wrong index, got %d", i, task.index) + break + } + if !task.called { + t.Errorf("task %d was not called", i) + break } - }() - // Inject a static node and wait for a remote dial, then redial, then nothing - addr := listener.Addr().(*net.TCPAddr) - static := &discover.Node{ - ID: discover.PubkeyID(&newkey().PublicKey), - IP: addr.IP, - TCP: uint16(addr.Port), } - server.AddPeer(static) + close(quit) + srv.Stop() select { - case conn := <-connected: - // Close the first connection, expect redial - conn.Close() - - case <-time.After(2 * server.staticCycle): - t.Fatalf("remote dial timeout") + case <-returned: + case <-time.After(500 * time.Millisecond): + t.Error("Server.run did not return within 500ms") } +} - select { - case conn := <-connected: - // Keep the second connection, don't expect redial - defer conn.Close() - - case <-time.After(2 * server.staticCycle): - t.Fatalf("remote re-dial timeout") - } +type taskgen struct { + newFunc func(running int, peers map[discover.NodeID]*Peer) []task + doneFunc func(task) +} - select { - case <-time.After(2 * server.staticCycle): - // Timeout as no dial occurred +func (tg taskgen) newTasks(running int, peers map[discover.NodeID]*Peer, now time.Time) []task { + return tg.newFunc(running, peers) +} +func (tg taskgen) taskDone(t task, now time.Time) { + tg.doneFunc(t) +} +func (tg taskgen) addStatic(*discover.Node) { +} - case <-connected: - t.Fatalf("connected node dialed") - } +type testTask struct { + index int + called bool } -// Tests that trusted peers and can connect above max peer caps. -func TestServerTrustedPeers(t *testing.T) { +func (t *testTask) Do(srv *Server) { + t.called = true +} - // Create a trusted peer to accept connections from - key := newkey() - trusted := &discover.Node{ - ID: discover.PubkeyID(&key.PublicKey), - } - // Create a test server with limited connection slots - started := make(chan *Peer) - server := &Server{ - ListenAddr: "127.0.0.1:0", +// This test checks that connections are disconnected +// just after the encryption handshake when the server is +// at capacity. Trusted connections should still be accepted. +func TestServerAtCap(t *testing.T) { + trustedID := randomID() + srv := &Server{ PrivateKey: newkey(), - MaxPeers: 3, + MaxPeers: 10, NoDial: true, - TrustedNodes: []*discover.Node{trusted}, - newPeerHook: func(p *Peer) { started <- p }, + TrustedNodes: []*discover.Node{{ID: trustedID}}, } - if err := server.Start(); err != nil { - t.Fatal(err) + if err := srv.Start(); err != nil { + t.Fatalf("could not start: %v", err) } - defer server.Stop() + defer srv.Stop() - // Fill up all the slots on the server - dialer := &net.Dialer{Deadline: time.Now().Add(3 * time.Second)} - for i := 0; i < server.MaxPeers; i++ { - // Establish a new connection - conn, err := dialer.Dial("tcp", server.ListenAddr) - if err != nil { - t.Fatalf("conn %d: dial error: %v", i, err) - } - defer conn.Close() + newconn := func(id discover.NodeID) *conn { + fd, _ := net.Pipe() + tx := newTestTransport(id, fd) + return &conn{fd: fd, transport: tx, flags: inboundConn, id: id, cont: make(chan error)} + } - // Run the handshakes just like a real peer would, and wait for completion - key := newkey() - shake := &protoHandshake{Version: baseProtocolVersion, ID: discover.PubkeyID(&key.PublicKey)} - if _, err = setupConn(conn, key, shake, server.Self(), keepalways); err != nil { - t.Fatalf("conn %d: unexpected error: %v", i, err) + // Inject a few connections to fill up the peer set. + for i := 0; i < 10; i++ { + c := newconn(randomID()) + if err := srv.checkpoint(c, srv.addpeer); err != nil { + t.Fatalf("could not add conn %d: %v", i, err) } - <-started } - // Dial from the trusted peer, ensure connection is accepted - conn, err := dialer.Dial("tcp", server.ListenAddr) - if err != nil { - t.Fatalf("trusted node: dial error: %v", err) + // Try inserting a non-trusted connection. + c := newconn(randomID()) + if err := srv.checkpoint(c, srv.posthandshake); err != DiscTooManyPeers { + t.Error("wrong error for insert:", err) } - defer conn.Close() - - shake := &protoHandshake{Version: baseProtocolVersion, ID: trusted.ID} - if _, err = setupConn(conn, key, shake, server.Self(), keepalways); err != nil { - t.Fatalf("trusted node: unexpected error: %v", err) + // Try inserting a trusted connection. + c = newconn(trustedID) + if err := srv.checkpoint(c, srv.posthandshake); err != nil { + t.Error("unexpected error for trusted conn @posthandshake:", err) } - select { - case <-started: - // Ok, trusted peer accepted - - case <-time.After(100 * time.Millisecond): - t.Fatalf("trusted node timeout") + if !c.is(trustedConn) { + t.Error("Server did not set trusted flag") } + } -// Tests that a failed dial will temporarily throttle a peer. -func TestServerMaxPendingDials(t *testing.T) { - // Start a simple test server - server := &Server{ - ListenAddr: "127.0.0.1:0", - PrivateKey: newkey(), - MaxPeers: 10, - MaxPendingPeers: 1, - } - if err := server.Start(); err != nil { - t.Fatal("failed to start test server: %v", err) +func TestServerSetupConn(t *testing.T) { + id := randomID() + srvkey := newkey() + srvid := discover.PubkeyID(&srvkey.PublicKey) + tests := []struct { + dontstart bool + tt *setupTransport + flags connFlag + dialDest *discover.Node + + wantCloseErr error + wantCalls string + }{ + { + dontstart: true, + tt: &setupTransport{id: id}, + wantCalls: "close,", + wantCloseErr: errServerStopped, + }, + { + tt: &setupTransport{id: id, encHandshakeErr: errors.New("read error")}, + flags: inboundConn, + wantCalls: "doEncHandshake,close,", + wantCloseErr: errors.New("read error"), + }, + { + tt: &setupTransport{id: id}, + dialDest: &discover.Node{ID: randomID()}, + flags: dynDialedConn, + wantCalls: "doEncHandshake,close,", + wantCloseErr: DiscUnexpectedIdentity, + }, + { + tt: &setupTransport{id: id, phs: &protoHandshake{ID: randomID()}}, + dialDest: &discover.Node{ID: id}, + flags: dynDialedConn, + wantCalls: "doEncHandshake,doProtoHandshake,close,", + wantCloseErr: DiscUnexpectedIdentity, + }, + { + tt: &setupTransport{id: id, protoHandshakeErr: errors.New("foo")}, + dialDest: &discover.Node{ID: id}, + flags: dynDialedConn, + wantCalls: "doEncHandshake,doProtoHandshake,close,", + wantCloseErr: errors.New("foo"), + }, + { + tt: &setupTransport{id: srvid, phs: &protoHandshake{ID: srvid}}, + flags: inboundConn, + wantCalls: "doEncHandshake,close,", + wantCloseErr: DiscSelf, + }, + { + tt: &setupTransport{id: id, phs: &protoHandshake{ID: id}}, + flags: inboundConn, + wantCalls: "doEncHandshake,doProtoHandshake,close,", + wantCloseErr: DiscUselessPeer, + }, } - defer server.Stop() - // Simulate two separate remote peers - peers := make(chan *discover.Node, 2) - conns := make(chan net.Conn, 2) - for i := 0; i < 2; i++ { - listener, err := net.Listen("tcp", "127.0.0.1:0") - if err != nil { - t.Fatalf("listener %d: failed to setup: %v", i, err) - } - defer listener.Close() - - addr := listener.Addr().(*net.TCPAddr) - peers <- &discover.Node{ - ID: discover.PubkeyID(&newkey().PublicKey), - IP: addr.IP, - TCP: uint16(addr.Port), + for i, test := range tests { + srv := &Server{ + PrivateKey: srvkey, + MaxPeers: 10, + NoDial: true, + Protocols: []Protocol{discard}, + newTransport: func(fd net.Conn) transport { return test.tt }, } - go func() { - conn, err := listener.Accept() - if err == nil { - conns <- conn + if !test.dontstart { + if err := srv.Start(); err != nil { + t.Fatalf("couldn't start server: %v", err) } - }() - } - // Request a dial for both peers - go func() { - for i := 0; i < 2; i++ { - server.staticDial <- <-peers // hack piggybacking the static implementation } - }() - - // Make sure only one outbound connection goes through - var conn net.Conn - - select { - case conn = <-conns: - case <-time.After(100 * time.Millisecond): - t.Fatalf("first dial timeout") - } - select { - case conn = <-conns: - t.Fatalf("second dial completed prematurely") - case <-time.After(100 * time.Millisecond): - } - // Finish the first dial, check the second - conn.Close() - select { - case conn = <-conns: - conn.Close() - - case <-time.After(100 * time.Millisecond): - t.Fatalf("second dial timeout") + p1, _ := net.Pipe() + srv.setupConn(p1, test.flags, test.dialDest) + if !reflect.DeepEqual(test.tt.closeErr, test.wantCloseErr) { + t.Errorf("test %d: close error mismatch: got %q, want %q", i, test.tt.closeErr, test.wantCloseErr) + } + if test.tt.calls != test.wantCalls { + t.Errorf("test %d: calls mismatch: got %q, want %q", i, test.tt.calls, test.wantCalls) + } } } -func TestServerMaxPendingAccepts(t *testing.T) { - // Start a test server and a peer sink for synchronization - started := make(chan *Peer) - server := &Server{ - ListenAddr: "127.0.0.1:0", - PrivateKey: newkey(), - MaxPeers: 10, - MaxPendingPeers: 1, - NoDial: true, - newPeerHook: func(p *Peer) { started <- p }, - } - if err := server.Start(); err != nil { - t.Fatal("failed to start test server: %v", err) - } - defer server.Stop() +type setupTransport struct { + id discover.NodeID + encHandshakeErr error - // Try and connect to the server on multiple threads concurrently - conns := make([]net.Conn, 2) - for i := 0; i < 2; i++ { - dialer := &net.Dialer{Deadline: time.Now().Add(3 * time.Second)} + phs *protoHandshake + protoHandshakeErr error - conn, err := dialer.Dial("tcp", server.ListenAddr) - if err != nil { - t.Fatalf("failed to dial server: %v", err) - } - conns[i] = conn - } - // Check that a handshake on the second doesn't pass - go func() { - key := newkey() - shake := &protoHandshake{Version: baseProtocolVersion, ID: discover.PubkeyID(&key.PublicKey)} - if _, err := setupConn(conns[1], key, shake, server.Self(), keepalways); err != nil { - t.Fatalf("failed to run handshake: %v", err) - } - }() - select { - case <-started: - t.Fatalf("handshake on second connection accepted") + calls string + closeErr error +} - case <-time.After(time.Second): - } - // Shake on first, check that both go through - go func() { - key := newkey() - shake := &protoHandshake{Version: baseProtocolVersion, ID: discover.PubkeyID(&key.PublicKey)} - if _, err := setupConn(conns[0], key, shake, server.Self(), keepalways); err != nil { - t.Fatalf("failed to run handshake: %v", err) - } - }() - for i := 0; i < 2; i++ { - select { - case <-started: - case <-time.After(time.Second): - t.Fatalf("peer %d: handshake timeout", i) - } +func (c *setupTransport) doEncHandshake(prv *ecdsa.PrivateKey, dialDest *discover.Node) (discover.NodeID, error) { + c.calls += "doEncHandshake," + return c.id, c.encHandshakeErr +} +func (c *setupTransport) doProtoHandshake(our *protoHandshake) (*protoHandshake, error) { + c.calls += "doProtoHandshake," + if c.protoHandshakeErr != nil { + return nil, c.protoHandshakeErr } + return c.phs, nil +} +func (c *setupTransport) close(err error) { + c.calls += "close," + c.closeErr = err +} + +// setupConn shouldn't write to/read from the connection. +func (c *setupTransport) WriteMsg(Msg) error { + panic("WriteMsg called on setupTransport") +} +func (c *setupTransport) ReadMsg() (Msg, error) { + panic("ReadMsg called on setupTransport") } func newkey() *ecdsa.PrivateKey { @@ -459,7 +419,3 @@ func randomID() (id discover.NodeID) { } return id } - -func keepalways(id discover.NodeID) bool { - return true -}