Unverified commit 4977eb22, authored by seemingwang, committed via GitHub

use cache when sampling neighbors (#36961)

Parent commit: d33e99fe
...@@ -386,7 +386,7 @@ int32_t GraphBrpcService::graph_random_sample_neighboors( ...@@ -386,7 +386,7 @@ int32_t GraphBrpcService::graph_random_sample_neighboors(
size_t node_num = request.params(0).size() / sizeof(uint64_t); size_t node_num = request.params(0).size() / sizeof(uint64_t);
uint64_t *node_data = (uint64_t *)(request.params(0).c_str()); uint64_t *node_data = (uint64_t *)(request.params(0).c_str());
int sample_size = *(uint64_t *)(request.params(1).c_str()); int sample_size = *(uint64_t *)(request.params(1).c_str());
std::vector<std::unique_ptr<char[]>> buffers(node_num); std::vector<std::shared_ptr<char>> buffers(node_num);
std::vector<int> actual_sizes(node_num, 0); std::vector<int> actual_sizes(node_num, 0);
((GraphTable *)table) ((GraphTable *)table)
->random_sample_neighboors(node_data, sample_size, buffers, actual_sizes); ->random_sample_neighboors(node_data, sample_size, buffers, actual_sizes);
...@@ -487,7 +487,7 @@ int32_t GraphBrpcService::sample_neighboors_across_multi_servers( ...@@ -487,7 +487,7 @@ int32_t GraphBrpcService::sample_neighboors_across_multi_servers(
request2server.size() - 1; request2server.size() - 1;
} }
size_t request_call_num = request2server.size(); size_t request_call_num = request2server.size();
std::vector<std::unique_ptr<char[]>> local_buffers; std::vector<std::shared_ptr<char>> local_buffers;
std::vector<int> local_actual_sizes; std::vector<int> local_actual_sizes;
std::vector<size_t> seq; std::vector<size_t> seq;
std::vector<std::vector<uint64_t>> node_id_buckets(request_call_num); std::vector<std::vector<uint64_t>> node_id_buckets(request_call_num);
......
...@@ -394,13 +394,87 @@ int32_t GraphTable::random_sample_nodes(int sample_size, ...@@ -394,13 +394,87 @@ int32_t GraphTable::random_sample_nodes(int sample_size,
} }
int32_t GraphTable::random_sample_neighboors( int32_t GraphTable::random_sample_neighboors(
uint64_t *node_ids, int sample_size, uint64_t *node_ids, int sample_size,
std::vector<std::unique_ptr<char[]>> &buffers, std::vector<std::shared_ptr<char>> &buffers,
std::vector<int> &actual_sizes) { std::vector<int> &actual_sizes) {
size_t node_num = buffers.size(); size_t node_num = buffers.size();
std::function<void(char *)> char_del = [](char *c) { delete[] c; };
std::vector<std::future<int>> tasks; std::vector<std::future<int>> tasks;
// Cache-enabled path: bucket the requested nodes per shard task, consult the
// scaled LRU cache first, sample only the cache misses from the graph, and
// insert the freshly sampled results back into the cache.
if (use_cache) {
// seq_id[s]  : original request indices (into node_ids/buffers) routed to
//              shard task s.
// id_list[s] : (node_id, sample_size) cache keys routed to shard task s.
std::vector<std::vector<uint32_t>> seq_id(shard_end - shard_start);
std::vector<std::vector<SampleKey>> id_list(shard_end - shard_start);
size_t index;
for (size_t idx = 0; idx < node_num; ++idx) {
index = get_thread_pool_index(node_ids[idx]);
seq_id[index].emplace_back(idx);
id_list[index].emplace_back(node_ids[idx], sample_size);
}
for (int i = 0; i < seq_id.size(); i++) {
if (seq_id[i].size() == 0) continue;
// One task per non-empty shard bucket. Capturing buffers/actual_sizes by
// reference is safe because every task is joined (t.get()) before return.
tasks.push_back(_shards_task_pool[i]->enqueue([&, i, this]() -> int {
uint64_t node_id;
// r receives the cache hits; hits come back in the same relative order
// as the queried keys, so a single forward cursor (index) matches them.
std::vector<std::pair<SampleKey, SampleResult>> r;
auto response =
scaled_lru->query(i, id_list[i].data(), id_list[i].size(), r);
int index = 0;  // cursor into the cache-hit list r
uint32_t idx;
// Freshly sampled results, collected for one batched cache insert below.
std::vector<SampleResult> sample_res;
std::vector<SampleKey> sample_keys;
auto &rng = _shards_task_rng_pool[i];
for (size_t k = 0; k < id_list[i].size(); k++) {
// sample_size is uniform within this call, so matching on node_key
// alone is sufficient to identify a cache hit for this key.
if (index < r.size() &&
r[index].first.node_key == id_list[i][k].node_key) {
// Cache hit: share ownership of the cached buffer with the caller.
idx = seq_id[i][k];
actual_sizes[idx] = r[index].second.actual_size;
buffers[idx] = r[index].second.buffer;
index++;
} else {
// Cache miss: sample this node's neighbors from the local shard.
node_id = id_list[i][k].node_key;
Node *node = find_node(node_id);
idx = seq_id[i][k];
int &actual_size = actual_sizes[idx];
if (node == nullptr) {
// Unknown node: report an empty sample, leave buffer untouched.
actual_size = 0;
continue;
}
std::shared_ptr<char> &buffer = buffers[idx];
std::vector<int> res = node->sample_k(sample_size, rng);
actual_size = res.size() * (Node::id_size + Node::weight_size);
int offset = 0;
uint64_t id;
float weight;
char *buffer_addr = new char[actual_size];
if (response == LRUResponse::ok) {
// Cache is accepting entries: SampleResult takes ownership of
// buffer_addr (its deleter frees it) and the caller's buffer
// shares that same allocation.
sample_keys.emplace_back(node_id, sample_size);
sample_res.emplace_back(actual_size, buffer_addr);
buffer = sample_res.back().buffer;
} else {
// Cache blocked/unavailable: the caller's buffer owns the
// allocation directly via the array deleter.
buffer.reset(buffer_addr, char_del);
}
// Serialize (neighbor id, edge weight) pairs into the flat buffer.
for (int &x : res) {
id = node->get_neighbor_id(x);
weight = node->get_neighbor_weight(x);
memcpy(buffer_addr + offset, &id, Node::id_size);
offset += Node::id_size;
memcpy(buffer_addr + offset, &weight, Node::weight_size);
offset += Node::weight_size;
}
}
}
// Batched insert of all fresh samples gathered for this shard.
if (sample_res.size()) {
scaled_lru->insert(i, sample_keys.data(), sample_res.data(),
sample_keys.size());
}
return 0;
}));
}
// Join every shard task before returning so buffers/actual_sizes are fully
// populated and the by-reference lambda captures never dangle.
for (auto &t : tasks) {
t.get();
}
return 0;
}
for (size_t idx = 0; idx < node_num; ++idx) { for (size_t idx = 0; idx < node_num; ++idx) {
uint64_t &node_id = node_ids[idx]; uint64_t &node_id = node_ids[idx];
std::unique_ptr<char[]> &buffer = buffers[idx]; std::shared_ptr<char> &buffer = buffers[idx];
int &actual_size = actual_sizes[idx]; int &actual_size = actual_sizes[idx];
int thread_pool_index = get_thread_pool_index(node_id); int thread_pool_index = get_thread_pool_index(node_id);
...@@ -419,7 +493,7 @@ int32_t GraphTable::random_sample_neighboors( ...@@ -419,7 +493,7 @@ int32_t GraphTable::random_sample_neighboors(
uint64_t id; uint64_t id;
float weight; float weight;
char *buffer_addr = new char[actual_size]; char *buffer_addr = new char[actual_size];
buffer.reset(buffer_addr); buffer.reset(buffer_addr, char_del);
for (int &x : res) { for (int &x : res) {
id = node->get_neighbor_id(x); id = node->get_neighbor_id(x);
weight = node->get_neighbor_weight(x); weight = node->get_neighbor_weight(x);
......
...@@ -80,6 +80,10 @@ enum LRUResponse { ok = 0, blocked = 1, err = 2 }; ...@@ -80,6 +80,10 @@ enum LRUResponse { ok = 0, blocked = 1, err = 2 };
struct SampleKey { struct SampleKey {
uint64_t node_key; uint64_t node_key;
size_t sample_size; size_t sample_size;
// Cache key for a neighbor-sampling request: a (node id, sample count)
// pair uniquely identifies a cached sample result.
SampleKey(uint64_t _node_key, size_t _sample_size)
: node_key(_node_key), sample_size(_sample_size) {
}
bool operator==(const SampleKey &s) const { bool operator==(const SampleKey &s) const {
return node_key == s.node_key && sample_size == s.sample_size; return node_key == s.node_key && sample_size == s.sample_size;
} }
...@@ -94,15 +98,13 @@ struct SampleKeyHash { ...@@ -94,15 +98,13 @@ struct SampleKeyHash {
class SampleResult { class SampleResult {
public: public:
size_t actual_size; size_t actual_size;
char *buffer; std::shared_ptr<char> buffer;
SampleResult(size_t _actual_size, char *_buffer) : actual_size(_actual_size) { SampleResult(size_t _actual_size, std::shared_ptr<char> &_buffer)
buffer = new char[actual_size]; : actual_size(_actual_size), buffer(_buffer) {}
memcpy(buffer, _buffer, actual_size); SampleResult(size_t _actual_size, char *_buffer)
} : actual_size(_actual_size),
~SampleResult() { buffer(_buffer, [](char *p) { delete[] p; }) {}
// std::cout<<"in SampleResult deconstructor\n"; ~SampleResult() {}
delete[] buffer;
}
}; };
template <typename K, typename V> template <typename K, typename V>
...@@ -364,7 +366,7 @@ class ScaledLRU { ...@@ -364,7 +366,7 @@ class ScaledLRU {
class GraphTable : public SparseTable { class GraphTable : public SparseTable {
public: public:
GraphTable() {} GraphTable() { use_cache = false; }
virtual ~GraphTable() {} virtual ~GraphTable() {}
virtual int32_t pull_graph_list(int start, int size, virtual int32_t pull_graph_list(int start, int size,
std::unique_ptr<char[]> &buffer, std::unique_ptr<char[]> &buffer,
...@@ -373,7 +375,7 @@ class GraphTable : public SparseTable { ...@@ -373,7 +375,7 @@ class GraphTable : public SparseTable {
virtual int32_t random_sample_neighboors( virtual int32_t random_sample_neighboors(
uint64_t *node_ids, int sample_size, uint64_t *node_ids, int sample_size,
std::vector<std::unique_ptr<char[]>> &buffers, std::vector<std::shared_ptr<char>> &buffers,
std::vector<int> &actual_sizes); std::vector<int> &actual_sizes);
int32_t random_sample_nodes(int sample_size, std::unique_ptr<char[]> &buffers, int32_t random_sample_nodes(int sample_size, std::unique_ptr<char[]> &buffers,
...@@ -431,6 +433,18 @@ class GraphTable : public SparseTable { ...@@ -431,6 +433,18 @@ class GraphTable : public SparseTable {
size_t get_server_num() { return server_num; } size_t get_server_num() { return server_num; }
// Enables the neighbor-sampling LRU cache for this table. Idempotent: the
// cache is created only on the first call; subsequent calls are no-ops.
// @param size_limit  capacity passed to the ScaledLRU cache.
// @param ttl         time-to-live for cache entries (see ScaledLRU).
// @return 0 always.
// Creation is guarded by mutex_. NOTE(review): use_cache appears to be read
// without this lock on the sampling path — presumably callers enable the
// cache before serving traffic; verify against the call sites.
virtual int32_t make_neigh_sample_cache(size_t size_limit, size_t ttl) {
{
std::unique_lock<std::mutex> lock(mutex_);
if (use_cache == false) {
// One LRU shard per graph shard owned by this table.
scaled_lru.reset(new ScaledLRU<SampleKey, SampleResult, SampleKeyHash>(
shard_end - shard_start, size_limit, ttl));
use_cache = true;
}
}
return 0;
}
protected: protected:
std::vector<GraphShard> shards; std::vector<GraphShard> shards;
size_t shard_start, shard_end, server_num, shard_num_per_server, shard_num; size_t shard_start, shard_end, server_num, shard_num_per_server, shard_num;
...@@ -446,6 +460,9 @@ class GraphTable : public SparseTable { ...@@ -446,6 +460,9 @@ class GraphTable : public SparseTable {
std::vector<std::shared_ptr<::ThreadPool>> _shards_task_pool; std::vector<std::shared_ptr<::ThreadPool>> _shards_task_pool;
std::vector<std::shared_ptr<std::mt19937_64>> _shards_task_rng_pool; std::vector<std::shared_ptr<std::mt19937_64>> _shards_task_rng_pool;
std::shared_ptr<ScaledLRU<SampleKey, SampleResult, SampleKeyHash>> scaled_lru;
bool use_cache;
mutable std::mutex mutex_;
}; };
} // namespace distributed } // namespace distributed
......
...@@ -440,6 +440,29 @@ void RunBrpcPushSparse() { ...@@ -440,6 +440,29 @@ void RunBrpcPushSparse() {
0, std::vector<uint64_t>(1, 10240001024), 4, vs); 0, std::vector<uint64_t>(1, 10240001024), 4, vs);
pull_status.wait(); pull_status.wait();
ASSERT_EQ(0, vs[0].size()); ASSERT_EQ(0, vs[0].size());
// Exercise the neighbor-sample cache: enable a small cache (capacity 4,
// ttl 6) on table 0, then verify that repeated queries for the same node
// within the TTL window return exactly the same sampled neighbors as the
// first (cache-filling) query.
paddle::distributed::GraphTable* g =
(paddle::distributed::GraphTable*)pserver_ptr_->table(0);
size_t ttl = 6;
g->make_neigh_sample_cache(4, ttl);
int round = 5;
while (round--) {
vs.clear();
// Fresh sample for node 37 — populates (or refreshes) the cache entry.
pull_status = worker_ptr_->batch_sample_neighboors(
0, std::vector<uint64_t>(1, 37), 1, vs);
pull_status.wait();
for (int i = 0; i < ttl; i++) {
// Each repeat within the TTL must be served from the cache and thus
// match the first sample exactly, despite random sampling.
std::vector<std::vector<std::pair<uint64_t, float>>> vs1;
pull_status = worker_ptr_->batch_sample_neighboors(
0, std::vector<uint64_t>(1, 37), 1, vs1);
pull_status.wait();
ASSERT_EQ(vs[0].size(), vs1[0].size());
for (int j = 0; j < vs[0].size(); j++) {
ASSERT_EQ(vs[0][j].first, vs1[0][j].first);
}
}
}
std::vector<distributed::FeatureNode> nodes; std::vector<distributed::FeatureNode> nodes;
pull_status = worker_ptr_->pull_graph_list(0, 0, 0, 1, 1, nodes); pull_status = worker_ptr_->pull_graph_list(0, 0, 0, 1, 1, nodes);
...@@ -611,58 +634,51 @@ void RunBrpcPushSparse() { ...@@ -611,58 +634,51 @@ void RunBrpcPushSparse() {
} }
void testCache() { void testCache() {
::paddle::distributed::ScaledLRU< ::paddle::distributed::ScaledLRU<::paddle::distributed::SampleKey,
::paddle::distributed::SampleKey, ::paddle::distributed::SampleResult,
std::shared_ptr<::paddle::distributed::SampleResult>,
::paddle::distributed::SampleKeyHash> ::paddle::distributed::SampleKeyHash>
st(1, 2, 4); st(1, 2, 4);
std::shared_ptr<::paddle::distributed::SampleResult> sp; char* str = new char[7];
char* str = (char*)"54321"; strcpy(str, "54321");
::paddle::distributed::SampleResult* result = ::paddle::distributed::SampleResult* result =
new ::paddle::distributed::SampleResult(5, str); new ::paddle::distributed::SampleResult(5, str);
::paddle::distributed::SampleKey skey = {6, 1}; ::paddle::distributed::SampleKey skey = {6, 1};
sp.reset(result);
std::vector<std::pair<::paddle::distributed::SampleKey, std::vector<std::pair<::paddle::distributed::SampleKey,
std::shared_ptr<::paddle::distributed::SampleResult>>> paddle::distributed::SampleResult>>
r; r;
st.query(0, &skey, 1, r); st.query(0, &skey, 1, r);
ASSERT_EQ((int)r.size(), 0); ASSERT_EQ((int)r.size(), 0);
st.insert(0, &skey, &sp, 1); st.insert(0, &skey, result, 1);
for (int i = 0; i < st.get_ttl(); i++) { for (int i = 0; i < st.get_ttl(); i++) {
st.query(0, &skey, 1, r); st.query(0, &skey, 1, r);
ASSERT_EQ((int)r.size(), 1); ASSERT_EQ((int)r.size(), 1);
char* p = (char*)r[0].second.get()->buffer; char* p = (char*)r[0].second.buffer.get();
for (int j = 0; j < r[0].second.get()->actual_size; j++) for (int j = 0; j < r[0].second.actual_size; j++) ASSERT_EQ(p[j], str[j]);
ASSERT_EQ(p[j], str[j]);
r.clear(); r.clear();
} }
st.query(0, &skey, 1, r); st.query(0, &skey, 1, r);
ASSERT_EQ((int)r.size(), 0); ASSERT_EQ((int)r.size(), 0);
str = (char*)"342cd4321"; str = new char[10];
strcpy(str, "54321678");
result = new ::paddle::distributed::SampleResult(strlen(str), str); result = new ::paddle::distributed::SampleResult(strlen(str), str);
std::shared_ptr<::paddle::distributed::SampleResult> sp1; st.insert(0, &skey, result, 1);
sp1.reset(result);
st.insert(0, &skey, &sp1, 1);
for (int i = 0; i < st.get_ttl() / 2; i++) { for (int i = 0; i < st.get_ttl() / 2; i++) {
st.query(0, &skey, 1, r); st.query(0, &skey, 1, r);
ASSERT_EQ((int)r.size(), 1); ASSERT_EQ((int)r.size(), 1);
char* p = (char*)r[0].second.get()->buffer; char* p = (char*)r[0].second.buffer.get();
for (int j = 0; j < r[0].second.get()->actual_size; j++) for (int j = 0; j < r[0].second.actual_size; j++) ASSERT_EQ(p[j], str[j]);
ASSERT_EQ(p[j], str[j]);
r.clear(); r.clear();
} }
str = (char*)"343332d4321"; str = new char[18];
strcpy(str, "343332d4321");
result = new ::paddle::distributed::SampleResult(strlen(str), str); result = new ::paddle::distributed::SampleResult(strlen(str), str);
std::shared_ptr<::paddle::distributed::SampleResult> sp2; st.insert(0, &skey, result, 1);
sp2.reset(result);
st.insert(0, &skey, &sp2, 1);
for (int i = 0; i < st.get_ttl(); i++) { for (int i = 0; i < st.get_ttl(); i++) {
st.query(0, &skey, 1, r); st.query(0, &skey, 1, r);
ASSERT_EQ((int)r.size(), 1); ASSERT_EQ((int)r.size(), 1);
char* p = (char*)r[0].second.get()->buffer; char* p = (char*)r[0].second.buffer.get();
for (int j = 0; j < r[0].second.get()->actual_size; j++) for (int j = 0; j < r[0].second.actual_size; j++) ASSERT_EQ(p[j], str[j]);
ASSERT_EQ(p[j], str[j]);
r.clear(); r.clear();
} }
st.query(0, &skey, 1, r); st.query(0, &skey, 1, r);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
To comment, please register.