diff --git a/core/tx_list.go b/core/tx_list.go index e30fee38f506353d76eb0125f2a68aafa7d6646d..8c69331cc4ffb6ba88de4561d69d6d16b7cf32a4 100644 --- a/core/tx_list.go +++ b/core/tx_list.go @@ -52,11 +52,11 @@ func (h *nonceHeap) Pop() interface{} { type txList struct { strict bool // Whether nonces are strictly continuous or not items map[uint64]*types.Transaction // Hash map storing the transaction data - cache types.Transactions // cache of the transactions already sorted + cache types.Transactions // Cache of the transactions already sorted first uint64 // Nonce of the lowest stored transaction (strict mode) last uint64 // Nonce of the highest stored transaction (strict mode) - index *nonceHeap // Heap of nonces of all teh stored transactions (non-strict mode) + index *nonceHeap // Heap of nonces of all the stored transactions (non-strict mode) costcap *big.Int // Price of the highest costing transaction (reset only if exceeds balance) } @@ -73,8 +73,8 @@ func newTxList(strict bool) *txList { } } -// Add tries to inserts a new transaction into the list, returning whether the -// transaction was acceped, and if yes, any previous transaction it replaced. +// Add tries to insert a new transaction into the list, returning whether the +// transaction was accepted, and if yes, any previous transaction it replaced. // // In case of strict lists (contiguous nonces) the nonce boundaries are updated // appropriately with the new transaction. Otherwise (gapped nonces) the heap of @@ -146,10 +146,10 @@ func (l *txList) Forward(threshold uint64) types.Transactions { // // This method uses the cached costcap to quickly decide if there's even a point // in calculating all the costs or if the balance covers all. If the threshold is -// loewr than the costcap, the costcap will be reset to a new high after removing +// lower than the costcap, the costcap will be reset to a new high after removing // expensive the too transactions. func (l *txList) Filter(threshold *big.Int) (types.Transactions, types.Transactions) { - // If all transactions are blow the threshold, short circuit + // If all transactions are below the threshold, short circuit if l.costcap.Cmp(threshold) <= 0 { return nil, nil } @@ -195,7 +195,7 @@ func (l *txList) Filter(threshold *big.Int) (types.Transactions, types.Transacti } // Cap places a hard limit on the number of items, returning all transactions -// exceeding tht limit. +// exceeding that limit. func (l *txList) Cap(threshold int) types.Transactions { // Short circuit if the number of items is under the limit if len(l.items) < threshold { @@ -239,8 +239,9 @@ func (l *txList) Remove(tx *types.Transaction) (bool, types.Transactions) { l.cache = nil // Remove all invalidated transactions (strict mode only!) - invalids := make(types.Transactions, 0, l.last-nonce) + var invalids types.Transactions if l.strict { + invalids = make(types.Transactions, 0, l.last-nonce) for i := nonce + 1; i <= l.last; i++ { invalids = append(invalids, l.items[i]) delete(l.items, i) @@ -255,7 +256,6 @@ func (l *txList) Remove(tx *types.Transaction) (bool, types.Transactions) { } } } - // Figure out the new highest nonce return true, invalids } return false, nil @@ -265,7 +265,7 @@ func (l *txList) Remove(tx *types.Transaction) (bool, types.Transactions) { // provided nonce that is ready for processing. The returned transactions will be // removed from the list. // -// Note, all transactions with nonces lower that start will also be returned to +// Note, all transactions with nonces lower than start will also be returned to // prevent getting into and invalid state. This is not something that should ever // happen but better to be self correcting than failing! func (l *txList) Ready(start uint64) types.Transactions { diff --git a/core/tx_pool.go b/core/tx_pool.go index c4dcceba0924ea8165047ad10691f32ebcc663c6..58d304f00b2a4d6de8fa9f098106b69e5a006072 100644 --- a/core/tx_pool.go +++ b/core/tx_pool.go @@ -20,6 +20,7 @@ import ( "errors" "fmt" "math/big" + "sort" "sync" "time" @@ -44,8 +45,11 @@ var ( ErrNegativeValue = errors.New("Negative value") ) -const ( - maxQueued = 64 // max limit of queued txs per address +var ( + maxQueuedPerAccount = uint64(64) // Max limit of queued transactions per address + maxQueuedInTotal = uint64(65536) // Max limit of queued transactions from all accounts + maxQueuedLifetime = 3 * time.Hour // Max amount of time transactions from idle accounts are queued + evictionInterval = time.Minute // Time interval to check for evictable transactions ) type stateFn func() (*state.StateDB, error) @@ -71,8 +75,10 @@ type TxPool struct { pending map[common.Address]*txList // All currently processable transactions queue map[common.Address]*txList // Queued but non-processable transactions all map[common.Hash]*types.Transaction // All transactions to allow lookups + beats map[common.Address]time.Time // Last heartbeat from each known account - wg sync.WaitGroup // for shutdown sync + wg sync.WaitGroup // for shutdown sync + quit chan struct{} homestead bool } @@ -83,6 +89,7 @@ func NewTxPool(config *ChainConfig, eventMux *event.TypeMux, currentStateFn stat pending: make(map[common.Address]*txList), queue: make(map[common.Address]*txList), all: make(map[common.Hash]*types.Transaction), + beats: make(map[common.Address]time.Time), eventMux: eventMux, currentState: currentStateFn, gasLimit: gasLimitFn, @@ -90,10 +97,12 @@ func NewTxPool(config *ChainConfig, eventMux *event.TypeMux, currentStateFn stat pendingState: nil, localTx: newTxSet(), events: eventMux.Subscribe(ChainHeadEvent{}, GasPriceChanged{}, RemovedTransactionEvent{}), + quit: make(chan struct{}), } - pool.wg.Add(1) + pool.wg.Add(2) go pool.eventLoop() + go pool.expirationLoop() return pool } @@ -154,6 +163,7 @@ func (pool *TxPool) resetState() { func (pool *TxPool) Stop() { pool.events.Unsubscribe() + close(pool.quit) pool.wg.Wait() glog.V(logger.Info).Infoln("Transaction pool stopped") } @@ -290,7 +300,7 @@ func (pool *TxPool) add(tx *types.Transaction) error { if pool.all[hash] != nil { return fmt.Errorf("Known transaction: %x", hash[:4]) } - // Otherwise ensure basic validation passes nd queue it up + // Otherwise ensure basic validation passes and queue it up if err := pool.validateTx(tx); err != nil { return err } @@ -308,7 +318,7 @@ func (pool *TxPool) add(tx *types.Transaction) error { return nil } -// enqueueTx inserts a new transction into the non-executable transaction queue. +// enqueueTx inserts a new transaction into the non-executable transaction queue. // // Note, this method assumes the pool lock is held! func (pool *TxPool) enqueueTx(hash common.Hash, tx *types.Transaction) { @@ -355,6 +365,7 @@ func (pool *TxPool) promoteTx(addr common.Address, hash common.Hash, tx *types.T pool.all[hash] = tx // Failsafe to work around direct pending inserts (tests) // Set the potentially new pending nonce and notify any subsystems of the new tx + pool.beats[addr] = time.Now() pool.pendingState.SetNonce(addr, list.last+1) go pool.eventMux.Post(TxPreEvent{tx}) } @@ -412,8 +423,8 @@ func (pool *TxPool) RemoveBatch(txs types.Transactions) { } } -// removeTx iterates removes a single transaction from the queue, moving all -// subsequent transactions back to the future queue. +// removeTx removes a single transaction from the queue, moving all subsequent +// transactions back to the future queue. func (pool *TxPool) removeTx(hash common.Hash) { // Fetch the transaction we wish to delete tx, ok := pool.all[hash] @@ -431,6 +442,8 @@ func (pool *TxPool) removeTx(hash common.Hash) { // If no more transactions are left, remove the list and reset the nonce if pending.Empty() { delete(pool.pending, addr) + delete(pool.beats, addr) + pool.pendingState.SetNonce(addr, tx.Nonce()) } else { // Otherwise update the nonce and postpone any invalidated transactions @@ -465,6 +478,8 @@ func (pool *TxPool) promoteExecutables() { return } // Iterate over all accounts and promote any executable transactions + queued := uint64(0) + for addr, list := range pool.queue { // Drop all transactions that are deemed too old (low nonce) for _, tx := range list.Forward(state.GetNonce(addr)) { @@ -489,17 +504,51 @@ func (pool *TxPool) promoteExecutables() { pool.promoteTx(addr, tx.Hash(), tx) } // Drop all transactions over the allowed limit - for _, tx := range list.Cap(maxQueued) { + for _, tx := range list.Cap(int(maxQueuedPerAccount)) { if glog.V(logger.Core) { glog.Infof("Removed cap-exceeding queued transaction: %v", tx) } delete(pool.all, tx.Hash()) } + queued += uint64(list.Len()) + // Delete the entire queue entry if it became empty. if list.Empty() { delete(pool.queue, addr) } } + // If we've queued more transactions than the hard limit, drop oldest ones + if queued > maxQueuedInTotal { + // Sort all accounts with queued transactions by heartbeat + addresses := make(addresssByHeartbeat, 0, len(pool.queue)) + for addr, _ := range pool.queue { + addresses = append(addresses, addressByHeartbeat{addr, pool.beats[addr]}) + } + sort.Sort(addresses) + + // Drop transactions until the total is below the limit + for drop := queued - maxQueuedInTotal; drop > 0; { + addr := addresses[len(addresses)-1] + list := pool.queue[addr.address] + + addresses = addresses[:len(addresses)-1] + + // Drop all transactions if they are less than the overflow + if size := uint64(list.Len()); size <= drop { + for _, tx := range list.Flatten() { + pool.removeTx(tx.Hash()) + } + drop -= size + continue + } + // Otherwise drop only last few transactions + txs := list.Flatten() + for i := len(txs) - 1; i >= 0 && drop > 0; i-- { + pool.removeTx(txs[i].Hash()) + drop-- + } + } + } } // demoteUnexecutables removes invalid and processed transactions from the pools @@ -540,10 +589,51 @@ func (pool *TxPool) demoteUnexecutables() { // Delete the entire queue entry if it became empty. if list.Empty() { delete(pool.pending, addr) + delete(pool.beats, addr) } } } +// expirationLoop is a loop that periodically iterates over all accounts with +// queued transactions and drop all that have been inactive for a prolonged amount +// of time. +func (pool *TxPool) expirationLoop() { + defer pool.wg.Done() + + evict := time.NewTicker(evictionInterval) + defer evict.Stop() + + for { + select { + case <-evict.C: + pool.mu.Lock() + for addr := range pool.queue { + if time.Since(pool.beats[addr]) > maxQueuedLifetime { + for _, tx := range pool.queue[addr].Flatten() { + pool.removeTx(tx.Hash()) + } + } + } + pool.mu.Unlock() + + case <-pool.quit: + return + } + } +} + +// addressByHeartbeat is an account address tagged with its last activity timestamp. +type addressByHeartbeat struct { + address common.Address + heartbeat time.Time +} + +type addresssByHeartbeat []addressByHeartbeat + +func (a addresssByHeartbeat) Len() int { return len(a) } +func (a addresssByHeartbeat) Less(i, j int) bool { return a[i].heartbeat.Before(a[j].heartbeat) } +func (a addresssByHeartbeat) Swap(i, j int) { a[i], a[j] = a[j], a[i] } + // txSet represents a set of transaction hashes in which entries // are automatically dropped after txSetDuration time type txSet struct { diff --git a/core/tx_pool_test.go b/core/tx_pool_test.go index ec54d8c0e121b5158efdff446097b190d39a00c2..f08334fa190b8ffd7fc96a3c80ea42a1c39754e8 100644 --- a/core/tx_pool_test.go +++ b/core/tx_pool_test.go @@ -19,7 +19,9 @@ package core import ( "crypto/ecdsa" "math/big" + "math/rand" "testing" + "time" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/state" @@ -38,10 +40,10 @@ func setupTxPool() (*TxPool, *ecdsa.PrivateKey) { db, _ := ethdb.NewMemDatabase() statedb, _ := state.New(common.Hash{}, db) - var m event.TypeMux key, _ := crypto.GenerateKey() - newPool := NewTxPool(testChainConfig(), &m, func() (*state.StateDB, error) { return statedb, nil }, func() *big.Int { return big.NewInt(1000000) }) + newPool := NewTxPool(testChainConfig(), new(event.TypeMux), func() (*state.StateDB, error) { return statedb, nil }, func() *big.Int { return big.NewInt(1000000) }) newPool.resetState() + return newPool, key } @@ -438,7 +440,7 @@ func TestTransactionPostponing(t *testing.T) { // Tests that if the transaction count belonging to a single account goes above // some threshold, the higher transactions are dropped to prevent DOS attacks. -func TestTransactionQueueLimiting(t *testing.T) { +func TestTransactionQueueAccountLimiting(t *testing.T) { // Create a test account and fund it pool, key := setupTxPool() account, _ := transaction(0, big.NewInt(0), key).From() @@ -447,25 +449,103 @@ func TestTransactionQueueLimiting(t *testing.T) { state.AddBalance(account, big.NewInt(1000000)) // Keep queuing up transactions and make sure all above a limit are dropped - for i := uint64(1); i <= maxQueued+5; i++ { + for i := uint64(1); i <= maxQueuedPerAccount+5; i++ { if err := pool.Add(transaction(i, big.NewInt(100000), key)); err != nil { t.Fatalf("tx %d: failed to add transaction: %v", i, err) } if len(pool.pending) != 0 { t.Errorf("tx %d: pending pool size mismatch: have %d, want %d", i, len(pool.pending), 0) } - if i <= maxQueued { + if i <= maxQueuedPerAccount { if pool.queue[account].Len() != int(i) { t.Errorf("tx %d: queue size mismatch: have %d, want %d", i, pool.queue[account].Len(), i) } } else { - if pool.queue[account].Len() != maxQueued { - t.Errorf("tx %d: queue limit mismatch: have %d, want %d", i, pool.queue[account].Len(), maxQueued) + if pool.queue[account].Len() != int(maxQueuedPerAccount) { + t.Errorf("tx %d: queue limit mismatch: have %d, want %d", i, pool.queue[account].Len(), maxQueuedPerAccount) } } } - if len(pool.all) != maxQueued { - t.Errorf("total transaction mismatch: have %d, want %d", len(pool.all), maxQueued) + if len(pool.all) != int(maxQueuedPerAccount) { + t.Errorf("total transaction mismatch: have %d, want %d", len(pool.all), maxQueuedPerAccount) + } +} + +// Tests that if the transaction count belonging to multiple accounts go above +// some threshold, the higher transactions are dropped to prevent DOS attacks. +func TestTransactionQueueGlobalLimiting(t *testing.T) { + // Reduce the queue limits to shorten test time + defer func(old uint64) { maxQueuedInTotal = old }(maxQueuedInTotal) + maxQueuedInTotal = maxQueuedPerAccount * 3 + + // Create the pool to test the limit enforcement with + db, _ := ethdb.NewMemDatabase() + statedb, _ := state.New(common.Hash{}, db) + + pool := NewTxPool(testChainConfig(), new(event.TypeMux), func() (*state.StateDB, error) { return statedb, nil }, func() *big.Int { return big.NewInt(1000000) }) + pool.resetState() + + // Create a number of test accounts and fund them + state, _ := pool.currentState() + + keys := make([]*ecdsa.PrivateKey, 5) + for i := 0; i < len(keys); i++ { + keys[i], _ = crypto.GenerateKey() + state.AddBalance(crypto.PubkeyToAddress(keys[i].PublicKey), big.NewInt(1000000)) + } + // Generate and queue a batch of transactions + nonces := make(map[common.Address]uint64) + + txs := make(types.Transactions, 0, 3*maxQueuedInTotal) + for len(txs) < cap(txs) { + key := keys[rand.Intn(len(keys))] + addr := crypto.PubkeyToAddress(key.PublicKey) + + txs = append(txs, transaction(nonces[addr]+1, big.NewInt(100000), key)) + nonces[addr]++ + } + // Import the batch and verify that limits have been enforced + pool.AddBatch(txs) + + queued := 0 + for addr, list := range pool.queue { + if list.Len() > int(maxQueuedPerAccount) { + t.Errorf("addr %x: queued accounts overflown allowance: %d > %d", addr, list.Len(), maxQueuedPerAccount) + } + queued += list.Len() + } + if queued > int(maxQueuedInTotal) { + t.Fatalf("total transactions overflow allowance: %d > %d", queued, maxQueuedInTotal) + } +} + +// Tests that if an account remains idle for a prolonged amount of time, any +// non-executable transactions queued up are dropped to prevent wasting resources +// on shuffling them around. +func TestTransactionQueueTimeLimiting(t *testing.T) { + // Reduce the queue limits to shorten test time + defer func(old time.Duration) { maxQueuedLifetime = old }(maxQueuedLifetime) + defer func(old time.Duration) { evictionInterval = old }(evictionInterval) + maxQueuedLifetime = time.Second + evictionInterval = time.Second + + // Create a test account and fund it + pool, key := setupTxPool() + account, _ := transaction(0, big.NewInt(0), key).From() + + state, _ := pool.currentState() + state.AddBalance(account, big.NewInt(1000000)) + + // Queue up a batch of transactions + for i := uint64(1); i <= maxQueuedPerAccount; i++ { + if err := pool.Add(transaction(i, big.NewInt(100000), key)); err != nil { + t.Fatalf("tx %d: failed to add transaction: %v", i, err) + } + } + // Wait until at least two expiration cycles hit and make sure the transactions are gone + time.Sleep(2 * evictionInterval) + if len(pool.queue) > 0 { + t.Fatalf("old transactions remained after eviction") } } @@ -481,7 +561,7 @@ func TestTransactionPendingLimiting(t *testing.T) { state.AddBalance(account, big.NewInt(1000000)) // Keep queuing up transactions and make sure all above a limit are dropped - for i := uint64(0); i < maxQueued+5; i++ { + for i := uint64(0); i < maxQueuedPerAccount+5; i++ { if err := pool.Add(transaction(i, big.NewInt(100000), key)); err != nil { t.Fatalf("tx %d: failed to add transaction: %v", i, err) } @@ -492,8 +572,8 @@ func TestTransactionPendingLimiting(t *testing.T) { t.Errorf("tx %d: queue size mismatch: have %d, want %d", i, pool.queue[account].Len(), 0) } } - if len(pool.all) != maxQueued+5 { - t.Errorf("total transaction mismatch: have %d, want %d", len(pool.all), maxQueued+5) + if len(pool.all) != int(maxQueuedPerAccount+5) { + t.Errorf("total transaction mismatch: have %d, want %d", len(pool.all), maxQueuedPerAccount+5) } } @@ -509,7 +589,7 @@ func testTransactionLimitingEquivalency(t *testing.T, origin uint64) { state1, _ := pool1.currentState() state1.AddBalance(account1, big.NewInt(1000000)) - for i := uint64(0); i < maxQueued+5; i++ { + for i := uint64(0); i < maxQueuedPerAccount+5; i++ { if err := pool1.Add(transaction(origin+i, big.NewInt(100000), key1)); err != nil { t.Fatalf("tx %d: failed to add transaction: %v", i, err) } @@ -521,7 +601,7 @@ func testTransactionLimitingEquivalency(t *testing.T, origin uint64) { state2.AddBalance(account2, big.NewInt(1000000)) txns := []*types.Transaction{} - for i := uint64(0); i < maxQueued+5; i++ { + for i := uint64(0); i < maxQueuedPerAccount+5; i++ { txns = append(txns, transaction(origin+i, big.NewInt(100000), key2)) } pool2.AddBatch(txns)