未验证 提交 9fc11db7 编写于 作者: S seemingwang 提交者: GitHub

optimize graph-engine sample api's data-transfer process (#37341)

* graph engine demo

* upload unsaved changes

* fix dependency error

* fix shard_num problem

* py client

* remove lock and graph-type

* add load direct graph

* add load direct graph

* add load direct graph

* batch random_sample

* batch_sample_k

* fix num_nodes size

* batch brpc

* batch brpc

* add test

* add test

* add load_nodes; change add_node function

* change sample return type to pair

* resolve conflict

* resolved conflict

* resolved conflict

* separate server and client

* merge pair type

* fix

* resolved conflict

* fixed segment fault; high-level VLOG for load edges and load nodes

* random_sample return 0

* rm useless loop

* test:load edge

* fix ret -1

* test: rm sample

* rm sample

* random_sample return future

* random_sample return int

* test fake node

* fixed here

* memory leak

* remove test code

* fix return problem

* add common_graph_table

* random sample node &test & change data-structure from linkedList to vector

* add common_graph_table

* sample with srand

* add node_types

* optimize nodes sample

* recover test

* random sample

* destruct weighted sampler

* GraphEdgeBlob

* WeightedGraphEdgeBlob to GraphEdgeBlob

* WeightedGraphEdgeBlob to GraphEdgeBlob

* pybind sample nodes api

* pull nodes with step

* fixed pull_graph_list bug; add test for pull_graph_list by step

* add graph table;name

* add graph table;name

* add pybind

* add pybind

* add FeatureNode

* add FeatureNode

* add FeatureNode Serialize

* add FeatureNode Serialize

* get_feat_node

* avoid local rpc

* fix get_node_feat

* fix get_node_feat

* remove log

* get_node_feat return  py:bytes

* merge develop with graph_engine

* fix threadpool.h head

* fix

* fix typo

* resolve conflict

* fix conflict

* recover lost content

* fix pybind of FeatureNode

* recover cmake

* recover tools

* resolve conflict

* resolve linking problem

* code style

* change test_server port

* fix code problems

* remove shard_num config

* remove redundent threads

* optimize start server

* remove logs

* fix code problems by reviewers' suggestions

* move graph files into a folder

* code style change

* remove graph operations from base table

* optimize get_feat function of graph engine

* fix long long count problem

* remove redandunt graph files

* remove unused shell

* recover dropout_op_pass.h

* fix potential stack overflow when request number is too large & node add & node clear & node remove

* when sample k is larger than neigbor num, return directly

* using random seed generator of paddle to speed up

* fix bug of random sample k

* fix code style

* fix code style

* add remove graph to fleet_py.cc

* fix blocking_queue problem

* fix style

* fix

* recover capacity check

* add remove graph node; add set_feature

* add remove graph node; add set_feature

* add remove graph node; add set_feature

* add remove graph node; add set_feature

* fix distributed op combining problems

* optimize

* remove logs

* fix MultiSlotDataGenerator error

* cache for graph engine

* fix type compare error

* more test&fix thread terminating problem

* remove header

* change time interval of shrink

* use cache when sample nodes

* remove unused function

* change unique_ptr to shared_ptr

* simplify cache template

* cache api on client

* fix

* reduce sample threads when cache is not used

* reduce cache memory

* cache optimization

* remove test function

* remove extra fetch function

