未验证 提交 9fc11db7 编写于 作者: S seemingwang 提交者: GitHub

optimize graph-engine sample api's data-transfer process (#37341)

* graph engine demo

* upload unsaved changes

* fix dependency error

* fix shard_num problem

* py client

* remove lock and graph-type

* add load direct graph

* add load direct graph

* add load direct graph

* batch random_sample

* batch_sample_k

* fix num_nodes size

* batch brpc

* batch brpc

* add test

* add test

* add load_nodes; change add_node function

* change sample return type to pair

* resolve conflict

* resolved conflict

* resolved conflict

* separate server and client

* merge pair type

* fix

* resolved conflict

* fixed segment fault; high-level VLOG for load edges and load nodes

* random_sample return 0

* rm useless loop

* test:load edge

* fix ret -1

* test: rm sample

* rm sample

* random_sample return future

* random_sample return int

* test fake node

* fixed here

* memory leak

* remove test code

* fix return problem

* add common_graph_table

* random sample node &test & change data-structure from linkedList to vector

* add common_graph_table

* sample with srand

* add node_types

* optimize nodes sample

* recover test

* random sample

* destruct weighted sampler

* GraphEdgeBlob

* WeightedGraphEdgeBlob to GraphEdgeBlob

* WeightedGraphEdgeBlob to GraphEdgeBlob

* pybind sample nodes api

* pull nodes with step

* fixed pull_graph_list bug; add test for pull_graph_list by step

* add graph table;name

* add graph table;name

* add pybind

* add pybind

* add FeatureNode

* add FeatureNode

* add FeatureNode Serialize

* add FeatureNode Serialize

* get_feat_node

* avoid local rpc

* fix get_node_feat

* fix get_node_feat

* remove log

* get_node_feat return  py:bytes

* merge develop with graph_engine

* fix threadpool.h head

* fix

* fix typo

* resolve conflict

* fix conflict

* recover lost content

* fix pybind of FeatureNode

* recover cmake

* recover tools

* resolve conflict

* resolve linking problem

* code style

* change test_server port

* fix code problems

* remove shard_num config

* remove redundent threads

* optimize start server

* remove logs

* fix code problems by reviewers' suggestions

* move graph files into a folder

* code style change

* remove graph operations from base table

* optimize get_feat function of graph engine

* fix long long count problem

* remove redandunt graph files

* remove unused shell

* recover dropout_op_pass.h

* fix potential stack overflow when request number is too large & node add & node clear & node remove

* when sample k is larger than neigbor num, return directly

* using random seed generator of paddle to speed up

* fix bug of random sample k

* fix code style

* fix code style

* add remove graph to fleet_py.cc

* fix blocking_queue problem

* fix style

* fix

* recover capacity check

* add remove graph node; add set_feature

* add remove graph node; add set_feature

* add remove graph node; add set_feature

* add remove graph node; add set_feature

* fix distributed op combining problems

* optimize

* remove logs

* fix MultiSlotDataGenerator error

* cache for graph engine

* fix type compare error

* more test&fix thread terminating problem

* remove header

* change time interval of shrink

* use cache when sample nodes

* remove unused function

* change unique_ptr to shared_ptr

* simplify cache template

* cache api on client

* fix

* reduce sample threads when cache is not used

* reduce cache memory

* cache optimization

* remove test function

* remove extra fetch function

