未验证 提交 5eb640c6 编写于 作者: S seemingwang 提交者: GitHub

Graph engine4 (#36587)

上级 d64f7b3b
......@@ -304,7 +304,63 @@ std::future<int32_t> GraphBrpcClient::remove_graph_node(
// char* &buffer,int &actual_size
std::future<int32_t> GraphBrpcClient::batch_sample_neighboors(
uint32_t table_id, std::vector<uint64_t> node_ids, int sample_size,
std::vector<std::vector<std::pair<uint64_t, float>>> &res) {
std::vector<std::vector<std::pair<uint64_t, float>>> &res,
int server_index) {
if (server_index != -1) {
res.resize(node_ids.size());
DownpourBrpcClosure *closure = new DownpourBrpcClosure(1, [&](void *done) {
int ret = 0;
auto *closure = (DownpourBrpcClosure *)done;
if (closure->check_response(0, PS_GRAPH_SAMPLE_NODES_FROM_ONE_SERVER) !=
0) {
ret = -1;
} else {
auto &res_io_buffer = closure->cntl(0)->response_attachment();
butil::IOBufBytesIterator io_buffer_itr(res_io_buffer);
size_t bytes_size = io_buffer_itr.bytes_left();
std::unique_ptr<char[]> buffer_wrapper(new char[bytes_size]);
char *buffer = buffer_wrapper.get();
io_buffer_itr.copy_and_forward((void *)(buffer), bytes_size);
size_t node_num = *(size_t *)buffer;
int *actual_sizes = (int *)(buffer + sizeof(size_t));
char *node_buffer = buffer + sizeof(size_t) + sizeof(int) * node_num;
int offset = 0;
for (size_t node_idx = 0; node_idx < node_num; ++node_idx) {
int actual_size = actual_sizes[node_idx];
int start = 0;
while (start < actual_size) {
res[node_idx].push_back(
{*(uint64_t *)(node_buffer + offset + start),
*(float *)(node_buffer + offset + start +
GraphNode::id_size)});
start += GraphNode::id_size + GraphNode::weight_size;
}
offset += actual_size;
}
}
closure->set_promise_value(ret);
});
auto promise = std::make_shared<std::promise<int32_t>>();
closure->add_promise(promise);
std::future<int> fut = promise->get_future();
;
closure->request(0)->set_cmd_id(PS_GRAPH_SAMPLE_NODES_FROM_ONE_SERVER);
closure->request(0)->set_table_id(table_id);
closure->request(0)->set_client_id(_client_id);
closure->request(0)->add_params((char *)node_ids.data(),
sizeof(uint64_t) * node_ids.size());
closure->request(0)->add_params((char *)&sample_size, sizeof(int));
;
// PsService_Stub rpc_stub(get_cmd_channel(server_index));
GraphPsService_Stub rpc_stub =
getServiceStub(get_cmd_channel(server_index));
closure->cntl(0)->set_log_id(butil::gettimeofday_ms());
rpc_stub.service(closure->cntl(0), closure->request(0),
closure->response(0), closure);
return fut;
}
std::vector<int> request2server;
std::vector<int> server2request(server_size, -1);
res.clear();
......
......@@ -64,7 +64,8 @@ class GraphBrpcClient : public BrpcPsClient {
// given a batch of nodes, sample graph_neighboors for each of them
virtual std::future<int32_t> batch_sample_neighboors(
uint32_t table_id, std::vector<uint64_t> node_ids, int sample_size,
std::vector<std::vector<std::pair<uint64_t, float>>>& res);
std::vector<std::vector<std::pair<uint64_t, float>>>& res,
int server_index = -1);
virtual std::future<int32_t> pull_graph_list(uint32_t table_id,
int server_index, int start,
......
......@@ -61,6 +61,10 @@ int32_t GraphBrpcServer::initialize() {
return 0;
}
brpc::Channel *GraphBrpcServer::get_cmd_channel(size_t server_index) {
return _pserver_channels[server_index].get();
}
uint64_t GraphBrpcServer::start(const std::string &ip, uint32_t port) {
std::unique_lock<std::mutex> lock(mutex_);
......@@ -80,6 +84,42 @@ uint64_t GraphBrpcServer::start(const std::string &ip, uint32_t port) {
return 0;
}
int32_t GraphBrpcServer::build_peer2peer_connection(int rank) {
this->rank = rank;
auto _env = environment();
brpc::ChannelOptions options;
options.protocol = "baidu_std";
options.timeout_ms = 500000;
options.connection_type = "pooled";
options.connect_timeout_ms = 10000;
options.max_retry = 3;
std::vector<PSHost> server_list = _env->get_ps_servers();
_pserver_channels.resize(server_list.size());
std::ostringstream os;
std::string server_ip_port;
for (size_t i = 0; i < server_list.size(); ++i) {
server_ip_port.assign(server_list[i].ip.c_str());
server_ip_port.append(":");
server_ip_port.append(std::to_string(server_list[i].port));
_pserver_channels[i].reset(new brpc::Channel());
if (_pserver_channels[i]->Init(server_ip_port.c_str(), "", &options) != 0) {
VLOG(0) << "GraphServer connect to Server:" << server_ip_port
<< " Failed! Try again.";
std::string int_ip_port =
GetIntTypeEndpoint(server_list[i].ip, server_list[i].port);
if (_pserver_channels[i]->Init(int_ip_port.c_str(), "", &options) != 0) {
LOG(ERROR) << "GraphServer connect to Server:" << int_ip_port
<< " Failed!";
return -1;
}
}
os << server_ip_port << ",";
}
LOG(INFO) << "servers peer2peer connection success:" << os.str();
return 0;
}
int32_t GraphBrpcService::clear_nodes(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
......@@ -160,6 +200,9 @@ int32_t GraphBrpcService::initialize() {
&GraphBrpcService::remove_graph_node;
_service_handler_map[PS_GRAPH_SET_NODE_FEAT] =
&GraphBrpcService::graph_set_node_feat;
_service_handler_map[PS_GRAPH_SAMPLE_NODES_FROM_ONE_SERVER] =
&GraphBrpcService::sample_neighboors_across_multi_servers;
// shard初始化,server启动后才可从env获取到server_list的shard信息
initialize_shard_info();
......@@ -172,10 +215,10 @@ int32_t GraphBrpcService::initialize_shard_info() {
if (_is_initialize_shard_info) {
return 0;
}
size_t shard_num = _server->environment()->get_ps_servers().size();
server_size = _server->environment()->get_ps_servers().size();
auto &table_map = *(_server->table());
for (auto itr : table_map) {
itr.second->set_shard(_rank, shard_num);
itr.second->set_shard(_rank, server_size);
}
_is_initialize_shard_info = true;
}
......@@ -209,8 +252,10 @@ void GraphBrpcService::service(google::protobuf::RpcController *cntl_base,
int service_ret = (this->*handler_func)(table, *request, *response, cntl);
if (service_ret != 0) {
response->set_err_code(service_ret);
if (!response->has_err_msg()) {
response->set_err_msg("server internal error");
}
}
}
int32_t GraphBrpcService::barrier(Table *table, const PsRequestMessage &request,
......@@ -403,7 +448,156 @@ int32_t GraphBrpcService::graph_get_node_feat(Table *table,
return 0;
}
int32_t GraphBrpcService::sample_neighboors_across_multi_servers(
Table *table, const PsRequestMessage &request, PsResponseMessage &response,
brpc::Controller *cntl) {
// sleep(5);
CHECK_TABLE_EXIST(table, request, response)
if (request.params_size() < 2) {
set_response_code(
response, -1,
"graph_random_sample request requires at least 2 arguments");
return 0;
}
size_t node_num = request.params(0).size() / sizeof(uint64_t),
size_of_size_t = sizeof(size_t);
uint64_t *node_data = (uint64_t *)(request.params(0).c_str());
int sample_size = *(uint64_t *)(request.params(1).c_str());
// std::vector<uint64_t> res = ((GraphTable
// *)table).filter_out_non_exist_nodes(node_data, sample_size);
std::vector<int> request2server;
std::vector<int> server2request(server_size, -1);
std::vector<uint64_t> local_id;
std::vector<int> local_query_idx;
size_t rank = get_rank();
for (int query_idx = 0; query_idx < node_num; ++query_idx) {
int server_index =
((GraphTable *)table)->get_server_index_by_id(node_data[query_idx]);
if (server2request[server_index] == -1) {
server2request[server_index] = request2server.size();
request2server.push_back(server_index);
}
}
if (server2request[rank] != -1) {
auto pos = server2request[rank];
std::swap(request2server[pos],
request2server[(int)request2server.size() - 1]);
server2request[request2server[pos]] = pos;
server2request[request2server[(int)request2server.size() - 1]] =
request2server.size() - 1;
}
size_t request_call_num = request2server.size();
std::vector<std::unique_ptr<char[]>> local_buffers;
std::vector<int> local_actual_sizes;
std::vector<size_t> seq;
std::vector<std::vector<uint64_t>> node_id_buckets(request_call_num);
std::vector<std::vector<int>> query_idx_buckets(request_call_num);
for (int query_idx = 0; query_idx < node_num; ++query_idx) {
int server_index =
((GraphTable *)table)->get_server_index_by_id(node_data[query_idx]);
int request_idx = server2request[server_index];
node_id_buckets[request_idx].push_back(node_data[query_idx]);
query_idx_buckets[request_idx].push_back(query_idx);
seq.push_back(request_idx);
}
size_t remote_call_num = request_call_num;
if (request2server.size() != 0 && request2server.back() == rank) {
remote_call_num--;
local_buffers.resize(node_id_buckets.back().size());
local_actual_sizes.resize(node_id_buckets.back().size());
}
cntl->response_attachment().append(&node_num, sizeof(size_t));
auto local_promise = std::make_shared<std::promise<int32_t>>();
std::future<int> local_fut = local_promise->get_future();
std::vector<bool> failed(server_size, false);
std::function<void(void *)> func = [&, node_id_buckets, query_idx_buckets,
request_call_num](void *done) {
local_fut.get();
std::vector<int> actual_size;
auto *closure = (DownpourBrpcClosure *)done;
std::vector<std::unique_ptr<butil::IOBufBytesIterator>> res(
remote_call_num);
size_t fail_num = 0;
for (size_t request_idx = 0; request_idx < remote_call_num; ++request_idx) {
if (closure->check_response(request_idx, PS_GRAPH_SAMPLE_NEIGHBOORS) !=
0) {
++fail_num;
failed[request2server[request_idx]] = true;
} else {
auto &res_io_buffer = closure->cntl(request_idx)->response_attachment();
size_t node_size;
res[request_idx].reset(new butil::IOBufBytesIterator(res_io_buffer));
size_t num;
res[request_idx]->copy_and_forward(&num, sizeof(size_t));
}
}
int size;
int local_index = 0;
for (size_t i = 0; i < node_num; i++) {
if (fail_num > 0 && failed[seq[i]]) {
size = 0;
} else if (request2server[seq[i]] != rank) {
res[seq[i]]->copy_and_forward(&size, sizeof(int));
} else {
size = local_actual_sizes[local_index++];
}
actual_size.push_back(size);
}
cntl->response_attachment().append(actual_size.data(),
actual_size.size() * sizeof(int));
local_index = 0;
for (size_t i = 0; i < node_num; i++) {
if (fail_num > 0 && failed[seq[i]]) {
continue;
} else if (request2server[seq[i]] != rank) {
char temp[actual_size[i] + 1];
res[seq[i]]->copy_and_forward(temp, actual_size[i]);
cntl->response_attachment().append(temp, actual_size[i]);
} else {
char *temp = local_buffers[local_index++].get();
cntl->response_attachment().append(temp, actual_size[i]);
}
}
closure->set_promise_value(0);
};
DownpourBrpcClosure *closure = new DownpourBrpcClosure(remote_call_num, func);
auto promise = std::make_shared<std::promise<int32_t>>();
closure->add_promise(promise);
std::future<int> fut = promise->get_future();
for (int request_idx = 0; request_idx < remote_call_num; ++request_idx) {
int server_index = request2server[request_idx];
closure->request(request_idx)->set_cmd_id(PS_GRAPH_SAMPLE_NEIGHBOORS);
closure->request(request_idx)->set_table_id(request.table_id());
closure->request(request_idx)->set_client_id(rank);
size_t node_num = node_id_buckets[request_idx].size();
closure->request(request_idx)
->add_params((char *)node_id_buckets[request_idx].data(),
sizeof(uint64_t) * node_num);
closure->request(request_idx)
->add_params((char *)&sample_size, sizeof(int));
PsService_Stub rpc_stub(
((GraphBrpcServer *)get_server())->get_cmd_channel(server_index));
// GraphPsService_Stub rpc_stub =
// getServiceStub(get_cmd_channel(server_index));
closure->cntl(request_idx)->set_log_id(butil::gettimeofday_ms());
rpc_stub.service(closure->cntl(request_idx), closure->request(request_idx),
closure->response(request_idx), closure);
}
if (server2request[rank] != -1) {
((GraphTable *)table)
->random_sample_neighboors(node_id_buckets.back().data(), sample_size,
local_buffers, local_actual_sizes);
}
local_promise.get()->set_value(0);
if (remote_call_num == 0) func(closure);
fut.get();
return 0;
}
int32_t GraphBrpcService::graph_set_node_feat(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
......@@ -412,7 +606,7 @@ int32_t GraphBrpcService::graph_set_node_feat(Table *table,
if (request.params_size() < 3) {
set_response_code(
response, -1,
"graph_set_node_feat request requires at least 2 arguments");
"graph_set_node_feat request requires at least 3 arguments");
return 0;
}
size_t node_num = request.params(0).size() / sizeof(uint64_t);
......
......@@ -32,6 +32,8 @@ class GraphBrpcServer : public PSServer {
virtual ~GraphBrpcServer() {}
PsBaseService *get_service() { return _service.get(); }
virtual uint64_t start(const std::string &ip, uint32_t port);
virtual int32_t build_peer2peer_connection(int rank);
virtual brpc::Channel *get_cmd_channel(size_t server_index);
virtual int32_t stop() {
std::unique_lock<std::mutex> lock(mutex_);
if (stoped_) return 0;
......@@ -50,6 +52,7 @@ class GraphBrpcServer : public PSServer {
mutable std::mutex mutex_;
std::condition_variable cv_;
bool stoped_ = false;
int rank;
brpc::Server _server;
std::shared_ptr<PsBaseService> _service;
std::vector<std::shared_ptr<brpc::Channel>> _pserver_channels;
......@@ -113,12 +116,18 @@ class GraphBrpcService : public PsBaseService {
int32_t print_table_stat(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t sample_neighboors_across_multi_servers(
Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
private:
bool _is_initialize_shard_info;
std::mutex _initialize_shard_mutex;
std::unordered_map<int32_t, serviceHandlerFunc> _msg_handler_map;
std::vector<float> _ori_values;
const int sample_nodes_ranges = 23;
size_t server_size;
std::shared_ptr<::ThreadPool> task_pool;
};
} // namespace distributed
......
......@@ -107,6 +107,7 @@ void GraphPyServer::start_server(bool block) {
empty_vec.push_back(empty_prog);
pserver_ptr->configure(server_proto, _ps_env, rank, empty_vec);
pserver_ptr->start(ip, port);
pserver_ptr->build_peer2peer_connection(rank);
std::condition_variable* cv_ = pserver_ptr->export_cv();
if (block) {
std::mutex mutex_;
......
......@@ -56,6 +56,7 @@ enum PsCmdID {
PS_GRAPH_ADD_GRAPH_NODE = 35;
PS_GRAPH_REMOVE_GRAPH_NODE = 36;
PS_GRAPH_SET_NODE_FEAT = 37;
PS_GRAPH_SAMPLE_NODES_FROM_ONE_SERVER = 38;
}
message PsRequestMessage {
......
......@@ -147,7 +147,7 @@ class PsBaseService : public PsService {
public:
PsBaseService() : _rank(0), _server(NULL), _config(NULL) {}
virtual ~PsBaseService() {}
virtual size_t get_rank() { return _rank; }
virtual int32_t configure(PSServer *server) {
_server = server;
_rank = _server->rank();
......@@ -167,6 +167,7 @@ class PsBaseService : public PsService {
}
virtual int32_t initialize() = 0;
PSServer *get_server() { return _server; }
protected:
size_t _rank;
......
......@@ -305,12 +305,12 @@ Node *GraphTable::find_node(uint64_t id) {
return node;
}
uint32_t GraphTable::get_thread_pool_index(uint64_t node_id) {
return node_id % shard_num % shard_num_per_table % task_pool_size_;
return node_id % shard_num % shard_num_per_server % task_pool_size_;
}
uint32_t GraphTable::get_thread_pool_index_by_shard_index(
uint64_t shard_index) {
return shard_index % shard_num_per_table % task_pool_size_;
return shard_index % shard_num_per_server % task_pool_size_;
}
int32_t GraphTable::clear_nodes() {
......@@ -575,6 +575,11 @@ int32_t GraphTable::pull_graph_list(int start, int total_size,
actual_size = size;
return 0;
}
int32_t GraphTable::get_server_index_by_id(uint64_t id) {
return id % shard_num / shard_num_per_server;
}
int32_t GraphTable::initialize() {
_shards_task_pool.resize(task_pool_size_);
for (size_t i = 0; i < _shards_task_pool.size(); ++i) {
......@@ -611,13 +616,12 @@ int32_t GraphTable::initialize() {
shard_num = _config.shard_num();
VLOG(0) << "in init graph table shard num = " << shard_num << " shard_idx"
<< _shard_idx;
shard_num_per_table = sparse_local_shard_num(shard_num, server_num);
shard_start = _shard_idx * shard_num_per_table;
shard_end = shard_start + shard_num_per_table;
shard_num_per_server = sparse_local_shard_num(shard_num, server_num);
shard_start = _shard_idx * shard_num_per_server;
shard_end = shard_start + shard_num_per_server;
VLOG(0) << "in init graph table shard idx = " << _shard_idx << " shard_start "
<< shard_start << " shard_end " << shard_end;
// shards.resize(shard_num_per_table);
shards = std::vector<GraphShard>(shard_num_per_table, GraphShard(shard_num));
shards = std::vector<GraphShard>(shard_num_per_server, GraphShard(shard_num));
return 0;
}
} // namespace distributed
......
......@@ -94,6 +94,7 @@ class GraphTable : public SparseTable {
int32_t remove_graph_node(std::vector<uint64_t> &id_list);
int32_t get_server_index_by_id(uint64_t id);
Node *find_node(uint64_t id);
virtual int32_t pull_sparse(float *values,
......@@ -128,9 +129,11 @@ class GraphTable : public SparseTable {
const std::vector<std::string> &feature_names,
const std::vector<std::vector<std::string>> &res);
size_t get_server_num() { return server_num; }
protected:
std::vector<GraphShard> shards;
size_t shard_start, shard_end, server_num, shard_num_per_table, shard_num;
size_t shard_start, shard_end, server_num, shard_num_per_server, shard_num;
const int task_pool_size_ = 24;
const int random_sample_nodes_ranges = 3;
......
......@@ -138,6 +138,10 @@ void testSingleSampleNeighboor(
for (auto g : s) {
ASSERT_EQ(true, s1.find(g) != s1.end());
}
vs.clear();
pull_status = worker_ptr_->batch_sample_neighboors(0, {96, 37}, 4, vs, 0);
pull_status.wait();
ASSERT_EQ(vs.size(), 2);
}
void testAddNode(
......@@ -356,6 +360,7 @@ void RunServer() {
pserver_ptr_->configure(server_proto, _ps_env, 0, empty_vec);
LOG(INFO) << "first server, run start(ip,port)";
pserver_ptr_->start(ip_, port_);
pserver_ptr_->build_peer2peer_connection(0);
LOG(INFO) << "init first server Done";
}
......@@ -373,6 +378,7 @@ void RunServer2() {
empty_vec2.push_back(empty_prog2);
pserver_ptr2->configure(server_proto2, _ps_env2, 1, empty_vec2);
pserver_ptr2->start(ip2, port2);
pserver_ptr2->build_peer2peer_connection(1);
}
void RunClient(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册