// Copyright 2015 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.

package downloader

import (
	"errors"
	"fmt"
	"math/big"
	"sync"
	"sync/atomic"
	"testing"
	"time"

	"github.com/ethereum/go-ethereum/common"
	"github.com/ethereum/go-ethereum/core"
	"github.com/ethereum/go-ethereum/core/state"
	"github.com/ethereum/go-ethereum/core/types"
	"github.com/ethereum/go-ethereum/crypto"
	"github.com/ethereum/go-ethereum/ethdb"
	"github.com/ethereum/go-ethereum/event"
	"github.com/ethereum/go-ethereum/params"
	"github.com/ethereum/go-ethereum/trie"
)

var (
	testdb, _   = ethdb.NewMemDatabase()
	testKey, _  = crypto.HexToECDSA("b71c71a67e1177ad4e901695e1b4b9ee17ae16c6668d313eac2f96dbcda3f291")
	testAddress = crypto.PubkeyToAddress(testKey.PublicKey)
	genesis     = core.GenesisBlockForTesting(testdb, testAddress, big.NewInt(1000000000))
)

// makeChain creates a chain of n blocks starting at and including parent.
// The returned hash chain is ordered head->parent. In addition, every 3rd block
// contains a transaction and every 5th an uncle to allow testing correct block
// reassembly.
func makeChain(n int, seed byte, parent *types.Block, parentReceipts types.Receipts) ([]common.Hash, map[common.Hash]*types.Header, map[common.Hash]*types.Block, map[common.Hash]types.Receipts) {
	// Generate the block chain
	blocks, receipts := core.GenerateChain(parent, testdb, n, func(i int, block *core.BlockGen) {
		block.SetCoinbase(common.Address{seed})

		// If the block number is a multiple of 3, send a bonus transaction to the miner
		if parent == genesis && i%3 == 0 {
			tx, err := types.NewTransaction(block.TxNonce(testAddress), common.Address{seed}, big.NewInt(1000), params.TxGas, nil, nil).SignECDSA(testKey)
			if err != nil {
				panic(err)
			}
			block.AddTx(tx)
		}
		// If the block number is a multiple of 5, add a bonus uncle to the block
		if i > 0 && i%5 == 0 {
			block.AddUncle(&types.Header{
				ParentHash: block.PrevBlock(i - 1).Hash(),
				Number:     big.NewInt(block.Number().Int64() - 1),
			})
		}
	})
	// Convert the block-chain into a hash-chain and header/block maps
	hashes := make([]common.Hash, n+1)
	hashes[len(hashes)-1] = parent.Hash()

	headerm := make(map[common.Hash]*types.Header, n+1)
	headerm[parent.Hash()] = parent.Header()

	blockm := make(map[common.Hash]*types.Block, n+1)
	blockm[parent.Hash()] = parent

	receiptm := make(map[common.Hash]types.Receipts, n+1)
	receiptm[parent.Hash()] = parentReceipts

	for i, b := range blocks {
		hashes[len(hashes)-i-2] = b.Hash()
		headerm[b.Hash()] = b.Header()
		blockm[b.Hash()] = b
		receiptm[b.Hash()] = receipts[i]
	}
	return hashes, headerm, blockm, receiptm
}
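
// A worked example of the ordering (illustrative only, not used by the tests):
// for makeChain(3, 0, genesis, nil) the returned slice holds the three new
// block hashes head-first, followed by the parent's own hash:
//
//	hashes[0] == block #3 (head)
//	hashes[1] == block #2
//	hashes[2] == block #1
//	hashes[3] == genesis.Hash()
//
// The header, block and receipt maps contain an entry for every element of
// the slice, including the parent.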

// makeChainFork creates two chains of length n, such that h1[:f] and
// h2[:f] are different but have a common suffix of length n-f.
func makeChainFork(n, f int, parent *types.Block, parentReceipts types.Receipts) ([]common.Hash, []common.Hash, map[common.Hash]*types.Header, map[common.Hash]*types.Header, map[common.Hash]*types.Block, map[common.Hash]*types.Block, map[common.Hash]types.Receipts, map[common.Hash]types.Receipts) {
	// Create the common suffix
	hashes, headers, blocks, receipts := makeChain(n-f, 0, parent, parentReceipts)

	// Create the forks
	hashes1, headers1, blocks1, receipts1 := makeChain(f, 1, blocks[hashes[0]], receipts[hashes[0]])
	hashes1 = append(hashes1, hashes[1:]...)

	hashes2, headers2, blocks2, receipts2 := makeChain(f, 2, blocks[hashes[0]], receipts[hashes[0]])
	hashes2 = append(hashes2, hashes[1:]...)

	for hash, header := range headers {
		headers1[hash] = header
		headers2[hash] = header
	}
	for hash, block := range blocks {
		blocks1[hash] = block
		blocks2[hash] = block
	}
	for hash, receipt := range receipts {
		receipts1[hash] = receipt
		receipts2[hash] = receipt
	}
	return hashes1, hashes2, headers1, headers2, blocks1, blocks2, receipts1, receipts2
}
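
// As an illustration, makeChainFork(10, 4, genesis, nil) builds a common
// 6-block trunk and extends it with two distinct 4-block branches (seeds 1
// and 2), so hashes1[4:] and hashes2[4:] are identical while hashes1[:4] and
// hashes2[:4] differ.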

// downloadTester is a test simulator for mocking out the local block chain.
type downloadTester struct {
	stateDb    ethdb.Database
	downloader *Downloader

	ownHashes   []common.Hash                  // Hash chain belonging to the tester
	ownHeaders  map[common.Hash]*types.Header  // Headers belonging to the tester
	ownBlocks   map[common.Hash]*types.Block   // Blocks belonging to the tester
	ownReceipts map[common.Hash]types.Receipts // Receipts belonging to the tester
	ownChainTd  map[common.Hash]*big.Int       // Total difficulties of the blocks in the local chain

	peerHashes   map[string][]common.Hash                  // Hash chain belonging to different test peers
	peerHeaders  map[string]map[common.Hash]*types.Header  // Headers belonging to different test peers
	peerBlocks   map[string]map[common.Hash]*types.Block   // Blocks belonging to different test peers
	peerReceipts map[string]map[common.Hash]types.Receipts // Receipts belonging to different test peers
	peerChainTds map[string]map[common.Hash]*big.Int       // Total difficulties of the blocks in the peer chains

	lock sync.RWMutex
}

// newTester creates a new downloader test mocker.
func newTester() *downloadTester {
	tester := &downloadTester{
		ownHashes:    []common.Hash{genesis.Hash()},
		ownHeaders:   map[common.Hash]*types.Header{genesis.Hash(): genesis.Header()},
		ownBlocks:    map[common.Hash]*types.Block{genesis.Hash(): genesis},
		ownReceipts:  map[common.Hash]types.Receipts{genesis.Hash(): nil},
		ownChainTd:   map[common.Hash]*big.Int{genesis.Hash(): genesis.Difficulty()},
		peerHashes:   make(map[string][]common.Hash),
		peerHeaders:  make(map[string]map[common.Hash]*types.Header),
		peerBlocks:   make(map[string]map[common.Hash]*types.Block),
		peerReceipts: make(map[string]map[common.Hash]types.Receipts),
		peerChainTds: make(map[string]map[common.Hash]*big.Int),
	}
	tester.stateDb, _ = ethdb.NewMemDatabase()
	tester.downloader = New(tester.stateDb, new(event.TypeMux), tester.hasHeader, tester.hasBlock, tester.getHeader,
		tester.getBlock, tester.headHeader, tester.headBlock, tester.headFastBlock, tester.commitHeadBlock, tester.getTd,
		tester.insertHeaders, tester.insertBlocks, tester.insertReceipts, tester.rollback, tester.dropPeer)

	return tester
}
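
// Typical usage, mirroring the tests below: build a chain, register it under
// a peer id, then synchronise against that peer.
//
//	hashes, headers, blocks, receipts := makeChain(10, 0, genesis, nil)
//	tester := newTester()
//	tester.newPeer("peer", 63, hashes, headers, blocks, receipts)
//	if err := tester.sync("peer", nil, FullSync); err != nil {
//		t.Fatalf("failed to synchronise blocks: %v", err)
//	}
//	assertOwnChain(t, tester, 11) // 10 blocks + genesis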

// sync starts synchronizing with a remote peer, blocking until it completes.
func (dl *downloadTester) sync(id string, td *big.Int, mode SyncMode) error {
	dl.lock.RLock()
	hash := dl.peerHashes[id][0]
	// If no particular TD was requested, load from the peer's blockchain
	if td == nil {
		td = big.NewInt(1)
		if diff, ok := dl.peerChainTds[id][hash]; ok {
			td = diff
		}
	}
	dl.lock.RUnlock()
	return dl.downloader.synchronise(id, hash, td, mode)
}

// hasHeader checks if a header is present in the tester's canonical chain.
func (dl *downloadTester) hasHeader(hash common.Hash) bool {
	return dl.getHeader(hash) != nil
}

// hasBlock checks if a block is present in the tester's canonical chain.
func (dl *downloadTester) hasBlock(hash common.Hash) bool {
	return dl.getBlock(hash) != nil
}

// getHeader retrieves a header from the tester's canonical chain.
func (dl *downloadTester) getHeader(hash common.Hash) *types.Header {
	dl.lock.RLock()
	defer dl.lock.RUnlock()

	return dl.ownHeaders[hash]
}

// getBlock retrieves a block from the tester's canonical chain.
func (dl *downloadTester) getBlock(hash common.Hash) *types.Block {
	dl.lock.RLock()
	defer dl.lock.RUnlock()

	return dl.ownBlocks[hash]
}

// headHeader retrieves the current head header from the canonical chain.
func (dl *downloadTester) headHeader() *types.Header {
	dl.lock.RLock()
	defer dl.lock.RUnlock()

	for i := len(dl.ownHashes) - 1; i >= 0; i-- {
		if header := dl.ownHeaders[dl.ownHashes[i]]; header != nil {
			return header
		}
	}
	return genesis.Header()
}

// headBlock retrieves the current head block from the canonical chain.
func (dl *downloadTester) headBlock() *types.Block {
	dl.lock.RLock()
	defer dl.lock.RUnlock()

	for i := len(dl.ownHashes) - 1; i >= 0; i-- {
		if block := dl.ownBlocks[dl.ownHashes[i]]; block != nil {
			if _, err := dl.stateDb.Get(block.Root().Bytes()); err == nil {
				return block
			}
		}
	}
	return genesis
}

// headFastBlock retrieves the current head fast-sync block from the canonical chain.
func (dl *downloadTester) headFastBlock() *types.Block {
	dl.lock.RLock()
	defer dl.lock.RUnlock()

	for i := len(dl.ownHashes) - 1; i >= 0; i-- {
		if block := dl.ownBlocks[dl.ownHashes[i]]; block != nil {
			return block
		}
	}
	return genesis
}

// commitHeadBlock manually sets the head block to a given hash.
func (dl *downloadTester) commitHeadBlock(hash common.Hash) error {
	// For now only check that the state trie is correct
	if block := dl.getBlock(hash); block != nil {
		_, err := trie.NewSecure(block.Root(), dl.stateDb)
		return err
	}
	return fmt.Errorf("non existent block: %x", hash[:4])
}

// getTd retrieves the block's total difficulty from the canonical chain.
func (dl *downloadTester) getTd(hash common.Hash) *big.Int {
	dl.lock.RLock()
	defer dl.lock.RUnlock()

	return dl.ownChainTd[hash]
}

// insertHeaders injects a new batch of headers into the simulated chain.
func (dl *downloadTester) insertHeaders(headers []*types.Header, checkFreq int) (int, error) {
	dl.lock.Lock()
	defer dl.lock.Unlock()

	// Do a quick check, as the blockchain.InsertHeaderChain doesn't insert anything in case of errors
	if _, ok := dl.ownHeaders[headers[0].ParentHash]; !ok {
		return 0, errors.New("unknown parent")
	}
	for i := 1; i < len(headers); i++ {
		if headers[i].ParentHash != headers[i-1].Hash() {
			return i, errors.New("unknown parent")
		}
	}
	// Do a full insert if pre-checks passed
	for i, header := range headers {
		if _, ok := dl.ownHeaders[header.Hash()]; ok {
			continue
		}
		if _, ok := dl.ownHeaders[header.ParentHash]; !ok {
			return i, errors.New("unknown parent")
		}
		dl.ownHashes = append(dl.ownHashes, header.Hash())
		dl.ownHeaders[header.Hash()] = header
		dl.ownChainTd[header.Hash()] = new(big.Int).Add(dl.ownChainTd[header.ParentHash], header.Difficulty)
	}
	return len(headers), nil
}

// insertBlocks injects a new batch of blocks into the simulated chain.
func (dl *downloadTester) insertBlocks(blocks types.Blocks) (int, error) {
	dl.lock.Lock()
	defer dl.lock.Unlock()

	for i, block := range blocks {
		if _, ok := dl.ownBlocks[block.ParentHash()]; !ok {
			return i, errors.New("unknown parent")
		}
		if _, ok := dl.ownHeaders[block.Hash()]; !ok {
			dl.ownHashes = append(dl.ownHashes, block.Hash())
			dl.ownHeaders[block.Hash()] = block.Header()
		}
		dl.ownBlocks[block.Hash()] = block
		dl.stateDb.Put(block.Root().Bytes(), []byte{0x00})
		dl.ownChainTd[block.Hash()] = new(big.Int).Add(dl.ownChainTd[block.ParentHash()], block.Difficulty())
	}
	return len(blocks), nil
}

// insertReceipts injects a new batch of receipts into the simulated chain.
func (dl *downloadTester) insertReceipts(blocks types.Blocks, receipts []types.Receipts) (int, error) {
	dl.lock.Lock()
	defer dl.lock.Unlock()

	for i := 0; i < len(blocks) && i < len(receipts); i++ {
		if _, ok := dl.ownHeaders[blocks[i].Hash()]; !ok {
			return i, errors.New("unknown owner")
		}
		if _, ok := dl.ownBlocks[blocks[i].ParentHash()]; !ok {
			return i, errors.New("unknown parent")
		}
		dl.ownBlocks[blocks[i].Hash()] = blocks[i]
		dl.ownReceipts[blocks[i].Hash()] = receipts[i]
	}
	return len(blocks), nil
}

// rollback removes some recently added elements from the chain.
func (dl *downloadTester) rollback(hashes []common.Hash) {
	dl.lock.Lock()
	defer dl.lock.Unlock()

	for i := len(hashes) - 1; i >= 0; i-- {
		if dl.ownHashes[len(dl.ownHashes)-1] == hashes[i] {
			dl.ownHashes = dl.ownHashes[:len(dl.ownHashes)-1]
		}
		delete(dl.ownChainTd, hashes[i])
		delete(dl.ownHeaders, hashes[i])
		delete(dl.ownReceipts, hashes[i])
		delete(dl.ownBlocks, hashes[i])
	}
}

// newPeer registers a new block download source into the downloader.
func (dl *downloadTester) newPeer(id string, version int, hashes []common.Hash, headers map[common.Hash]*types.Header, blocks map[common.Hash]*types.Block, receipts map[common.Hash]types.Receipts) error {
	return dl.newSlowPeer(id, version, hashes, headers, blocks, receipts, 0)
}

// newSlowPeer registers a new block download source into the downloader, with a
// specific delay time on processing the network packets sent to it, simulating
// potentially slow network IO.
func (dl *downloadTester) newSlowPeer(id string, version int, hashes []common.Hash, headers map[common.Hash]*types.Header, blocks map[common.Hash]*types.Block, receipts map[common.Hash]types.Receipts, delay time.Duration) error {
	dl.lock.Lock()
	defer dl.lock.Unlock()

	var err error
	switch version {
	case 61:
		err = dl.downloader.RegisterPeer(id, version, hashes[0], dl.peerGetRelHashesFn(id, delay), dl.peerGetAbsHashesFn(id, delay), dl.peerGetBlocksFn(id, delay), nil, nil, nil, nil, nil)
	case 62:
		err = dl.downloader.RegisterPeer(id, version, hashes[0], nil, nil, nil, dl.peerGetRelHeadersFn(id, delay), dl.peerGetAbsHeadersFn(id, delay), dl.peerGetBodiesFn(id, delay), nil, nil)
	case 63:
		err = dl.downloader.RegisterPeer(id, version, hashes[0], nil, nil, nil, dl.peerGetRelHeadersFn(id, delay), dl.peerGetAbsHeadersFn(id, delay), dl.peerGetBodiesFn(id, delay), dl.peerGetReceiptsFn(id, delay), dl.peerGetNodeDataFn(id, delay))
	case 64:
		err = dl.downloader.RegisterPeer(id, version, hashes[0], nil, nil, nil, dl.peerGetRelHeadersFn(id, delay), dl.peerGetAbsHeadersFn(id, delay), dl.peerGetBodiesFn(id, delay), dl.peerGetReceiptsFn(id, delay), dl.peerGetNodeDataFn(id, delay))
	}
	if err == nil {
		// Assign the owned hashes, headers and blocks to the peer (deep copy)
		dl.peerHashes[id] = make([]common.Hash, len(hashes))
		copy(dl.peerHashes[id], hashes)

		dl.peerHeaders[id] = make(map[common.Hash]*types.Header)
		dl.peerBlocks[id] = make(map[common.Hash]*types.Block)
		dl.peerReceipts[id] = make(map[common.Hash]types.Receipts)
		dl.peerChainTds[id] = make(map[common.Hash]*big.Int)

		genesis := hashes[len(hashes)-1]
		if header := headers[genesis]; header != nil {
			dl.peerHeaders[id][genesis] = header
			dl.peerChainTds[id][genesis] = header.Difficulty
		}
		if block := blocks[genesis]; block != nil {
			dl.peerBlocks[id][genesis] = block
			dl.peerChainTds[id][genesis] = block.Difficulty()
		}

		for i := len(hashes) - 2; i >= 0; i-- {
			hash := hashes[i]

			if header, ok := headers[hash]; ok {
				dl.peerHeaders[id][hash] = header
				if _, ok := dl.peerHeaders[id][header.ParentHash]; ok {
					dl.peerChainTds[id][hash] = new(big.Int).Add(header.Difficulty, dl.peerChainTds[id][header.ParentHash])
				}
			}
			if block, ok := blocks[hash]; ok {
				dl.peerBlocks[id][hash] = block
				if _, ok := dl.peerBlocks[id][block.ParentHash()]; ok {
					dl.peerChainTds[id][hash] = new(big.Int).Add(block.Difficulty(), dl.peerChainTds[id][block.ParentHash()])
				}
			}
			if receipt, ok := receipts[hash]; ok {
				dl.peerReceipts[id][hash] = receipt
			}
		}
	}
	return err
}

// dropPeer simulates a hard peer removal from the connection pool.
func (dl *downloadTester) dropPeer(id string) {
	dl.lock.Lock()
	defer dl.lock.Unlock()

	delete(dl.peerHashes, id)
	delete(dl.peerHeaders, id)
	delete(dl.peerBlocks, id)
	delete(dl.peerChainTds, id)

	dl.downloader.UnregisterPeer(id)
}

// peerGetRelHashesFn constructs a GetHashes function associated with a specific
// peer in the download tester. The returned function can be used to retrieve
// batches of hashes from the particularly requested peer.
func (dl *downloadTester) peerGetRelHashesFn(id string, delay time.Duration) func(head common.Hash) error {
	return func(head common.Hash) error {
		time.Sleep(delay)

		dl.lock.RLock()
		defer dl.lock.RUnlock()

		// Gather the next batch of hashes
		hashes := dl.peerHashes[id]
		result := make([]common.Hash, 0, MaxHashFetch)
		for i, hash := range hashes {
			if hash == head {
				i++
				for len(result) < cap(result) && i < len(hashes) {
					result = append(result, hashes[i])
					i++
				}
				break
			}
		}
		// Delay delivery a bit to allow attacks to unfold
		go func() {
			time.Sleep(time.Millisecond)
			dl.downloader.DeliverHashes(id, result)
		}()
		return nil
	}
}

// peerGetAbsHashesFn constructs a GetHashesFromNumber function associated with
// a particular peer in the download tester. The returned function can be used to
// retrieve batches of hashes from the particularly requested peer.
func (dl *downloadTester) peerGetAbsHashesFn(id string, delay time.Duration) func(uint64, int) error {
	return func(head uint64, count int) error {
		time.Sleep(delay)

		dl.lock.RLock()
		defer dl.lock.RUnlock()

		// Gather the next batch of hashes
		hashes := dl.peerHashes[id]
		result := make([]common.Hash, 0, count)
		for i := 0; i < count && len(hashes)-int(head)-1-i >= 0; i++ {
			result = append(result, hashes[len(hashes)-int(head)-1-i])
		}
		// Delay delivery a bit to allow attacks to unfold
		go func() {
			time.Sleep(time.Millisecond)
			dl.downloader.DeliverHashes(id, result)
		}()
		return nil
	}
}
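
// The index arithmetic above follows from the head->parent ordering of the
// hash chain: the block with absolute number k lives at slice index
// len(hashes)-1-k. For example, with 5 blocks on top of genesis we have
// len(hashes) == 6, so a request (head=2, count=3) returns blocks #2, #3 and
// #4 from indices 3, 2 and 1 respectively.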

// peerGetBlocksFn constructs a getBlocks function associated with a particular
// peer in the download tester. The returned function can be used to retrieve
// batches of blocks from the particularly requested peer.
func (dl *downloadTester) peerGetBlocksFn(id string, delay time.Duration) func([]common.Hash) error {
	return func(hashes []common.Hash) error {
		time.Sleep(delay)

		dl.lock.RLock()
		defer dl.lock.RUnlock()

		blocks := dl.peerBlocks[id]
		result := make([]*types.Block, 0, len(hashes))
		for _, hash := range hashes {
			if block, ok := blocks[hash]; ok {
				result = append(result, block)
			}
		}
		go dl.downloader.DeliverBlocks(id, result)

		return nil
	}
}

// peerGetRelHeadersFn constructs a GetBlockHeaders function based on a hashed
// origin; associated with a particular peer in the download tester. The returned
// function can be used to retrieve batches of headers from the particular peer.
func (dl *downloadTester) peerGetRelHeadersFn(id string, delay time.Duration) func(common.Hash, int, int, bool) error {
	return func(origin common.Hash, amount int, skip int, reverse bool) error {
		// Find the canonical number of the hash
		dl.lock.RLock()
		number := uint64(0)
		for num, hash := range dl.peerHashes[id] {
			if hash == origin {
				number = uint64(len(dl.peerHashes[id]) - num - 1)
				break
			}
		}
		dl.lock.RUnlock()

		// Use the absolute header fetcher to satisfy the query
		return dl.peerGetAbsHeadersFn(id, delay)(number, amount, skip, reverse)
	}
}

// peerGetAbsHeadersFn constructs a GetBlockHeaders function based on a numbered
// origin; associated with a particular peer in the download tester. The returned
// function can be used to retrieve batches of headers from the particular peer.
func (dl *downloadTester) peerGetAbsHeadersFn(id string, delay time.Duration) func(uint64, int, int, bool) error {
	return func(origin uint64, amount int, skip int, reverse bool) error {
		time.Sleep(delay)

		dl.lock.RLock()
		defer dl.lock.RUnlock()

		// Gather the next batch of headers
		hashes := dl.peerHashes[id]
		headers := dl.peerHeaders[id]
		result := make([]*types.Header, 0, amount)
		for i := 0; i < amount && len(hashes)-int(origin)-1-i >= 0; i++ {
			if header, ok := headers[hashes[len(hashes)-int(origin)-1-i]]; ok {
				result = append(result, header)
			}
		}
		// Delay delivery a bit to allow attacks to unfold
		go func() {
			time.Sleep(time.Millisecond)
			dl.downloader.DeliverHeaders(id, result)
		}()
		return nil
	}
}

// peerGetBodiesFn constructs a getBlockBodies method associated with a particular
// peer in the download tester. The returned function can be used to retrieve
// batches of block bodies from the particularly requested peer.
func (dl *downloadTester) peerGetBodiesFn(id string, delay time.Duration) func([]common.Hash) error {
	return func(hashes []common.Hash) error {
		time.Sleep(delay)

		dl.lock.RLock()
		defer dl.lock.RUnlock()

		blocks := dl.peerBlocks[id]

		transactions := make([][]*types.Transaction, 0, len(hashes))
		uncles := make([][]*types.Header, 0, len(hashes))

		for _, hash := range hashes {
			if block, ok := blocks[hash]; ok {
				transactions = append(transactions, block.Transactions())
				uncles = append(uncles, block.Uncles())
			}
		}
		go dl.downloader.DeliverBodies(id, transactions, uncles)

		return nil
	}
}

// peerGetReceiptsFn constructs a getReceipts method associated with a particular
// peer in the download tester. The returned function can be used to retrieve
// batches of block receipts from the particularly requested peer.
func (dl *downloadTester) peerGetReceiptsFn(id string, delay time.Duration) func([]common.Hash) error {
	return func(hashes []common.Hash) error {
		time.Sleep(delay)

		dl.lock.RLock()
		defer dl.lock.RUnlock()

		receipts := dl.peerReceipts[id]

		results := make([][]*types.Receipt, 0, len(hashes))
		for _, hash := range hashes {
			if receipt, ok := receipts[hash]; ok {
				results = append(results, receipt)
			}
		}
		go dl.downloader.DeliverReceipts(id, results)

		return nil
	}
}

// peerGetNodeDataFn constructs a getNodeData method associated with a particular
// peer in the download tester. The returned function can be used to retrieve
// batches of node state data from the particularly requested peer.
func (dl *downloadTester) peerGetNodeDataFn(id string, delay time.Duration) func([]common.Hash) error {
	return func(hashes []common.Hash) error {
		time.Sleep(delay)

		dl.lock.RLock()
		defer dl.lock.RUnlock()

		results := make([][]byte, 0, len(hashes))
		for _, hash := range hashes {
			if data, err := testdb.Get(hash.Bytes()); err == nil {
				results = append(results, data)
			}
		}
		go dl.downloader.DeliverNodeData(id, results)

		return nil
	}
}

// assertOwnChain checks if the local chain contains the correct number of items
// of the various chain components.
func assertOwnChain(t *testing.T, tester *downloadTester, length int) {
	assertOwnForkedChain(t, tester, 1, []int{length})
}

// assertOwnForkedChain checks if the local forked chain contains the correct
// number of items of the various chain components.
func assertOwnForkedChain(t *testing.T, tester *downloadTester, common int, lengths []int) {
	// Initialize the counters for the first fork
	headers, blocks := lengths[0], lengths[0]

	minReceipts, maxReceipts := lengths[0]-fsMinFullBlocks-fsPivotInterval, lengths[0]-fsMinFullBlocks
	if minReceipts < 0 {
		minReceipts = 1
	}
	if maxReceipts < 0 {
		maxReceipts = 1
	}
	// Update the counters for each subsequent fork
	for _, length := range lengths[1:] {
		headers += length - common
		blocks += length - common

		minReceipts += length - common - fsMinFullBlocks - fsPivotInterval
		maxReceipts += length - common - fsMinFullBlocks
	}
	switch tester.downloader.mode {
	case FullSync:
		minReceipts, maxReceipts = 1, 1
	case LightSync:
		blocks, minReceipts, maxReceipts = 1, 1, 1
	}
	if hs := len(tester.ownHeaders); hs != headers {
		t.Fatalf("synchronised headers mismatch: have %v, want %v", hs, headers)
	}
	if bs := len(tester.ownBlocks); bs != blocks {
		t.Fatalf("synchronised blocks mismatch: have %v, want %v", bs, blocks)
	}
	if rs := len(tester.ownReceipts); rs < minReceipts || rs > maxReceipts {
		t.Fatalf("synchronised receipts mismatch: have %v, want between [%v, %v]", rs, minReceipts, maxReceipts)
	}
	// Verify the state trie too for fast syncs
	if tester.downloader.mode == FastSync {
		index := 0
		if pivot := int(tester.downloader.queue.fastSyncPivot); pivot < common {
			index = pivot
		} else {
			index = len(tester.ownHashes) - lengths[len(lengths)-1] + int(tester.downloader.queue.fastSyncPivot)
		}
		if index > 0 {
			if statedb, err := state.New(tester.ownHeaders[tester.ownHashes[index]].Root, tester.stateDb); statedb == nil || err != nil {
				t.Fatalf("state reconstruction failed: %v", err)
			}
		}
	}
}
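
// The receipt bounds above come from the fast sync pivot: receipts are only
// fetched for blocks below the pivot, which floats between fsMinFullBlocks and
// fsMinFullBlocks+fsPivotInterval blocks behind the head. Hence a fast sync of
// a single chain of length L should own between L-fsMinFullBlocks-fsPivotInterval
// and L-fsMinFullBlocks receipt entries (clamped to at least the genesis entry),
// while full and light syncs only ever hold the genesis placeholder.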

// Tests that simple synchronization against a canonical chain works correctly.
// In this test common ancestor lookup should be short circuited and not require
// binary searching.
func TestCanonicalSynchronisation61(t *testing.T)      { testCanonicalSynchronisation(t, 61, FullSync) }
func TestCanonicalSynchronisation62(t *testing.T)      { testCanonicalSynchronisation(t, 62, FullSync) }
func TestCanonicalSynchronisation63Full(t *testing.T)  { testCanonicalSynchronisation(t, 63, FullSync) }
func TestCanonicalSynchronisation63Fast(t *testing.T)  { testCanonicalSynchronisation(t, 63, FastSync) }
func TestCanonicalSynchronisation64Full(t *testing.T)  { testCanonicalSynchronisation(t, 64, FullSync) }
func TestCanonicalSynchronisation64Fast(t *testing.T)  { testCanonicalSynchronisation(t, 64, FastSync) }
func TestCanonicalSynchronisation64Light(t *testing.T) { testCanonicalSynchronisation(t, 64, LightSync) }

func testCanonicalSynchronisation(t *testing.T, protocol int, mode SyncMode) {
	t.Parallel()

	// Create a small enough block chain to download
	targetBlocks := blockCacheLimit - 15
	hashes, headers, blocks, receipts := makeChain(targetBlocks, 0, genesis, nil)

	tester := newTester()
	tester.newPeer("peer", protocol, hashes, headers, blocks, receipts)

	// Synchronise with the peer and make sure all relevant data was retrieved
	if err := tester.sync("peer", nil, mode); err != nil {
		t.Fatalf("failed to synchronise blocks: %v", err)
	}
	assertOwnChain(t, tester, targetBlocks+1)
}

// Tests that if a large batch of blocks is being downloaded, it is throttled
// until the cached blocks are retrieved.
func TestThrottling61(t *testing.T)     { testThrottling(t, 61, FullSync) }
func TestThrottling62(t *testing.T)     { testThrottling(t, 62, FullSync) }
func TestThrottling63Full(t *testing.T) { testThrottling(t, 63, FullSync) }
func TestThrottling63Fast(t *testing.T) { testThrottling(t, 63, FastSync) }
func TestThrottling64Full(t *testing.T) { testThrottling(t, 64, FullSync) }
func TestThrottling64Fast(t *testing.T) { testThrottling(t, 64, FastSync) }

func testThrottling(t *testing.T, protocol int, mode SyncMode) {
	t.Parallel()

	// Create a long block chain to download and the tester
	targetBlocks := 8 * blockCacheLimit
	hashes, headers, blocks, receipts := makeChain(targetBlocks, 0, genesis, nil)

	tester := newTester()
	tester.newPeer("peer", protocol, hashes, headers, blocks, receipts)

	// Wrap the importer to allow stepping
	blocked, proceed := uint32(0), make(chan struct{})
	tester.downloader.chainInsertHook = func(results []*fetchResult) {
		atomic.StoreUint32(&blocked, uint32(len(results)))
		<-proceed
	}
	// Start a synchronisation concurrently
	errc := make(chan error)
	go func() {
		errc <- tester.sync("peer", nil, mode)
	}()
	// Iteratively take some blocks, always checking the retrieval count
	for {
		// Check the retrieval count synchronously (! reason for this ugly block)
		tester.lock.RLock()
		retrieved := len(tester.ownBlocks)
		tester.lock.RUnlock()
		if retrieved >= targetBlocks+1 {
			break
		}
		// Wait a bit for sync to throttle itself
		var cached, frozen int
		for start := time.Now(); time.Since(start) < time.Second; {
			time.Sleep(25 * time.Millisecond)

			tester.lock.Lock()
			tester.downloader.queue.lock.Lock()
			cached = len(tester.downloader.queue.blockDonePool)
			if mode == FastSync {
				if receipts := len(tester.downloader.queue.receiptDonePool); receipts < cached {
					if tester.downloader.queue.resultCache[receipts].Header.Number.Uint64() < tester.downloader.queue.fastSyncPivot {
						cached = receipts
					}
				}
			}
			frozen = int(atomic.LoadUint32(&blocked))
			retrieved = len(tester.ownBlocks)
			tester.downloader.queue.lock.Unlock()
			tester.lock.Unlock()

			if cached == blockCacheLimit || retrieved+cached+frozen == targetBlocks+1 {
				break
			}
		}
		// Make sure we filled up the cache, then exhaust it
		time.Sleep(25 * time.Millisecond) // give it a chance to screw up

		tester.lock.RLock()
		retrieved = len(tester.ownBlocks)
		tester.lock.RUnlock()
		if cached != blockCacheLimit && retrieved+cached+frozen != targetBlocks+1 {
			t.Fatalf("block count mismatch: have %v, want %v (owned %v, blocked %v, target %v)", cached, blockCacheLimit, retrieved, frozen, targetBlocks+1)
		}
		// Permit the blocked blocks to import
		if atomic.LoadUint32(&blocked) > 0 {
			atomic.StoreUint32(&blocked, uint32(0))
			proceed <- struct{}{}
		}
	}
	// Check that we haven't pulled more blocks than available
	assertOwnChain(t, tester, targetBlocks+1)
	if err := <-errc; err != nil {
		t.Fatalf("block synchronization failed: %v", err)
	}
}
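
// The bookkeeping invariant of the loop above: every block is at all times in
// exactly one of three places (already imported and counted as retrieved,
// waiting in the download cache, or held back by the blocked import hook), so
// either the cache is filled to blockCacheLimit or the three counters sum to
// the full chain length of targetBlocks+1.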

// Tests that simple synchronization against a forked chain works correctly. In
// this test common ancestor lookup should *not* be short circuited, and a full
// binary search should be executed.
func TestForkedSynchronisation61(t *testing.T)      { testForkedSynchronisation(t, 61, FullSync) }
func TestForkedSynchronisation62(t *testing.T)      { testForkedSynchronisation(t, 62, FullSync) }
func TestForkedSynchronisation63Full(t *testing.T)  { testForkedSynchronisation(t, 63, FullSync) }
func TestForkedSynchronisation63Fast(t *testing.T)  { testForkedSynchronisation(t, 63, FastSync) }
func TestForkedSynchronisation64Full(t *testing.T)  { testForkedSynchronisation(t, 64, FullSync) }
func TestForkedSynchronisation64Fast(t *testing.T)  { testForkedSynchronisation(t, 64, FastSync) }
func TestForkedSynchronisation64Light(t *testing.T) { testForkedSynchronisation(t, 64, LightSync) }

func testForkedSynchronisation(t *testing.T, protocol int, mode SyncMode) {
	t.Parallel()

	// Create a long enough forked chain
	common, fork := MaxHashFetch, 2*MaxHashFetch
	hashesA, hashesB, headersA, headersB, blocksA, blocksB, receiptsA, receiptsB := makeChainFork(common+fork, fork, genesis, nil)

	tester := newTester()
	tester.newPeer("fork A", protocol, hashesA, headersA, blocksA, receiptsA)
	tester.newPeer("fork B", protocol, hashesB, headersB, blocksB, receiptsB)

	// Synchronise with the peer and make sure all blocks were retrieved
	if err := tester.sync("fork A", nil, mode); err != nil {
		t.Fatalf("failed to synchronise blocks: %v", err)
	}
	assertOwnChain(t, tester, common+fork+1)

	// Synchronise with the second peer and make sure that fork is pulled too
	if err := tester.sync("fork B", nil, mode); err != nil {
		t.Fatalf("failed to synchronise blocks: %v", err)
	}
	assertOwnForkedChain(t, tester, common+1, []int{common + fork + 1, common + fork + 1})
}

// Tests that an inactive downloader will not accept incoming hashes and blocks.
func TestInactiveDownloader61(t *testing.T) {
	t.Parallel()
	tester := newTester()

	// Check that neither hashes nor blocks are accepted
	if err := tester.downloader.DeliverHashes("bad peer", []common.Hash{}); err != errNoSyncActive {
		t.Errorf("error mismatch: have %v, want %v", err, errNoSyncActive)
	}
	if err := tester.downloader.DeliverBlocks("bad peer", []*types.Block{}); err != errNoSyncActive {
		t.Errorf("error mismatch: have %v, want %v", err, errNoSyncActive)
	}
}

// Tests that an inactive downloader will not accept incoming block headers and
// bodies.
func TestInactiveDownloader62(t *testing.T) {
	t.Parallel()
	tester := newTester()

	// Check that neither block headers nor bodies are accepted
	if err := tester.downloader.DeliverHeaders("bad peer", []*types.Header{}); err != errNoSyncActive {
		t.Errorf("error mismatch: have %v, want %v", err, errNoSyncActive)
	}
	if err := tester.downloader.DeliverBodies("bad peer", [][]*types.Transaction{}, [][]*types.Header{}); err != errNoSyncActive {
		t.Errorf("error mismatch: have %v, want %v", err, errNoSyncActive)
	}
}

// Tests that an inactive downloader will not accept incoming block headers,
// bodies and receipts.
func TestInactiveDownloader63(t *testing.T) {
	t.Parallel()
	tester := newTester()

	// Check that neither block headers nor bodies are accepted
	if err := tester.downloader.DeliverHeaders("bad peer", []*types.Header{}); err != errNoSyncActive {
		t.Errorf("error mismatch: have %v, want %v", err, errNoSyncActive)
	}
	if err := tester.downloader.DeliverBodies("bad peer", [][]*types.Transaction{}, [][]*types.Header{}); err != errNoSyncActive {
		t.Errorf("error mismatch: have %v, want %v", err, errNoSyncActive)
	}
	if err := tester.downloader.DeliverReceipts("bad peer", [][]*types.Receipt{}); err != errNoSyncActive {
		t.Errorf("error mismatch: have %v, want %v", err, errNoSyncActive)
	}
}

// Tests that a canceled download wipes all previously accumulated state.
func TestCancel61(t *testing.T)      { testCancel(t, 61, FullSync) }
func TestCancel62(t *testing.T)      { testCancel(t, 62, FullSync) }
func TestCancel63Full(t *testing.T)  { testCancel(t, 63, FullSync) }
func TestCancel63Fast(t *testing.T)  { testCancel(t, 63, FastSync) }
func TestCancel64Full(t *testing.T)  { testCancel(t, 64, FullSync) }
func TestCancel64Fast(t *testing.T)  { testCancel(t, 64, FastSync) }
func TestCancel64Light(t *testing.T) { testCancel(t, 64, LightSync) }

func testCancel(t *testing.T, protocol int, mode SyncMode) {
	t.Parallel()

	// Create a small enough block chain to download and the tester
	targetBlocks := blockCacheLimit - 15
	if targetBlocks >= MaxHashFetch {
		targetBlocks = MaxHashFetch - 15
	}
	if targetBlocks >= MaxHeaderFetch {
		targetBlocks = MaxHeaderFetch - 15
	}
	hashes, headers, blocks, receipts := makeChain(targetBlocks, 0, genesis, nil)

	tester := newTester()
	tester.newPeer("peer", protocol, hashes, headers, blocks, receipts)

	// Make sure canceling works with a pristine downloader
	tester.downloader.cancel()
	if !tester.downloader.queue.Idle() {
		t.Errorf("download queue not idle")
	}
	// Synchronise with the peer, but cancel afterwards
	if err := tester.sync("peer", nil, mode); err != nil {
		t.Fatalf("failed to synchronise blocks: %v", err)
	}
	tester.downloader.cancel()
	if !tester.downloader.queue.Idle() {
		t.Errorf("download queue not idle")
	}
}

// Tests that synchronisation from multiple peers works as intended (multi thread sanity test).
func TestMultiSynchronisation61(t *testing.T)      { testMultiSynchronisation(t, 61, FullSync) }
func TestMultiSynchronisation62(t *testing.T)      { testMultiSynchronisation(t, 62, FullSync) }
func TestMultiSynchronisation63Full(t *testing.T)  { testMultiSynchronisation(t, 63, FullSync) }
func TestMultiSynchronisation63Fast(t *testing.T)  { testMultiSynchronisation(t, 63, FastSync) }
func TestMultiSynchronisation64Full(t *testing.T)  { testMultiSynchronisation(t, 64, FullSync) }
func TestMultiSynchronisation64Fast(t *testing.T)  { testMultiSynchronisation(t, 64, FastSync) }
func TestMultiSynchronisation64Light(t *testing.T) { testMultiSynchronisation(t, 64, LightSync) }

func testMultiSynchronisation(t *testing.T, protocol int, mode SyncMode) {
	t.Parallel()

	// Create various peers with various parts of the chain
	targetPeers := 8
	targetBlocks := targetPeers*blockCacheLimit - 15
	hashes, headers, blocks, receipts := makeChain(targetBlocks, 0, genesis, nil)

	tester := newTester()
	for i := 0; i < targetPeers; i++ {
		id := fmt.Sprintf("peer #%d", i)
		tester.newPeer(id, protocol, hashes[i*blockCacheLimit:], headers, blocks, receipts)
	}
	if err := tester.sync("peer #0", nil, mode); err != nil {
		t.Fatalf("failed to synchronise blocks: %v", err)
	}
	assertOwnChain(t, tester, targetBlocks+1)
}

// Tests that synchronisations behave well in multi-version protocol environments
// and do not wreak havoc on other nodes in the network.
func TestMultiProtoSynchronisation61(t *testing.T)      { testMultiProtoSync(t, 61, FullSync) }
func TestMultiProtoSynchronisation62(t *testing.T)      { testMultiProtoSync(t, 62, FullSync) }
func TestMultiProtoSynchronisation63Full(t *testing.T)  { testMultiProtoSync(t, 63, FullSync) }
func TestMultiProtoSynchronisation63Fast(t *testing.T)  { testMultiProtoSync(t, 63, FastSync) }
func TestMultiProtoSynchronisation64Full(t *testing.T)  { testMultiProtoSync(t, 64, FullSync) }
func TestMultiProtoSynchronisation64Fast(t *testing.T)  { testMultiProtoSync(t, 64, FastSync) }
func TestMultiProtoSynchronisation64Light(t *testing.T) { testMultiProtoSync(t, 64, LightSync) }

func testMultiProtoSync(t *testing.T, protocol int, mode SyncMode) {
	t.Parallel()

	// Create a small enough block chain to download
	targetBlocks := blockCacheLimit - 15
	hashes, headers, blocks, receipts := makeChain(targetBlocks, 0, genesis, nil)

	// Create peers of every type
	tester := newTester()
	tester.newPeer("peer 61", 61, hashes, nil, blocks, nil)
	tester.newPeer("peer 62", 62, hashes, headers, blocks, nil)
	tester.newPeer("peer 63", 63, hashes, headers, blocks, receipts)
	tester.newPeer("peer 64", 64, hashes, headers, blocks, receipts)

	// Synchronise with the requested peer and make sure all blocks were retrieved
	if err := tester.sync(fmt.Sprintf("peer %d", protocol), nil, mode); err != nil {
		t.Fatalf("failed to synchronise blocks: %v", err)
	}
	assertOwnChain(t, tester, targetBlocks+1)

	// Check that no peers have been dropped off
	for _, version := range []int{61, 62, 63, 64} {
		peer := fmt.Sprintf("peer %d", version)
		if _, ok := tester.peerHashes[peer]; !ok {
			t.Errorf("%s dropped", peer)
		}
	}
}

// Tests that if a block is empty (e.g. header only), no body request should be
// made, and instead the header should be assembled into a whole block by itself.
func TestEmptyShortCircuit62(t *testing.T)      { testEmptyShortCircuit(t, 62, FullSync) }
func TestEmptyShortCircuit63Full(t *testing.T)  { testEmptyShortCircuit(t, 63, FullSync) }
func TestEmptyShortCircuit63Fast(t *testing.T)  { testEmptyShortCircuit(t, 63, FastSync) }
func TestEmptyShortCircuit64Full(t *testing.T)  { testEmptyShortCircuit(t, 64, FullSync) }
func TestEmptyShortCircuit64Fast(t *testing.T)  { testEmptyShortCircuit(t, 64, FastSync) }
func TestEmptyShortCircuit64Light(t *testing.T) { testEmptyShortCircuit(t, 64, LightSync) }

func testEmptyShortCircuit(t *testing.T, protocol int, mode SyncMode) {
	t.Parallel()

	// Create a block chain to download
	targetBlocks := 2*blockCacheLimit - 15
	hashes, headers, blocks, receipts := makeChain(targetBlocks, 0, genesis, nil)

	tester := newTester()
	tester.newPeer("peer", protocol, hashes, headers, blocks, receipts)

	// Instrument the downloader to signal body requests
	bodiesHave, receiptsHave := int32(0), int32(0)
	tester.downloader.bodyFetchHook = func(headers []*types.Header) {
		atomic.AddInt32(&bodiesHave, int32(len(headers)))
	}
	tester.downloader.receiptFetchHook = func(headers []*types.Header) {
		atomic.AddInt32(&receiptsHave, int32(len(headers)))
	}
	// Synchronise with the peer and make sure all blocks were retrieved
	if err := tester.sync("peer", nil, mode); err != nil {
		t.Fatalf("failed to synchronise blocks: %v", err)
	}
	assertOwnChain(t, tester, targetBlocks+1)

	// Validate the number of block bodies that should have been requested
	bodiesNeeded, receiptsNeeded := 0, 0
	for _, block := range blocks {
		if mode != LightSync && block != genesis && (len(block.Transactions()) > 0 || len(block.Uncles()) > 0) {
			bodiesNeeded++
		}
	}
	for hash, receipt := range receipts {
		if mode == FastSync && len(receipt) > 0 && headers[hash].Number.Uint64() <= tester.downloader.queue.fastSyncPivot {
			receiptsNeeded++
		}
	}
	if int(bodiesHave) != bodiesNeeded {
		t.Errorf("body retrieval count mismatch: have %v, want %v", bodiesHave, bodiesNeeded)
	}
	if int(receiptsHave) != receiptsNeeded {
		t.Errorf("receipt retrieval count mismatch: have %v, want %v", receiptsHave, receiptsNeeded)
	}
}

// Tests that headers are enqueued continuously, preventing malicious nodes from
// stalling the downloader by feeding gapped header chains.
func TestMissingHeaderAttack62(t *testing.T)      { testMissingHeaderAttack(t, 62, FullSync) }
func TestMissingHeaderAttack63Full(t *testing.T)  { testMissingHeaderAttack(t, 63, FullSync) }
func TestMissingHeaderAttack63Fast(t *testing.T)  { testMissingHeaderAttack(t, 63, FastSync) }
func TestMissingHeaderAttack64Full(t *testing.T)  { testMissingHeaderAttack(t, 64, FullSync) }
func TestMissingHeaderAttack64Fast(t *testing.T)  { testMissingHeaderAttack(t, 64, FastSync) }
func TestMissingHeaderAttack64Light(t *testing.T) { testMissingHeaderAttack(t, 64, LightSync) }

func testMissingHeaderAttack(t *testing.T, protocol int, mode SyncMode) {
	t.Parallel()

	// Create a small enough block chain to download
	targetBlocks := blockCacheLimit - 15
	hashes, headers, blocks, receipts := makeChain(targetBlocks, 0, genesis, nil)

	tester := newTester()

	// Attempt a full sync with an attacker feeding gapped headers
	tester.newPeer("attack", protocol, hashes, headers, blocks, receipts)
	missing := targetBlocks / 2
	delete(tester.peerHeaders["attack"], hashes[missing])

	if err := tester.sync("attack", nil, mode); err == nil {
		t.Fatalf("succeeded attacker synchronisation")
	}
	// Synchronise with the valid peer and make sure sync succeeds
	tester.newPeer("valid", protocol, hashes, headers, blocks, receipts)
	if err := tester.sync("valid", nil, mode); err != nil {
		t.Fatalf("failed to synchronise blocks: %v", err)
	}
	assertOwnChain(t, tester, targetBlocks+1)
}

// Tests that if requested headers are shifted (i.e. first is missing), the queue
// detects the invalid numbering.
func TestShiftedHeaderAttack62(t *testing.T)      { testShiftedHeaderAttack(t, 62, FullSync) }
func TestShiftedHeaderAttack63Full(t *testing.T)  { testShiftedHeaderAttack(t, 63, FullSync) }
func TestShiftedHeaderAttack63Fast(t *testing.T)  { testShiftedHeaderAttack(t, 63, FastSync) }
func TestShiftedHeaderAttack64Full(t *testing.T)  { testShiftedHeaderAttack(t, 64, FullSync) }
func TestShiftedHeaderAttack64Fast(t *testing.T)  { testShiftedHeaderAttack(t, 64, FastSync) }
func TestShiftedHeaderAttack64Light(t *testing.T) { testShiftedHeaderAttack(t, 64, LightSync) }

func testShiftedHeaderAttack(t *testing.T, protocol int, mode SyncMode) {
	// Create a small enough block chain to download
	targetBlocks := blockCacheLimit - 15
	hashes, headers, blocks, receipts := makeChain(targetBlocks, 0, genesis, nil)

	tester := newTester()

	// Attempt a full sync with an attacker feeding shifted headers
	tester.newPeer("attack", protocol, hashes, headers, blocks, receipts)
	delete(tester.peerHeaders["attack"], hashes[len(hashes)-2])
	delete(tester.peerBlocks["attack"], hashes[len(hashes)-2])
	delete(tester.peerReceipts["attack"], hashes[len(hashes)-2])

	if err := tester.sync("attack", nil, mode); err == nil {
		t.Fatalf("succeeded attacker synchronisation")
	}
	// Synchronise with the valid peer and make sure sync succeeds
	tester.newPeer("valid", protocol, hashes, headers, blocks, receipts)
	if err := tester.sync("valid", nil, mode); err != nil {
		t.Fatalf("failed to synchronise blocks: %v", err)
	}
	assertOwnChain(t, tester, targetBlocks+1)
}

// Tests that upon detecting an invalid header, the recent ones are rolled back
func TestInvalidHeaderRollback63Fast(t *testing.T)  { testInvalidHeaderRollback(t, 63, FastSync) }
func TestInvalidHeaderRollback64Fast(t *testing.T)  { testInvalidHeaderRollback(t, 64, FastSync) }
func TestInvalidHeaderRollback64Light(t *testing.T) { testInvalidHeaderRollback(t, 64, LightSync) }

func testInvalidHeaderRollback(t *testing.T, protocol int, mode SyncMode) {
	// Create a small enough block chain to download
	targetBlocks := 3*fsHeaderSafetyNet + fsMinFullBlocks
	hashes, headers, blocks, receipts := makeChain(targetBlocks, 0, genesis, nil)

	tester := newTester()

	// Attempt to sync with an attacker that feeds junk during the fast sync phase.
	// This should result in the last fsHeaderSafetyNet headers being rolled back.
	tester.newPeer("fast-attack", protocol, hashes, headers, blocks, receipts)
	missing := fsHeaderSafetyNet + MaxHeaderFetch + 1
	delete(tester.peerHeaders["fast-attack"], hashes[len(hashes)-missing])

	if err := tester.sync("fast-attack", nil, mode); err == nil {
		t.Fatalf("succeeded fast attacker synchronisation")
	}
	if head := tester.headHeader().Number.Int64(); int(head) > MaxHeaderFetch {
		t.Errorf("rollback head mismatch: have %v, want at most %v", head, MaxHeaderFetch)
	}
	// Attempt to sync with an attacker that feeds junk during the block import phase.
	// This should result in both the last fsHeaderSafetyNet number of headers being
	// rolled back, and also the pivot point being reverted to a non-block status.
	tester.newPeer("block-attack", protocol, hashes, headers, blocks, receipts)
	missing = 3*fsHeaderSafetyNet + MaxHeaderFetch + 1
	delete(tester.peerHeaders["block-attack"], hashes[len(hashes)-missing])

	if err := tester.sync("block-attack", nil, mode); err == nil {
		t.Fatalf("succeeded block attacker synchronisation")
	}
	if head := tester.headHeader().Number.Int64(); int(head) > 2*fsHeaderSafetyNet+MaxHeaderFetch {
		t.Errorf("rollback head mismatch: have %v, want at most %v", head, 2*fsHeaderSafetyNet+MaxHeaderFetch)
	}
	if mode == FastSync {
		if head := tester.headBlock().NumberU64(); head != 0 {
			t.Errorf("fast sync pivot block #%d not rolled back", head)
		}
	}
	// Attempt to sync with an attacker that withholds promised blocks after the
	// fast sync pivot point. This could be a trial to leave the node with a bad
	// but already imported pivot block.
	tester.newPeer("withhold-attack", protocol, hashes, headers, blocks, receipts)
	missing = 3*fsHeaderSafetyNet + MaxHeaderFetch + 1

	tester.downloader.noFast = false
	tester.downloader.syncInitHook = func(uint64, uint64) {
		for i := missing; i <= len(hashes); i++ {
			delete(tester.peerHeaders["withhold-attack"], hashes[len(hashes)-i])
		}
		tester.downloader.syncInitHook = nil
	}

	if err := tester.sync("withhold-attack", nil, mode); err == nil {
		t.Fatalf("succeeded withholding attacker synchronisation")
	}
	if head := tester.headHeader().Number.Int64(); int(head) > 2*fsHeaderSafetyNet+MaxHeaderFetch {
		t.Errorf("rollback head mismatch: have %v, want at most %v", head, 2*fsHeaderSafetyNet+MaxHeaderFetch)
	}
	if mode == FastSync {
		if head := tester.headBlock().NumberU64(); head != 0 {
			t.Errorf("fast sync pivot block #%d not rolled back", head)
		}
	}
	// Synchronise with the valid peer and make sure sync succeeds. Since the last
	// rollback should also disable fast syncing for this process, verify that we
	// did a fresh full sync. Note, we can't assert anything about the receipts
	// since we won't purge the database of them, hence we can't use assertOwnChain.
	tester.newPeer("valid", protocol, hashes, headers, blocks, receipts)
	if err := tester.sync("valid", nil, mode); err != nil {
		t.Fatalf("failed to synchronise blocks: %v", err)
	}
	if hs := len(tester.ownHeaders); hs != len(headers) {
		t.Fatalf("synchronised headers mismatch: have %v, want %v", hs, len(headers))
	}
	if mode != LightSync {
		if bs := len(tester.ownBlocks); bs != len(blocks) {
			t.Fatalf("synchronised blocks mismatch: have %v, want %v", bs, len(blocks))
		}
	}
}

// Tests that a peer advertising a high TD doesn't get to stall the downloader
// afterwards by not sending any useful hashes.
func TestHighTDStarvationAttack61(t *testing.T)      { testHighTDStarvationAttack(t, 61, FullSync) }
func TestHighTDStarvationAttack62(t *testing.T)      { testHighTDStarvationAttack(t, 62, FullSync) }
func TestHighTDStarvationAttack63Full(t *testing.T)  { testHighTDStarvationAttack(t, 63, FullSync) }
func TestHighTDStarvationAttack63Fast(t *testing.T)  { testHighTDStarvationAttack(t, 63, FastSync) }
func TestHighTDStarvationAttack64Full(t *testing.T)  { testHighTDStarvationAttack(t, 64, FullSync) }
func TestHighTDStarvationAttack64Fast(t *testing.T)  { testHighTDStarvationAttack(t, 64, FastSync) }
func TestHighTDStarvationAttack64Light(t *testing.T) { testHighTDStarvationAttack(t, 64, LightSync) }

func testHighTDStarvationAttack(t *testing.T, protocol int, mode SyncMode) {
	t.Parallel()

	tester := newTester()
	hashes, headers, blocks, receipts := makeChain(0, 0, genesis, nil)

	tester.newPeer("attack", protocol, []common.Hash{hashes[0]}, headers, blocks, receipts)
	if err := tester.sync("attack", big.NewInt(1000000), mode); err != errStallingPeer {
		t.Fatalf("synchronisation error mismatch: have %v, want %v", err, errStallingPeer)
	}
}

// Tests that misbehaving peers are disconnected, whilst behaving ones are not.
func TestBlockHeaderAttackerDropping61(t *testing.T) { testBlockHeaderAttackerDropping(t, 61) }
func TestBlockHeaderAttackerDropping62(t *testing.T) { testBlockHeaderAttackerDropping(t, 62) }
func TestBlockHeaderAttackerDropping63(t *testing.T) { testBlockHeaderAttackerDropping(t, 63) }
func TestBlockHeaderAttackerDropping64(t *testing.T) { testBlockHeaderAttackerDropping(t, 64) }

func testBlockHeaderAttackerDropping(t *testing.T, protocol int) {
	// Define the disconnection requirement for individual hash fetch errors
	tests := []struct {
		result error
		drop   bool
	}{
		{nil, false},                   // Sync succeeded, all is well
		{errBusy, false},               // Sync is already in progress, no problem
		{errUnknownPeer, false},        // Peer is unknown, was already dropped, don't double drop
		{errBadPeer, true},             // Peer was deemed bad for some reason, drop it
		{errStallingPeer, true},        // Peer was detected to be stalling, drop it
		{errNoPeers, false},            // No peers to download from, soft race, no issue
		{errTimeout, true},             // No hashes received in due time, drop the peer
		{errEmptyHashSet, true},        // No hashes were returned as a response, drop as it's a dead end
		{errEmptyHeaderSet, true},      // No headers were returned as a response, drop as it's a dead end
		{errPeersUnavailable, true},    // Nobody had the advertised blocks, drop the advertiser
		{errInvalidChain, true},        // Hash chain was detected as invalid, definitely drop
		{errInvalidBlock, false},       // A bad peer was detected, but not the sync origin
		{errInvalidBody, false},        // A bad peer was detected, but not the sync origin
		{errInvalidReceipt, false},     // A bad peer was detected, but not the sync origin
		{errCancelHashFetch, false},    // Synchronisation was canceled, origin may be innocent, don't drop
		{errCancelBlockFetch, false},   // Synchronisation was canceled, origin may be innocent, don't drop
		{errCancelHeaderFetch, false},  // Synchronisation was canceled, origin may be innocent, don't drop
		{errCancelBodyFetch, false},    // Synchronisation was canceled, origin may be innocent, don't drop
		{errCancelReceiptFetch, false}, // Synchronisation was canceled, origin may be innocent, don't drop
		{errCancelProcessing, false},   // Synchronisation was canceled, origin may be innocent, don't drop
	}
	// Run the tests and check disconnection status
	tester := newTester()
	for i, tt := range tests {
		// Register a new peer and ensure its presence
		id := fmt.Sprintf("test %d", i)
		if err := tester.newPeer(id, protocol, []common.Hash{genesis.Hash()}, nil, nil, nil); err != nil {
			t.Fatalf("test %d: failed to register new peer: %v", i, err)
		}
		if _, ok := tester.peerHashes[id]; !ok {
			t.Fatalf("test %d: registered peer not found", i)
		}
		// Simulate a synchronisation and check the required result
		tester.downloader.synchroniseMock = func(string, common.Hash) error { return tt.result }
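		// The mocked synchroniser short-circuits the real sync and simply
		// returns the canned error, so only the error-to-drop mapping is
		// exercised below.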

		tester.downloader.Synchronise(id, genesis.Hash(), big.NewInt(1000), FullSync)
		if _, ok := tester.peerHashes[id]; !ok != tt.drop {
			t.Errorf("test %d: peer drop mismatch for %v: have %v, want %v", i, tt.result, !ok, tt.drop)
		}
	}
}

// Tests that synchronisation progress (origin block number, current block number
// and highest block number) is tracked and updated correctly.
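//
// As a rough consumer-side sketch of the same API these assertions poll
// (d is assumed to be a *Downloader):
//
//	origin, current, latest := d.Progress()
//	fmt.Printf("block %d of %d (sync started at %d)\n", current, latest, origin)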
func TestSyncProgress61(t *testing.T)      { testSyncProgress(t, 61, FullSync) }
func TestSyncProgress62(t *testing.T)      { testSyncProgress(t, 62, FullSync) }
func TestSyncProgress63Full(t *testing.T)  { testSyncProgress(t, 63, FullSync) }
func TestSyncProgress63Fast(t *testing.T)  { testSyncProgress(t, 63, FastSync) }
func TestSyncProgress64Full(t *testing.T)  { testSyncProgress(t, 64, FullSync) }
func TestSyncProgress64Fast(t *testing.T)  { testSyncProgress(t, 64, FastSync) }
func TestSyncProgress64Light(t *testing.T) { testSyncProgress(t, 64, LightSync) }

func testSyncProgress(t *testing.T, protocol int, mode SyncMode) {
	t.Parallel()

	// Create a small enough block chain to download
	targetBlocks := blockCacheLimit - 15
	hashes, headers, blocks, receipts := makeChain(targetBlocks, 0, genesis, nil)

	// Set a sync init hook to catch progress changes
	starting := make(chan struct{})
	progress := make(chan struct{})

	tester := newTester()
	tester.downloader.syncInitHook = func(origin, latest uint64) {
		starting <- struct{}{}
		<-progress
	}
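	// The init hook parks each sync between the two channels, letting the test
	// sample Progress() at a deterministic mid-sync point.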
	// Retrieve the sync progress and ensure all values are zero (pristine sync)
	if origin, current, latest := tester.downloader.Progress(); origin != 0 || current != 0 || latest != 0 {
		t.Fatalf("Pristine progress mismatch: have %v/%v/%v, want %v/%v/%v", origin, current, latest, 0, 0, 0)
	}
	// Synchronise half the blocks and check initial progress
	tester.newPeer("peer-half", protocol, hashes[targetBlocks/2:], headers, blocks, receipts)
	pending := new(sync.WaitGroup)
	pending.Add(1)

	go func() {
		defer pending.Done()
		if err := tester.sync("peer-half", nil, mode); err != nil {
			t.Fatalf("failed to synchronise blocks: %v", err)
		}
	}()
	<-starting
	if origin, current, latest := tester.downloader.Progress(); origin != 0 || current != 0 || latest != uint64(targetBlocks/2+1) {
		t.Fatalf("Initial progress mismatch: have %v/%v/%v, want %v/%v/%v", origin, current, latest, 0, 0, targetBlocks/2+1)
	}
	progress <- struct{}{}
	pending.Wait()

	// Synchronise all the blocks and check continuation progress
	tester.newPeer("peer-full", protocol, hashes, headers, blocks, receipts)
	pending.Add(1)

	go func() {
		defer pending.Done()
		if err := tester.sync("peer-full", nil, mode); err != nil {
			t.Fatalf("failed to synchronise blocks: %v", err)
		}
	}()
	<-starting
	if origin, current, latest := tester.downloader.Progress(); origin != uint64(targetBlocks/2+1) || current != uint64(targetBlocks/2+1) || latest != uint64(targetBlocks) {
		t.Fatalf("Completing progress mismatch: have %v/%v/%v, want %v/%v/%v", origin, current, latest, targetBlocks/2+1, targetBlocks/2+1, targetBlocks)
	}
	progress <- struct{}{}
	pending.Wait()

	// Check final progress after successful sync
	if origin, current, latest := tester.downloader.Progress(); origin != uint64(targetBlocks/2+1) || current != uint64(targetBlocks) || latest != uint64(targetBlocks) {
		t.Fatalf("Final progress mismatch: have %v/%v/%v, want %v/%v/%v", origin, current, latest, targetBlocks/2+1, targetBlocks, targetBlocks)
	}
}

// Tests that synchronisation progress (origin block number and highest block
// number) is tracked and updated correctly in case of a fork (or manual head
// reversal).
func TestForkedSyncProgress61(t *testing.T)      { testForkedSyncProgress(t, 61, FullSync) }
func TestForkedSyncProgress62(t *testing.T)      { testForkedSyncProgress(t, 62, FullSync) }
func TestForkedSyncProgress63Full(t *testing.T)  { testForkedSyncProgress(t, 63, FullSync) }
func TestForkedSyncProgress63Fast(t *testing.T)  { testForkedSyncProgress(t, 63, FastSync) }
func TestForkedSyncProgress64Full(t *testing.T)  { testForkedSyncProgress(t, 64, FullSync) }
func TestForkedSyncProgress64Fast(t *testing.T)  { testForkedSyncProgress(t, 64, FastSync) }
func TestForkedSyncProgress64Light(t *testing.T) { testForkedSyncProgress(t, 64, LightSync) }

func testForkedSyncProgress(t *testing.T, protocol int, mode SyncMode) {
	t.Parallel()

	// Create a forked chain to simulate origin reversal
	common, fork := MaxHashFetch, 2*MaxHashFetch
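	// makeChainFork builds two chains that share `common` ancestor blocks and
	// then diverge for a further `fork` blocks each.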
	hashesA, hashesB, headersA, headersB, blocksA, blocksB, receiptsA, receiptsB := makeChainFork(common+fork, fork, genesis, nil)

	// Set a sync init hook to catch progress changes
	starting := make(chan struct{})
	progress := make(chan struct{})

	tester := newTester()
	tester.downloader.syncInitHook = func(origin, latest uint64) {
		starting <- struct{}{}
		<-progress
	}
	// Retrieve the sync progress and ensure all values are zero (pristine sync)
	if origin, current, latest := tester.downloader.Progress(); origin != 0 || current != 0 || latest != 0 {
		t.Fatalf("Pristine progress mismatch: have %v/%v/%v, want %v/%v/%v", origin, current, latest, 0, 0, 0)
	}
	// Synchronise with one of the forks and check progress
	tester.newPeer("fork A", protocol, hashesA, headersA, blocksA, receiptsA)
	pending := new(sync.WaitGroup)
	pending.Add(1)

	go func() {
		defer pending.Done()
		if err := tester.sync("fork A", nil, mode); err != nil {
			t.Fatalf("failed to synchronise blocks: %v", err)
		}
	}()
	<-starting
	if origin, current, latest := tester.downloader.Progress(); origin != 0 || current != 0 || latest != uint64(len(hashesA)-1) {
		t.Fatalf("Initial progress mismatch: have %v/%v/%v, want %v/%v/%v", origin, current, latest, 0, 0, len(hashesA)-1)
	}
	progress <- struct{}{}
	pending.Wait()

	// Simulate a successful sync above the fork
	tester.downloader.syncStatsChainOrigin = tester.downloader.syncStatsChainHeight
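	// With the origin pinned at the previous height, the next sync has to wind
	// it back to the common ancestor once fork B is detected.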

	// Synchronise with the second fork and check progress resets
	tester.newPeer("fork B", protocol, hashesB, headersB, blocksB, receiptsB)
	pending.Add(1)

	go func() {
		defer pending.Done()
		if err := tester.sync("fork B", nil, mode); err != nil {
			t.Fatalf("failed to synchronise blocks: %v", err)
		}
	}()
	<-starting
	if origin, current, latest := tester.downloader.Progress(); origin != uint64(common) || current != uint64(len(hashesA)-1) || latest != uint64(len(hashesB)-1) {
		t.Fatalf("Forking progress mismatch: have %v/%v/%v, want %v/%v/%v", origin, current, latest, common, len(hashesA)-1, len(hashesB)-1)
	}
	progress <- struct{}{}
	pending.Wait()

	// Check final progress after successful sync
	if origin, current, latest := tester.downloader.Progress(); origin != uint64(common) || current != uint64(len(hashesB)-1) || latest != uint64(len(hashesB)-1) {
		t.Fatalf("Final progress mismatch: have %v/%v/%v, want %v/%v/%v", origin, current, latest, common, len(hashesB)-1, len(hashesB)-1)
	}
}

// Tests that if synchronisation is aborted due to some failure, then the progress
// origin is not updated in the next sync cycle, as it should be considered the
// continuation of the previous sync and not a new instance.
func TestFailedSyncProgress61(t *testing.T)      { testFailedSyncProgress(t, 61, FullSync) }
func TestFailedSyncProgress62(t *testing.T)      { testFailedSyncProgress(t, 62, FullSync) }
func TestFailedSyncProgress63Full(t *testing.T)  { testFailedSyncProgress(t, 63, FullSync) }
func TestFailedSyncProgress63Fast(t *testing.T)  { testFailedSyncProgress(t, 63, FastSync) }
func TestFailedSyncProgress64Full(t *testing.T)  { testFailedSyncProgress(t, 64, FullSync) }
func TestFailedSyncProgress64Fast(t *testing.T)  { testFailedSyncProgress(t, 64, FastSync) }
func TestFailedSyncProgress64Light(t *testing.T) { testFailedSyncProgress(t, 64, LightSync) }

func testFailedSyncProgress(t *testing.T, protocol int, mode SyncMode) {
	t.Parallel()

	// Create a small enough block chain to download
	targetBlocks := blockCacheLimit - 15
	hashes, headers, blocks, receipts := makeChain(targetBlocks, 0, genesis, nil)

	// Set a sync init hook to catch progress changes
	starting := make(chan struct{})
	progress := make(chan struct{})

	tester := newTester()
	tester.downloader.syncInitHook = func(origin, latest uint64) {
		starting <- struct{}{}
		<-progress
	}
	// Retrieve the sync progress and ensure all values are zero (pristine sync)
	if origin, current, latest := tester.downloader.Progress(); origin != 0 || current != 0 || latest != 0 {
		t.Fatalf("Pristine progress mismatch: have %v/%v/%v, want %v/%v/%v", origin, current, latest, 0, 0, 0)
	}
	// Attempt a full sync with a faulty peer
	tester.newPeer("faulty", protocol, hashes, headers, blocks, receipts)
	missing := targetBlocks / 2
	delete(tester.peerHeaders["faulty"], hashes[missing])
	delete(tester.peerBlocks["faulty"], hashes[missing])
	delete(tester.peerReceipts["faulty"], hashes[missing])

	pending := new(sync.WaitGroup)
	pending.Add(1)

	go func() {
		defer pending.Done()
		if err := tester.sync("faulty", nil, mode); err == nil {
			t.Fatalf("succeeded faulty synchronisation")
		}
	}()
	<-starting
	if origin, current, latest := tester.downloader.Progress(); origin != 0 || current != 0 || latest != uint64(targetBlocks) {
		t.Fatalf("Initial progress mismatch: have %v/%v/%v, want %v/%v/%v", origin, current, latest, 0, 0, targetBlocks)
	}
	progress <- struct{}{}
	pending.Wait()

	// Synchronise with a good peer and check that the progress origin remains the same after a failure
	tester.newPeer("valid", protocol, hashes, headers, blocks, receipts)
	pending.Add(1)

	go func() {
		defer pending.Done()
		if err := tester.sync("valid", nil, mode); err != nil {
			t.Fatalf("failed to synchronise blocks: %v", err)
		}
	}()
	<-starting
	if origin, current, latest := tester.downloader.Progress(); origin != 0 || current > uint64(targetBlocks/2) || latest != uint64(targetBlocks) {
		t.Fatalf("Completing progress mismatch: have %v/%v/%v, want %v/0-%v/%v", origin, current, latest, 0, targetBlocks/2, targetBlocks)
	}
	progress <- struct{}{}
	pending.Wait()

	// Check final progress after successful sync
	if origin, current, latest := tester.downloader.Progress(); origin > uint64(targetBlocks/2) || current != uint64(targetBlocks) || latest != uint64(targetBlocks) {
		t.Fatalf("Final progress mismatch: have %v/%v/%v, want 0-%v/%v/%v", origin, current, latest, targetBlocks/2, targetBlocks, targetBlocks)
	}
}

// Tests that if an attacker fakes a chain height, after the attack is detected,
// the progress height is successfully reduced at the next sync invocation.
func TestFakedSyncProgress61(t *testing.T)      { testFakedSyncProgress(t, 61, FullSync) }
func TestFakedSyncProgress62(t *testing.T)      { testFakedSyncProgress(t, 62, FullSync) }
func TestFakedSyncProgress63Full(t *testing.T)  { testFakedSyncProgress(t, 63, FullSync) }
func TestFakedSyncProgress63Fast(t *testing.T)  { testFakedSyncProgress(t, 63, FastSync) }
func TestFakedSyncProgress64Full(t *testing.T)  { testFakedSyncProgress(t, 64, FullSync) }
func TestFakedSyncProgress64Fast(t *testing.T)  { testFakedSyncProgress(t, 64, FastSync) }
func TestFakedSyncProgress64Light(t *testing.T) { testFakedSyncProgress(t, 64, LightSync) }

func testFakedSyncProgress(t *testing.T, protocol int, mode SyncMode) {
	t.Parallel()

	// Create a small block chain
	targetBlocks := blockCacheLimit - 15
	hashes, headers, blocks, receipts := makeChain(targetBlocks+3, 0, genesis, nil)

	// Set a sync init hook to catch progress changes
	starting := make(chan struct{})
	progress := make(chan struct{})

	tester := newTester()
	tester.downloader.syncInitHook = func(origin, latest uint64) {
		starting <- struct{}{}
		<-progress
	}
	// Retrieve the sync progress and ensure all values are zero (pristine sync)
	if origin, current, latest := tester.downloader.Progress(); origin != 0 || current != 0 || latest != 0 {
		t.Fatalf("Pristine progress mismatch: have %v/%v/%v, want %v/%v/%v", origin, current, latest, 0, 0, 0)
	}
	// Create and sync with an attacker that promises a higher chain than available
	tester.newPeer("attack", protocol, hashes, headers, blocks, receipts)
	for i := 1; i < 3; i++ {
		delete(tester.peerHeaders["attack"], hashes[i])
		delete(tester.peerBlocks["attack"], hashes[i])
		delete(tester.peerReceipts["attack"], hashes[i])
	}
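	// The head header itself is kept so the peer can still advertise the full
	// height, but the data just below it is gone, leaving blocks it can never
	// deliver.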

	pending := new(sync.WaitGroup)
	pending.Add(1)

	go func() {
		defer pending.Done()
		if err := tester.sync("attack", nil, mode); err == nil {
			t.Fatalf("succeeded attacker synchronisation")
		}
	}()
	<-starting
	if origin, current, latest := tester.downloader.Progress(); origin != 0 || current != 0 || latest != uint64(targetBlocks+3) {
		t.Fatalf("Initial progress mismatch: have %v/%v/%v, want %v/%v/%v", origin, current, latest, 0, 0, targetBlocks+3)
	}
	progress <- struct{}{}
	pending.Wait()

	// Synchronise with a good peer and check that the progress height has been reduced to the true value
	tester.newPeer("valid", protocol, hashes[3:], headers, blocks, receipts)
	pending.Add(1)

	go func() {
		defer pending.Done()
		if err := tester.sync("valid", nil, mode); err != nil {
			t.Fatalf("failed to synchronise blocks: %v", err)
		}
	}()
	<-starting
	if origin, current, latest := tester.downloader.Progress(); origin != 0 || current > uint64(targetBlocks) || latest != uint64(targetBlocks) {
		t.Fatalf("Completing progress mismatch: have %v/%v/%v, want %v/0-%v/%v", origin, current, latest, 0, targetBlocks, targetBlocks)
	}
	progress <- struct{}{}
	pending.Wait()

	// Check final progress after successful sync
	if origin, current, latest := tester.downloader.Progress(); origin > uint64(targetBlocks) || current != uint64(targetBlocks) || latest != uint64(targetBlocks) {
		t.Fatalf("Final progress mismatch: have %v/%v/%v, want 0-%v/%v/%v", origin, current, latest, targetBlocks, targetBlocks, targetBlocks)
	}
}

// This test reproduces an issue where unexpected deliveries would
// block indefinitely if they arrived at the right time.
func TestDeliverHeadersHang62(t *testing.T)      { testDeliverHeadersHang(t, 62, FullSync) }
func TestDeliverHeadersHang63Full(t *testing.T)  { testDeliverHeadersHang(t, 63, FullSync) }
func TestDeliverHeadersHang63Fast(t *testing.T)  { testDeliverHeadersHang(t, 63, FastSync) }
func TestDeliverHeadersHang64Full(t *testing.T)  { testDeliverHeadersHang(t, 64, FullSync) }
func TestDeliverHeadersHang64Fast(t *testing.T)  { testDeliverHeadersHang(t, 64, FastSync) }
func TestDeliverHeadersHang64Light(t *testing.T) { testDeliverHeadersHang(t, 64, LightSync) }

func testDeliverHeadersHang(t *testing.T, protocol int, mode SyncMode) {
	t.Parallel()
	hashes, headers, blocks, receipts := makeChain(5, 0, genesis, nil)
	fakeHeads := []*types.Header{{}, {}, {}, {}}
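	// The hang is a timing-sensitive race, so repeat the scenario many times
	// to give it a realistic chance of triggering.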
	for i := 0; i < 200; i++ {
		tester := newTester()
		tester.newPeer("peer", protocol, hashes, headers, blocks, receipts)
		// Whenever the downloader requests headers, flood it with
		// a lot of unrequested header deliveries.
		tester.downloader.peers.peers["peer"].getAbsHeaders = func(from uint64, count, skip int, reverse bool) error {
			deliveriesDone := make(chan struct{}, 500)
			for i := 0; i < cap(deliveriesDone); i++ {
				peer := fmt.Sprintf("fake-peer%d", i)
				go func() {
					tester.downloader.DeliverHeaders(peer, fakeHeads)
					deliveriesDone <- struct{}{}
				}()
			}
			// Deliver the actual requested headers.
			impl := tester.peerGetAbsHeadersFn("peer", 0)
			go impl(from, count, skip, reverse)
			// None of the extra deliveries should block.
			timeout := time.After(5 * time.Second)
			for i := 0; i < cap(deliveriesDone); i++ {
				select {
				case <-deliveriesDone:
				case <-timeout:
					panic("blocked")
				}
			}
			return nil
		}
		if err := tester.sync("peer", nil, mode); err != nil {
			t.Errorf("sync failed: %v", err)
		}
	}
}