* graph-engine data transfer optimization
Co-authored-by: NHuang Zhengjie <270018958@qq.com>
Co-authored-by: NWeiyue Su <weiyue.su@gmail.com>
Co-authored-by: Nsuweiyue <suweiyue@baidu.com>
Co-authored-by: Nluobin06 <luobin06@baidu.com>
Co-authored-by: Nliweibin02 <liweibin02@baidu.com>
Co-authored-by: Ntangwei12 <tangwei12@baidu.com>
上级 c13edf66
......@@ -304,10 +304,15 @@ std::future<int32_t> GraphBrpcClient::remove_graph_node(
// char* &buffer,int &actual_size
std::future<int32_t> GraphBrpcClient::batch_sample_neighbors(
uint32_t table_id, std::vector<uint64_t> node_ids, int sample_size,
std::vector<std::vector<std::pair<uint64_t, float>>> &res,
// std::vector<std::vector<std::pair<uint64_t, float>>> &res,
std::vector<std::vector<uint64_t>> &res,
std::vector<std::vector<float>> &res_weight, bool need_weight,
int server_index) {
if (server_index != -1) {
res.resize(node_ids.size());
if (need_weight) {
res_weight.resize(node_ids.size());
}
DownpourBrpcClosure *closure = new DownpourBrpcClosure(1, [&](void *done) {
int ret = 0;
auto *closure = (DownpourBrpcClosure *)done;
......@@ -331,11 +336,14 @@ std::future<int32_t> GraphBrpcClient::batch_sample_neighbors(
int actual_size = actual_sizes[node_idx];
int start = 0;
while (start < actual_size) {
res[node_idx].push_back(
{*(uint64_t *)(node_buffer + offset + start),
*(float *)(node_buffer + offset + start +
GraphNode::id_size)});
start += GraphNode::id_size + GraphNode::weight_size;
res[node_idx].emplace_back(
*(uint64_t *)(node_buffer + offset + start));
start += GraphNode::id_size;
if (need_weight) {
res_weight[node_idx].emplace_back(
*(float *)(node_buffer + offset + start));
start += GraphNode::weight_size;
}
}
offset += actual_size;
}
......@@ -352,6 +360,7 @@ std::future<int32_t> GraphBrpcClient::batch_sample_neighbors(
closure->request(0)->add_params((char *)node_ids.data(),
sizeof(uint64_t) * node_ids.size());
closure->request(0)->add_params((char *)&sample_size, sizeof(int));
closure->request(0)->add_params((char *)&need_weight, sizeof(bool));
;
// PsService_Stub rpc_stub(get_cmd_channel(server_index));
GraphPsService_Stub rpc_stub =
......@@ -364,13 +373,18 @@ std::future<int32_t> GraphBrpcClient::batch_sample_neighbors(
std::vector<int> request2server;
std::vector<int> server2request(server_size, -1);
res.clear();
res_weight.clear();
for (int query_idx = 0; query_idx < node_ids.size(); ++query_idx) {
int server_index = get_server_index_by_id(node_ids[query_idx]);
if (server2request[server_index] == -1) {
server2request[server_index] = request2server.size();
request2server.push_back(server_index);
}
res.push_back(std::vector<std::pair<uint64_t, float>>());
// res.push_back(std::vector<std::pair<uint64_t, float>>());
res.push_back({});
if (need_weight) {
res_weight.push_back({});
}
}
size_t request_call_num = request2server.size();
std::vector<std::vector<uint64_t>> node_id_buckets(request_call_num);
......@@ -413,11 +427,14 @@ std::future<int32_t> GraphBrpcClient::batch_sample_neighbors(
int actual_size = actual_sizes[node_idx];
int start = 0;
while (start < actual_size) {
res[query_idx].push_back(
{*(uint64_t *)(node_buffer + offset + start),
*(float *)(node_buffer + offset + start +
GraphNode::id_size)});
start += GraphNode::id_size + GraphNode::weight_size;
res[query_idx].emplace_back(
*(uint64_t *)(node_buffer + offset + start));
start += GraphNode::id_size;
if (need_weight) {
res_weight[query_idx].emplace_back(
*(float *)(node_buffer + offset + start));
start += GraphNode::weight_size;
}
}
offset += actual_size;
}
......@@ -445,6 +462,8 @@ std::future<int32_t> GraphBrpcClient::batch_sample_neighbors(
sizeof(uint64_t) * node_num);
closure->request(request_idx)
->add_params((char *)&sample_size, sizeof(int));
closure->request(request_idx)
->add_params((char *)&need_weight, sizeof(bool));
// PsService_Stub rpc_stub(get_cmd_channel(server_index));
GraphPsService_Stub rpc_stub =
getServiceStub(get_cmd_channel(server_index));
......
......@@ -64,7 +64,8 @@ class GraphBrpcClient : public BrpcPsClient {
// given a batch of nodes, sample graph_neighbors for each of them
virtual std::future<int32_t> batch_sample_neighbors(
uint32_t table_id, std::vector<uint64_t> node_ids, int sample_size,
std::vector<std::vector<std::pair<uint64_t, float>>>& res,
std::vector<std::vector<uint64_t>>& res,
std::vector<std::vector<float>>& res_weight, bool need_weight,
int server_index = -1);
virtual std::future<int32_t> pull_graph_list(uint32_t table_id,
......
......@@ -378,19 +378,21 @@ int32_t GraphBrpcService::graph_random_sample_neighbors(
Table *table, const PsRequestMessage &request, PsResponseMessage &response,
brpc::Controller *cntl) {
CHECK_TABLE_EXIST(table, request, response)
if (request.params_size() < 2) {
if (request.params_size() < 3) {
set_response_code(
response, -1,
"graph_random_sample request requires at least 2 arguments");
"graph_random_sample_neighbors request requires at least 3 arguments");
return 0;
}
size_t node_num = request.params(0).size() / sizeof(uint64_t);
uint64_t *node_data = (uint64_t *)(request.params(0).c_str());
int sample_size = *(uint64_t *)(request.params(1).c_str());
bool need_weight = *(bool *)(request.params(2).c_str());
std::vector<std::shared_ptr<char>> buffers(node_num);
std::vector<int> actual_sizes(node_num, 0);
((GraphTable *)table)
->random_sample_neighbors(node_data, sample_size, buffers, actual_sizes);
->random_sample_neighbors(node_data, sample_size, buffers, actual_sizes,
need_weight);
cntl->response_attachment().append(&node_num, sizeof(size_t));
cntl->response_attachment().append(actual_sizes.data(),
......@@ -454,16 +456,17 @@ int32_t GraphBrpcService::sample_neighbors_across_multi_servers(
brpc::Controller *cntl) {
// sleep(5);
CHECK_TABLE_EXIST(table, request, response)
if (request.params_size() < 2) {
set_response_code(
response, -1,
"graph_random_neighbors_sample request requires at least 2 arguments");
if (request.params_size() < 3) {
set_response_code(response, -1,
"sample_neighbors_across_multi_servers request requires "
"at least 3 arguments");
return 0;
}
size_t node_num = request.params(0).size() / sizeof(uint64_t),
size_of_size_t = sizeof(size_t);
uint64_t *node_data = (uint64_t *)(request.params(0).c_str());
int sample_size = *(uint64_t *)(request.params(1).c_str());
bool need_weight = *(uint64_t *)(request.params(2).c_str());
// std::vector<uint64_t> res = ((GraphTable
// *)table).filter_out_non_exist_nodes(node_data, sample_size);
std::vector<int> request2server;
......@@ -581,6 +584,8 @@ int32_t GraphBrpcService::sample_neighbors_across_multi_servers(
sizeof(uint64_t) * node_num);
closure->request(request_idx)
->add_params((char *)&sample_size, sizeof(int));
closure->request(request_idx)
->add_params((char *)&need_weight, sizeof(bool));
PsService_Stub rpc_stub(
((GraphBrpcServer *)get_server())->get_cmd_channel(server_index));
// GraphPsService_Stub rpc_stub =
......@@ -592,7 +597,8 @@ int32_t GraphBrpcService::sample_neighbors_across_multi_servers(
if (server2request[rank] != -1) {
((GraphTable *)table)
->random_sample_neighbors(node_id_buckets.back().data(), sample_size,
local_buffers, local_actual_sizes);
local_buffers, local_actual_sizes,
need_weight);
}
local_promise.get()->set_value(0);
if (remote_call_num == 0) func(closure);
......
......@@ -295,11 +295,13 @@ GraphPyClient::batch_sample_neighbors(std::string name,
std::vector<uint64_t> node_ids,
int sample_size, bool return_weight,
bool return_edges) {
std::vector<std::vector<std::pair<uint64_t, float>>> v;
// std::vector<std::vector<std::pair<uint64_t, float>>> v;
std::vector<std::vector<uint64_t>> v;
std::vector<std::vector<float>> v1;
if (this->table_id_map.count(name)) {
uint32_t table_id = this->table_id_map[name];
auto status =
worker_ptr->batch_sample_neighbors(table_id, node_ids, sample_size, v);
auto status = worker_ptr->batch_sample_neighbors(
table_id, node_ids, sample_size, v, v1, return_weight);
status.wait();
}
......@@ -313,9 +315,10 @@ GraphPyClient::batch_sample_neighbors(std::string name,
if (return_edges) res.first.push_back({});
for (size_t i = 0; i < v.size(); i++) {
for (size_t j = 0; j < v[i].size(); j++) {
res.first[0].push_back(v[i][j].first);
// res.first[0].push_back(v[i][j].first);
res.first[0].push_back(v[i][j]);
if (return_edges) res.first[2].push_back(node_ids[i]);
if (return_weight) res.second.push_back(v[i][j].second);
if (return_weight) res.second.push_back(v1[i][j]);
}
if (i == v.size() - 1) break;
......
......@@ -396,8 +396,8 @@ int32_t GraphTable::random_sample_nodes(int sample_size,
}
int32_t GraphTable::random_sample_neighbors(
uint64_t *node_ids, int sample_size,
std::vector<std::shared_ptr<char>> &buffers,
std::vector<int> &actual_sizes) {
std::vector<std::shared_ptr<char>> &buffers, std::vector<int> &actual_sizes,
bool need_weight) {
size_t node_num = buffers.size();
std::function<void(char *)> char_del = [](char *c) { delete[] c; };
std::vector<std::future<int>> tasks;
......@@ -407,7 +407,7 @@ int32_t GraphTable::random_sample_neighbors(
for (size_t idx = 0; idx < node_num; ++idx) {
index = get_thread_pool_index(node_ids[idx]);
seq_id[index].emplace_back(idx);
id_list[index].emplace_back(node_ids[idx], sample_size);
id_list[index].emplace_back(node_ids[idx], sample_size, need_weight);
}
for (int i = 0; i < seq_id.size(); i++) {
if (seq_id[i].size() == 0) continue;
......@@ -442,13 +442,15 @@ int32_t GraphTable::random_sample_neighbors(
}
std::shared_ptr<char> &buffer = buffers[idx];
std::vector<int> res = node->sample_k(sample_size, rng);
actual_size = res.size() * (Node::id_size + Node::weight_size);
actual_size =
res.size() * (need_weight ? (Node::id_size + Node::weight_size)
: Node::id_size);
int offset = 0;
uint64_t id;
float weight;
char *buffer_addr = new char[actual_size];
if (response == LRUResponse::ok) {
sample_keys.emplace_back(node_id, sample_size);
sample_keys.emplace_back(node_id, sample_size, need_weight);
sample_res.emplace_back(actual_size, buffer_addr);
buffer = sample_res.back().buffer;
} else {
......@@ -456,14 +458,16 @@ int32_t GraphTable::random_sample_neighbors(
}
for (int &x : res) {
id = node->get_neighbor_id(x);
weight = node->get_neighbor_weight(x);
memcpy(buffer_addr + offset, &id, Node::id_size);
offset += Node::id_size;
if (need_weight) {
weight = node->get_neighbor_weight(x);
memcpy(buffer_addr + offset, &weight, Node::weight_size);
offset += Node::weight_size;
}
}
}
}
if (sample_res.size()) {
scaled_lru->insert(i, sample_keys.data(), sample_res.data(),
sample_keys.size());
......
......@@ -80,10 +80,14 @@ enum LRUResponse { ok = 0, blocked = 1, err = 2 };
struct SampleKey {
uint64_t node_key;
size_t sample_size;
SampleKey(uint64_t _node_key, size_t _sample_size)
: node_key(_node_key), sample_size(_sample_size) {}
bool is_weighted;
SampleKey(uint64_t _node_key, size_t _sample_size, bool _is_weighted)
: node_key(_node_key),
sample_size(_sample_size),
is_weighted(_is_weighted) {}
bool operator==(const SampleKey &s) const {
return node_key == s.node_key && sample_size == s.sample_size;
return node_key == s.node_key && sample_size == s.sample_size &&
is_weighted == s.is_weighted;
}
};
......@@ -360,7 +364,7 @@ class GraphTable : public SparseTable {
virtual int32_t random_sample_neighbors(
uint64_t *node_ids, int sample_size,
std::vector<std::shared_ptr<char>> &buffers,
std::vector<int> &actual_sizes);
std::vector<int> &actual_sizes, bool need_weight);
int32_t random_sample_nodes(int sample_size, std::unique_ptr<char[]> &buffers,
int &actual_sizes);
......
......@@ -51,6 +51,7 @@ class Node {
protected:
uint64_t id;
bool is_weighted;
};
class GraphNode : public Node {
......
......@@ -107,15 +107,16 @@ void testFeatureNodeSerializeFloat64() {
void testSingleSampleNeighboor(
std::shared_ptr<paddle::distributed::GraphBrpcClient>& worker_ptr_) {
std::vector<std::vector<std::pair<uint64_t, float>>> vs;
std::vector<std::vector<uint64_t>> vs;
std::vector<std::vector<float>> vs1;
auto pull_status = worker_ptr_->batch_sample_neighbors(
0, std::vector<uint64_t>(1, 37), 4, vs);
0, std::vector<uint64_t>(1, 37), 4, vs, vs1, true);
pull_status.wait();
std::unordered_set<uint64_t> s;
std::unordered_set<uint64_t> s1 = {112, 45, 145};
for (auto g : vs[0]) {
s.insert(g.first);
s.insert(g);
}
ASSERT_EQ(s.size(), 3);
for (auto g : s) {
......@@ -124,19 +125,21 @@ void testSingleSampleNeighboor(
s.clear();
s1.clear();
vs.clear();
vs1.clear();
pull_status = worker_ptr_->batch_sample_neighbors(
0, std::vector<uint64_t>(1, 96), 4, vs);
0, std::vector<uint64_t>(1, 96), 4, vs, vs1, true);
pull_status.wait();
s1 = {111, 48, 247};
for (auto g : vs[0]) {
s.insert(g.first);
s.insert(g);
}
ASSERT_EQ(s.size(), 3);
for (auto g : s) {
ASSERT_EQ(true, s1.find(g) != s1.end());
}
vs.clear();
pull_status = worker_ptr_->batch_sample_neighbors(0, {96, 37}, 4, vs, 0);
pull_status =
worker_ptr_->batch_sample_neighbors(0, {96, 37}, 4, vs, vs1, true, 0);
pull_status.wait();
ASSERT_EQ(vs.size(), 2);
}
......@@ -194,14 +197,16 @@ void testAddNode(
}
void testBatchSampleNeighboor(
std::shared_ptr<paddle::distributed::GraphBrpcClient>& worker_ptr_) {
std::vector<std::vector<std::pair<uint64_t, float>>> vs;
std::vector<std::vector<uint64_t>> vs;
std::vector<std::vector<float>> vs1;
std::vector<std::uint64_t> v = {37, 96};
auto pull_status = worker_ptr_->batch_sample_neighbors(0, v, 4, vs);
auto pull_status =
worker_ptr_->batch_sample_neighbors(0, v, 4, vs, vs1, false);
pull_status.wait();
std::unordered_set<uint64_t> s;
std::unordered_set<uint64_t> s1 = {112, 45, 145};
for (auto g : vs[0]) {
s.insert(g.first);
s.insert(g);
}
ASSERT_EQ(s.size(), 3);
for (auto g : s) {
......@@ -211,7 +216,7 @@ void testBatchSampleNeighboor(
s1.clear();
s1 = {111, 48, 247};
for (auto g : vs[1]) {
s.insert(g.first);
s.insert(g);
}
ASSERT_EQ(s.size(), 3);
for (auto g : s) {
......@@ -221,10 +226,6 @@ void testBatchSampleNeighboor(
void testCache();
void testGraphToBuffer();
// std::string nodes[] = {std::string("37\taa\t45;0.34\t145;0.31\t112;0.21"),
// std::string("96\tfeature\t48;1.4\t247;0.31\t111;1.21"),
// std::string("59\ttreat\t45;0.34\t145;0.31\t112;0.21"),
// std::string("97\tfood\t48;1.4\t247;0.31\t111;1.21")};
std::string edges[] = {
std::string("37\t45\t0.34"), std::string("37\t145\t0.31"),
......@@ -427,15 +428,16 @@ void RunBrpcPushSparse() {
worker_ptr_->load(0, std::string(edge_file_name), std::string("e>"));
srand(time(0));
pull_status.wait();
std::vector<std::vector<std::pair<uint64_t, float>>> vs;
std::vector<std::vector<uint64_t>> _vs;
std::vector<std::vector<float>> vs;
testSampleNodes(worker_ptr_);
sleep(5);
testSingleSampleNeighboor(worker_ptr_);
testBatchSampleNeighboor(worker_ptr_);
pull_status = worker_ptr_->batch_sample_neighbors(
0, std::vector<uint64_t>(1, 10240001024), 4, vs);
0, std::vector<uint64_t>(1, 10240001024), 4, _vs, vs, true);
pull_status.wait();
ASSERT_EQ(0, vs[0].size());
ASSERT_EQ(0, _vs[0].size());
paddle::distributed::GraphTable* g =
(paddle::distributed::GraphTable*)pserver_ptr_->table(0);
size_t ttl = 6;
......@@ -444,18 +446,19 @@ void RunBrpcPushSparse() {
while (round--) {
vs.clear();
pull_status = worker_ptr_->batch_sample_neighbors(
0, std::vector<uint64_t>(1, 37), 1, vs);
0, std::vector<uint64_t>(1, 37), 1, _vs, vs, false);
pull_status.wait();
for (int i = 0; i < ttl; i++) {
std::vector<std::vector<std::pair<uint64_t, float>>> vs1;
std::vector<std::vector<uint64_t>> vs1;
std::vector<std::vector<float>> vs2;
pull_status = worker_ptr_->batch_sample_neighbors(
0, std::vector<uint64_t>(1, 37), 1, vs1);
0, std::vector<uint64_t>(1, 37), 1, vs1, vs2, false);
pull_status.wait();
ASSERT_EQ(vs[0].size(), vs1[0].size());
ASSERT_EQ(_vs[0].size(), vs1[0].size());
for (int j = 0; j < vs[0].size(); j++) {
ASSERT_EQ(vs[0][j].first, vs1[0][j].first);
for (int j = 0; j < _vs[0].size(); j++) {
ASSERT_EQ(_vs[0][j], vs1[0][j]);
}
}
}
......@@ -639,7 +642,7 @@ void testCache() {
strcpy(str, "54321");
::paddle::distributed::SampleResult* result =
new ::paddle::distributed::SampleResult(5, str);
::paddle::distributed::SampleKey skey = {6, 1};
::paddle::distributed::SampleKey skey = {6, 1, false};
std::vector<std::pair<::paddle::distributed::SampleKey,
paddle::distributed::SampleResult>>
r;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册