Unverified commit 60ac1602, authored by seemingwang and committed by GitHub

fix potential overflow problem & node add & node remove & node clear (#33055)

* graph engine demo

* upload unsaved changes

* fix dependency error

* fix shard_num problem

* py client

* remove lock and graph-type

* add load direct graph

* add load direct graph

* add load direct graph

* batch random_sample

* batch_sample_k

* fix num_nodes size

* batch brpc

* batch brpc

* add test

* add test

* add load_nodes; change add_node function

* change sample return type to pair

* resolve conflict

* resolved conflict

* resolved conflict

* separate server and client

* merge pair type

* fix

* resolved conflict

* fixed segmentation fault; high-level VLOG for load edges and load nodes

* random_sample return 0

* rm useless loop

* test:load edge

* fix ret -1

* test: rm sample

* rm sample

* random_sample return future

* random_sample return int

* test fake node

* fixed here

* memory leak

* remove test code

* fix return problem

* add common_graph_table

* random sample node &test & change data-structure from linkedList to vector

* add common_graph_table

* sample with srand

* add node_types

* optimize nodes sample

* recover test

* random sample

* destruct weighted sampler

* GraphEdgeBlob

* WeightedGraphEdgeBlob to GraphEdgeBlob

* WeightedGraphEdgeBlob to GraphEdgeBlob

* pybind sample nodes api

* pull nodes with step

* fixed pull_graph_list bug; add test for pull_graph_list by step

* add graph table;name

* add graph table;name

* add pybind

* add pybind

* add FeatureNode

* add FeatureNode

* add FeatureNode Serialize

* add FeatureNode Serialize

* get_feat_node

* avoid local rpc

* fix get_node_feat

* fix get_node_feat

* remove log

* get_node_feat return py::bytes

* merge develop with graph_engine

* fix threadpool.h head

* fix

* fix typo

* resolve conflict

* fix conflict

* recover lost content

* fix pybind of FeatureNode

* recover cmake

* recover tools

* resolve conflict

* resolve linking problem

* code style

* change test_server port

* fix code problems

* remove shard_num config

* remove redundant threads

* optimize start server

* remove logs

* fix code problems by reviewers' suggestions

* move graph files into a folder

* code style change

* remove graph operations from base table

* optimize get_feat function of graph engine

* fix long long count problem

* remove redundant graph files

* remove unused shell

* recover dropout_op_pass.h

* fix potential stack overflow when request number is too large & node add & node clear & node remove
Co-authored-by: Huang Zhengjie <270018958@qq.com>
Co-authored-by: Weiyue Su <weiyue.su@gmail.com>
Co-authored-by: suweiyue <suweiyue@baidu.com>
Co-authored-by: luobin06 <luobin06@baidu.com>
Co-authored-by: liweibin02 <liweibin02@baidu.com>
Co-authored-by: tangwei12 <tangwei12@baidu.com>
Parent d6aea4ac
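The client-side entry points added here are clear_nodes, add_graph_node, and remove_graph_node on GraphBrpcClient; the testAddNode case at the bottom of this diff drives all three against table 0. Below is a minimal usage sketch, not part of the commit: the worker_ptr_ setup, the demo ids, and the header include are assumptions modeled on the test harness.

// Hypothetical usage sketch of the three new RPCs, modeled on testAddNode
// further down in this diff. The client header path is an assumption.
#include <cstdint>
#include <memory>
#include <vector>
// #include "paddle/fluid/distributed/service/graph_brpc_client.h"  // assumed path

void demo_node_maintenance(
    std::shared_ptr<paddle::distributed::GraphBrpcClient>& worker_ptr_) {
  // Drop every node currently stored in graph table 0 on all servers.
  worker_ptr_->clear_nodes(0).wait();

  // Insert a few nodes; an empty weight list marks all of them as unweighted.
  std::vector<uint64_t> ids = {37, 59, 97};
  std::vector<bool> weights;
  worker_ptr_->add_graph_node(0, ids, weights).wait();

  // Remove one of them again and wait for the servers to acknowledge.
  std::vector<uint64_t> to_remove = {59};
  worker_ptr_->remove_graph_node(0, to_remove).wait();
}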
......@@ -80,11 +80,11 @@ std::future<int32_t> GraphBrpcClient::get_node_feat(
[&, node_id_buckets, query_idx_buckets, request_call_num](void *done) {
int ret = 0;
auto *closure = (DownpourBrpcClosure *)done;
int fail_num = 0;
size_t fail_num = 0;
for (int request_idx = 0; request_idx < request_call_num;
++request_idx) {
if (closure->check_response(request_idx,
PS_GRAPH_SAMPLE_NEIGHBOORS) != 0) {
if (closure->check_response(request_idx, PS_GRAPH_GET_NODE_FEAT) !=
0) {
++fail_num;
} else {
auto &res_io_buffer =
......@@ -144,6 +144,163 @@ std::future<int32_t> GraphBrpcClient::get_node_feat(
return fut;
}
std::future<int32_t> GraphBrpcClient::clear_nodes(uint32_t table_id) {
DownpourBrpcClosure *closure = new DownpourBrpcClosure(
server_size, [&, server_size = this->server_size ](void *done) {
int ret = 0;
auto *closure = (DownpourBrpcClosure *)done;
size_t fail_num = 0;
for (size_t request_idx = 0; request_idx < server_size; ++request_idx) {
if (closure->check_response(request_idx, PS_GRAPH_CLEAR) != 0) {
++fail_num;
break;
}
}
ret = fail_num == 0 ? 0 : -1;
closure->set_promise_value(ret);
});
auto promise = std::make_shared<std::promise<int32_t>>();
closure->add_promise(promise);
std::future<int> fut = promise->get_future();
for (size_t i = 0; i < server_size; i++) {
int server_index = i;
closure->request(server_index)->set_cmd_id(PS_GRAPH_CLEAR);
closure->request(server_index)->set_table_id(table_id);
closure->request(server_index)->set_client_id(_client_id);
GraphPsService_Stub rpc_stub =
getServiceStub(get_cmd_channel(server_index));
closure->cntl(server_index)->set_log_id(butil::gettimeofday_ms());
rpc_stub.service(closure->cntl(server_index),
closure->request(server_index),
closure->response(server_index), closure);
}
return fut;
}
std::future<int32_t> GraphBrpcClient::add_graph_node(
uint32_t table_id, std::vector<uint64_t> &node_id_list,
std::vector<bool> &is_weighted_list) {
std::vector<std::vector<uint64_t>> request_bucket;
std::vector<std::vector<bool>> is_weighted_bucket;
bool add_weight = is_weighted_list.size() > 0;
std::vector<int> server_index_arr;
std::vector<int> index_mapping(server_size, -1);
for (size_t query_idx = 0; query_idx < node_id_list.size(); ++query_idx) {
int server_index = get_server_index_by_id(node_id_list[query_idx]);
if (index_mapping[server_index] == -1) {
index_mapping[server_index] = request_bucket.size();
server_index_arr.push_back(server_index);
request_bucket.push_back(std::vector<uint64_t>());
if (add_weight) is_weighted_bucket.push_back(std::vector<bool>());
}
request_bucket[index_mapping[server_index]].push_back(
node_id_list[query_idx]);
if (add_weight)
is_weighted_bucket[index_mapping[server_index]].push_back(
query_idx < is_weighted_list.size() ? is_weighted_list[query_idx]
: false);
}
size_t request_call_num = request_bucket.size();
DownpourBrpcClosure *closure = new DownpourBrpcClosure(
request_call_num, [&, request_call_num](void *done) {
int ret = 0;
auto *closure = (DownpourBrpcClosure *)done;
size_t fail_num = 0;
for (size_t request_idx = 0; request_idx < request_call_num;
++request_idx) {
if (closure->check_response(request_idx, PS_GRAPH_ADD_GRAPH_NODE) !=
0) {
++fail_num;
}
}
ret = fail_num == request_call_num ? -1 : 0;
closure->set_promise_value(ret);
});
auto promise = std::make_shared<std::promise<int32_t>>();
closure->add_promise(promise);
std::future<int> fut = promise->get_future();
for (size_t request_idx = 0; request_idx < request_call_num; ++request_idx) {
int server_index = server_index_arr[request_idx];
closure->request(request_idx)->set_cmd_id(PS_GRAPH_ADD_GRAPH_NODE);
closure->request(request_idx)->set_table_id(table_id);
closure->request(request_idx)->set_client_id(_client_id);
size_t node_num = request_bucket[request_idx].size();
closure->request(request_idx)
->add_params((char *)request_bucket[request_idx].data(),
sizeof(uint64_t) * node_num);
if (add_weight) {
bool weighted[is_weighted_bucket[request_idx].size() + 1];
for (size_t j = 0; j < is_weighted_bucket[request_idx].size(); j++)
weighted[j] = is_weighted_bucket[request_idx][j];
closure->request(request_idx)
->add_params((char *)weighted,
sizeof(bool) * is_weighted_bucket[request_idx].size());
}
// PsService_Stub rpc_stub(get_cmd_channel(server_index));
GraphPsService_Stub rpc_stub =
getServiceStub(get_cmd_channel(server_index));
closure->cntl(request_idx)->set_log_id(butil::gettimeofday_ms());
rpc_stub.service(closure->cntl(request_idx), closure->request(request_idx),
closure->response(request_idx), closure);
}
return fut;
}
std::future<int32_t> GraphBrpcClient::remove_graph_node(
uint32_t table_id, std::vector<uint64_t> &node_id_list) {
std::vector<std::vector<uint64_t>> request_bucket;
std::vector<int> server_index_arr;
std::vector<int> index_mapping(server_size, -1);
for (size_t query_idx = 0; query_idx < node_id_list.size(); ++query_idx) {
int server_index = get_server_index_by_id(node_id_list[query_idx]);
if (index_mapping[server_index] == -1) {
index_mapping[server_index] = request_bucket.size();
server_index_arr.push_back(server_index);
request_bucket.push_back(std::vector<uint64_t>());
}
request_bucket[index_mapping[server_index]].push_back(
node_id_list[query_idx]);
}
size_t request_call_num = request_bucket.size();
DownpourBrpcClosure *closure = new DownpourBrpcClosure(
request_call_num, [&, request_call_num](void *done) {
int ret = 0;
auto *closure = (DownpourBrpcClosure *)done;
int fail_num = 0;
for (size_t request_idx = 0; request_idx < request_call_num;
++request_idx) {
if (closure->check_response(request_idx,
PS_GRAPH_REMOVE_GRAPH_NODE) != 0) {
++fail_num;
}
}
ret = fail_num == request_call_num ? -1 : 0;
closure->set_promise_value(ret);
});
auto promise = std::make_shared<std::promise<int32_t>>();
closure->add_promise(promise);
std::future<int> fut = promise->get_future();
for (size_t request_idx = 0; request_idx < request_call_num; ++request_idx) {
int server_index = server_index_arr[request_idx];
closure->request(request_idx)->set_cmd_id(PS_GRAPH_REMOVE_GRAPH_NODE);
closure->request(request_idx)->set_table_id(table_id);
closure->request(request_idx)->set_client_id(_client_id);
size_t node_num = request_bucket[request_idx].size();
closure->request(request_idx)
->add_params((char *)request_bucket[request_idx].data(),
sizeof(uint64_t) * node_num);
// PsService_Stub rpc_stub(get_cmd_channel(server_index));
GraphPsService_Stub rpc_stub =
getServiceStub(get_cmd_channel(server_index));
closure->cntl(request_idx)->set_log_id(butil::gettimeofday_ms());
rpc_stub.service(closure->cntl(request_idx), closure->request(request_idx),
closure->response(request_idx), closure);
}
return fut;
}
// char* &buffer,int &actual_size
std::future<int32_t> GraphBrpcClient::batch_sample_neighboors(
uint32_t table_id, std::vector<uint64_t> node_ids, int sample_size,
......@@ -174,8 +331,8 @@ std::future<int32_t> GraphBrpcClient::batch_sample_neighboors(
[&, node_id_buckets, query_idx_buckets, request_call_num](void *done) {
int ret = 0;
auto *closure = (DownpourBrpcClosure *)done;
int fail_num = 0;
for (int request_idx = 0; request_idx < request_call_num;
size_t fail_num = 0;
for (size_t request_idx = 0; request_idx < request_call_num;
++request_idx) {
if (closure->check_response(request_idx,
PS_GRAPH_SAMPLE_NEIGHBOORS) != 0) {
......@@ -254,13 +411,14 @@ std::future<int32_t> GraphBrpcClient::random_sample_nodes(
auto &res_io_buffer = closure->cntl(0)->response_attachment();
butil::IOBufBytesIterator io_buffer_itr(res_io_buffer);
size_t bytes_size = io_buffer_itr.bytes_left();
char buffer[bytes_size];
char *buffer = new char[bytes_size];
auto size = io_buffer_itr.copy_and_forward((void *)(buffer), bytes_size);
int index = 0;
while (index < bytes_size) {
ids.push_back(*(uint64_t *)(buffer + index));
index += GraphNode::id_size;
}
delete[] buffer;
}
closure->set_promise_value(ret);
});
......@@ -292,7 +450,7 @@ std::future<int32_t> GraphBrpcClient::pull_graph_list(
auto &res_io_buffer = closure->cntl(0)->response_attachment();
butil::IOBufBytesIterator io_buffer_itr(res_io_buffer);
size_t bytes_size = io_buffer_itr.bytes_left();
char buffer[bytes_size];
char *buffer = new char[bytes_size];
io_buffer_itr.copy_and_forward((void *)(buffer), bytes_size);
int index = 0;
while (index < bytes_size) {
......@@ -301,6 +459,7 @@ std::future<int32_t> GraphBrpcClient::pull_graph_list(
index += node.get_size(false);
res.push_back(node);
}
delete buffer;
}
closure->set_promise_value(ret);
});
......
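The "potential overflow" named in the title is the variable-length stack array (char buffer[bytes_size]) previously used to unpack response attachments in random_sample_nodes and pull_graph_list; a large enough response could overflow the thread stack, so the diff above moves the buffer to the heap with new[]/delete[]. The following is a self-contained sketch of the same fix expressed with std::vector, shown for illustration only (it is not the committed code):

// Illustration of the stack-VLA overflow and the heap-buffer fix pattern.
// The parsing loop mirrors the id-unpacking loop in the diff above.
#include <cstdint>
#include <cstring>
#include <vector>

std::vector<uint64_t> unpack_ids(const char* payload, size_t bytes_size) {
  // char buffer[bytes_size];            // old pattern: stack VLA, may overflow
  std::vector<char> buffer(bytes_size);  // heap-backed, freed automatically
  std::memcpy(buffer.data(), payload, bytes_size);
  std::vector<uint64_t> ids;
  size_t index = 0;
  while (index + sizeof(uint64_t) <= bytes_size) {
    uint64_t id;
    std::memcpy(&id, buffer.data() + index, sizeof(id));  // avoids unaligned reads
    ids.push_back(id);
    index += sizeof(uint64_t);  // GraphNode::id_size in the real code
  }
  return ids;
}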
......@@ -78,6 +78,13 @@ class GraphBrpcClient : public BrpcPsClient {
const uint32_t& table_id, const std::vector<uint64_t>& node_ids,
const std::vector<std::string>& feature_names,
std::vector<std::vector<std::string>>& res);
virtual std::future<int32_t> clear_nodes(uint32_t table_id);
virtual std::future<int32_t> add_graph_node(
uint32_t table_id, std::vector<uint64_t>& node_id_list,
std::vector<bool>& is_weighted_list);
virtual std::future<int32_t> remove_graph_node(
uint32_t table_id, std::vector<uint64_t>& node_id_list);
virtual int32_t initialize();
int get_shard_num() { return shard_num; }
void set_shard_num(int shard_num) { this->shard_num = shard_num; }
......
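On the client, add_graph_node and remove_graph_node group node ids by destination server (via index_mapping and request_bucket) and then issue one RPC per non-empty bucket. A standalone sketch of that bucketing step, with a caller-supplied id-to-server function standing in for get_server_index_by_id:

// Sketch of the per-server request bucketing used by the new client calls.
// The id-to-server function is a stand-in for the real shard mapping.
#include <cstdint>
#include <functional>
#include <vector>

std::vector<std::vector<uint64_t>> bucket_by_server(
    const std::vector<uint64_t>& node_ids, int server_size,
    const std::function<int(uint64_t)>& server_index_of) {
  std::vector<std::vector<uint64_t>> request_bucket;
  std::vector<int> index_mapping(server_size, -1);  // server -> bucket slot
  for (uint64_t id : node_ids) {
    int server_index = server_index_of(id);
    if (index_mapping[server_index] == -1) {
      index_mapping[server_index] = static_cast<int>(request_bucket.size());
      request_bucket.emplace_back();
    }
    request_bucket[index_mapping[server_index]].push_back(id);
  }
  return request_bucket;  // one RPC per bucket in the real client
}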
......@@ -24,6 +24,14 @@
namespace paddle {
namespace distributed {
#define CHECK_TABLE_EXIST(table, request, response) \
if (table == NULL) { \
std::string err_msg("table not found with table_id:"); \
err_msg.append(std::to_string(request.table_id())); \
set_response_code(response, -1, err_msg.c_str()); \
return -1; \
}
int32_t GraphBrpcServer::initialize() {
auto &service_config = _config.downpour_server_param().service_param();
if (!service_config.has_service_class()) {
......@@ -71,6 +79,58 @@ uint64_t GraphBrpcServer::start(const std::string &ip, uint32_t port) {
return 0;
}
int32_t GraphBrpcService::clear_nodes(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
((GraphTable *)table)->clear_nodes();
return 0;
}
int32_t GraphBrpcService::add_graph_node(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
CHECK_TABLE_EXIST(table, request, response)
if (request.params_size() < 1) {
set_response_code(
response, -1,
"graph_get_node_feat request requires at least 2 arguments");
return 0;
}
size_t node_num = request.params(0).size() / sizeof(uint64_t);
uint64_t *node_data = (uint64_t *)(request.params(0).c_str());
std::vector<uint64_t> node_ids(node_data, node_data + node_num);
std::vector<bool> is_weighted_list;
if (request.params_size() == 2) {
size_t weight_list_size = request.params(1).size() / sizeof(bool);
bool *is_weighted_buffer = (bool *)(request.params(1).c_str());
is_weighted_list = std::vector<bool>(is_weighted_buffer,
is_weighted_buffer + weight_list_size);
}
((GraphTable *)table)->add_graph_node(node_ids, is_weighted_list);
return 0;
}
int32_t GraphBrpcService::remove_graph_node(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
CHECK_TABLE_EXIST(table, request, response)
if (request.params_size() < 1) {
set_response_code(
response, -1,
"graph_get_node_feat request requires at least 1 argument");
return 0;
}
size_t node_num = request.params(0).size() / sizeof(uint64_t);
uint64_t *node_data = (uint64_t *)(request.params(0).c_str());
std::vector<uint64_t> node_ids(node_data, node_data + node_num);
((GraphTable *)table)->remove_graph_node(node_ids);
return 0;
}
int32_t GraphBrpcServer::port() { return _server.listen_address().port; }
int32_t GraphBrpcService::initialize() {
......@@ -92,21 +152,17 @@ int32_t GraphBrpcService::initialize() {
&GraphBrpcService::graph_random_sample_nodes;
_service_handler_map[PS_GRAPH_GET_NODE_FEAT] =
&GraphBrpcService::graph_get_node_feat;
_service_handler_map[PS_GRAPH_CLEAR] = &GraphBrpcService::clear_nodes;
_service_handler_map[PS_GRAPH_ADD_GRAPH_NODE] =
&GraphBrpcService::add_graph_node;
_service_handler_map[PS_GRAPH_REMOVE_GRAPH_NODE] =
&GraphBrpcService::remove_graph_node;
// shard initialization: the shard info of server_list can only be obtained from env after the server has started
initialize_shard_info();
return 0;
}
#define CHECK_TABLE_EXIST(table, request, response) \
if (table == NULL) { \
std::string err_msg("table not found with table_id:"); \
err_msg.append(std::to_string(request.table_id())); \
set_response_code(response, -1, err_msg.c_str()); \
return -1; \
}
int32_t GraphBrpcService::initialize_shard_info() {
if (!_is_initialize_shard_info) {
std::lock_guard<std::mutex> guard(_initialize_shard_mutex);
......
......@@ -86,6 +86,13 @@ class GraphBrpcService : public PsBaseService {
int32_t graph_get_node_feat(Table *table, const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl);
int32_t clear_nodes(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t add_graph_node(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t remove_graph_node(Table *table, const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl);
int32_t barrier(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t load_one_table(Table *table, const PsRequestMessage &request,
......
......@@ -44,6 +44,9 @@ void GraphPyService::add_table_feat_conf(std::string table_name,
}
}
void add_graph_node(std::vector<uint64_t> node_ids,
std::vector<bool> weight_list) {}
void remove_graph_node(std::vector<uint64_t> node_ids) {}
void GraphPyService::set_up(std::string ips_str, int shard_num,
std::vector<std::string> node_types,
std::vector<std::string> edge_types) {
......@@ -247,6 +250,34 @@ void GraphPyClient::load_edge_file(std::string name, std::string filepath,
}
}
void GraphPyClient::clear_nodes(std::string name) {
if (this->table_id_map.count(name)) {
uint32_t table_id = this->table_id_map[name];
auto status = get_ps_client()->clear_nodes(table_id);
status.wait();
}
}
void GraphPyClient::add_graph_node(std::string name,
std::vector<uint64_t>& node_ids,
std::vector<bool>& weight_list) {
if (this->table_id_map.count(name)) {
uint32_t table_id = this->table_id_map[name];
auto status =
get_ps_client()->add_graph_node(table_id, node_ids, weight_list);
status.wait();
}
}
void GraphPyClient::remove_graph_node(std::string name,
std::vector<uint64_t>& node_ids) {
if (this->table_id_map.count(name)) {
uint32_t table_id = this->table_id_map[name];
auto status = get_ps_client()->remove_graph_node(table_id, node_ids);
status.wait();
}
}
void GraphPyClient::load_node_file(std::string name, std::string filepath) {
// 'n' means load nodes and 'node_type' follows
std::string params = "n" + name;
......
......@@ -141,6 +141,10 @@ class GraphPyClient : public GraphPyService {
void finalize_worker();
void load_edge_file(std::string name, std::string filepath, bool reverse);
void load_node_file(std::string name, std::string filepath);
void clear_nodes(std::string name);
void add_graph_node(std::string name, std::vector<uint64_t>& node_ids,
std::vector<bool>& weight_list);
void remove_graph_node(std::string name, std::vector<uint64_t>& node_ids);
int get_client_id() { return client_id; }
void set_client_id(int client_id) { this->client_id = client_id; }
void start_client();
......
......@@ -52,6 +52,9 @@ enum PsCmdID {
PS_GRAPH_SAMPLE_NEIGHBOORS = 31;
PS_GRAPH_SAMPLE_NODES = 32;
PS_GRAPH_GET_NODE_FEAT = 33;
PS_GRAPH_CLEAR = 34;
PS_GRAPH_ADD_GRAPH_NODE = 35;
PS_GRAPH_REMOVE_GRAPH_NODE = 36;
}
message PsRequestMessage {
......
......@@ -35,6 +35,77 @@ std::vector<Node *> GraphShard::get_batch(int start, int end, int step) {
size_t GraphShard::get_size() { return bucket.size(); }
int32_t GraphTable::add_graph_node(std::vector<uint64_t> &id_list,
std::vector<bool> &is_weight_list) {
size_t node_size = id_list.size();
std::vector<std::vector<std::pair<uint64_t, bool>>> batch(task_pool_size_);
for (size_t i = 0; i < node_size; i++) {
size_t shard_id = id_list[i] % shard_num;
if (shard_id >= shard_end || shard_id < shard_start) {
continue;
}
batch[get_thread_pool_index(id_list[i])].push_back(
{id_list[i], i < is_weight_list.size() ? is_weight_list[i] : false});
}
std::vector<std::future<int>> tasks;
for (size_t i = 0; i < batch.size(); ++i) {
if (!batch[i].size()) continue;
tasks.push_back(_shards_task_pool[i]->enqueue([&batch, i, this]() -> int {
for (auto &p : batch[i]) {
size_t index = p.first % this->shard_num - this->shard_start;
this->shards[index].add_graph_node(p.first)->build_edges(p.second);
}
return 0;
}));
}
for (size_t i = 0; i < tasks.size(); i++) tasks[i].get();
return 0;
}
int32_t GraphTable::remove_graph_node(std::vector<uint64_t> &id_list) {
size_t node_size = id_list.size();
std::vector<std::vector<uint64_t>> batch(task_pool_size_);
for (size_t i = 0; i < node_size; i++) {
size_t shard_id = id_list[i] % shard_num;
if (shard_id >= shard_end || shard_id < shard_start) continue;
batch[get_thread_pool_index(id_list[i])].push_back(id_list[i]);
}
std::vector<std::future<int>> tasks;
for (size_t i = 0; i < batch.size(); ++i) {
if (!batch[i].size()) continue;
tasks.push_back(_shards_task_pool[i]->enqueue([&batch, i, this]() -> int {
for (auto &p : batch[i]) {
size_t index = p % this->shard_num - this->shard_start;
this->shards[index].delete_node(p);
}
return 0;
}));
}
for (size_t i = 0; i < tasks.size(); i++) tasks[i].get();
return 0;
}
void GraphShard::clear() {
for (size_t i = 0; i < bucket.size(); i++) {
delete bucket[i];
}
bucket.clear();
node_location.clear();
}
GraphShard::~GraphShard() { clear(); }
void GraphShard::delete_node(uint64_t id) {
auto iter = node_location.find(id);
if (iter == node_location.end()) return;
int pos = iter->second;
delete bucket[pos];
if (pos != (int)bucket.size() - 1) {
bucket[pos] = bucket.back();
node_location[bucket.back()->get_id()] = pos;
}
node_location.erase(id);
bucket.pop_back();
}
GraphNode *GraphShard::add_graph_node(uint64_t id) {
if (node_location.find(id) == node_location.end()) {
node_location[id] = bucket.size();
......@@ -79,11 +150,7 @@ int32_t GraphTable::get_nodes_ids_by_ranges(
int start = 0, end, index = 0, total_size = 0;
res.clear();
std::vector<std::future<std::vector<uint64_t>>> tasks;
// std::string temp = "";
// for(int i = 0;i < shards.size();i++)
// temp+= std::to_string((int)shards[i].get_size()) + " ";
// VLOG(0)<<"range distribution "<<temp;
for (int i = 0; i < shards.size() && index < ranges.size(); i++) {
for (size_t i = 0; i < shards.size() && index < (int)ranges.size(); i++) {
end = total_size + shards[i].get_size();
start = total_size;
while (start < end && index < ranges.size()) {
......@@ -97,7 +164,6 @@ int32_t GraphTable::get_nodes_ids_by_ranges(
start = second;
first -= total_size;
second -= total_size;
// VLOG(0)<<" FIND RANGE "<<i<<" "<<first<<" "<<second;
tasks.push_back(_shards_task_pool[i % task_pool_size_]->enqueue(
[this, first, second, i]() -> std::vector<uint64_t> {
return shards[i].get_ids_by_range(first, second);
......@@ -106,7 +172,7 @@ int32_t GraphTable::get_nodes_ids_by_ranges(
}
total_size += shards[i].get_size();
}
for (int i = 0; i < tasks.size(); i++) {
for (size_t i = 0; i < tasks.size(); i++) {
auto vec = tasks[i].get();
for (auto &id : vec) {
res.push_back(id);
......@@ -219,7 +285,7 @@ int32_t GraphTable::load_edges(const std::string &path, bool reverse_edge) {
for (auto &shard : shards) {
auto bucket = shard.get_bucket();
for (int i = 0; i < bucket.size(); i++) {
for (size_t i = 0; i < bucket.size(); i++) {
bucket[i]->build_sampler(sample_type);
}
}
......@@ -238,10 +304,29 @@ Node *GraphTable::find_node(uint64_t id) {
uint32_t GraphTable::get_thread_pool_index(uint64_t node_id) {
return node_id % shard_num % shard_num_per_table % task_pool_size_;
}
uint32_t GraphTable::get_thread_pool_index_by_shard_index(
uint64_t shard_index) {
return shard_index % shard_num_per_table % task_pool_size_;
}
int32_t GraphTable::clear_nodes() {
std::vector<std::future<int>> tasks;
for (size_t i = 0; i < shards.size(); i++) {
tasks.push_back(
_shards_task_pool[get_thread_pool_index_by_shard_index(i)]->enqueue(
[this, i]() -> int {
this->shards[i].clear();
return 0;
}));
}
for (size_t i = 0; i < tasks.size(); i++) tasks[i].get();
return 0;
}
int32_t GraphTable::random_sample_nodes(int sample_size,
std::unique_ptr<char[]> &buffer,
int &actual_size) {
bool need_feature = false;
int total_size = 0;
for (int i = 0; i < shards.size(); i++) {
total_size += shards[i].get_size();
......@@ -281,7 +366,7 @@ int32_t GraphTable::random_sample_nodes(int sample_size,
}
std::vector<std::pair<int, int>> first_half, second_half;
int start_index = rand() % total_size;
for (int i = 0; i < ranges_len.size() && i < ranges_pos.size(); i++) {
for (size_t i = 0; i < ranges_len.size() && i < ranges_pos.size(); i++) {
if (ranges_pos[i] + ranges_len[i] - 1 + start_index < total_size)
first_half.push_back({ranges_pos[i] + start_index,
ranges_pos[i] + ranges_len[i] + start_index});
......@@ -386,7 +471,6 @@ std::pair<int32_t, std::string> GraphTable::parse_feature(
if (this->feat_id_map.count(fields[0])) {
int32_t id = this->feat_id_map[fields[0]];
std::string dtype = this->feat_dtype[id];
int32_t shape = this->feat_shape[id];
std::vector<std::string> values(fields.begin() + 1, fields.end());
if (dtype == "feasign") {
return std::make_pair<int32_t, std::string>(
......
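GraphShard::delete_node above removes a node in O(1): it deletes the target, moves the last bucket entry into the freed slot, fixes that entry's position in node_location, and pops the tail. A self-contained sketch of the same swap-and-pop pattern over plain ids (no Node* ownership), for illustration:

// Standalone sketch of the swap-and-pop deletion used by GraphShard:
// a vector of items plus an id -> index map gives O(1) insert and erase.
#include <cstdint>
#include <unordered_map>
#include <vector>

struct IdBucket {
  std::vector<uint64_t> bucket;                // ids (Node* in the real shard)
  std::unordered_map<uint64_t, int> location;  // id -> index in bucket

  void add(uint64_t id) {
    if (location.count(id)) return;
    location[id] = static_cast<int>(bucket.size());
    bucket.push_back(id);
  }

  void erase(uint64_t id) {
    auto it = location.find(id);
    if (it == location.end()) return;
    int pos = it->second;
    if (pos != static_cast<int>(bucket.size()) - 1) {
      bucket[pos] = bucket.back();  // move the last element into the hole
      location[bucket[pos]] = pos;  // fix its recorded position
    }
    bucket.pop_back();
    location.erase(id);
  }
};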
......@@ -36,11 +36,12 @@ class GraphShard {
size_t get_size();
GraphShard() {}
GraphShard(int shard_num) { this->shard_num = shard_num; }
~GraphShard();
std::vector<Node *> &get_bucket() { return bucket; }
std::vector<Node *> get_batch(int start, int end, int step);
std::vector<uint64_t> get_ids_by_range(int start, int end) {
std::vector<uint64_t> res;
for (int i = start; i < end && i < bucket.size(); i++) {
for (int i = start; i < end && i < (int)bucket.size(); i++) {
res.push_back(bucket[i]->get_id());
}
return res;
......@@ -48,6 +49,8 @@ class GraphShard {
GraphNode *add_graph_node(uint64_t id);
FeatureNode *add_feature_node(uint64_t id);
Node *find_node(uint64_t id);
void delete_node(uint64_t id);
void clear();
void add_neighboor(uint64_t id, uint64_t dst_id, float weight);
std::unordered_map<uint64_t, int> get_node_location() {
return node_location;
......@@ -85,6 +88,11 @@ class GraphTable : public SparseTable {
int32_t load_nodes(const std::string &path, std::string node_type);
int32_t add_graph_node(std::vector<uint64_t> &id_list,
std::vector<bool> &is_weight_list);
int32_t remove_graph_node(std::vector<uint64_t> &id_list);
Node *find_node(uint64_t id);
virtual int32_t pull_sparse(float *values,
......@@ -97,6 +105,7 @@ class GraphTable : public SparseTable {
return 0;
}
virtual int32_t clear_nodes();
virtual void clear() {}
virtual int32_t flush() { return 0; }
virtual int32_t shrink(const std::string &param) { return 0; }
......@@ -105,6 +114,7 @@ class GraphTable : public SparseTable {
return 0;
}
virtual int32_t initialize_shard() { return 0; }
virtual uint32_t get_thread_pool_index_by_shard_index(uint64_t shard_index);
virtual uint32_t get_thread_pool_index(uint64_t node_id);
virtual std::pair<int32_t, std::string> parse_feature(std::string feat_str);
......@@ -128,4 +138,5 @@ class GraphTable : public SparseTable {
std::vector<std::shared_ptr<::ThreadPool>> _shards_task_pool;
};
} // namespace distributed
}; // namespace paddle
......@@ -124,7 +124,6 @@ void testSingleSampleNeighboor(
for (auto g : s) {
ASSERT_EQ(true, s1.find(g) != s1.end());
}
VLOG(0) << "test single done";
s.clear();
s1.clear();
vs.clear();
......@@ -141,6 +140,57 @@ void testSingleSampleNeighboor(
}
}
void testAddNode(
std::shared_ptr<paddle::distributed::GraphBrpcClient>& worker_ptr_) {
worker_ptr_->clear_nodes(0);
int total_num = 270000;
uint64_t id;
std::unordered_set<uint64_t> id_set;
for (int i = 0; i < total_num; i++) {
while (id_set.find(id = rand()) != id_set.end())
;
id_set.insert(id);
}
std::vector<uint64_t> id_list(id_set.begin(), id_set.end());
std::vector<bool> weight_list;
auto status = worker_ptr_->add_graph_node(0, id_list, weight_list);
status.wait();
std::vector<uint64_t> ids[2];
for (int i = 0; i < 2; i++) {
auto sample_status =
worker_ptr_->random_sample_nodes(0, i, total_num, ids[i]);
sample_status.wait();
}
std::unordered_set<uint64_t> id_set_check(ids[0].begin(), ids[0].end());
for (auto x : ids[1]) id_set_check.insert(x);
ASSERT_EQ(id_set.size(), id_set_check.size());
for (auto x : id_set) {
ASSERT_EQ(id_set_check.find(x) != id_set_check.end(), true);
}
std::vector<uint64_t> remove_ids;
for (auto p : id_set_check) {
if (remove_ids.size() == 0)
remove_ids.push_back(p);
else if (remove_ids.size() < total_num / 2 && rand() % 2 == 1) {
remove_ids.push_back(p);
}
}
for (auto p : remove_ids) id_set_check.erase(p);
status = worker_ptr_->remove_graph_node(0, remove_ids);
status.wait();
for (int i = 0; i < 2; i++) ids[i].clear();
for (int i = 0; i < 2; i++) {
auto sample_status =
worker_ptr_->random_sample_nodes(0, i, total_num, ids[i]);
sample_status.wait();
}
std::unordered_set<uint64_t> id_set_check1(ids[0].begin(), ids[0].end());
for (auto x : ids[1]) id_set_check1.insert(x);
ASSERT_EQ(id_set_check1.size(), id_set_check.size());
for (auto x : id_set_check1) {
ASSERT_EQ(id_set_check.find(x) != id_set_check.end(), true);
}
}
void testBatchSampleNeighboor(
std::shared_ptr<paddle::distributed::GraphBrpcClient>& worker_ptr_) {
std::vector<std::vector<std::pair<uint64_t, float>>> vs;
......@@ -527,6 +577,7 @@ void RunBrpcPushSparse() {
std::remove(edge_file_name);
std::remove(node_file_name);
testAddNode(worker_ptr_);
LOG(INFO) << "Run stop_server";
worker_ptr_->stop_server();
LOG(INFO) << "Run finalize_worker";
......