* graph-engine data transfer optimization
Co-authored-by: NHuang Zhengjie <270018958@qq.com>
Co-authored-by: NWeiyue Su <weiyue.su@gmail.com>
Co-authored-by: Nsuweiyue <suweiyue@baidu.com>
Co-authored-by: Nluobin06 <luobin06@baidu.com>
Co-authored-by: Nliweibin02 <liweibin02@baidu.com>
Co-authored-by: Ntangwei12 <tangwei12@baidu.com>
上级 c13edf66
...@@ -304,10 +304,15 @@ std::future<int32_t> GraphBrpcClient::remove_graph_node( ...@@ -304,10 +304,15 @@ std::future<int32_t> GraphBrpcClient::remove_graph_node(
// char* &buffer,int &actual_size // char* &buffer,int &actual_size
std::future<int32_t> GraphBrpcClient::batch_sample_neighbors( std::future<int32_t> GraphBrpcClient::batch_sample_neighbors(
uint32_t table_id, std::vector<uint64_t> node_ids, int sample_size, uint32_t table_id, std::vector<uint64_t> node_ids, int sample_size,
std::vector<std::vector<std::pair<uint64_t, float>>> &res, // std::vector<std::vector<std::pair<uint64_t, float>>> &res,
std::vector<std::vector<uint64_t>> &res,
std::vector<std::vector<float>> &res_weight, bool need_weight,
int server_index) { int server_index) {
if (server_index != -1) { if (server_index != -1) {
res.resize(node_ids.size()); res.resize(node_ids.size());
if (need_weight) {
res_weight.resize(node_ids.size());
}
DownpourBrpcClosure *closure = new DownpourBrpcClosure(1, [&](void *done) { DownpourBrpcClosure *closure = new DownpourBrpcClosure(1, [&](void *done) {
int ret = 0; int ret = 0;
auto *closure = (DownpourBrpcClosure *)done; auto *closure = (DownpourBrpcClosure *)done;
...@@ -331,11 +336,14 @@ std::future<int32_t> GraphBrpcClient::batch_sample_neighbors( ...@@ -331,11 +336,14 @@ std::future<int32_t> GraphBrpcClient::batch_sample_neighbors(
int actual_size = actual_sizes[node_idx]; int actual_size = actual_sizes[node_idx];
int start = 0; int start = 0;
while (start < actual_size) { while (start < actual_size) {
res[node_idx].push_back( res[node_idx].emplace_back(
{*(uint64_t *)(node_buffer + offset + start), *(uint64_t *)(node_buffer + offset + start));
*(float *)(node_buffer + offset + start + start += GraphNode::id_size;
GraphNode::id_size)}); if (need_weight) {
start += GraphNode::id_size + GraphNode::weight_size; res_weight[node_idx].emplace_back(
*(float *)(node_buffer + offset + start));
start += GraphNode::weight_size;
}
} }
offset += actual_size; offset += actual_size;
} }
...@@ -352,6 +360,7 @@ std::future<int32_t> GraphBrpcClient::batch_sample_neighbors( ...@@ -352,6 +360,7 @@ std::future<int32_t> GraphBrpcClient::batch_sample_neighbors(
closure->request(0)->add_params((char *)node_ids.data(), closure->request(0)->add_params((char *)node_ids.data(),
sizeof(uint64_t) * node_ids.size()); sizeof(uint64_t) * node_ids.size());
closure->request(0)->add_params((char *)&sample_size, sizeof(int)); closure->request(0)->add_params((char *)&sample_size, sizeof(int));
closure->request(0)->add_params((char *)&need_weight, sizeof(bool));
; ;
// PsService_Stub rpc_stub(get_cmd_channel(server_index)); // PsService_Stub rpc_stub(get_cmd_channel(server_index));
GraphPsService_Stub rpc_stub = GraphPsService_Stub rpc_stub =
...@@ -364,13 +373,18 @@ std::future<int32_t> GraphBrpcClient::batch_sample_neighbors( ...@@ -364,13 +373,18 @@ std::future<int32_t> GraphBrpcClient::batch_sample_neighbors(
std::vector<int> request2server; std::vector<int> request2server;
std::vector<int> server2request(server_size, -1); std::vector<int> server2request(server_size, -1);
res.clear(); res.clear();
res_weight.clear();
for (int query_idx = 0; query_idx < node_ids.size(); ++query_idx) { for (int query_idx = 0; query_idx < node_ids.size(); ++query_idx) {
int server_index = get_server_index_by_id(node_ids[query_idx]); int server_index = get_server_index_by_id(node_ids[query_idx]);
if (server2request[server_index] == -1) { if (server2request[server_index] == -1) {
server2request[server_index] = request2server.size(); server2request[server_index] = request2server.size();
request2server.push_back(server_index); request2server.push_back(server_index);
} }
res.push_back(std::vector<std::pair<uint64_t, float>>()); // res.push_back(std::vector<std::pair<uint64_t, float>>());
res.push_back({});
if (need_weight) {
res_weight.push_back({});
}
} }
size_t request_call_num = request2server.size(); size_t request_call_num = request2server.size();
std::vector<std::vector<uint64_t>> node_id_buckets(request_call_num); std::vector<std::vector<uint64_t>> node_id_buckets(request_call_num);
...@@ -413,11 +427,14 @@ std::future<int32_t> GraphBrpcClient::batch_sample_neighbors( ...@@ -413,11 +427,14 @@ std::future<int32_t> GraphBrpcClient::batch_sample_neighbors(
int actual_size = actual_sizes[node_idx]; int actual_size = actual_sizes[node_idx];
int start = 0; int start = 0;
while (start < actual_size) { while (start < actual_size) {
res[query_idx].push_back( res[query_idx].emplace_back(
{*(uint64_t *)(node_buffer + offset + start), *(uint64_t *)(node_buffer + offset + start));
*(float *)(node_buffer + offset + start + start += GraphNode::id_size;
GraphNode::id_size)}); if (need_weight) {
start += GraphNode::id_size + GraphNode::weight_size; res_weight[query_idx].emplace_back(
*(float *)(node_buffer + offset + start));
start += GraphNode::weight_size;
}
} }
offset += actual_size; offset += actual_size;
} }
...@@ -445,6 +462,8 @@ std::future<int32_t> GraphBrpcClient::batch_sample_neighbors( ...@@ -445,6 +462,8 @@ std::future<int32_t> GraphBrpcClient::batch_sample_neighbors(
sizeof(uint64_t) * node_num); sizeof(uint64_t) * node_num);
closure->request(request_idx) closure->request(request_idx)
->add_params((char *)&sample_size, sizeof(int)); ->add_params((char *)&sample_size, sizeof(int));
closure->request(request_idx)
->add_params((char *)&need_weight, sizeof(bool));
// PsService_Stub rpc_stub(get_cmd_channel(server_index)); // PsService_Stub rpc_stub(get_cmd_channel(server_index));
GraphPsService_Stub rpc_stub = GraphPsService_Stub rpc_stub =
getServiceStub(get_cmd_channel(server_index)); getServiceStub(get_cmd_channel(server_index));
......
...@@ -64,7 +64,8 @@ class GraphBrpcClient : public BrpcPsClient { ...@@ -64,7 +64,8 @@ class GraphBrpcClient : public BrpcPsClient {
// given a batch of nodes, sample graph_neighbors for each of them // given a batch of nodes, sample graph_neighbors for each of them
virtual std::future<int32_t> batch_sample_neighbors( virtual std::future<int32_t> batch_sample_neighbors(
uint32_t table_id, std::vector<uint64_t> node_ids, int sample_size, uint32_t table_id, std::vector<uint64_t> node_ids, int sample_size,
std::vector<std::vector<std::pair<uint64_t, float>>>& res, std::vector<std::vector<uint64_t>>& res,
std::vector<std::vector<float>>& res_weight, bool need_weight,
int server_index = -1); int server_index = -1);
virtual std::future<int32_t> pull_graph_list(uint32_t table_id, virtual std::future<int32_t> pull_graph_list(uint32_t table_id,
......
...@@ -378,19 +378,21 @@ int32_t GraphBrpcService::graph_random_sample_neighbors( ...@@ -378,19 +378,21 @@ int32_t GraphBrpcService::graph_random_sample_neighbors(
Table *table, const PsRequestMessage &request, PsResponseMessage &response, Table *table, const PsRequestMessage &request, PsResponseMessage &response,
brpc::Controller *cntl) { brpc::Controller *cntl) {
CHECK_TABLE_EXIST(table, request, response) CHECK_TABLE_EXIST(table, request, response)
if (request.params_size() < 2) { if (request.params_size() < 3) {
set_response_code( set_response_code(
response, -1, response, -1,
"graph_random_sample request requires at least 2 arguments"); "graph_random_sample_neighbors request requires at least 3 arguments");
return 0; return 0;
} }
size_t node_num = request.params(0).size() / sizeof(uint64_t); size_t node_num = request.params(0).size() / sizeof(uint64_t);
uint64_t *node_data = (uint64_t *)(request.params(0).c_str()); uint64_t *node_data = (uint64_t *)(request.params(0).c_str());
int sample_size = *(uint64_t *)(request.params(1).c_str()); int sample_size = *(uint64_t *)(request.params(1).c_str());
bool need_weight = *(bool *)(request.params(2).c_str());
std::vector<std::shared_ptr<char>> buffers(node_num); std::vector<std::shared_ptr<char>> buffers(node_num);
std::vector<int> actual_sizes(node_num, 0); std::vector<int> actual_sizes(node_num, 0);
((GraphTable *)table) ((GraphTable *)table)
->random_sample_neighbors(node_data, sample_size, buffers, actual_sizes); ->random_sample_neighbors(node_data, sample_size, buffers, actual_sizes,
need_weight);
cntl->response_attachment().append(&node_num, sizeof(size_t)); cntl->response_attachment().append(&node_num, sizeof(size_t));
cntl->response_attachment().append(actual_sizes.data(), cntl->response_attachment().append(actual_sizes.data(),
...@@ -454,16 +456,17 @@ int32_t GraphBrpcService::sample_neighbors_across_multi_servers( ...@@ -454,16 +456,17 @@ int32_t GraphBrpcService::sample_neighbors_across_multi_servers(
brpc::Controller *cntl) { brpc::Controller *cntl) {
// sleep(5); // sleep(5);
CHECK_TABLE_EXIST(table, request, response) CHECK_TABLE_EXIST(table, request, response)
if (request.params_size() < 2) { if (request.params_size() < 3) {
set_response_code( set_response_code(response, -1,
response, -1, "sample_neighbors_across_multi_servers request requires "
"graph_random_neighbors_sample request requires at least 2 arguments"); "at least 3 arguments");
return 0; return 0;
} }
size_t node_num = request.params(0).size() / sizeof(uint64_t), size_t node_num = request.params(0).size() / sizeof(uint64_t),
size_of_size_t = sizeof(size_t); size_of_size_t = sizeof(size_t);
uint64_t *node_data = (uint64_t *)(request.params(0).c_str()); uint64_t *node_data = (uint64_t *)(request.params(0).c_str());
int sample_size = *(uint64_t *)(request.params(1).c_str()); int sample_size = *(uint64_t *)(request.params(1).c_str());
bool need_weight = *(uint64_t *)(request.params(2).c_str());
// std::vector<uint64_t> res = ((GraphTable // std::vector<uint64_t> res = ((GraphTable
// *)table).filter_out_non_exist_nodes(node_data, sample_size); // *)table).filter_out_non_exist_nodes(node_data, sample_size);
std::vector<int> request2server; std::vector<int> request2server;
...@@ -581,6 +584,8 @@ int32_t GraphBrpcService::sample_neighbors_across_multi_servers( ...@@ -581,6 +584,8 @@ int32_t GraphBrpcService::sample_neighbors_across_multi_servers(
sizeof(uint64_t) * node_num); sizeof(uint64_t) * node_num);
closure->request(request_idx) closure->request(request_idx)
->add_params((char *)&sample_size, sizeof(int)); ->add_params((char *)&sample_size, sizeof(int));
closure->request(request_idx)
->add_params((char *)&need_weight, sizeof(bool));
PsService_Stub rpc_stub( PsService_Stub rpc_stub(
((GraphBrpcServer *)get_server())->get_cmd_channel(server_index)); ((GraphBrpcServer *)get_server())->get_cmd_channel(server_index));
// GraphPsService_Stub rpc_stub = // GraphPsService_Stub rpc_stub =
...@@ -592,7 +597,8 @@ int32_t GraphBrpcService::sample_neighbors_across_multi_servers( ...@@ -592,7 +597,8 @@ int32_t GraphBrpcService::sample_neighbors_across_multi_servers(
if (server2request[rank] != -1) { if (server2request[rank] != -1) {
((GraphTable *)table) ((GraphTable *)table)
->random_sample_neighbors(node_id_buckets.back().data(), sample_size, ->random_sample_neighbors(node_id_buckets.back().data(), sample_size,
local_buffers, local_actual_sizes); local_buffers, local_actual_sizes,
need_weight);
} }
local_promise.get()->set_value(0); local_promise.get()->set_value(0);
if (remote_call_num == 0) func(closure); if (remote_call_num == 0) func(closure);
......
...@@ -295,11 +295,13 @@ GraphPyClient::batch_sample_neighbors(std::string name, ...@@ -295,11 +295,13 @@ GraphPyClient::batch_sample_neighbors(std::string name,
std::vector<uint64_t> node_ids, std::vector<uint64_t> node_ids,
int sample_size, bool return_weight, int sample_size, bool return_weight,
bool return_edges) { bool return_edges) {
std::vector<std::vector<std::pair<uint64_t, float>>> v; // std::vector<std::vector<std::pair<uint64_t, float>>> v;
std::vector<std::vector<uint64_t>> v;
std::vector<std::vector<float>> v1;
if (this->table_id_map.count(name)) { if (this->table_id_map.count(name)) {
uint32_t table_id = this->table_id_map[name]; uint32_t table_id = this->table_id_map[name];
auto status = auto status = worker_ptr->batch_sample_neighbors(
worker_ptr->batch_sample_neighbors(table_id, node_ids, sample_size, v); table_id, node_ids, sample_size, v, v1, return_weight);
status.wait(); status.wait();
} }
...@@ -313,9 +315,10 @@ GraphPyClient::batch_sample_neighbors(std::string name, ...@@ -313,9 +315,10 @@ GraphPyClient::batch_sample_neighbors(std::string name,
if (return_edges) res.first.push_back({}); if (return_edges) res.first.push_back({});
for (size_t i = 0; i < v.size(); i++) { for (size_t i = 0; i < v.size(); i++) {
for (size_t j = 0; j < v[i].size(); j++) { for (size_t j = 0; j < v[i].size(); j++) {
res.first[0].push_back(v[i][j].first); // res.first[0].push_back(v[i][j].first);
res.first[0].push_back(v[i][j]);
if (return_edges) res.first[2].push_back(node_ids[i]); if (return_edges) res.first[2].push_back(node_ids[i]);
if (return_weight) res.second.push_back(v[i][j].second); if (return_weight) res.second.push_back(v1[i][j]);
} }
if (i == v.size() - 1) break; if (i == v.size() - 1) break;
......
...@@ -396,8 +396,8 @@ int32_t GraphTable::random_sample_nodes(int sample_size, ...@@ -396,8 +396,8 @@ int32_t GraphTable::random_sample_nodes(int sample_size,
} }
int32_t GraphTable::random_sample_neighbors( int32_t GraphTable::random_sample_neighbors(
uint64_t *node_ids, int sample_size, uint64_t *node_ids, int sample_size,
std::vector<std::shared_ptr<char>> &buffers, std::vector<std::shared_ptr<char>> &buffers, std::vector<int> &actual_sizes,
std::vector<int> &actual_sizes) { bool need_weight) {
size_t node_num = buffers.size(); size_t node_num = buffers.size();
std::function<void(char *)> char_del = [](char *c) { delete[] c; }; std::function<void(char *)> char_del = [](char *c) { delete[] c; };
std::vector<std::future<int>> tasks; std::vector<std::future<int>> tasks;
...@@ -407,7 +407,7 @@ int32_t GraphTable::random_sample_neighbors( ...@@ -407,7 +407,7 @@ int32_t GraphTable::random_sample_neighbors(
for (size_t idx = 0; idx < node_num; ++idx) { for (size_t idx = 0; idx < node_num; ++idx) {
index = get_thread_pool_index(node_ids[idx]); index = get_thread_pool_index(node_ids[idx]);
seq_id[index].emplace_back(idx); seq_id[index].emplace_back(idx);
id_list[index].emplace_back(node_ids[idx], sample_size); id_list[index].emplace_back(node_ids[idx], sample_size, need_weight);
} }
for (int i = 0; i < seq_id.size(); i++) { for (int i = 0; i < seq_id.size(); i++) {
if (seq_id[i].size() == 0) continue; if (seq_id[i].size() == 0) continue;
...@@ -442,13 +442,15 @@ int32_t GraphTable::random_sample_neighbors( ...@@ -442,13 +442,15 @@ int32_t GraphTable::random_sample_neighbors(
} }
std::shared_ptr<char> &buffer = buffers[idx]; std::shared_ptr<char> &buffer = buffers[idx];
std::vector<int> res = node->sample_k(sample_size, rng); std::vector<int> res = node->sample_k(sample_size, rng);
actual_size = res.size() * (Node::id_size + Node::weight_size); actual_size =
res.size() * (need_weight ? (Node::id_size + Node::weight_size)
: Node::id_size);
int offset = 0; int offset = 0;
uint64_t id; uint64_t id;
float weight; float weight;
char *buffer_addr = new char[actual_size]; char *buffer_addr = new char[actual_size];
if (response == LRUResponse::ok) { if (response == LRUResponse::ok) {
sample_keys.emplace_back(node_id, sample_size); sample_keys.emplace_back(node_id, sample_size, need_weight);
sample_res.emplace_back(actual_size, buffer_addr); sample_res.emplace_back(actual_size, buffer_addr);
buffer = sample_res.back().buffer; buffer = sample_res.back().buffer;
} else { } else {
...@@ -456,11 +458,13 @@ int32_t GraphTable::random_sample_neighbors( ...@@ -456,11 +458,13 @@ int32_t GraphTable::random_sample_neighbors(
} }
for (int &x : res) { for (int &x : res) {
id = node->get_neighbor_id(x); id = node->get_neighbor_id(x);
weight = node->get_neighbor_weight(x);
memcpy(buffer_addr + offset, &id, Node::id_size); memcpy(buffer_addr + offset, &id, Node::id_size);
offset += Node::id_size; offset += Node::id_size;
memcpy(buffer_addr + offset, &weight, Node::weight_size); if (need_weight) {
offset += Node::weight_size; weight = node->get_neighbor_weight(x);
memcpy(buffer_addr + offset, &weight, Node::weight_size);
offset += Node::weight_size;
}
} }
} }
} }
......
...@@ -80,10 +80,14 @@ enum LRUResponse { ok = 0, blocked = 1, err = 2 }; ...@@ -80,10 +80,14 @@ enum LRUResponse { ok = 0, blocked = 1, err = 2 };
struct SampleKey { struct SampleKey {
uint64_t node_key; uint64_t node_key;
size_t sample_size; size_t sample_size;
SampleKey(uint64_t _node_key, size_t _sample_size) bool is_weighted;
: node_key(_node_key), sample_size(_sample_size) {} SampleKey(uint64_t _node_key, size_t _sample_size, bool _is_weighted)
: node_key(_node_key),
sample_size(_sample_size),
is_weighted(_is_weighted) {}
bool operator==(const SampleKey &s) const { bool operator==(const SampleKey &s) const {
return node_key == s.node_key && sample_size == s.sample_size; return node_key == s.node_key && sample_size == s.sample_size &&
is_weighted == s.is_weighted;
} }
}; };
...@@ -360,7 +364,7 @@ class GraphTable : public SparseTable { ...@@ -360,7 +364,7 @@ class GraphTable : public SparseTable {
virtual int32_t random_sample_neighbors( virtual int32_t random_sample_neighbors(
uint64_t *node_ids, int sample_size, uint64_t *node_ids, int sample_size,
std::vector<std::shared_ptr<char>> &buffers, std::vector<std::shared_ptr<char>> &buffers,
std::vector<int> &actual_sizes); std::vector<int> &actual_sizes, bool need_weight);
int32_t random_sample_nodes(int sample_size, std::unique_ptr<char[]> &buffers, int32_t random_sample_nodes(int sample_size, std::unique_ptr<char[]> &buffers,
int &actual_sizes); int &actual_sizes);
......
...@@ -51,6 +51,7 @@ class Node { ...@@ -51,6 +51,7 @@ class Node {
protected: protected:
uint64_t id; uint64_t id;
bool is_weighted;
}; };
class GraphNode : public Node { class GraphNode : public Node {
......
...@@ -107,15 +107,16 @@ void testFeatureNodeSerializeFloat64() { ...@@ -107,15 +107,16 @@ void testFeatureNodeSerializeFloat64() {
void testSingleSampleNeighboor( void testSingleSampleNeighboor(
std::shared_ptr<paddle::distributed::GraphBrpcClient>& worker_ptr_) { std::shared_ptr<paddle::distributed::GraphBrpcClient>& worker_ptr_) {
std::vector<std::vector<std::pair<uint64_t, float>>> vs; std::vector<std::vector<uint64_t>> vs;
std::vector<std::vector<float>> vs1;
auto pull_status = worker_ptr_->batch_sample_neighbors( auto pull_status = worker_ptr_->batch_sample_neighbors(
0, std::vector<uint64_t>(1, 37), 4, vs); 0, std::vector<uint64_t>(1, 37), 4, vs, vs1, true);
pull_status.wait(); pull_status.wait();
std::unordered_set<uint64_t> s; std::unordered_set<uint64_t> s;
std::unordered_set<uint64_t> s1 = {112, 45, 145}; std::unordered_set<uint64_t> s1 = {112, 45, 145};
for (auto g : vs[0]) { for (auto g : vs[0]) {
s.insert(g.first); s.insert(g);
} }
ASSERT_EQ(s.size(), 3); ASSERT_EQ(s.size(), 3);
for (auto g : s) { for (auto g : s) {
...@@ -124,19 +125,21 @@ void testSingleSampleNeighboor( ...@@ -124,19 +125,21 @@ void testSingleSampleNeighboor(
s.clear(); s.clear();
s1.clear(); s1.clear();
vs.clear(); vs.clear();
vs1.clear();
pull_status = worker_ptr_->batch_sample_neighbors( pull_status = worker_ptr_->batch_sample_neighbors(
0, std::vector<uint64_t>(1, 96), 4, vs); 0, std::vector<uint64_t>(1, 96), 4, vs, vs1, true);
pull_status.wait(); pull_status.wait();
s1 = {111, 48, 247}; s1 = {111, 48, 247};
for (auto g : vs[0]) { for (auto g : vs[0]) {
s.insert(g.first); s.insert(g);
} }
ASSERT_EQ(s.size(), 3); ASSERT_EQ(s.size(), 3);
for (auto g : s) { for (auto g : s) {
ASSERT_EQ(true, s1.find(g) != s1.end()); ASSERT_EQ(true, s1.find(g) != s1.end());
} }
vs.clear(); vs.clear();
pull_status = worker_ptr_->batch_sample_neighbors(0, {96, 37}, 4, vs, 0); pull_status =
worker_ptr_->batch_sample_neighbors(0, {96, 37}, 4, vs, vs1, true, 0);
pull_status.wait(); pull_status.wait();
ASSERT_EQ(vs.size(), 2); ASSERT_EQ(vs.size(), 2);
} }
...@@ -194,14 +197,16 @@ void testAddNode( ...@@ -194,14 +197,16 @@ void testAddNode(
} }
void testBatchSampleNeighboor( void testBatchSampleNeighboor(
std::shared_ptr<paddle::distributed::GraphBrpcClient>& worker_ptr_) { std::shared_ptr<paddle::distributed::GraphBrpcClient>& worker_ptr_) {
std::vector<std::vector<std::pair<uint64_t, float>>> vs; std::vector<std::vector<uint64_t>> vs;
std::vector<std::vector<float>> vs1;
std::vector<std::uint64_t> v = {37, 96}; std::vector<std::uint64_t> v = {37, 96};
auto pull_status = worker_ptr_->batch_sample_neighbors(0, v, 4, vs); auto pull_status =
worker_ptr_->batch_sample_neighbors(0, v, 4, vs, vs1, false);
pull_status.wait(); pull_status.wait();
std::unordered_set<uint64_t> s; std::unordered_set<uint64_t> s;
std::unordered_set<uint64_t> s1 = {112, 45, 145}; std::unordered_set<uint64_t> s1 = {112, 45, 145};
for (auto g : vs[0]) { for (auto g : vs[0]) {
s.insert(g.first); s.insert(g);
} }
ASSERT_EQ(s.size(), 3); ASSERT_EQ(s.size(), 3);
for (auto g : s) { for (auto g : s) {
...@@ -211,7 +216,7 @@ void testBatchSampleNeighboor( ...@@ -211,7 +216,7 @@ void testBatchSampleNeighboor(
s1.clear(); s1.clear();
s1 = {111, 48, 247}; s1 = {111, 48, 247};
for (auto g : vs[1]) { for (auto g : vs[1]) {
s.insert(g.first); s.insert(g);
} }
ASSERT_EQ(s.size(), 3); ASSERT_EQ(s.size(), 3);
for (auto g : s) { for (auto g : s) {
...@@ -221,10 +226,6 @@ void testBatchSampleNeighboor( ...@@ -221,10 +226,6 @@ void testBatchSampleNeighboor(
void testCache(); void testCache();
void testGraphToBuffer(); void testGraphToBuffer();
// std::string nodes[] = {std::string("37\taa\t45;0.34\t145;0.31\t112;0.21"),
// std::string("96\tfeature\t48;1.4\t247;0.31\t111;1.21"),
// std::string("59\ttreat\t45;0.34\t145;0.31\t112;0.21"),
// std::string("97\tfood\t48;1.4\t247;0.31\t111;1.21")};
std::string edges[] = { std::string edges[] = {
std::string("37\t45\t0.34"), std::string("37\t145\t0.31"), std::string("37\t45\t0.34"), std::string("37\t145\t0.31"),
...@@ -427,15 +428,16 @@ void RunBrpcPushSparse() { ...@@ -427,15 +428,16 @@ void RunBrpcPushSparse() {
worker_ptr_->load(0, std::string(edge_file_name), std::string("e>")); worker_ptr_->load(0, std::string(edge_file_name), std::string("e>"));
srand(time(0)); srand(time(0));
pull_status.wait(); pull_status.wait();
std::vector<std::vector<std::pair<uint64_t, float>>> vs; std::vector<std::vector<uint64_t>> _vs;
std::vector<std::vector<float>> vs;
testSampleNodes(worker_ptr_); testSampleNodes(worker_ptr_);
sleep(5); sleep(5);
testSingleSampleNeighboor(worker_ptr_); testSingleSampleNeighboor(worker_ptr_);
testBatchSampleNeighboor(worker_ptr_); testBatchSampleNeighboor(worker_ptr_);
pull_status = worker_ptr_->batch_sample_neighbors( pull_status = worker_ptr_->batch_sample_neighbors(
0, std::vector<uint64_t>(1, 10240001024), 4, vs); 0, std::vector<uint64_t>(1, 10240001024), 4, _vs, vs, true);
pull_status.wait(); pull_status.wait();
ASSERT_EQ(0, vs[0].size()); ASSERT_EQ(0, _vs[0].size());
paddle::distributed::GraphTable* g = paddle::distributed::GraphTable* g =
(paddle::distributed::GraphTable*)pserver_ptr_->table(0); (paddle::distributed::GraphTable*)pserver_ptr_->table(0);
size_t ttl = 6; size_t ttl = 6;
...@@ -444,18 +446,19 @@ void RunBrpcPushSparse() { ...@@ -444,18 +446,19 @@ void RunBrpcPushSparse() {
while (round--) { while (round--) {
vs.clear(); vs.clear();
pull_status = worker_ptr_->batch_sample_neighbors( pull_status = worker_ptr_->batch_sample_neighbors(
0, std::vector<uint64_t>(1, 37), 1, vs); 0, std::vector<uint64_t>(1, 37), 1, _vs, vs, false);
pull_status.wait(); pull_status.wait();
for (int i = 0; i < ttl; i++) { for (int i = 0; i < ttl; i++) {
std::vector<std::vector<std::pair<uint64_t, float>>> vs1; std::vector<std::vector<uint64_t>> vs1;
std::vector<std::vector<float>> vs2;
pull_status = worker_ptr_->batch_sample_neighbors( pull_status = worker_ptr_->batch_sample_neighbors(
0, std::vector<uint64_t>(1, 37), 1, vs1); 0, std::vector<uint64_t>(1, 37), 1, vs1, vs2, false);
pull_status.wait(); pull_status.wait();
ASSERT_EQ(vs[0].size(), vs1[0].size()); ASSERT_EQ(_vs[0].size(), vs1[0].size());
for (int j = 0; j < vs[0].size(); j++) { for (int j = 0; j < _vs[0].size(); j++) {
ASSERT_EQ(vs[0][j].first, vs1[0][j].first); ASSERT_EQ(_vs[0][j], vs1[0][j]);
} }
} }
} }
...@@ -639,7 +642,7 @@ void testCache() { ...@@ -639,7 +642,7 @@ void testCache() {
strcpy(str, "54321"); strcpy(str, "54321");
::paddle::distributed::SampleResult* result = ::paddle::distributed::SampleResult* result =
new ::paddle::distributed::SampleResult(5, str); new ::paddle::distributed::SampleResult(5, str);
::paddle::distributed::SampleKey skey = {6, 1}; ::paddle::distributed::SampleKey skey = {6, 1, false};
std::vector<std::pair<::paddle::distributed::SampleKey, std::vector<std::pair<::paddle::distributed::SampleKey,
paddle::distributed::SampleResult>> paddle::distributed::SampleResult>>
r; r;
...@@ -695,4 +698,4 @@ void testGraphToBuffer() { ...@@ -695,4 +698,4 @@ void testGraphToBuffer() {
VLOG(0) << s1.get_feature(0); VLOG(0) << s1.get_feature(0);
} }
TEST(RunBrpcPushSparse, Run) { RunBrpcPushSparse(); } TEST(RunBrpcPushSparse, Run) { RunBrpcPushSparse(); }
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册