diff --git a/paddle/fluid/distributed/table/common_graph_table.h b/paddle/fluid/distributed/table/common_graph_table.h index 2eb93a15f74360aae02e7d67b5e1766ebf92347b..ef5235eab1034c9c313aa8042054df7bb1970d0d 100644 --- a/paddle/fluid/distributed/table/common_graph_table.h +++ b/paddle/fluid/distributed/table/common_graph_table.h @@ -123,6 +123,8 @@ class RandomSampleLRU { node_size = 0; node_head = node_end = NULL; global_ttl = father->ttl; + extra_penalty = 0; + size_limit = (father->size_limit / father->shard_num + 1); } ~RandomSampleLRU() { @@ -138,16 +140,16 @@ class RandomSampleLRU { return LRUResponse::blocked; int init_node_size = node_size; try { + // pthread_rwlock_rdlock(&father->rwlock); for (size_t i = 0; i < length; i++) { auto iter = key_map.find(keys[i]); if (iter != key_map.end()) { res.emplace_back(keys[i], iter->second->data); iter->second->ttl--; if (iter->second->ttl == 0) { - remove(iter->second, true); - } else { remove(iter->second); - add_to_tail(iter->second); + } else { + move_to_tail(iter->second); } } } @@ -168,14 +170,12 @@ class RandomSampleLRU { for (size_t i = 0; i < length; i++) { auto iter = key_map.find(keys[i]); if (iter != key_map.end()) { + move_to_tail(iter->second); iter->second->ttl = global_ttl; - remove(iter->second); - add_to_tail(iter->second); iter->second->data = data[i]; } else { LRUNode *temp = new LRUNode(keys[i], data[i], global_ttl); - add_to_tail(temp); - key_map[keys[i]] = temp; + add_new(temp); } } } catch (...) { @@ -187,25 +187,34 @@ class RandomSampleLRU { father->handle_size_diff(node_size - init_node_size); return LRUResponse::ok; } - void remove(LRUNode *node, bool del = false) { - if (node->pre) { - node->pre->next = node->next; - } else { - node_head = node->next; - } - if (node->next) { - node->next->pre = node->pre; - } else { - node_end = node->pre; - } + void remove(LRUNode *node) { + fetch(node); node_size--; - if (del) { - delete node; - key_map.erase(node->key); + key_map.erase(node->key); + delete node; + if (node_size >= size_limit) { + extra_penalty -= 1.0; } } - void add_to_tail(LRUNode *node) { + void move_to_tail(LRUNode *node) { + fetch(node); + place_at_tail(node); + } + + void add_new(LRUNode *node) { + node->ttl = global_ttl; + place_at_tail(node); + node_size++; + key_map[node->key] = node; + if (node_size > size_limit) { + extra_penalty += penalty_inc; + if (extra_penalty >= 1.0) { + remove(node_head); + } + } + } + void place_at_tail(LRUNode *node) { if (node_end == NULL) { node_head = node_end = node; node->next = node->pre = NULL; @@ -215,25 +224,40 @@ class RandomSampleLRU { node->next = NULL; node_end = node; } - node_size++; node->ms = std::chrono::duration_cast( std::chrono::system_clock::now().time_since_epoch()); } + void fetch(LRUNode *node) { + if (node->pre) { + node->pre->next = node->next; + } else { + node_head = node->next; + } + if (node->next) { + node->next->pre = node->pre; + } else { + node_end = node->pre; + } + } + private: std::unordered_map *> key_map; ScaledLRU *father; - size_t global_ttl; + size_t global_ttl, size_limit; int node_size; LRUNode *node_head, *node_end; friend class ScaledLRU; + float extra_penalty; + const float penalty_inc = 0.75; }; template class ScaledLRU { public: - ScaledLRU(size_t shard_num, size_t size_limit, size_t _ttl) + ScaledLRU(size_t _shard_num, size_t size_limit, size_t _ttl) : size_limit(size_limit), ttl(_ttl) { + shard_num = _shard_num; pthread_rwlock_init(&rwlock, NULL); stop = false; thread_pool.reset(new ::ThreadPool(1)); @@ -244,12 +268,11 @@ class ScaledLRU { while (true) { { std::unique_lock lock(mutex_); - cv_.wait_for(lock, std::chrono::milliseconds(3000)); + cv_.wait_for(lock, std::chrono::milliseconds(20000)); if (stop) { return; } } - auto status = thread_pool->enqueue([this]() -> int { return shrink(); }); status.wait(); @@ -271,12 +294,11 @@ class ScaledLRU { } int shrink() { int node_size = 0; - std::string t = ""; for (size_t i = 0; i < lru_pool.size(); i++) { node_size += lru_pool[i].node_size; } - if (node_size <= size_limit) return 0; + if (node_size <= 1.2 * size_limit) return 0; if (pthread_rwlock_wrlock(&rwlock) == 0) { try { global_count = 0; @@ -301,14 +323,16 @@ class ScaledLRU { q.push({next, remove_node.lru_pointer}); } global_count--; - remove_node.lru_pointer->key_map.erase(remove_node.node->key); - remove_node.lru_pointer->remove(remove_node.node, true); + remove_node.lru_pointer->remove(remove_node.node); } - // VLOG(0)<<"after shrinking cache, cached nodes count = - // "< int(1.5 * size_limit)) { + if (global_count > int(1.25 * size_limit)) { // VLOG(0)<<"global_count too large "<enqueue([this]() -> int { return shrink(); }); @@ -332,6 +356,7 @@ class ScaledLRU { private: pthread_rwlock_t rwlock; + size_t shard_num; int global_count; size_t size_limit; size_t ttl;