From 4977eb22561ab078460e3f43a4fae7f21759248e Mon Sep 17 00:00:00 2001 From: seemingwang Date: Thu, 4 Nov 2021 11:11:48 +0800 Subject: [PATCH] use cache when sampling neighbors (#36961) --- .../distributed/service/graph_brpc_server.cc | 4 +- .../distributed/table/common_graph_table.cc | 80 ++++++++++++++++++- .../distributed/table/common_graph_table.h | 39 ++++++--- .../fluid/distributed/test/graph_node_test.cc | 68 ++++++++++------ 4 files changed, 149 insertions(+), 42 deletions(-) diff --git a/paddle/fluid/distributed/service/graph_brpc_server.cc b/paddle/fluid/distributed/service/graph_brpc_server.cc index b404082f7c4..424cf281bf3 100644 --- a/paddle/fluid/distributed/service/graph_brpc_server.cc +++ b/paddle/fluid/distributed/service/graph_brpc_server.cc @@ -386,7 +386,7 @@ int32_t GraphBrpcService::graph_random_sample_neighboors( size_t node_num = request.params(0).size() / sizeof(uint64_t); uint64_t *node_data = (uint64_t *)(request.params(0).c_str()); int sample_size = *(uint64_t *)(request.params(1).c_str()); - std::vector> buffers(node_num); + std::vector> buffers(node_num); std::vector actual_sizes(node_num, 0); ((GraphTable *)table) ->random_sample_neighboors(node_data, sample_size, buffers, actual_sizes); @@ -487,7 +487,7 @@ int32_t GraphBrpcService::sample_neighboors_across_multi_servers( request2server.size() - 1; } size_t request_call_num = request2server.size(); - std::vector> local_buffers; + std::vector> local_buffers; std::vector local_actual_sizes; std::vector seq; std::vector> node_id_buckets(request_call_num); diff --git a/paddle/fluid/distributed/table/common_graph_table.cc b/paddle/fluid/distributed/table/common_graph_table.cc index 2c20e79b3b2..47b966182e6 100644 --- a/paddle/fluid/distributed/table/common_graph_table.cc +++ b/paddle/fluid/distributed/table/common_graph_table.cc @@ -394,13 +394,87 @@ int32_t GraphTable::random_sample_nodes(int sample_size, } int32_t GraphTable::random_sample_neighboors( uint64_t *node_ids, int sample_size, - std::vector> &buffers, + std::vector> &buffers, std::vector &actual_sizes) { size_t node_num = buffers.size(); + std::function char_del = [](char *c) { delete[] c; }; std::vector> tasks; + if (use_cache) { + std::vector> seq_id(shard_end - shard_start); + std::vector> id_list(shard_end - shard_start); + size_t index; + for (size_t idx = 0; idx < node_num; ++idx) { + index = get_thread_pool_index(node_ids[idx]); + seq_id[index].emplace_back(idx); + id_list[index].emplace_back(node_ids[idx], sample_size); + } + for (int i = 0; i < seq_id.size(); i++) { + if (seq_id[i].size() == 0) continue; + tasks.push_back(_shards_task_pool[i]->enqueue([&, i, this]() -> int { + uint64_t node_id; + std::vector> r; + auto response = + scaled_lru->query(i, id_list[i].data(), id_list[i].size(), r); + int index = 0; + uint32_t idx; + std::vector sample_res; + std::vector sample_keys; + auto &rng = _shards_task_rng_pool[i]; + for (size_t k = 0; k < id_list[i].size(); k++) { + if (index < r.size() && + r[index].first.node_key == id_list[i][k].node_key) { + idx = seq_id[i][k]; + actual_sizes[idx] = r[index].second.actual_size; + buffers[idx] = r[index].second.buffer; + index++; + } else { + node_id = id_list[i][k].node_key; + Node *node = find_node(node_id); + idx = seq_id[i][k]; + int &actual_size = actual_sizes[idx]; + if (node == nullptr) { + actual_size = 0; + continue; + } + std::shared_ptr &buffer = buffers[idx]; + std::vector res = node->sample_k(sample_size, rng); + actual_size = res.size() * (Node::id_size + Node::weight_size); + int offset = 0; + uint64_t id; + float weight; + char *buffer_addr = new char[actual_size]; + if (response == LRUResponse::ok) { + sample_keys.emplace_back(node_id, sample_size); + sample_res.emplace_back(actual_size, buffer_addr); + buffer = sample_res.back().buffer; + } else { + buffer.reset(buffer_addr, char_del); + } + for (int &x : res) { + id = node->get_neighbor_id(x); + weight = node->get_neighbor_weight(x); + memcpy(buffer_addr + offset, &id, Node::id_size); + offset += Node::id_size; + memcpy(buffer_addr + offset, &weight, Node::weight_size); + offset += Node::weight_size; + } + } + } + if (sample_res.size()) { + scaled_lru->insert(i, sample_keys.data(), sample_res.data(), + sample_keys.size()); + } + return 0; + })); + } + for (auto &t : tasks) { + t.get(); + } + return 0; + } for (size_t idx = 0; idx < node_num; ++idx) { uint64_t &node_id = node_ids[idx]; - std::unique_ptr &buffer = buffers[idx]; + std::shared_ptr &buffer = buffers[idx]; int &actual_size = actual_sizes[idx]; int thread_pool_index = get_thread_pool_index(node_id); @@ -419,7 +493,7 @@ int32_t GraphTable::random_sample_neighboors( uint64_t id; float weight; char *buffer_addr = new char[actual_size]; - buffer.reset(buffer_addr); + buffer.reset(buffer_addr, char_del); for (int &x : res) { id = node->get_neighbor_id(x); weight = node->get_neighbor_weight(x); diff --git a/paddle/fluid/distributed/table/common_graph_table.h b/paddle/fluid/distributed/table/common_graph_table.h index 5c226a14cd6..0e2d09effeb 100644 --- a/paddle/fluid/distributed/table/common_graph_table.h +++ b/paddle/fluid/distributed/table/common_graph_table.h @@ -80,6 +80,10 @@ enum LRUResponse { ok = 0, blocked = 1, err = 2 }; struct SampleKey { uint64_t node_key; size_t sample_size; + SampleKey(uint64_t _node_key, size_t _sample_size) + : node_key(_node_key), sample_size(_sample_size) { + // std::cerr<<"in constructor of samplekey\n"; + } bool operator==(const SampleKey &s) const { return node_key == s.node_key && sample_size == s.sample_size; } @@ -94,15 +98,13 @@ struct SampleKeyHash { class SampleResult { public: size_t actual_size; - char *buffer; - SampleResult(size_t _actual_size, char *_buffer) : actual_size(_actual_size) { - buffer = new char[actual_size]; - memcpy(buffer, _buffer, actual_size); - } - ~SampleResult() { - // std::cout<<"in SampleResult deconstructor\n"; - delete[] buffer; - } + std::shared_ptr buffer; + SampleResult(size_t _actual_size, std::shared_ptr &_buffer) + : actual_size(_actual_size), buffer(_buffer) {} + SampleResult(size_t _actual_size, char *_buffer) + : actual_size(_actual_size), + buffer(_buffer, [](char *p) { delete[] p; }) {} + ~SampleResult() {} }; template @@ -364,7 +366,7 @@ class ScaledLRU { class GraphTable : public SparseTable { public: - GraphTable() {} + GraphTable() { use_cache = false; } virtual ~GraphTable() {} virtual int32_t pull_graph_list(int start, int size, std::unique_ptr &buffer, @@ -373,7 +375,7 @@ class GraphTable : public SparseTable { virtual int32_t random_sample_neighboors( uint64_t *node_ids, int sample_size, - std::vector> &buffers, + std::vector> &buffers, std::vector &actual_sizes); int32_t random_sample_nodes(int sample_size, std::unique_ptr &buffers, @@ -431,6 +433,18 @@ class GraphTable : public SparseTable { size_t get_server_num() { return server_num; } + virtual int32_t make_neigh_sample_cache(size_t size_limit, size_t ttl) { + { + std::unique_lock lock(mutex_); + if (use_cache == false) { + scaled_lru.reset(new ScaledLRU( + shard_end - shard_start, size_limit, ttl)); + use_cache = true; + } + } + return 0; + } + protected: std::vector shards; size_t shard_start, shard_end, server_num, shard_num_per_server, shard_num; @@ -446,6 +460,9 @@ class GraphTable : public SparseTable { std::vector> _shards_task_pool; std::vector> _shards_task_rng_pool; + std::shared_ptr> scaled_lru; + bool use_cache; + mutable std::mutex mutex_; }; } // namespace distributed diff --git a/paddle/fluid/distributed/test/graph_node_test.cc b/paddle/fluid/distributed/test/graph_node_test.cc index 859478e1677..47dc7212575 100644 --- a/paddle/fluid/distributed/test/graph_node_test.cc +++ b/paddle/fluid/distributed/test/graph_node_test.cc @@ -440,6 +440,29 @@ void RunBrpcPushSparse() { 0, std::vector(1, 10240001024), 4, vs); pull_status.wait(); ASSERT_EQ(0, vs[0].size()); + paddle::distributed::GraphTable* g = + (paddle::distributed::GraphTable*)pserver_ptr_->table(0); + size_t ttl = 6; + g->make_neigh_sample_cache(4, ttl); + int round = 5; + while (round--) { + vs.clear(); + pull_status = worker_ptr_->batch_sample_neighboors( + 0, std::vector(1, 37), 1, vs); + pull_status.wait(); + + for (int i = 0; i < ttl; i++) { + std::vector>> vs1; + pull_status = worker_ptr_->batch_sample_neighboors( + 0, std::vector(1, 37), 1, vs1); + pull_status.wait(); + ASSERT_EQ(vs[0].size(), vs1[0].size()); + + for (int j = 0; j < vs[0].size(); j++) { + ASSERT_EQ(vs[0][j].first, vs1[0][j].first); + } + } + } std::vector nodes; pull_status = worker_ptr_->pull_graph_list(0, 0, 0, 1, 1, nodes); @@ -611,58 +634,51 @@ void RunBrpcPushSparse() { } void testCache() { - ::paddle::distributed::ScaledLRU< - ::paddle::distributed::SampleKey, - std::shared_ptr<::paddle::distributed::SampleResult>, - ::paddle::distributed::SampleKeyHash> + ::paddle::distributed::ScaledLRU<::paddle::distributed::SampleKey, + ::paddle::distributed::SampleResult, + ::paddle::distributed::SampleKeyHash> st(1, 2, 4); - std::shared_ptr<::paddle::distributed::SampleResult> sp; - char* str = (char*)"54321"; + char* str = new char[7]; + strcpy(str, "54321"); ::paddle::distributed::SampleResult* result = new ::paddle::distributed::SampleResult(5, str); ::paddle::distributed::SampleKey skey = {6, 1}; - sp.reset(result); std::vector>> + paddle::distributed::SampleResult>> r; st.query(0, &skey, 1, r); ASSERT_EQ((int)r.size(), 0); - st.insert(0, &skey, &sp, 1); + st.insert(0, &skey, result, 1); for (int i = 0; i < st.get_ttl(); i++) { st.query(0, &skey, 1, r); ASSERT_EQ((int)r.size(), 1); - char* p = (char*)r[0].second.get()->buffer; - for (int j = 0; j < r[0].second.get()->actual_size; j++) - ASSERT_EQ(p[j], str[j]); + char* p = (char*)r[0].second.buffer.get(); + for (int j = 0; j < r[0].second.actual_size; j++) ASSERT_EQ(p[j], str[j]); r.clear(); } st.query(0, &skey, 1, r); ASSERT_EQ((int)r.size(), 0); - str = (char*)"342cd4321"; + str = new char[10]; + strcpy(str, "54321678"); result = new ::paddle::distributed::SampleResult(strlen(str), str); - std::shared_ptr<::paddle::distributed::SampleResult> sp1; - sp1.reset(result); - st.insert(0, &skey, &sp1, 1); + st.insert(0, &skey, result, 1); for (int i = 0; i < st.get_ttl() / 2; i++) { st.query(0, &skey, 1, r); ASSERT_EQ((int)r.size(), 1); - char* p = (char*)r[0].second.get()->buffer; - for (int j = 0; j < r[0].second.get()->actual_size; j++) - ASSERT_EQ(p[j], str[j]); + char* p = (char*)r[0].second.buffer.get(); + for (int j = 0; j < r[0].second.actual_size; j++) ASSERT_EQ(p[j], str[j]); r.clear(); } - str = (char*)"343332d4321"; + str = new char[18]; + strcpy(str, "343332d4321"); result = new ::paddle::distributed::SampleResult(strlen(str), str); - std::shared_ptr<::paddle::distributed::SampleResult> sp2; - sp2.reset(result); - st.insert(0, &skey, &sp2, 1); + st.insert(0, &skey, result, 1); for (int i = 0; i < st.get_ttl(); i++) { st.query(0, &skey, 1, r); ASSERT_EQ((int)r.size(), 1); - char* p = (char*)r[0].second.get()->buffer; - for (int j = 0; j < r[0].second.get()->actual_size; j++) - ASSERT_EQ(p[j], str[j]); + char* p = (char*)r[0].second.buffer.get(); + for (int j = 0; j < r[0].second.actual_size; j++) ASSERT_EQ(p[j], str[j]); r.clear(); } st.query(0, &skey, 1, r); -- GitLab