downloader_test.go 23.9 KB
Newer Older
1 2 3 4
package downloader

import (
	"encoding/binary"
5
	"fmt"
6 7 8 9 10
	"math/big"
	"testing"
	"time"

	"github.com/ethereum/go-ethereum/common"
11
	"github.com/ethereum/go-ethereum/core"
12
	"github.com/ethereum/go-ethereum/core/types"
O
obscuren 已提交
13
	"github.com/ethereum/go-ethereum/event"
14 15
)

16 17 18
var (
	knownHash   = common.Hash{1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
	unknownHash = common.Hash{9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9}
19
	bannedHash  = common.Hash{5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5}
20 21

	genesis = createBlock(1, common.Hash{}, knownHash)
22
)
23

24
func createHashes(start, amount int) (hashes []common.Hash) {
25 26 27 28
	hashes = make([]common.Hash, amount+1)
	hashes[len(hashes)-1] = knownHash

	for i := range hashes[:len(hashes)-1] {
29
		binary.BigEndian.PutUint64(hashes[i][:8], uint64(start+i+2))
30 31 32 33
	}
	return
}

34
func createBlock(i int, parent, hash common.Hash) *types.Block {
35 36 37
	header := &types.Header{Number: big.NewInt(int64(i))}
	block := types.NewBlockWithHeader(header)
	block.HeaderHash = hash
38
	block.ParentHeaderHash = parent
39 40 41
	return block
}

42 43
func createBlocksFromHashes(hashes []common.Hash) map[common.Hash]*types.Block {
	blocks := make(map[common.Hash]*types.Block)
44 45 46 47 48 49
	for i := 0; i < len(hashes); i++ {
		parent := knownHash
		if i < len(hashes)-1 {
			parent = hashes[i+1]
		}
		blocks[hashes[i]] = createBlock(len(hashes)-i, parent, hashes[i])
50 51 52 53 54
	}
	return blocks
}

type downloadTester struct {
55 56
	downloader *Downloader

57 58 59 60
	ownHashes  []common.Hash                           // Hash chain belonging to the tester
	ownBlocks  map[common.Hash]*types.Block            // Blocks belonging to the tester
	peerHashes map[string][]common.Hash                // Hash chain belonging to different test peers
	peerBlocks map[string]map[common.Hash]*types.Block // Blocks belonging to different test peers
61

62
	maxHashFetch int // Overrides the maximum number of retrieved hashes
63 64
}

65
func newTester() *downloadTester {
66
	tester := &downloadTester{
67 68 69 70
		ownHashes:  []common.Hash{knownHash},
		ownBlocks:  map[common.Hash]*types.Block{knownHash: genesis},
		peerHashes: make(map[string][]common.Hash),
		peerBlocks: make(map[string]map[common.Hash]*types.Block),
71
	}
O
obscuren 已提交
72
	var mux event.TypeMux
73
	downloader := New(&mux, tester.hasBlock, tester.getBlock, tester.dropPeer)
74 75 76 77 78
	tester.downloader = downloader

	return tester
}

79 80 81
// syncTake is starts synchronising with a remote peer, but concurrently it also
// starts fetching blocks that the downloader retrieved. IT blocks until both go
// routines terminate.
82
func (dl *downloadTester) syncTake(peerId string, head common.Hash) ([]*Block, error) {
83 84
	// Start a block collector to take blocks as they become available
	done := make(chan struct{})
85
	took := []*Block{}
86 87 88 89 90 91 92 93 94
	go func() {
		for running := true; running; {
			select {
			case <-done:
				running = false
			default:
				time.Sleep(time.Millisecond)
			}
			// Take a batch of blocks and accumulate
95 96 97 98 99 100
			blocks := dl.downloader.TakeBlocks()
			for _, block := range blocks {
				dl.ownHashes = append(dl.ownHashes, block.RawBlock.Hash())
				dl.ownBlocks[block.RawBlock.Hash()] = block.RawBlock
			}
			took = append(took, blocks...)
101 102 103 104
		}
		done <- struct{}{}
	}()
	// Start the downloading, sync the taker and return
105
	err := dl.downloader.synchronise(peerId, head)
106 107 108 109 110

	done <- struct{}{}
	<-done

	return took, err
O
obscuren 已提交
111 112
}

113
// hasBlock checks if a block is present in the testers canonical chain.
114
func (dl *downloadTester) hasBlock(hash common.Hash) bool {
115
	return dl.getBlock(hash) != nil
116 117
}

118
// getBlock retrieves a block from the testers canonical chain.
119
func (dl *downloadTester) getBlock(hash common.Hash) *types.Block {
120 121 122 123 124 125 126 127 128 129 130 131
	return dl.ownBlocks[hash]
}

// newPeer registers a new block download source into the downloader.
func (dl *downloadTester) newPeer(id string, hashes []common.Hash, blocks map[common.Hash]*types.Block) error {
	err := dl.downloader.RegisterPeer(id, hashes[0], dl.peerGetHashesFn(id), dl.peerGetBlocksFn(id))
	if err == nil {
		// Assign the owned hashes and blocks to the peer
		dl.peerHashes[id] = hashes
		dl.peerBlocks[id] = blocks
	}
	return err
132 133
}

134 135 136 137 138 139 140 141
// dropPeer simulates a hard peer removal from the connection pool.
func (dl *downloadTester) dropPeer(id string) {
	delete(dl.peerHashes, id)
	delete(dl.peerBlocks, id)

	dl.downloader.UnregisterPeer(id)
}

142 143 144 145 146 147 148 149 150 151 152 153 154 155
// peerGetBlocksFn constructs a getHashes function associated with a particular
// peer in the download tester. The returned function can be used to retrieve
// batches of hashes from the particularly requested peer.
func (dl *downloadTester) peerGetHashesFn(id string) func(head common.Hash) error {
	return func(head common.Hash) error {
		limit := MaxHashFetch
		if dl.maxHashFetch > 0 {
			limit = dl.maxHashFetch
		}
		// Gather the next batch of hashes
		hashes := dl.peerHashes[id]
		result := make([]common.Hash, 0, limit)
		for i, hash := range hashes {
			if hash == head {
156
				i++
157 158 159 160 161
				for len(result) < cap(result) && i < len(hashes) {
					result = append(result, hashes[i])
					i++
				}
				break
162 163
			}
		}
164 165 166 167 168 169
		// Delay delivery a bit to allow attacks to unfold
		go func() {
			time.Sleep(time.Millisecond)
			dl.downloader.DeliverHashes(id, result)
		}()
		return nil
170
	}
171 172
}

173 174 175 176
// peerGetBlocksFn constructs a getBlocks function associated with a particular
// peer in the download tester. The returned function can be used to retrieve
// batches of blocks from the particularly requested peer.
func (dl *downloadTester) peerGetBlocksFn(id string) func([]common.Hash) error {
177
	return func(hashes []common.Hash) error {
178 179
		blocks := dl.peerBlocks[id]
		result := make([]*types.Block, 0, len(hashes))
180
		for _, hash := range hashes {
181 182
			if block, ok := blocks[hash]; ok {
				result = append(result, block)
183
			}
184
		}
185
		go dl.downloader.DeliverBlocks(id, result)
186 187 188 189 190

		return nil
	}
}

191 192 193 194
// Tests that simple synchronization, without throttling from a good peer works.
func TestSynchronisation(t *testing.T) {
	// Create a small enough block chain to download and the tester
	targetBlocks := blockCacheLimit - 15
195
	hashes := createHashes(0, targetBlocks)
196 197
	blocks := createBlocksFromHashes(hashes)

198 199
	tester := newTester()
	tester.newPeer("peer", hashes, blocks)
200

201
	// Synchronise with the peer and make sure all blocks were retrieved
202
	if err := tester.downloader.synchronise("peer", hashes[0]); err != nil {
203
		t.Fatalf("failed to synchronise blocks: %v", err)
204
	}
205
	if queued := len(tester.downloader.queue.blockPool); queued != targetBlocks {
206
		t.Fatalf("synchronised block mismatch: have %v, want %v", queued, targetBlocks)
207 208 209
	}
}

210 211 212 213
// Tests that the synchronized blocks can be correctly retrieved.
func TestBlockTaking(t *testing.T) {
	// Create a small enough block chain to download and the tester
	targetBlocks := blockCacheLimit - 15
214 215 216
	hashes := createHashes(0, targetBlocks)
	blocks := createBlocksFromHashes(hashes)

217 218
	tester := newTester()
	tester.newPeer("peer", hashes, blocks)
219

220
	// Synchronise with the peer and test block retrieval
221
	if err := tester.downloader.synchronise("peer", hashes[0]); err != nil {
222
		t.Fatalf("failed to synchronise blocks: %v", err)
223
	}
224 225
	if took := tester.downloader.TakeBlocks(); len(took) != targetBlocks {
		t.Fatalf("took block mismatch: have %v, want %v", len(took), targetBlocks)
226
	}
227
}
228

229
// Tests that an inactive downloader will not accept incoming hashes and blocks.
230
func TestInactiveDownloader(t *testing.T) {
231
	tester := newTester()
232

233
	// Check that neither hashes nor blocks are accepted
234
	if err := tester.downloader.DeliverHashes("bad peer", []common.Hash{}); err != errNoSyncActive {
235 236
		t.Errorf("error mismatch: have %v, want %v", err, errNoSyncActive)
	}
237
	if err := tester.downloader.DeliverBlocks("bad peer", []*types.Block{}); err != errNoSyncActive {
238
		t.Errorf("error mismatch: have %v, want %v", err, errNoSyncActive)
239 240 241
	}
}

242
// Tests that a canceled download wipes all previously accumulated state.
243
func TestCancel(t *testing.T) {
244 245
	// Create a small enough block chain to download and the tester
	targetBlocks := blockCacheLimit - 15
246 247 248
	hashes := createHashes(0, targetBlocks)
	blocks := createBlocksFromHashes(hashes)

249 250
	tester := newTester()
	tester.newPeer("peer", hashes, blocks)
251

252
	// Synchronise with the peer, but cancel afterwards
253
	if err := tester.downloader.synchronise("peer", hashes[0]); err != nil {
254
		t.Fatalf("failed to synchronise blocks: %v", err)
255 256
	}
	if !tester.downloader.Cancel() {
257
		t.Fatalf("cancel operation failed")
258
	}
259 260 261 262 263 264 265
	// Make sure the queue reports empty and no blocks can be taken
	hashCount, blockCount := tester.downloader.queue.Size()
	if hashCount > 0 || blockCount > 0 {
		t.Errorf("block or hash count mismatch: %d hashes, %d blocks, want 0", hashCount, blockCount)
	}
	if took := tester.downloader.TakeBlocks(); len(took) != 0 {
		t.Errorf("taken blocks mismatch: have %d, want %d", len(took), 0)
266 267 268
	}
}

269 270
// Tests that if a large batch of blocks are being downloaded, it is throttled
// until the cached blocks are retrieved.
271
func TestThrottling(t *testing.T) {
272 273
	// Create a long block chain to download and the tester
	targetBlocks := 8 * blockCacheLimit
274 275 276
	hashes := createHashes(0, targetBlocks)
	blocks := createBlocksFromHashes(hashes)

277 278
	tester := newTester()
	tester.newPeer("peer", hashes, blocks)
279

280 281 282
	// Start a synchronisation concurrently
	errc := make(chan error)
	go func() {
283
		errc <- tester.downloader.synchronise("peer", hashes[0])
284 285 286
	}()
	// Iteratively take some blocks, always checking the retrieval count
	for total := 0; total < targetBlocks; {
287 288 289 290 291 292 293
		// Wait a bit for sync to complete
		for start := time.Now(); time.Since(start) < 3*time.Second; {
			time.Sleep(25 * time.Millisecond)
			if len(tester.downloader.queue.blockPool) == blockCacheLimit {
				break
			}
		}
294 295 296 297 298 299 300 301 302
		// Fetch the next batch of blocks
		took := tester.downloader.TakeBlocks()
		if len(took) != blockCacheLimit {
			t.Fatalf("block count mismatch: have %v, want %v", len(took), blockCacheLimit)
		}
		total += len(took)
		if total > targetBlocks {
			t.Fatalf("target block count mismatch: have %v, want %v", total, targetBlocks)
		}
303
	}
304 305
	if err := <-errc; err != nil {
		t.Fatalf("block synchronization failed: %v", err)
306 307
	}
}
308 309 310 311 312 313 314 315 316 317 318 319

// Tests that if a peer returns an invalid chain with a block pointing to a non-
// existing parent, it is correctly detected and handled.
func TestNonExistingParentAttack(t *testing.T) {
	// Forge a single-link chain with a forged header
	hashes := createHashes(0, 1)
	blocks := createBlocksFromHashes(hashes)

	forged := blocks[hashes[0]]
	forged.ParentHeaderHash = unknownHash

	// Try and sync with the malicious node and check that it fails
320 321 322
	tester := newTester()
	tester.newPeer("attack", hashes, blocks)
	if err := tester.downloader.synchronise("attack", hashes[0]); err != nil {
323 324
		t.Fatalf("failed to synchronise blocks: %v", err)
	}
325 326 327
	bs := tester.downloader.TakeBlocks()
	if len(bs) != 1 {
		t.Fatalf("retrieved block mismatch: have %v, want %v", len(bs), 1)
328
	}
329
	if tester.hasBlock(bs[0].RawBlock.ParentHash()) {
330
		t.Fatalf("tester knows about the unknown hash")
331 332 333 334 335
	}
	tester.downloader.Cancel()

	// Reconstruct a valid chain, and try to synchronize with it
	forged.ParentHeaderHash = knownHash
336 337
	tester.newPeer("valid", hashes, blocks)
	if err := tester.downloader.synchronise("valid", hashes[0]); err != nil {
338 339
		t.Fatalf("failed to synchronise blocks: %v", err)
	}
340
	bs = tester.downloader.TakeBlocks()
341
	if len(bs) != 1 {
342
		t.Fatalf("retrieved block mismatch: have %v, want %v", len(bs), 1)
343
	}
344
	if !tester.hasBlock(bs[0].RawBlock.ParentHash()) {
345 346
		t.Fatalf("tester doesn't know about the origin hash")
	}
347
}
348 349 350 351 352

// Tests that if a malicious peers keeps sending us repeating hashes, we don't
// loop indefinitely.
func TestRepeatingHashAttack(t *testing.T) {
	// Create a valid chain, but drop the last link
353
	hashes := createHashes(0, blockCacheLimit)
354
	blocks := createBlocksFromHashes(hashes)
355
	forged := hashes[:len(hashes)-1]
356 357

	// Try and sync with the malicious node
358 359
	tester := newTester()
	tester.newPeer("attack", forged, blocks)
360 361 362

	errc := make(chan error)
	go func() {
363
		errc <- tester.downloader.synchronise("attack", hashes[0])
364 365 366 367
	}()

	// Make sure that syncing returns and does so with a failure
	select {
368
	case <-time.After(time.Second):
369 370 371 372 373 374
		t.Fatalf("synchronisation blocked")
	case err := <-errc:
		if err == nil {
			t.Fatalf("synchronisation succeeded")
		}
	}
375
	// Ensure that a valid chain can still pass sync
376 377
	tester.newPeer("valid", hashes, blocks)
	if err := tester.downloader.synchronise("valid", hashes[0]); err != nil {
378 379
		t.Fatalf("failed to synchronise blocks: %v", err)
	}
380
}
381 382 383 384 385

// Tests that if a malicious peers returns a non-existent block hash, it should
// eventually time out and the sync reattempted.
func TestNonExistingBlockAttack(t *testing.T) {
	// Create a valid chain, but forge the last link
386
	hashes := createHashes(0, blockCacheLimit)
387
	blocks := createBlocksFromHashes(hashes)
388
	origin := hashes[len(hashes)/2]
389 390 391 392

	hashes[len(hashes)/2] = unknownHash

	// Try and sync with the malicious node and check that it fails
393 394 395
	tester := newTester()
	tester.newPeer("attack", hashes, blocks)
	if err := tester.downloader.synchronise("attack", hashes[0]); err != errPeersUnavailable {
396 397
		t.Fatalf("synchronisation error mismatch: have %v, want %v", err, errPeersUnavailable)
	}
398 399
	// Ensure that a valid chain can still pass sync
	hashes[len(hashes)/2] = origin
400 401
	tester.newPeer("valid", hashes, blocks)
	if err := tester.downloader.synchronise("valid", hashes[0]); err != nil {
402 403
		t.Fatalf("failed to synchronise blocks: %v", err)
	}
404
}
405 406 407 408 409 410 411 412

// Tests that if a malicious peer is returning hashes in a weird order, that the
// sync throttler doesn't choke on them waiting for the valid blocks.
func TestInvalidHashOrderAttack(t *testing.T) {
	// Create a valid long chain, but reverse some hashes within
	hashes := createHashes(0, 4*blockCacheLimit)
	blocks := createBlocksFromHashes(hashes)

413 414 415 416 417
	chunk1 := make([]common.Hash, blockCacheLimit)
	chunk2 := make([]common.Hash, blockCacheLimit)
	copy(chunk1, hashes[blockCacheLimit:2*blockCacheLimit])
	copy(chunk2, hashes[2*blockCacheLimit:3*blockCacheLimit])

418 419
	reverse := make([]common.Hash, len(hashes))
	copy(reverse, hashes)
420 421
	copy(reverse[2*blockCacheLimit:], chunk1)
	copy(reverse[blockCacheLimit:], chunk2)
422 423

	// Try and sync with the malicious node and check that it fails
424 425
	tester := newTester()
	tester.newPeer("attack", reverse, blocks)
426 427
	if _, err := tester.syncTake("attack", reverse[0]); err != errInvalidChain {
		t.Fatalf("synchronisation error mismatch: have %v, want %v", err, errInvalidChain)
428 429
	}
	// Ensure that a valid chain can still pass sync
430
	tester.newPeer("valid", hashes, blocks)
431 432 433 434
	if _, err := tester.syncTake("valid", hashes[0]); err != nil {
		t.Fatalf("failed to synchronise blocks: %v", err)
	}
}
435 436 437 438

// Tests that if a malicious peer makes up a random hash chain and tries to push
// indefinitely, it actually gets caught with it.
func TestMadeupHashChainAttack(t *testing.T) {
439
	blockSoftTTL = 100 * time.Millisecond
440 441 442 443 444 445
	crossCheckCycle = 25 * time.Millisecond

	// Create a long chain of hashes without backing blocks
	hashes := createHashes(0, 1024*blockCacheLimit)

	// Try and sync with the malicious node and check that it fails
446 447
	tester := newTester()
	tester.newPeer("attack", hashes, nil)
448 449
	if _, err := tester.syncTake("attack", hashes[0]); err != errCrossCheckFailed {
		t.Fatalf("synchronisation error mismatch: have %v, want %v", err, errCrossCheckFailed)
450 451
	}
}
452

453 454 455 456 457 458 459
// Tests that if a malicious peer makes up a random hash chain, and tries to push
// indefinitely, one hash at a time, it actually gets caught with it. The reason
// this is separate from the classical made up chain attack is that sending hashes
// one by one prevents reliable block/parent verification.
func TestMadeupHashChainDrippingAttack(t *testing.T) {
	// Create a random chain of hashes to drip
	hashes := createHashes(0, 16*blockCacheLimit)
460
	tester := newTester()
461 462 463

	// Try and sync with the attacker, one hash at a time
	tester.maxHashFetch = 1
464
	tester.newPeer("attack", hashes, nil)
465 466
	if _, err := tester.syncTake("attack", hashes[0]); err != errStallingPeer {
		t.Fatalf("synchronisation error mismatch: have %v, want %v", err, errStallingPeer)
467 468 469
	}
}

470 471 472
// Tests that if a malicious peer makes up a random block chain, and tried to
// push indefinitely, it actually gets caught with it.
func TestMadeupBlockChainAttack(t *testing.T) {
473
	defaultBlockTTL := blockSoftTTL
474 475
	defaultCrossCheckCycle := crossCheckCycle

476
	blockSoftTTL = 100 * time.Millisecond
477 478 479
	crossCheckCycle = 25 * time.Millisecond

	// Create a long chain of blocks and simulate an invalid chain by dropping every second
480
	hashes := createHashes(0, 16*blockCacheLimit)
481 482 483 484 485 486 487
	blocks := createBlocksFromHashes(hashes)

	gapped := make([]common.Hash, len(hashes)/2)
	for i := 0; i < len(gapped); i++ {
		gapped[i] = hashes[2*i]
	}
	// Try and sync with the malicious node and check that it fails
488 489
	tester := newTester()
	tester.newPeer("attack", gapped, blocks)
490 491
	if _, err := tester.syncTake("attack", gapped[0]); err != errCrossCheckFailed {
		t.Fatalf("synchronisation error mismatch: have %v, want %v", err, errCrossCheckFailed)
492 493
	}
	// Ensure that a valid chain can still pass sync
494
	blockSoftTTL = defaultBlockTTL
495 496
	crossCheckCycle = defaultCrossCheckCycle

497
	tester.newPeer("valid", hashes, blocks)
498 499 500 501
	if _, err := tester.syncTake("valid", hashes[0]); err != nil {
		t.Fatalf("failed to synchronise blocks: %v", err)
	}
}
502 503 504 505 506

// Advanced form of the above forged blockchain attack, where not only does the
// attacker make up a valid hashes for random blocks, but also forges the block
// parents to point to existing hashes.
func TestMadeupParentBlockChainAttack(t *testing.T) {
507
	defaultBlockTTL := blockSoftTTL
508 509
	defaultCrossCheckCycle := crossCheckCycle

510
	blockSoftTTL = 100 * time.Millisecond
511 512 513 514 515 516 517 518 519 520
	crossCheckCycle = 25 * time.Millisecond

	// Create a long chain of blocks and simulate an invalid chain by dropping every second
	hashes := createHashes(0, 16*blockCacheLimit)
	blocks := createBlocksFromHashes(hashes)
	forges := createBlocksFromHashes(hashes)
	for hash, block := range forges {
		block.ParentHeaderHash = hash // Simulate pointing to already known hash
	}
	// Try and sync with the malicious node and check that it fails
521 522
	tester := newTester()
	tester.newPeer("attack", hashes, forges)
523 524
	if _, err := tester.syncTake("attack", hashes[0]); err != errCrossCheckFailed {
		t.Fatalf("synchronisation error mismatch: have %v, want %v", err, errCrossCheckFailed)
525 526
	}
	// Ensure that a valid chain can still pass sync
527
	blockSoftTTL = defaultBlockTTL
528 529
	crossCheckCycle = defaultCrossCheckCycle

530
	tester.newPeer("valid", hashes, blocks)
531 532 533 534
	if _, err := tester.syncTake("valid", hashes[0]); err != nil {
		t.Fatalf("failed to synchronise blocks: %v", err)
	}
}
535 536 537 538 539 540 541 542 543 544 545 546

// Tests that if one/multiple malicious peers try to feed a banned blockchain to
// the downloader, it will not keep refetching the same chain indefinitely, but
// gradually block pieces of it, until it's head is also blocked.
func TestBannedChainStarvationAttack(t *testing.T) {
	// Construct a valid chain, but ban one of the hashes in it
	hashes := createHashes(0, 8*blockCacheLimit)
	hashes[len(hashes)/2+23] = bannedHash // weird index to have non multiple of ban chunk size

	blocks := createBlocksFromHashes(hashes)

	// Create the tester and ban the selected hash
547
	tester := newTester()
548 549 550 551
	tester.downloader.banned.Add(bannedHash)

	// Iteratively try to sync, and verify that the banned hash list grows until
	// the head of the invalid chain is blocked too.
552
	tester.newPeer("attack", hashes, blocks)
553 554
	for banned := tester.downloader.banned.Size(); ; {
		// Try to sync with the attacker, check hash chain failure
555
		if _, err := tester.syncTake("attack", hashes[0]); err != errInvalidChain {
556 557 558
			if tester.downloader.banned.Has(hashes[0]) && err == errBannedHead {
				break
			}
559
			t.Fatalf("synchronisation error mismatch: have %v, want %v", err, errInvalidChain)
560 561 562 563 564 565 566 567
		}
		// Check that the ban list grew with at least 1 new item, or all banned
		bans := tester.downloader.banned.Size()
		if bans < banned+1 {
			t.Fatalf("ban count mismatch: have %v, want %v+", bans, banned+1)
		}
		banned = bans
	}
568
	// Check that after banning an entire chain, bad peers get dropped
569
	if err := tester.newPeer("new attacker", hashes, blocks); err != errBannedHead {
570 571 572 573 574
		t.Fatalf("peer registration mismatch: have %v, want %v", err, errBannedHead)
	}
	if peer := tester.downloader.peers.Peer("net attacker"); peer != nil {
		t.Fatalf("banned attacker registered: %v", peer)
	}
575
}
576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591

// Tests that if a peer sends excessively many/large invalid chains that are
// gradually banned, it will have an upper limit on the consumed memory and also
// the origin bad hashes will not be evacuated.
func TestBannedChainMemoryExhaustionAttack(t *testing.T) {
	// Reduce the test size a bit
	MaxBlockFetch = 4
	maxBannedHashes = 256

	// Construct a banned chain with more chunks than the ban limit
	hashes := createHashes(0, maxBannedHashes*MaxBlockFetch)
	hashes[len(hashes)-1] = bannedHash // weird index to have non multiple of ban chunk size

	blocks := createBlocksFromHashes(hashes)

	// Create the tester and ban the selected hash
592
	tester := newTester()
593 594 595 596
	tester.downloader.banned.Add(bannedHash)

	// Iteratively try to sync, and verify that the banned hash list grows until
	// the head of the invalid chain is blocked too.
597
	tester.newPeer("attack", hashes, blocks)
598 599
	for {
		// Try to sync with the attacker, check hash chain failure
600 601
		if _, err := tester.syncTake("attack", hashes[0]); err != errInvalidChain {
			t.Fatalf("synchronisation error mismatch: have %v, want %v", err, errInvalidChain)
602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617
		}
		// Short circuit if the entire chain was banned
		if tester.downloader.banned.Has(hashes[0]) {
			break
		}
		// Otherwise ensure we never exceed the memory allowance and the hard coded bans are untouched
		if bans := tester.downloader.banned.Size(); bans > maxBannedHashes {
			t.Fatalf("ban cap exceeded: have %v, want max %v", bans, maxBannedHashes)
		}
		for hash, _ := range core.BadHashes {
			if !tester.downloader.banned.Has(hash) {
				t.Fatalf("hard coded ban evacuated: %x", hash)
			}
		}
	}
}
618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661

// Tests that peers are disconnected after misbehaving synchronisation results,
// while well-behaved peers are kept.
func TestAttackerDropping(t *testing.T) {
	// Each synchronisation outcome, and whether it should cost the peer its connection
	cases := []struct {
		result error
		drop   bool
	}{
		{nil, false},                 // Sync succeeded, all is well
		{errBusy, false},             // Sync is already in progress, no problem
		{errUnknownPeer, false},      // Peer is unknown, was already dropped, don't double drop
		{errBadPeer, true},           // Peer was deemed bad for some reason, drop it
		{errStallingPeer, true},      // Peer was detected to be stalling, drop it
		{errBannedHead, true},        // Peer's head hash is a known bad hash, drop it
		{errNoPeers, false},          // No peers to download from, soft race, no issue
		{errPendingQueue, false},     // There are blocks still cached, wait to exhaust, no issue
		{errTimeout, true},           // No hashes received in due time, drop the peer
		{errEmptyHashSet, true},      // No hashes were returned as a response, drop as it's a dead end
		{errPeersUnavailable, true},  // Nobody had the advertised blocks, drop the advertiser
		{errInvalidChain, true},      // Hash chain was detected as invalid, definitely drop
		{errCrossCheckFailed, true},  // Hash-origin failed to pass a block cross check, drop
		{errCancelHashFetch, false},  // Synchronisation was canceled, origin may be innocent, don't drop
		{errCancelBlockFetch, false}, // Synchronisation was canceled, origin may be innocent, don't drop
	}
	tester := newTester()
	for idx, tc := range cases {
		// Register a fresh peer for this case and make sure it is tracked
		id := fmt.Sprintf("test %d", idx)
		if err := tester.newPeer(id, []common.Hash{knownHash}, nil); err != nil {
			t.Fatalf("test %d: failed to register new peer: %v", idx, err)
		}
		if _, ok := tester.peerHashes[id]; !ok {
			t.Fatalf("test %d: registered peer not found", idx)
		}
		// Force the synchronisation to produce this case's result, then check
		// whether the peer survived
		result := tc.result
		tester.downloader.synchroniseMock = func(string, common.Hash) error { return result }
		tester.downloader.Synchronise(id, knownHash)

		_, registered := tester.peerHashes[id]
		if dropped := !registered; dropped != tc.drop {
			t.Errorf("test %d: peer drop mismatch for %v: have %v, want %v", idx, tc.result, dropped, tc.drop)
		}
	}
}