diff --git a/paddle/fluid/distributed/table/common_graph_table.h b/paddle/fluid/distributed/table/common_graph_table.h index d681262c664807943bd3dda9bce4512495a441ed..5c226a14cd656a31aa66e98096530645d249a0c6 100644 --- a/paddle/fluid/distributed/table/common_graph_table.h +++ b/paddle/fluid/distributed/table/common_graph_table.h @@ -17,11 +17,23 @@ #include #include #include +#include +#include +#include +#include +#include +#include #include +#include #include #include // NOLINT +#include +#include +#include #include +#include #include +#include #include #include #include "paddle/fluid/distributed/table/accessor.h" @@ -62,6 +74,294 @@ class GraphShard { int shard_num; std::vector bucket; }; + +enum LRUResponse { ok = 0, blocked = 1, err = 2 }; + +struct SampleKey { + uint64_t node_key; + size_t sample_size; + bool operator==(const SampleKey &s) const { + return node_key == s.node_key && sample_size == s.sample_size; + } +}; + +struct SampleKeyHash { + size_t operator()(const SampleKey &s) const { + return s.node_key ^ s.sample_size; + } +}; + +class SampleResult { + public: + size_t actual_size; + char *buffer; + SampleResult(size_t _actual_size, char *_buffer) : actual_size(_actual_size) { + buffer = new char[actual_size]; + memcpy(buffer, _buffer, actual_size); + } + ~SampleResult() { + // std::cout<<"in SampleResult deconstructor\n"; + delete[] buffer; + } +}; + +template +class LRUNode { + public: + LRUNode(K _key, V _data, size_t _ttl) : key(_key), data(_data), ttl(_ttl) { + next = pre = NULL; + } + std::chrono::milliseconds ms; + // the last hit time + K key; + V data; + size_t ttl; + // time to live + LRUNode *pre, *next; +}; +template > +class ScaledLRU; + +template > +class RandomSampleLRU { + public: + RandomSampleLRU(ScaledLRU *_father) : father(_father) { + node_size = 0; + node_head = node_end = NULL; + global_ttl = father->ttl; + } + + ~RandomSampleLRU() { + LRUNode *p; + while (node_head != NULL) { + p = node_head->next; + delete node_head; + node_head = p; + } + } + LRUResponse query(K *keys, size_t length, std::vector> &res) { + if (pthread_rwlock_tryrdlock(&father->rwlock) != 0) + return LRUResponse::blocked; + int init_node_size = node_size; + try { + for (size_t i = 0; i < length; i++) { + auto iter = key_map.find(keys[i]); + if (iter != key_map.end()) { + res.push_back({keys[i], iter->second->data}); + iter->second->ttl--; + if (iter->second->ttl == 0) { + remove(iter->second, true); + } else { + remove(iter->second); + add_to_tail(iter->second); + } + } + } + } catch (...) { + pthread_rwlock_unlock(&father->rwlock); + father->handle_size_diff(node_size - init_node_size); + return LRUResponse::err; + } + pthread_rwlock_unlock(&father->rwlock); + father->handle_size_diff(node_size - init_node_size); + return LRUResponse::ok; + } + LRUResponse insert(K *keys, V *data, size_t length) { + if (pthread_rwlock_tryrdlock(&father->rwlock) != 0) + return LRUResponse::blocked; + int init_node_size = node_size; + try { + for (size_t i = 0; i < length; i++) { + auto iter = key_map.find(keys[i]); + if (iter != key_map.end()) { + iter->second->ttl = global_ttl; + remove(iter->second); + add_to_tail(iter->second); + iter->second->data = data[i]; + } else { + LRUNode *temp = new LRUNode(keys[i], data[i], global_ttl); + add_to_tail(temp); + key_map[keys[i]] = temp; + } + } + } catch (...) { + pthread_rwlock_unlock(&father->rwlock); + father->handle_size_diff(node_size - init_node_size); + return LRUResponse::err; + } + pthread_rwlock_unlock(&father->rwlock); + father->handle_size_diff(node_size - init_node_size); + return LRUResponse::ok; + } + void remove(LRUNode *node, bool del = false) { + if (node->pre) { + node->pre->next = node->next; + } else { + node_head = node->next; + } + if (node->next) { + node->next->pre = node->pre; + } else { + node_end = node->pre; + } + node_size--; + if (del) { + delete node; + key_map.erase(node->key); + } + } + + void add_to_tail(LRUNode *node) { + if (node_end == NULL) { + node_head = node_end = node; + node->next = node->pre = NULL; + } else { + node_end->next = node; + node->pre = node_end; + node->next = NULL; + node_end = node; + } + node_size++; + node->ms = std::chrono::duration_cast( + std::chrono::system_clock::now().time_since_epoch()); + } + + private: + std::unordered_map *, Hash> key_map; + ScaledLRU *father; + size_t global_ttl; + int node_size; + LRUNode *node_head, *node_end; + friend class ScaledLRU; +}; + +template +class ScaledLRU { + public: + ScaledLRU(size_t shard_num, size_t size_limit, size_t _ttl) + : size_limit(size_limit), ttl(_ttl) { + pthread_rwlock_init(&rwlock, NULL); + stop = false; + thread_pool.reset(new ::ThreadPool(1)); + global_count = 0; + lru_pool = std::vector>( + shard_num, RandomSampleLRU(this)); + shrink_job = std::thread([this]() -> void { + while (true) { + { + std::unique_lock lock(mutex_); + cv_.wait_for(lock, std::chrono::milliseconds(3000)); + if (stop) { + return; + } + } + + // shrink(); + // std::cerr<<"shrink job in queue\n"; + auto status = + thread_pool->enqueue([this]() -> int { return shrink(); }); + status.wait(); + } + }); + shrink_job.detach(); + } + ~ScaledLRU() { + std::unique_lock lock(mutex_); + // std::cerr<<"cancel shrink job\n"; + stop = true; + cv_.notify_one(); + // pthread_cancel(shrink_job.native_handle()); + } + LRUResponse query(size_t index, K *keys, size_t length, + std::vector> &res) { + return lru_pool[index].query(keys, length, res); + } + LRUResponse insert(size_t index, K *keys, V *data, size_t length) { + return lru_pool[index].insert(keys, data, length); + } + int shrink() { + int node_size = 0; + std::string t = ""; + for (size_t i = 0; i < lru_pool.size(); i++) { + node_size += lru_pool[i].node_size; + // t += std::to_string(i) + "->" + std::to_string(lru_pool[i].node_size) + + // " "; + } + // std::cout<, + std::greater> + q; + for (size_t i = 0; i < lru_pool.size(); i++) { + if (lru_pool[i].node_size > 0) { + global_count += lru_pool[i].node_size; + q.push({lru_pool[i].node_head, &lru_pool[i]}); + } + } + if (global_count > size_limit) { + // std::cout<<"before shrinking cache, cached nodes count = + // "<next; + if (next) { + q.push({next, remove_node.lru_pointer}); + } + global_count--; + remove_node.lru_pointer->key_map.erase(remove_node.node->key); + remove_node.lru_pointer->remove(remove_node.node, true); + } + // std::cout<<"after shrinking cache, cached nodes count = + // "< int(1.5 * size_limit)) { + // std::cout<<"global_count too large "<enqueue([this]() -> int { return shrink(); }); + } + } + } + + size_t get_ttl() { return ttl; } + + private: + pthread_rwlock_t rwlock; + int global_count; + size_t size_limit; + size_t ttl; + bool stop; + std::thread shrink_job; + std::vector> lru_pool; + mutable std::mutex mutex_; + std::condition_variable cv_; + struct RemovedNode { + LRUNode *node; + RandomSampleLRU *lru_pointer; + bool operator>(const RemovedNode &a) const { return node->ms > a.node->ms; } + }; + std::shared_ptr<::ThreadPool> thread_pool; + friend class RandomSampleLRU; +}; + class GraphTable : public SparseTable { public: GraphTable() {} diff --git a/paddle/fluid/distributed/test/graph_node_test.cc b/paddle/fluid/distributed/test/graph_node_test.cc index 613770220f9d7995242da79f3b5fd70142c119f0..859478e167771456d245ded975970a633379f4bf 100644 --- a/paddle/fluid/distributed/test/graph_node_test.cc +++ b/paddle/fluid/distributed/test/graph_node_test.cc @@ -222,6 +222,7 @@ void testBatchSampleNeighboor( } } +void testCache(); void testGraphToBuffer(); // std::string nodes[] = {std::string("37\taa\t45;0.34\t145;0.31\t112;0.21"), // std::string("96\tfeature\t48;1.4\t247;0.31\t111;1.21"), @@ -400,6 +401,8 @@ void RunClient( } void RunBrpcPushSparse() { + std::cout << "in test cache"; + testCache(); setenv("http_proxy", "", 1); setenv("https_proxy", "", 1); prepare_file(edge_file_name, 1); @@ -607,6 +610,64 @@ void RunBrpcPushSparse() { client1.stop_server(); } +void testCache() { + ::paddle::distributed::ScaledLRU< + ::paddle::distributed::SampleKey, + std::shared_ptr<::paddle::distributed::SampleResult>, + ::paddle::distributed::SampleKeyHash> + st(1, 2, 4); + std::shared_ptr<::paddle::distributed::SampleResult> sp; + char* str = (char*)"54321"; + ::paddle::distributed::SampleResult* result = + new ::paddle::distributed::SampleResult(5, str); + ::paddle::distributed::SampleKey skey = {6, 1}; + sp.reset(result); + std::vector>> + r; + st.query(0, &skey, 1, r); + ASSERT_EQ((int)r.size(), 0); + + st.insert(0, &skey, &sp, 1); + for (int i = 0; i < st.get_ttl(); i++) { + st.query(0, &skey, 1, r); + ASSERT_EQ((int)r.size(), 1); + char* p = (char*)r[0].second.get()->buffer; + for (int j = 0; j < r[0].second.get()->actual_size; j++) + ASSERT_EQ(p[j], str[j]); + r.clear(); + } + st.query(0, &skey, 1, r); + ASSERT_EQ((int)r.size(), 0); + str = (char*)"342cd4321"; + result = new ::paddle::distributed::SampleResult(strlen(str), str); + std::shared_ptr<::paddle::distributed::SampleResult> sp1; + sp1.reset(result); + st.insert(0, &skey, &sp1, 1); + for (int i = 0; i < st.get_ttl() / 2; i++) { + st.query(0, &skey, 1, r); + ASSERT_EQ((int)r.size(), 1); + char* p = (char*)r[0].second.get()->buffer; + for (int j = 0; j < r[0].second.get()->actual_size; j++) + ASSERT_EQ(p[j], str[j]); + r.clear(); + } + str = (char*)"343332d4321"; + result = new ::paddle::distributed::SampleResult(strlen(str), str); + std::shared_ptr<::paddle::distributed::SampleResult> sp2; + sp2.reset(result); + st.insert(0, &skey, &sp2, 1); + for (int i = 0; i < st.get_ttl(); i++) { + st.query(0, &skey, 1, r); + ASSERT_EQ((int)r.size(), 1); + char* p = (char*)r[0].second.get()->buffer; + for (int j = 0; j < r[0].second.get()->actual_size; j++) + ASSERT_EQ(p[j], str[j]); + r.clear(); + } + st.query(0, &skey, 1, r); + ASSERT_EQ((int)r.size(), 0); +} void testGraphToBuffer() { ::paddle::distributed::GraphNode s, s1; s.set_feature_size(1);