package downloader

import (
	"encoding/binary"
	"errors"
	"fmt"
	"math/big"
	"sync/atomic"
	"testing"
	"time"

	"github.com/ethereum/go-ethereum/common"
	"github.com/ethereum/go-ethereum/core"
	"github.com/ethereum/go-ethereum/core/types"
	"github.com/ethereum/go-ethereum/event"
)

var (
	knownHash   = common.Hash{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}
	unknownHash = common.Hash{2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2}
	bannedHash  = common.Hash{3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3}

	genesis = createBlock(1, common.Hash{}, knownHash)
)

// idCounter is used by the createHashes function to generate deterministic but unique hashes.
var idCounter = int64(2) // #1 is the genesis block

// createHashes generates a batch of hashes rooted at a specific point in the chain.
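// The result is ordered head first: hashes[0] is the newest hash and the last element is the given root.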
func createHashes(amount int, root common.Hash) (hashes []common.Hash) {
	hashes = make([]common.Hash, amount+1)
	hashes[len(hashes)-1] = root

	for i := 0; i < len(hashes)-1; i++ {
		binary.BigEndian.PutUint64(hashes[i][:8], uint64(idCounter))
		idCounter++
	}
	return
}

// createBlock assembles a new block at the given chain height.
func createBlock(i int, parent, hash common.Hash) *types.Block {
	header := &types.Header{Number: big.NewInt(int64(i))}
	block := types.NewBlockWithHeader(header)
	block.HeaderHash = hash
	block.ParentHeaderHash = parent
	return block
}

// copyBlock makes a deep copy of a block suitable for local modifications.
func copyBlock(block *types.Block) *types.Block {
	return createBlock(int(block.Number().Int64()), block.ParentHeaderHash, block.HeaderHash)
}

// createBlocksFromHashes assembles a chain of blocks backed by the given hash
// chain, wiring each block's parent to the next (older) hash in the slice.
func createBlocksFromHashes(hashes []common.Hash) map[common.Hash]*types.Block {
	blocks := make(map[common.Hash]*types.Block)
	for i := 0; i < len(hashes); i++ {
		parent := knownHash
		if i < len(hashes)-1 {
			parent = hashes[i+1]
		}
		blocks[hashes[i]] = createBlock(len(hashes)-i, parent, hashes[i])
	}
	return blocks
}

// downloadTester is a test harness simulating a local block chain to import
// into, along with a set of remote peers to download hashes and blocks from.
type downloadTester struct {
	downloader *Downloader

	ownHashes  []common.Hash                           // Hash chain belonging to the tester
	ownBlocks  map[common.Hash]*types.Block            // Blocks belonging to the tester
	peerHashes map[string][]common.Hash                // Hash chain belonging to different test peers
	peerBlocks map[string]map[common.Hash]*types.Block // Blocks belonging to different test peers

	maxHashFetch int // Overrides the maximum number of retrieved hashes
}

// newTester creates a fresh instance of the download tester, seeded with only
// the genesis block as the local chain.
func newTester() *downloadTester {
	tester := &downloadTester{
		ownHashes:  []common.Hash{knownHash},
		ownBlocks:  map[common.Hash]*types.Block{knownHash: genesis},
		peerHashes: make(map[string][]common.Hash),
		peerBlocks: make(map[string]map[common.Hash]*types.Block),
	}
	var mux event.TypeMux
	downloader := New(&mux, tester.hasBlock, tester.getBlock, tester.insertChain, tester.dropPeer)
	tester.downloader = downloader

	return tester
}

// sync starts synchronizing with a remote peer, blocking until it completes.
func (dl *downloadTester) sync(id string) error {
	err := dl.downloader.synchronise(id, dl.peerHashes[id][0])
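	// Wait for the downloader's background block processing to drain before returning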
	for atomic.LoadInt32(&dl.downloader.processing) == 1 {
		time.Sleep(time.Millisecond)
	}
	return err
}

// hasBlock checks if a block is present in the tester's canonical chain.
func (dl *downloadTester) hasBlock(hash common.Hash) bool {
	return dl.getBlock(hash) != nil
}

// getBlock retrieves a block from the tester's canonical chain.
func (dl *downloadTester) getBlock(hash common.Hash) *types.Block {
	return dl.ownBlocks[hash]
}

// insertChain injects a new batch of blocks into the simulated chain.
func (dl *downloadTester) insertChain(blocks types.Blocks) (int, error) {
	for i, block := range blocks {
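		// Reject any block whose parent is not already part of the local chain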
		if _, ok := dl.ownBlocks[block.ParentHash()]; !ok {
			return i, errors.New("unknown parent")
		}
		dl.ownHashes = append(dl.ownHashes, block.Hash())
		dl.ownBlocks[block.Hash()] = block
	}
	return len(blocks), nil
}

// newPeer registers a new block download source into the downloader.
func (dl *downloadTester) newPeer(id string, hashes []common.Hash, blocks map[common.Hash]*types.Block) error {
	return dl.newSlowPeer(id, hashes, blocks, 0)
}

// newSlowPeer registers a new block download source into the downloader, with a
// specific delay time on processing the network packets sent to it, simulating
// potentially slow network IO.
func (dl *downloadTester) newSlowPeer(id string, hashes []common.Hash, blocks map[common.Hash]*types.Block, delay time.Duration) error {
	err := dl.downloader.RegisterPeer(id, hashes[0], dl.peerGetHashesFn(id, delay), dl.peerGetBlocksFn(id, delay))
	if err == nil {
		// Assign the owned hashes and blocks to the peer (deep copy)
		dl.peerHashes[id] = make([]common.Hash, len(hashes))
		copy(dl.peerHashes[id], hashes)

		dl.peerBlocks[id] = make(map[common.Hash]*types.Block)
		for hash, block := range blocks {
			dl.peerBlocks[id][hash] = copyBlock(block)
		}
	}
	return err
}

// dropPeer simulates a hard peer removal from the connection pool.
func (dl *downloadTester) dropPeer(id string) {
	delete(dl.peerHashes, id)
	delete(dl.peerBlocks, id)

	dl.downloader.UnregisterPeer(id)
}

// peerGetHashesFn constructs a getHashes function associated with a particular
// peer in the download tester. The returned function can be used to retrieve
// batches of hashes from the requested peer.
func (dl *downloadTester) peerGetHashesFn(id string, delay time.Duration) func(head common.Hash) error {
	return func(head common.Hash) error {
		time.Sleep(delay)

		limit := MaxHashFetch
		if dl.maxHashFetch > 0 {
			limit = dl.maxHashFetch
		}
		// Gather the next batch of hashes
		hashes := dl.peerHashes[id]
		result := make([]common.Hash, 0, limit)
		for i, hash := range hashes {
			if hash == head {
				i++
				for len(result) < cap(result) && i < len(hashes) {
					result = append(result, hashes[i])
					i++
				}
				break
			}
		}
		// Delay delivery a bit to allow attacks to unfold
		go func() {
			time.Sleep(time.Millisecond)
			dl.downloader.DeliverHashes(id, result)
		}()
		return nil
	}
}

// peerGetBlocksFn constructs a getBlocks function associated with a particular
// peer in the download tester. The returned function can be used to retrieve
// batches of blocks from the requested peer.
func (dl *downloadTester) peerGetBlocksFn(id string, delay time.Duration) func([]common.Hash) error {
	return func(hashes []common.Hash) error {
		time.Sleep(delay)

		blocks := dl.peerBlocks[id]
		result := make([]*types.Block, 0, len(hashes))
		for _, hash := range hashes {
			if block, ok := blocks[hash]; ok {
				result = append(result, block)
			}
		}
		go dl.downloader.DeliverBlocks(id, result)

		return nil
	}
}

// Tests that simple synchronisation from a good peer works, without throttling.
func TestSynchronisation(t *testing.T) {
	// Create a small enough block chain to download and the tester
	targetBlocks := blockCacheLimit - 15
	hashes := createHashes(targetBlocks, knownHash)
	blocks := createBlocksFromHashes(hashes)

	tester := newTester()
	tester.newPeer("peer", hashes, blocks)

	// Synchronise with the peer and make sure all blocks were retrieved
	if err := tester.sync("peer"); err != nil {
		t.Fatalf("failed to synchronise blocks: %v", err)
	}
	if imported := len(tester.ownBlocks); imported != targetBlocks+1 {
		t.Fatalf("synchronised block mismatch: have %v, want %v", imported, targetBlocks+1)
	}
}

// Tests that an inactive downloader will not accept incoming hashes and blocks.
func TestInactiveDownloader(t *testing.T) {
	tester := newTester()

	// Check that neither hashes nor blocks are accepted
	if err := tester.downloader.DeliverHashes("bad peer", []common.Hash{}); err != errNoSyncActive {
		t.Errorf("error mismatch: have %v, want %v", err, errNoSyncActive)
	}
	if err := tester.downloader.DeliverBlocks("bad peer", []*types.Block{}); err != errNoSyncActive {
		t.Errorf("error mismatch: have %v, want %v", err, errNoSyncActive)
	}
}

// Tests that a canceled download wipes all previously accumulated state.
func TestCancel(t *testing.T) {
	// Create a small enough block chain to download and the tester
	targetBlocks := blockCacheLimit - 15
	hashes := createHashes(targetBlocks, knownHash)
	blocks := createBlocksFromHashes(hashes)

	tester := newTester()
	tester.newPeer("peer", hashes, blocks)

	// Make sure canceling works with a pristine downloader
	tester.downloader.Cancel()
	hashCount, blockCount := tester.downloader.queue.Size()
	if hashCount > 0 || blockCount > 0 {
		t.Errorf("block or hash count mismatch: %d hashes, %d blocks, want 0", hashCount, blockCount)
	}
	// Synchronise with the peer, but cancel afterwards
	if err := tester.sync("peer"); err != nil {
		t.Fatalf("failed to synchronise blocks: %v", err)
	}
	tester.downloader.Cancel()
	hashCount, blockCount = tester.downloader.queue.Size()
	if hashCount > 0 || blockCount > 0 {
		t.Errorf("block or hash count mismatch: %d hashes, %d blocks, want 0", hashCount, blockCount)
	}
}

// Tests that if a large batch of blocks is being downloaded, it is throttled
// until the cached blocks are retrieved.
func TestThrottling(t *testing.T) {
	// Create a long block chain to download and the tester
	targetBlocks := 8 * blockCacheLimit
	hashes := createHashes(targetBlocks, knownHash)
	blocks := createBlocksFromHashes(hashes)

	tester := newTester()
	tester.newPeer("peer", hashes, blocks)

	// Wrap the importer to allow stepping
	done := make(chan int)
	tester.downloader.insertChain = func(blocks types.Blocks) (int, error) {
		n, err := tester.insertChain(blocks)
		done <- n
		return n, err
	}
	// Start a synchronisation concurrently
	errc := make(chan error)
	go func() {
		errc <- tester.sync("peer")
	}()
	// Iteratively take some blocks, always checking the retrieval count
	for len(tester.ownBlocks) < targetBlocks+1 {
		// Wait a bit for sync to throttle itself
		var cached int
		for start := time.Now(); time.Since(start) < 3*time.Second; {
			time.Sleep(25 * time.Millisecond)

			cached = len(tester.downloader.queue.blockPool)
			if cached == blockCacheLimit || len(tester.ownBlocks)+cached == targetBlocks+1 {
				break
			}
		}
		// Make sure we filled up the cache, then exhaust it
		time.Sleep(25 * time.Millisecond) // give it a chance to screw up
		if cached != blockCacheLimit && len(tester.ownBlocks)+cached < targetBlocks+1 {
			t.Fatalf("block count mismatch: have %v, want %v", cached, blockCacheLimit)
		}
		<-done // finish previous blocking import
		for cached > maxBlockProcess {
			cached -= <-done
		}
		time.Sleep(25 * time.Millisecond) // yield to the insertion
	}
	<-done // finish the last blocking import

	// Check that we haven't pulled more blocks than available
	if len(tester.ownBlocks) > targetBlocks+1 {
		t.Fatalf("target block count mismatch: have %v, want %v", len(tester.ownBlocks), targetBlocks+1)
	}
	if err := <-errc; err != nil {
		t.Fatalf("block synchronization failed: %v", err)
	}
}

// Tests that synchronisation from multiple peers works as intended (multi thread sanity test).
func TestMultiSynchronisation(t *testing.T) {
	// Create various peers with various parts of the chain
	targetPeers := 16
	targetBlocks := targetPeers*blockCacheLimit - 15

	hashes := createHashes(targetBlocks, knownHash)
	blocks := createBlocksFromHashes(hashes)

	tester := newTester()
	for i := 0; i < targetPeers; i++ {
		id := fmt.Sprintf("peer #%d", i)
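		// Each successive peer only advertises a shorter, older suffix of the hash chain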
		tester.newPeer(id, hashes[i*blockCacheLimit:], blocks)
	}
	// Synchronise with the middle peer and make sure half of the blocks were retrieved
	id := fmt.Sprintf("peer #%d", targetPeers/2)
	if err := tester.sync(id); err != nil {
		t.Fatalf("failed to synchronise blocks: %v", err)
	}
	if imported := len(tester.ownBlocks); imported != len(tester.peerHashes[id]) {
		t.Fatalf("synchronised block mismatch: have %v, want %v", imported, len(tester.peerHashes[id]))
	}
	// Synchronise with the best peer and make sure everything is retrieved
	if err := tester.sync("peer #0"); err != nil {
		t.Fatalf("failed to synchronise blocks: %v", err)
	}
	if imported := len(tester.ownBlocks); imported != targetBlocks+1 {
		t.Fatalf("synchronised block mismatch: have %v, want %v", imported, targetBlocks+1)
	}
}

// Tests that synchronising with a peer that is very slow at network IO does not
// stall the other peers in the system.
func TestSlowSynchronisation(t *testing.T) {
	tester := newTester()

	// Create a batch of blocks, with a slow and a full speed peer
	targetCycles := 2
	targetBlocks := targetCycles*blockCacheLimit - 15
	targetIODelay := 500 * time.Millisecond

	hashes := createHashes(targetBlocks, knownHash)
	blocks := createBlocksFromHashes(hashes)

	tester.newSlowPeer("fast", hashes, blocks, 0)
	tester.newSlowPeer("slow", hashes, blocks, targetIODelay)

	// Try to sync with the peers (pull hashes from fast)
	start := time.Now()
	if err := tester.sync("fast"); err != nil {
		t.Fatalf("failed to synchronise blocks: %v", err)
	}
	if imported := len(tester.ownBlocks); imported != targetBlocks+1 {
		t.Fatalf("synchronised block mismatch: have %v, want %v", imported, targetBlocks+1)
	}
	// Check that the slow peer got hit at most once per block-cache-size import
	limit := time.Duration(targetCycles+1) * targetIODelay
	if delay := time.Since(start); delay >= limit {
		t.Fatalf("synchronisation exceeded delay limit: have %v, want %v", delay, limit)
	}
}

// Tests that if a peer returns an invalid chain with a block pointing to a non-
// existing parent, it is correctly detected and handled.
func TestNonExistingParentAttack(t *testing.T) {
	tester := newTester()

	// Forge a single-link chain with a forged header
	hashes := createHashes(1, knownHash)
	blocks := createBlocksFromHashes(hashes)
	tester.newPeer("valid", hashes, blocks)

	hashes = createHashes(1, knownHash)
	blocks = createBlocksFromHashes(hashes)
	blocks[hashes[0]].ParentHeaderHash = unknownHash
	tester.newPeer("attack", hashes, blocks)

	// Try and sync with the malicious node and check that it fails
	if err := tester.sync("attack"); err == nil {
		t.Fatalf("block synchronization succeeded")
	}
	if tester.hasBlock(hashes[0]) {
		t.Fatalf("tester accepted unknown-parent block: %v", blocks[hashes[0]])
	}
	// Try to synchronize with the valid chain and make sure it succeeds
	if err := tester.sync("valid"); err != nil {
		t.Fatalf("failed to synchronise blocks: %v", err)
	}
	if !tester.hasBlock(tester.peerHashes["valid"][0]) {
		t.Fatalf("tester didn't accept known-parent block: %v", tester.peerBlocks["valid"][hashes[0]])
	}
}

// Tests that if a malicious peer keeps sending us repeating hashes, we don't
// loop indefinitely.
func TestRepeatingHashAttack(t *testing.T) { // TODO: Is this thing valid??
	tester := newTester()

	// Create a valid chain, but drop the last link
	hashes := createHashes(blockCacheLimit, knownHash)
	blocks := createBlocksFromHashes(hashes)
	tester.newPeer("valid", hashes, blocks)
	tester.newPeer("attack", hashes[:len(hashes)-1], blocks)

	// Try and sync with the malicious node
	errc := make(chan error)
	go func() {
		errc <- tester.sync("attack")
	}()
	// Make sure that syncing returns and does so with a failure
	select {
	case <-time.After(time.Second):
		t.Fatalf("synchronisation blocked")
	case err := <-errc:
		if err == nil {
			t.Fatalf("synchronisation succeeded")
		}
	}
	// Ensure that a valid chain can still pass sync
	if err := tester.sync("valid"); err != nil {
		t.Fatalf("failed to synchronise blocks: %v", err)
	}
}

// Tests that if a malicious peer returns a non-existent block hash, the sync
// eventually times out and is reattempted.
func TestNonExistingBlockAttack(t *testing.T) {
	tester := newTester()

	// Create a valid chain, but forge the last link
	hashes := createHashes(blockCacheLimit, knownHash)
	blocks := createBlocksFromHashes(hashes)
	tester.newPeer("valid", hashes, blocks)

	hashes[len(hashes)/2] = unknownHash
	tester.newPeer("attack", hashes, blocks)

	// Try and sync with the malicious node and check that it fails
	if err := tester.sync("attack"); err != errPeersUnavailable {
		t.Fatalf("synchronisation error mismatch: have %v, want %v", err, errPeersUnavailable)
	}
	// Ensure that a valid chain can still pass sync
	if err := tester.sync("valid"); err != nil {
		t.Fatalf("failed to synchronise blocks: %v", err)
	}
}

// Tests that if a malicious peer returns hashes in a weird order, the sync
// throttler doesn't choke on them while waiting for the valid blocks.
func TestInvalidHashOrderAttack(t *testing.T) {
	tester := newTester()

	// Create a valid long chain, but reverse some hashes within
	hashes := createHashes(4*blockCacheLimit, knownHash)
	blocks := createBlocksFromHashes(hashes)
	tester.newPeer("valid", hashes, blocks)

	chunk1 := make([]common.Hash, blockCacheLimit)
	chunk2 := make([]common.Hash, blockCacheLimit)
	copy(chunk1, hashes[blockCacheLimit:2*blockCacheLimit])
	copy(chunk2, hashes[2*blockCacheLimit:3*blockCacheLimit])

	copy(hashes[2*blockCacheLimit:], chunk1)
	copy(hashes[blockCacheLimit:], chunk2)
	tester.newPeer("attack", hashes, blocks)

	// Try and sync with the malicious node and check that it fails
	if err := tester.sync("attack"); err != errInvalidChain {
		t.Fatalf("synchronisation error mismatch: have %v, want %v", err, errInvalidChain)
	}
	// Ensure that a valid chain can still pass sync
	if err := tester.sync("valid"); err != nil {
		t.Fatalf("failed to synchronise blocks: %v", err)
	}
}

// Tests that if a malicious peer makes up a random hash chain and tries to push
// indefinitely, it actually gets caught with it.
func TestMadeupHashChainAttack(t *testing.T) {
	tester := newTester()
	blockSoftTTL = 100 * time.Millisecond
	crossCheckCycle = 25 * time.Millisecond

	// Create a long chain of hashes without backing blocks
	hashes := createHashes(4*blockCacheLimit, knownHash)
	blocks := createBlocksFromHashes(hashes)

	tester.newPeer("valid", hashes, blocks)
	tester.newPeer("attack", createHashes(1024*blockCacheLimit, knownHash), nil)

	// Try and sync with the malicious node and check that it fails
	if err := tester.sync("attack"); err != errCrossCheckFailed {
		t.Fatalf("synchronisation error mismatch: have %v, want %v", err, errCrossCheckFailed)
	}
	// Ensure that a valid chain can still pass sync
	if err := tester.sync("valid"); err != nil {
		t.Fatalf("failed to synchronise blocks: %v", err)
	}
}

// Tests that if a malicious peer makes up a random hash chain, and tries to push
// indefinitely, one hash at a time, it actually gets caught with it. The reason
// this is separate from the classical made up chain attack is that sending hashes
// one by one prevents reliable block/parent verification.
func TestMadeupHashChainDrippingAttack(t *testing.T) {
	// Create a random chain of hashes to drip
	hashes := createHashes(16*blockCacheLimit, knownHash)
	tester := newTester()

	// Try and sync with the attacker, one hash at a time
	tester.maxHashFetch = 1
	tester.newPeer("attack", hashes, nil)
	if err := tester.sync("attack"); err != errStallingPeer {
		t.Fatalf("synchronisation error mismatch: have %v, want %v", err, errStallingPeer)
	}
}

// Tests that if a malicious peer makes up a random block chain, and tries to
// push indefinitely, it actually gets caught with it.
func TestMadeupBlockChainAttack(t *testing.T) {
	defaultBlockTTL := blockSoftTTL
	defaultCrossCheckCycle := crossCheckCycle

	blockSoftTTL = 100 * time.Millisecond
	crossCheckCycle = 25 * time.Millisecond

	// Create a long chain of blocks and simulate an invalid chain by dropping every second
	hashes := createHashes(16*blockCacheLimit, knownHash)
	blocks := createBlocksFromHashes(hashes)

	gapped := make([]common.Hash, len(hashes)/2)
	for i := 0; i < len(gapped); i++ {
		gapped[i] = hashes[2*i]
	}
	// Try and sync with the malicious node and check that it fails
	tester := newTester()
	tester.newPeer("attack", gapped, blocks)
	if err := tester.sync("attack"); err != errCrossCheckFailed {
		t.Fatalf("synchronisation error mismatch: have %v, want %v", err, errCrossCheckFailed)
	}
	// Ensure that a valid chain can still pass sync
	blockSoftTTL = defaultBlockTTL
	crossCheckCycle = defaultCrossCheckCycle

	tester.newPeer("valid", hashes, blocks)
	if err := tester.sync("valid"); err != nil {
		t.Fatalf("failed to synchronise blocks: %v", err)
	}
}

// Advanced form of the above forged blockchain attack, where not only does the
// attacker make up valid hashes for random blocks, but also forges the block
// parents to point to existing hashes.
func TestMadeupParentBlockChainAttack(t *testing.T) {
	tester := newTester()

	defaultBlockTTL := blockSoftTTL
	defaultCrossCheckCycle := crossCheckCycle

	blockSoftTTL = 100 * time.Millisecond
	crossCheckCycle = 25 * time.Millisecond

	// Create a long chain of blocks and simulate an invalid chain by forging the parent links
	hashes := createHashes(16*blockCacheLimit, knownHash)
	blocks := createBlocksFromHashes(hashes)
	tester.newPeer("valid", hashes, blocks)

	for _, block := range blocks {
		block.ParentHeaderHash = knownHash // Simulate pointing to already known hash
	}
	tester.newPeer("attack", hashes, blocks)

	// Try and sync with the malicious node and check that it fails
	if err := tester.sync("attack"); err != errCrossCheckFailed {
		t.Fatalf("synchronisation error mismatch: have %v, want %v", err, errCrossCheckFailed)
	}
	// Ensure that a valid chain can still pass sync
	blockSoftTTL = defaultBlockTTL
	crossCheckCycle = defaultCrossCheckCycle

	if err := tester.sync("valid"); err != nil {
		t.Fatalf("failed to synchronise blocks: %v", err)
	}
}

// Tests that if one/multiple malicious peers try to feed a banned blockchain to
// the downloader, it will not keep refetching the same chain indefinitely, but
// gradually block pieces of it, until its head is also blocked.
func TestBannedChainStarvationAttack(t *testing.T) {
	// Create the tester and ban the selected hash
	tester := newTester()
	tester.downloader.banned.Add(bannedHash)

	// Construct a valid chain, fork it and ban the fork
	hashes := createHashes(8*blockCacheLimit, knownHash)
	blocks := createBlocksFromHashes(hashes)
	tester.newPeer("valid", hashes, blocks)

	fork := len(hashes)/2 - 23
	hashes = append(createHashes(4*blockCacheLimit, bannedHash), hashes[fork:]...)
	blocks = createBlocksFromHashes(hashes)
	tester.newPeer("attack", hashes, blocks)

	// Iteratively try to sync, and verify that the banned hash list grows until
	// the head of the invalid chain is blocked too.
	for banned := tester.downloader.banned.Size(); ; {
		// Try to sync with the attacker, check hash chain failure
		if err := tester.sync("attack"); err != errInvalidChain {
			if tester.downloader.banned.Has(hashes[0]) && err == errBannedHead {
				break
			}
			t.Fatalf("synchronisation error mismatch: have %v, want %v", err, errInvalidChain)
		}
		// Check that the ban list grew with at least 1 new item, or all banned
		bans := tester.downloader.banned.Size()
		if bans < banned+1 {
			t.Fatalf("ban count mismatch: have %v, want %v+", bans, banned+1)
		}
		banned = bans
	}
	// Check that after banning an entire chain, bad peers get dropped
	if err := tester.newPeer("new attacker", hashes, blocks); err != errBannedHead {
		t.Fatalf("peer registration mismatch: have %v, want %v", err, errBannedHead)
	}
	if peer := tester.downloader.peers.Peer("new attacker"); peer != nil {
		t.Fatalf("banned attacker registered: %v", peer)
	}
	// Ensure that a valid chain can still pass sync
	if err := tester.sync("valid"); err != nil {
		t.Fatalf("failed to synchronise blocks: %v", err)
	}
}

// Tests that if a peer sends excessively many/large invalid chains that are
// gradually banned, there is an upper limit on the consumed memory, and the
// original hard coded bad hashes are never evicted.
func TestBannedChainMemoryExhaustionAttack(t *testing.T) {
	// Create the tester and ban the selected hash
	tester := newTester()
	tester.downloader.banned.Add(bannedHash)

	// Reduce the test size a bit
	defaultMaxBlockFetch := MaxBlockFetch
	defaultMaxBannedHashes := maxBannedHashes

	MaxBlockFetch = 4
	maxBannedHashes = 256

	// Construct a banned chain with more chunks than the ban limit
	hashes := createHashes(8*blockCacheLimit, knownHash)
	blocks := createBlocksFromHashes(hashes)
	tester.newPeer("valid", hashes, blocks)

	fork := len(hashes)/2 - 23
	hashes = append(createHashes(maxBannedHashes*MaxBlockFetch, bannedHash), hashes[fork:]...)
	blocks = createBlocksFromHashes(hashes)
	tester.newPeer("attack", hashes, blocks)

	// Iteratively try to sync, and verify that the banned hash list grows until
	// the head of the invalid chain is blocked too.
	for {
		// Try to sync with the attacker, check hash chain failure
		if err := tester.sync("attack"); err != errInvalidChain {
			t.Fatalf("synchronisation error mismatch: have %v, want %v", err, errInvalidChain)
		}
		// Short circuit if the entire chain was banned
		if tester.downloader.banned.Has(hashes[0]) {
			break
		}
		// Otherwise ensure we never exceed the memory allowance and the hard coded bans are untouched
		if bans := tester.downloader.banned.Size(); bans > maxBannedHashes {
			t.Fatalf("ban cap exceeded: have %v, want max %v", bans, maxBannedHashes)
		}
		for hash := range core.BadHashes {
			if !tester.downloader.banned.Has(hash) {
				t.Fatalf("hard coded ban evacuated: %x", hash)
			}
		}
	}
	// Ensure that a valid chain can still pass sync
	MaxBlockFetch = defaultMaxBlockFetch
	maxBannedHashes = defaultMaxBannedHashes

	if err := tester.sync("valid"); err != nil {
		t.Fatalf("failed to synchronise blocks: %v", err)
	}
}

// Tests that misbehaving peers are disconnected, whilst behaving ones are not.
func TestHashAttackerDropping(t *testing.T) {
	// Define the disconnection requirement for individual hash fetch errors
	tests := []struct {
		result error
		drop   bool
	}{
		{nil, false},                  // Sync succeeded, all is well
		{errBusy, false},              // Sync is already in progress, no problem
		{errUnknownPeer, false},       // Peer is unknown, was already dropped, don't double drop
		{errBadPeer, true},            // Peer was deemed bad for some reason, drop it
		{errStallingPeer, true},       // Peer was detected to be stalling, drop it
		{errBannedHead, true},         // Peer's head hash is a known bad hash, drop it
		{errNoPeers, false},           // No peers to download from, soft race, no issue
		{errPendingQueue, false},      // There are blocks still cached, wait to exhaust, no issue
		{errTimeout, true},            // No hashes received in due time, drop the peer
		{errEmptyHashSet, true},       // No hashes were returned as a response, drop as it's a dead end
		{errPeersUnavailable, true},   // Nobody had the advertised blocks, drop the advertiser
		{errInvalidChain, true},       // Hash chain was detected as invalid, definitely drop
		{errCrossCheckFailed, true},   // Hash-origin failed to pass a block cross check, drop
		{errCancelHashFetch, false},   // Synchronisation was canceled, origin may be innocent, don't drop
		{errCancelBlockFetch, false},  // Synchronisation was canceled, origin may be innocent, don't drop
		{errCancelChainImport, false}, // Synchronisation was canceled, origin may be innocent, don't drop
	}
	// Run the tests and check disconnection status
	tester := newTester()
	for i, tt := range tests {
		// Register a new peer and ensure its presence
		id := fmt.Sprintf("test %d", i)
		if err := tester.newPeer(id, []common.Hash{knownHash}, nil); err != nil {
			t.Fatalf("test %d: failed to register new peer: %v", i, err)
		}
		if _, ok := tester.peerHashes[id]; !ok {
			t.Fatalf("test %d: registered peer not found", i)
		}
		// Simulate a synchronisation and check the required result
		tester.downloader.synchroniseMock = func(string, common.Hash) error { return tt.result }

		tester.downloader.Synchronise(id, knownHash)
		if _, ok := tester.peerHashes[id]; !ok != tt.drop {
			t.Errorf("test %d: peer drop mismatch for %v: have %v, want %v", i, tt.result, !ok, tt.drop)
		}
	}
}

// Tests that feeding bad blocks will result in a peer drop.
func TestBlockAttackerDropping(t *testing.T) {
	// Define the disconnection requirement for individual block import errors
	tests := []struct {
		failure bool
		drop    bool
	}{{true, true}, {false, false}}

	// Run the tests and check disconnection status
	tester := newTester()
	for i, tt := range tests {
		// Register a new peer and ensure its presence
		id := fmt.Sprintf("test %d", i)
		if err := tester.newPeer(id, []common.Hash{{}}, nil); err != nil {
			t.Fatalf("test %d: failed to register new peer: %v", i, err)
		}
		if _, ok := tester.peerHashes[id]; !ok {
			t.Fatalf("test %d: registered peer not found", i)
		}
		// Assemble a good or bad block, depending on the test
		raw := createBlock(1, knownHash, common.Hash{})
		if tt.failure {
			raw = createBlock(1, unknownHash, common.Hash{})
		}
		block := &Block{OriginPeer: id, RawBlock: raw}

		// Simulate block processing and check the result
		tester.downloader.queue.blockCache[0] = block
		tester.downloader.process()
		if _, ok := tester.peerHashes[id]; !ok != tt.drop {
			t.Errorf("test %d: peer drop mismatch for %v: have %v, want %v", i, tt.failure, !ok, tt.drop)
		}
	}
}