package downloader

import (
	"encoding/binary"
	"math/big"
	"testing"
	"time"

	"github.com/ethereum/go-ethereum/common"
	"github.com/ethereum/go-ethereum/core/types"
	"github.com/ethereum/go-ethereum/event"
)

var (
	knownHash   = common.Hash{1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
	unknownHash = common.Hash{9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9}
)

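// createHashes generates a chain of amount+1 hashes, the last of which is the
// known hash; the remaining entries are derived from their index in the chain.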
func createHashes(start, amount int) (hashes []common.Hash) {
	hashes = make([]common.Hash, amount+1)
	hashes[len(hashes)-1] = knownHash

	for i := range hashes[:len(hashes)-1] {
		binary.BigEndian.PutUint64(hashes[i][:8], uint64(i+2))
	}
	return
}

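// createBlock assembles a test block with the given number, parent hash and own hash.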
func createBlock(i int, parent, hash common.Hash) *types.Block {
	header := &types.Header{Number: big.NewInt(int64(i))}
	block := types.NewBlockWithHeader(header)
	block.HeaderHash = hash
	block.ParentHeaderHash = parent
	return block
}

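// createBlocksFromHashes assembles a block for every hash in the chain, linking
// each block to the hash that follows it in the slice (its parent).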
func createBlocksFromHashes(hashes []common.Hash) map[common.Hash]*types.Block {
	blocks := make(map[common.Hash]*types.Block)
	for i := 0; i < len(hashes); i++ {
		parent := knownHash
		if i < len(hashes)-1 {
			parent = hashes[i+1]
		}
		blocks[hashes[i]] = createBlock(len(hashes)-i, parent, hashes[i])
	}
	return blocks
}

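// downloadTester is a test harness combining a downloader with the simulated
// hash and block chains it synchronises against.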
type downloadTester struct {
	downloader *Downloader

	hashes []common.Hash                // Chain of hashes simulating the canonical chain
	blocks map[common.Hash]*types.Block // Blocks associated with the hashes
	chain  []common.Hash                // Block-chain being constructed

	t            *testing.T
	pcount       int
	done         chan bool
	activePeerId string
}

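// newTester creates a download tester seeded with the given hash chain and block
// collection, wiring a fresh downloader up to the tester's callbacks.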
func newTester(t *testing.T, hashes []common.Hash, blocks map[common.Hash]*types.Block) *downloadTester {
	tester := &downloadTester{
		t: t,

		hashes: hashes,
		blocks: blocks,
		chain:  []common.Hash{knownHash},

		done: make(chan bool),
	}
	var mux event.TypeMux
	downloader := New(&mux, tester.hasBlock, tester.getBlock)
	tester.downloader = downloader

	return tester
}

// sync is a simple wrapper around the downloader to start synchronisation and
// block until it returns
func (dl *downloadTester) sync(peerId string, head common.Hash) error {
	dl.activePeerId = peerId
	return dl.downloader.Synchronise(peerId, head)
}

// syncTake starts synchronising with a remote peer, but concurrently it also
// starts fetching blocks that the downloader retrieved. It blocks until both go
// routines terminate.
func (dl *downloadTester) syncTake(peerId string, head common.Hash) (types.Blocks, error) {
	// Start a block collector to take blocks as they become available
	done := make(chan struct{})
	took := []*types.Block{}
	go func() {
		for running := true; running; {
			select {
			case <-done:
				running = false
			default:
				time.Sleep(time.Millisecond)
			}
			// Take a batch of blocks and accumulate
			took = append(took, dl.downloader.TakeBlocks()...)
		}
		done <- struct{}{}
	}()
	// Start the downloading, sync the taker and return
	err := dl.sync(peerId, head)

	done <- struct{}{}
	<-done

	return took, err
}

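// insertBlocks extends the tester's constructed chain with the hashes of the given blocks.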
func (dl *downloadTester) insertBlocks(blocks types.Blocks) {
	for _, block := range blocks {
		dl.chain = append(dl.chain, block.Hash())
	}
}

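// hasBlock reports whether the given hash is already part of the constructed chain.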
func (dl *downloadTester) hasBlock(hash common.Hash) bool {
	for _, h := range dl.chain {
		if h == hash {
			return true
		}
	}
	return false
}

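// getBlock retrieves a block from the tester's block collection; note that as
// written it always returns the block behind the known hash.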
func (dl *downloadTester) getBlock(hash common.Hash) *types.Block {
	return dl.blocks[knownHash]
}

// getHashes retrieves a batch of hashes for reconstructing the chain.
func (dl *downloadTester) getHashes(head common.Hash) error {
	// Gather the next batch of hashes
	hashes := make([]common.Hash, 0, maxHashFetch)
	for i, hash := range dl.hashes {
		if hash == head {
			i++
			for len(hashes) < cap(hashes) && i < len(dl.hashes) {
				hashes = append(hashes, dl.hashes[i])
				i++
			}
			break
		}
	}
	// Delay delivery a bit to allow attacks to unfold
	id := dl.activePeerId
	go func() {
		time.Sleep(time.Millisecond)
		dl.downloader.DeliverHashes(id, hashes)
	}()
	return nil
}

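// getBlocks constructs a block retrieval function bound to the given peer id,
// delivering any locally known blocks to the downloader asynchronously.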
func (dl *downloadTester) getBlocks(id string) func([]common.Hash) error {
	return func(hashes []common.Hash) error {
		blocks := make([]*types.Block, 0, len(hashes))
		for _, hash := range hashes {
			if block, ok := dl.blocks[hash]; ok {
				blocks = append(blocks, block)
			}
		}
		go dl.downloader.DeliverBlocks(id, blocks)

		return nil
	}
}

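// newPeer registers a well-behaved peer that serves both hashes and blocks.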
func (dl *downloadTester) newPeer(id string, td *big.Int, hash common.Hash) {
	dl.pcount++

	dl.downloader.RegisterPeer(id, hash, dl.getHashes, dl.getBlocks(id))
}

func (dl *downloadTester) badBlocksPeer(id string, td *big.Int, hash common.Hash) {
	dl.pcount++

	// This bad peer never returns any blocks
	dl.downloader.RegisterPeer(id, hash, dl.getHashes, func([]common.Hash) error {
		return nil
	})
}

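// Tests that basic synchronisation against a good peer fills the download
// queue's block cache with all the requested blocks, even with useless peers
// also attached.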
func TestDownload(t *testing.T) {
	minDesiredPeerCount = 4
	blockTTL = 1 * time.Second

	targetBlocks := 1000
	hashes := createHashes(0, targetBlocks)
	blocks := createBlocksFromHashes(hashes)
	tester := newTester(t, hashes, blocks)

	tester.newPeer("peer1", big.NewInt(10000), hashes[0])
	tester.newPeer("peer2", big.NewInt(0), common.Hash{})
	tester.badBlocksPeer("peer3", big.NewInt(0), common.Hash{})
	tester.badBlocksPeer("peer4", big.NewInt(0), common.Hash{})
	tester.activePeerId = "peer1"

	err := tester.sync("peer1", hashes[0])
	if err != nil {
		t.Error("download error", err)
	}

	inqueue := len(tester.downloader.queue.blockCache)
	if inqueue != targetBlocks {
		t.Error("expected", targetBlocks, "have", inqueue)
	}
}

func TestMissing(t *testing.T) {
	targetBlocks := 1000
	hashes := createHashes(0, 1000)
	extraHashes := createHashes(1001, 1003)
	blocks := createBlocksFromHashes(append(extraHashes, hashes...))
	tester := newTester(t, hashes, blocks)

	tester.newPeer("peer1", big.NewInt(10000), hashes[len(hashes)-1])

	hashes = append(extraHashes, hashes[:len(hashes)-1]...)
	tester.newPeer("peer2", big.NewInt(0), common.Hash{})

	err := tester.sync("peer1", hashes[0])
	if err != nil {
		t.Error("download error", err)
	}

	inqueue := len(tester.downloader.queue.blockCache)
	if inqueue != targetBlocks {
		t.Error("expected", targetBlocks, "have", inqueue)
	}
}

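// Tests that all the blocks gathered by a synchronisation can be retrieved with
// TakeBlocks afterwards.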
func TestTaking(t *testing.T) {
	minDesiredPeerCount = 4
	blockTTL = 1 * time.Second

	targetBlocks := 1000
	hashes := createHashes(0, targetBlocks)
	blocks := createBlocksFromHashes(hashes)
	tester := newTester(t, hashes, blocks)

	tester.newPeer("peer1", big.NewInt(10000), hashes[0])
	tester.newPeer("peer2", big.NewInt(0), common.Hash{})
	tester.badBlocksPeer("peer3", big.NewInt(0), common.Hash{})
	tester.badBlocksPeer("peer4", big.NewInt(0), common.Hash{})

	err := tester.sync("peer1", hashes[0])
	if err != nil {
		t.Error("download error", err)
	}
	bs := tester.downloader.TakeBlocks()
	if len(bs) != targetBlocks {
		t.Error("retrieved block mismatch: have %v, want %v", len(bs), targetBlocks)
	}
}

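// Tests that hash and block deliveries are refused while no synchronisation is
// active.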
func TestInactiveDownloader(t *testing.T) {
	targetBlocks := 1000
	hashes := createHashes(0, targetBlocks)
	blocks := createBlocksFromHashSet(createHashSet(hashes))
	tester := newTester(t, hashes, nil)

	err := tester.downloader.DeliverHashes("bad peer 001", hashes)
	if err != errNoSyncActive {
		t.Error("expected no sync error, got", err)
	}

	err = tester.downloader.DeliverBlocks("bad peer 001", blocks)
	if err != errNoSyncActive {
		t.Error("expected no sync error, got", err)
	}
}

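// Tests that cancelling a synchronisation empties both the hash and the block
// queues.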
func TestCancel(t *testing.T) {
	minDesiredPeerCount = 4
	blockTTL = 1 * time.Second

	targetBlocks := 1000
	hashes := createHashes(0, targetBlocks)
	blocks := createBlocksFromHashes(hashes)
	tester := newTester(t, hashes, blocks)

	tester.newPeer("peer1", big.NewInt(10000), hashes[0])

	err := tester.sync("peer1", hashes[0])
	if err != nil {
		t.Error("download error", err)
	}

	if !tester.downloader.Cancel() {
		t.Error("cancel operation unsuccessfull")
	}

	hashSize, blockSize := tester.downloader.queue.Size()
	if hashSize > 0 || blockSize > 0 {
		t.Error("block (", blockSize, ") or hash (", hashSize, ") not 0")
	}
}

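// Tests that synchronising a chain much larger than the block cache still
// completes when the blocks are concurrently taken out of the cache (i.e. that
// throttling does not stall the download).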
func TestThrottling(t *testing.T) {
	minDesiredPeerCount = 4
	blockTTL = 1 * time.Second

	targetBlocks := 16 * blockCacheLimit
	hashes := createHashes(0, targetBlocks)
	blocks := createBlocksFromHashes(hashes)
	tester := newTester(t, hashes, blocks)

	tester.newPeer("peer1", big.NewInt(10000), hashes[0])
	tester.newPeer("peer2", big.NewInt(0), common.Hash{})
	tester.badBlocksPeer("peer3", big.NewInt(0), common.Hash{})
	tester.badBlocksPeer("peer4", big.NewInt(0), common.Hash{})

	// Concurrently download and take the blocks
	took, err := tester.syncTake("peer1", hashes[0])
	if err != nil {
		t.Fatalf("failed to synchronise blocks: %v", err)
	}
	if len(took) != targetBlocks {
		t.Fatalf("downloaded block mismatch: have %v, want %v", len(took), targetBlocks)
	}
}

// Tests that if a peer returns an invalid chain with a block pointing to a non-
// existing parent, it is correctly detected and handled.
func TestNonExistingParentAttack(t *testing.T) {
	// Forge a single-link chain with a forged header
	hashes := createHashes(0, 1)
	blocks := createBlocksFromHashes(hashes)

	forged := blocks[hashes[0]]
	forged.ParentHeaderHash = unknownHash

	// Try and sync with the malicious node and check that it fails
	tester := newTester(t, hashes, blocks)
	tester.newPeer("attack", big.NewInt(10000), hashes[0])
	if err := tester.sync("attack", hashes[0]); err != nil {
		t.Fatalf("failed to synchronise blocks: %v", err)
	}
	bs := tester.downloader.TakeBlocks()
	if len(bs) != 1 {
		t.Fatalf("retrieved block mismatch: have %v, want %v", len(bs), 1)
	}
	if tester.hasBlock(bs[0].ParentHash()) {
		t.Fatalf("tester knows about the unknown hash")
	}
	tester.downloader.Cancel()

	// Reconstruct a valid chain, and try to synchronize with it
	forged.ParentHeaderHash = knownHash
	tester.newPeer("valid", big.NewInt(20000), hashes[0])
	if err := tester.sync("valid", hashes[0]); err != nil {
		t.Fatalf("failed to synchronise blocks: %v", err)
	}
	bs = tester.downloader.TakeBlocks()
	if len(bs) != 1 {
		t.Fatalf("retrieved block mismatch: have %v, want %v", len(bs), 1)
	}
	if !tester.hasBlock(bs[0].ParentHash()) {
		t.Fatalf("tester doesn't know about the origin hash")
	}
}

// Tests that if a malicious peer keeps sending us repeating hashes, we don't
// loop indefinitely.
func TestRepeatingHashAttack(t *testing.T) {
	// Create a valid chain, but drop the last link
	hashes := createHashes(0, blockCacheLimit)
	blocks := createBlocksFromHashes(hashes)
	forged := hashes[:len(hashes)-1]

	// Try and sync with the malicious node
	tester := newTester(t, forged, blocks)
	tester.newPeer("attack", big.NewInt(10000), forged[0])

	errc := make(chan error)
	go func() {
		errc <- tester.sync("attack", hashes[0])
	}()

	// Make sure that syncing returns and does so with a failure
	select {
	case <-time.After(100 * time.Millisecond):
		t.Fatalf("synchronisation blocked")
	case err := <-errc:
		if err == nil {
			t.Fatalf("synchronisation succeeded")
		}
	}
	// Ensure that a valid chain can still pass sync
	tester.hashes = hashes
	tester.newPeer("valid", big.NewInt(20000), hashes[0])
	if err := tester.sync("valid", hashes[0]); err != nil {
		t.Fatalf("failed to synchronise blocks: %v", err)
	}
}

// Tests that if a malicious peer returns a non-existent block hash, the sync
// eventually times out and can be reattempted with another peer.
func TestNonExistingBlockAttack(t *testing.T) {
	// Create a valid chain, but forge a middle link into an unknown hash
	hashes := createHashes(0, blockCacheLimit)
	blocks := createBlocksFromHashes(hashes)
	origin := hashes[len(hashes)/2]

	hashes[len(hashes)/2] = unknownHash

	// Try and sync with the malicious node and check that it fails
	tester := newTester(t, hashes, blocks)
	tester.newPeer("attack", big.NewInt(10000), hashes[0])
	if err := tester.sync("attack", hashes[0]); err != errPeersUnavailable {
		t.Fatalf("synchronisation error mismatch: have %v, want %v", err, errPeersUnavailable)
	}
	// Ensure that a valid chain can still pass sync
	hashes[len(hashes)/2] = origin
	tester.newPeer("valid", big.NewInt(20000), hashes[0])
	if err := tester.sync("valid", hashes[0]); err != nil {
		t.Fatalf("failed to synchronise blocks: %v", err)
	}
}

// Tests that if a malicious peer returns hashes in a weird order, the
// sync throttler doesn't choke on them waiting for the valid blocks.
func TestInvalidHashOrderAttack(t *testing.T) {
	// Create a valid long chain, but reverse some hashes within
	hashes := createHashes(0, 4*blockCacheLimit)
	blocks := createBlocksFromHashes(hashes)

	chunk1 := make([]common.Hash, blockCacheLimit)
	chunk2 := make([]common.Hash, blockCacheLimit)
	copy(chunk1, hashes[blockCacheLimit:2*blockCacheLimit])
	copy(chunk2, hashes[2*blockCacheLimit:3*blockCacheLimit])

	reverse := make([]common.Hash, len(hashes))
	copy(reverse, hashes)
	copy(reverse[2*blockCacheLimit:], chunk1)
	copy(reverse[blockCacheLimit:], chunk2)

	// Try and sync with the malicious node and check that it fails
	tester := newTester(t, reverse, blocks)
	tester.newPeer("attack", big.NewInt(10000), reverse[0])
	if _, err := tester.syncTake("attack", reverse[0]); err != ErrInvalidChain {
		t.Fatalf("synchronisation error mismatch: have %v, want %v", err, ErrInvalidChain)
	}
	// Ensure that a valid chain can still pass sync
	tester.hashes = hashes
	tester.newPeer("valid", big.NewInt(20000), hashes[0])
	if _, err := tester.syncTake("valid", hashes[0]); err != nil {
		t.Fatalf("failed to synchronise blocks: %v", err)
	}
}

// Tests that if a malicious peer makes up a random hash chain and tries to push
// indefinitely, it actually gets caught with it.
func TestMadeupHashChainAttack(t *testing.T) {
	blockTTL = 100 * time.Millisecond
	crossCheckCycle = 25 * time.Millisecond

	// Create a long chain of hashes without backing blocks
	hashes := createHashes(0, 1024*blockCacheLimit)

	// Try and sync with the malicious node and check that it fails
	tester := newTester(t, hashes, nil)
	tester.newPeer("attack", big.NewInt(10000), hashes[0])
	if _, err := tester.syncTake("attack", hashes[0]); err != ErrCrossCheckFailed {
		t.Fatalf("synchronisation error mismatch: have %v, want %v", err, ErrCrossCheckFailed)
	}
}

// Tests that if a malicious peer makes up a random block chain and tries to
// push indefinitely, it actually gets caught with it.
func TestMadeupBlockChainAttack(t *testing.T) {
	defaultBlockTTL := blockTTL
	defaultCrossCheckCycle := crossCheckCycle

	blockTTL = 100 * time.Millisecond
	crossCheckCycle = 25 * time.Millisecond

	// Create a long chain of blocks and simulate an invalid chain by dropping every second hash
	hashes := createHashes(0, 32*blockCacheLimit)
	blocks := createBlocksFromHashes(hashes)

	gapped := make([]common.Hash, len(hashes)/2)
	for i := 0; i < len(gapped); i++ {
		gapped[i] = hashes[2*i]
	}
	// Try and sync with the malicious node and check that it fails
	tester := newTester(t, gapped, blocks)
	tester.newPeer("attack", big.NewInt(10000), gapped[0])
	if _, err := tester.syncTake("attack", gapped[0]); err != ErrCrossCheckFailed {
		t.Fatalf("synchronisation error mismatch: have %v, want %v", err, ErrCrossCheckFailed)
	}
	// Ensure that a valid chain can still pass sync
	blockTTL = defaultBlockTTL
	crossCheckCycle = defaultCrossCheckCycle

	tester.hashes = hashes
	tester.newPeer("valid", big.NewInt(20000), hashes[0])
	if _, err := tester.syncTake("valid", hashes[0]); err != nil {
		t.Fatalf("failed to synchronise blocks: %v", err)
	}
}