package downloader

import (
	"encoding/binary"
	"math/big"
	"testing"
	"time"

	"github.com/ethereum/go-ethereum/common"
	"github.com/ethereum/go-ethereum/core/types"
)

var (
	knownHash   = common.Hash{1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
	unknownHash = common.Hash{9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9}
)
17

18
func createHashes(start, amount int) (hashes []common.Hash) {
19 20 21 22 23 24 25 26 27 28
	hashes = make([]common.Hash, amount+1)
	hashes[len(hashes)-1] = knownHash

	for i := range hashes[:len(hashes)-1] {
		binary.BigEndian.PutUint64(hashes[i][:8], uint64(i+2))
	}

	return
}

29 30 31 32
func createBlock(i int, prevHash, hash common.Hash) *types.Block {
	header := &types.Header{Number: big.NewInt(int64(i))}
	block := types.NewBlockWithHeader(header)
	block.HeaderHash = hash
33
	block.ParentHeaderHash = prevHash
34 35 36
	return block
}

37 38
func createBlocksFromHashes(hashes []common.Hash) map[common.Hash]*types.Block {
	blocks := make(map[common.Hash]*types.Block)
39

40
	for i, hash := range hashes {
41
		blocks[hash] = createBlock(len(hashes)-i, knownHash, hash)
42 43 44 45 46 47
	}

	return blocks
}

type downloadTester struct {
48 49 50 51 52 53
	downloader *Downloader

	hashes []common.Hash                // Chain of hashes simulating
	blocks map[common.Hash]*types.Block // Blocks associated with the hashes
	chain  []common.Hash                // Block-chain being constructed

O
obscuren 已提交
54 55 56 57
	t            *testing.T
	pcount       int
	done         chan bool
	activePeerId string
58 59 60
}

func newTester(t *testing.T, hashes []common.Hash, blocks map[common.Hash]*types.Block) *downloadTester {
61 62 63 64 65 66 67 68 69
	tester := &downloadTester{
		t: t,

		hashes: hashes,
		blocks: blocks,
		chain:  []common.Hash{knownHash},

		done: make(chan bool),
	}
70
	downloader := New(tester.hasBlock, tester.getBlock)
71 72 73 74 75
	tester.downloader = downloader

	return tester
}

O
obscuren 已提交
76 77
func (dl *downloadTester) sync(peerId string, hash common.Hash) error {
	dl.activePeerId = peerId
78
	return dl.downloader.Synchronise(peerId, hash)
O
obscuren 已提交
79 80
}

81 82 83 84 85 86
func (dl *downloadTester) insertBlocks(blocks types.Blocks) {
	for _, block := range blocks {
		dl.chain = append(dl.chain, block.Hash())
	}
}

87
func (dl *downloadTester) hasBlock(hash common.Hash) bool {
88 89 90 91
	for _, h := range dl.chain {
		if h == hash {
			return true
		}
92 93 94 95
	}
	return false
}

96 97
func (dl *downloadTester) getBlock(hash common.Hash) *types.Block {
	return dl.blocks[knownHash]
98 99 100
}

func (dl *downloadTester) getHashes(hash common.Hash) error {
101
	dl.downloader.DeliverHashes(dl.activePeerId, dl.hashes)
102 103 104 105 106 107 108 109 110 111
	return nil
}

func (dl *downloadTester) getBlocks(id string) func([]common.Hash) error {
	return func(hashes []common.Hash) error {
		blocks := make([]*types.Block, len(hashes))
		for i, hash := range hashes {
			blocks[i] = dl.blocks[hash]
		}

112
		go dl.downloader.DeliverBlocks(id, blocks)
113 114 115 116 117 118 119 120

		return nil
	}
}

func (dl *downloadTester) newPeer(id string, td *big.Int, hash common.Hash) {
	dl.pcount++

O
obscuren 已提交
121
	dl.downloader.RegisterPeer(id, hash, dl.getHashes, dl.getBlocks(id))
122 123 124 125 126 127
}

func (dl *downloadTester) badBlocksPeer(id string, td *big.Int, hash common.Hash) {
	dl.pcount++

	// This bad peer never returns any blocks
O
obscuren 已提交
128
	dl.downloader.RegisterPeer(id, hash, dl.getHashes, func([]common.Hash) error {
129 130 131 132 133
		return nil
	})
}

func TestDownload(t *testing.T) {
134
	minDesiredPeerCount = 4
O
obscuren 已提交
135
	blockTtl = 1 * time.Second
136

137 138
	targetBlocks := 1000
	hashes := createHashes(0, targetBlocks)
139 140 141
	blocks := createBlocksFromHashes(hashes)
	tester := newTester(t, hashes, blocks)

142
	tester.newPeer("peer1", big.NewInt(10000), hashes[0])
143 144 145
	tester.newPeer("peer2", big.NewInt(0), common.Hash{})
	tester.badBlocksPeer("peer3", big.NewInt(0), common.Hash{})
	tester.badBlocksPeer("peer4", big.NewInt(0), common.Hash{})
O
obscuren 已提交
146
	tester.activePeerId = "peer1"
147

O
obscuren 已提交
148
	err := tester.sync("peer1", hashes[0])
149 150 151 152
	if err != nil {
		t.Error("download error", err)
	}

153
	inqueue := len(tester.downloader.queue.blockCache)
154 155
	if inqueue != targetBlocks {
		t.Error("expected", targetBlocks, "have", inqueue)
156 157
	}
}
158 159

func TestMissing(t *testing.T) {
160
	targetBlocks := 1000
161 162 163 164 165 166 167 168 169 170
	hashes := createHashes(0, 1000)
	extraHashes := createHashes(1001, 1003)
	blocks := createBlocksFromHashes(append(extraHashes, hashes...))
	tester := newTester(t, hashes, blocks)

	tester.newPeer("peer1", big.NewInt(10000), hashes[len(hashes)-1])

	hashes = append(extraHashes, hashes[:len(hashes)-1]...)
	tester.newPeer("peer2", big.NewInt(0), common.Hash{})

O
obscuren 已提交
171
	err := tester.sync("peer1", hashes[0])
172 173
	if err != nil {
		t.Error("download error", err)
174 175
	}

176
	inqueue := len(tester.downloader.queue.blockCache)
177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195
	if inqueue != targetBlocks {
		t.Error("expected", targetBlocks, "have", inqueue)
	}
}

func TestTaking(t *testing.T) {
	minDesiredPeerCount = 4
	blockTtl = 1 * time.Second

	targetBlocks := 1000
	hashes := createHashes(0, targetBlocks)
	blocks := createBlocksFromHashes(hashes)
	tester := newTester(t, hashes, blocks)

	tester.newPeer("peer1", big.NewInt(10000), hashes[0])
	tester.newPeer("peer2", big.NewInt(0), common.Hash{})
	tester.badBlocksPeer("peer3", big.NewInt(0), common.Hash{})
	tester.badBlocksPeer("peer4", big.NewInt(0), common.Hash{})

O
obscuren 已提交
196
	err := tester.sync("peer1", hashes[0])
197 198 199
	if err != nil {
		t.Error("download error", err)
	}
200
	bs := tester.downloader.TakeBlocks()
201 202
	if len(bs) != targetBlocks {
		t.Error("retrieved block mismatch: have %v, want %v", len(bs), targetBlocks)
203
	}
204
}
205

206 207 208 209 210 211
func TestInactiveDownloader(t *testing.T) {
	targetBlocks := 1000
	hashes := createHashes(0, targetBlocks)
	blocks := createBlocksFromHashSet(createHashSet(hashes))
	tester := newTester(t, hashes, nil)

212
	err := tester.downloader.DeliverHashes("bad peer 001", hashes)
213 214 215 216
	if err != errNoSyncActive {
		t.Error("expected no sync error, got", err)
	}

217
	err = tester.downloader.DeliverBlocks("bad peer 001", blocks)
218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248
	if err != errNoSyncActive {
		t.Error("expected no sync error, got", err)
	}
}

func TestCancel(t *testing.T) {
	minDesiredPeerCount = 4
	blockTtl = 1 * time.Second

	targetBlocks := 1000
	hashes := createHashes(0, targetBlocks)
	blocks := createBlocksFromHashes(hashes)
	tester := newTester(t, hashes, blocks)

	tester.newPeer("peer1", big.NewInt(10000), hashes[0])

	err := tester.sync("peer1", hashes[0])
	if err != nil {
		t.Error("download error", err)
	}

	if !tester.downloader.Cancel() {
		t.Error("cancel operation unsuccessfull")
	}

	hashSize, blockSize := tester.downloader.queue.Size()
	if hashSize > 0 || blockSize > 0 {
		t.Error("block (", blockSize, ") or hash (", hashSize, ") not 0")
	}
}

249 250 251 252
func TestThrottling(t *testing.T) {
	minDesiredPeerCount = 4
	blockTtl = 1 * time.Second

253
	targetBlocks := 16 * blockCacheLimit
254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271
	hashes := createHashes(0, targetBlocks)
	blocks := createBlocksFromHashes(hashes)
	tester := newTester(t, hashes, blocks)

	tester.newPeer("peer1", big.NewInt(10000), hashes[0])
	tester.newPeer("peer2", big.NewInt(0), common.Hash{})
	tester.badBlocksPeer("peer3", big.NewInt(0), common.Hash{})
	tester.badBlocksPeer("peer4", big.NewInt(0), common.Hash{})

	// Concurrently download and take the blocks
	errc := make(chan error, 1)
	go func() {
		errc <- tester.sync("peer1", hashes[0])
	}()

	done := make(chan struct{})
	took := []*types.Block{}
	go func() {
272
		for running := true; running; {
273 274
			select {
			case <-done:
275
				running = false
276
			default:
277
				time.Sleep(time.Millisecond)
278
			}
279
			// Take a batch of blocks and accumulate
280
			took = append(took, tester.downloader.TakeBlocks()...)
281
		}
282
		done <- struct{}{}
283 284
	}()

285
	// Synchronise the two threads and verify
286 287 288 289 290
	err := <-errc
	done <- struct{}{}
	<-done

	if err != nil {
291
		t.Fatalf("failed to synchronise blocks: %v", err)
292 293 294 295 296
	}
	if len(took) != targetBlocks {
		t.Fatalf("downloaded block mismatch: have %v, want %v", len(took), targetBlocks)
	}
}
297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313

// Tests that if a peer returns an invalid chain with a block pointing to a non-
// existing parent, it is correctly detected and handled.
func TestNonExistingParentAttack(t *testing.T) {
	// Forge a single-link chain with a forged header
	hashes := createHashes(0, 1)
	blocks := createBlocksFromHashes(hashes)

	forged := blocks[hashes[0]]
	forged.ParentHeaderHash = unknownHash

	// Try and sync with the malicious node and check that it fails
	tester := newTester(t, hashes, blocks)
	tester.newPeer("attack", big.NewInt(10000), hashes[0])
	if err := tester.sync("attack", hashes[0]); err != nil {
		t.Fatalf("failed to synchronise blocks: %v", err)
	}
314 315 316
	bs := tester.downloader.TakeBlocks()
	if len(bs) != 1 {
		t.Fatalf("retrieved block mismatch: have %v, want %v", len(bs), 1)
317
	}
318 319
	if tester.hasBlock(bs[0].ParentHash()) {
		t.Fatalf("tester knows about the unknown hash")
320 321 322 323 324 325 326 327 328
	}
	tester.downloader.Cancel()

	// Reconstruct a valid chain, and try to synchronize with it
	forged.ParentHeaderHash = knownHash
	tester.newPeer("valid", big.NewInt(20000), hashes[0])
	if err := tester.sync("valid", hashes[0]); err != nil {
		t.Fatalf("failed to synchronise blocks: %v", err)
	}
329
	bs = tester.downloader.TakeBlocks()
330
	if len(bs) != 1 {
331
		t.Fatalf("retrieved block mismatch: have %v, want %v", len(bs), 1)
332
	}
333 334 335
	if !tester.hasBlock(bs[0].ParentHash()) {
		t.Fatalf("tester doesn't know about the origin hash")
	}
336
}