diff --git a/paddle/fluid/distributed/service/graph_brpc_client.cc b/paddle/fluid/distributed/service/graph_brpc_client.cc
index 68d9c9669b6972f5042c86273db936995bec6a9e..9f65a66708def030f7dfd9a9ac668c79ad60744f 100644
--- a/paddle/fluid/distributed/service/graph_brpc_client.cc
+++ b/paddle/fluid/distributed/service/graph_brpc_client.cc
@@ -304,7 +304,63 @@ std::future<int32_t> GraphBrpcClient::remove_graph_node(
 // char* &buffer,int &actual_size
 std::future<int32_t> GraphBrpcClient::batch_sample_neighboors(
     uint32_t table_id, std::vector<uint64_t> node_ids, int sample_size,
-    std::vector<std::vector<std::pair<uint64_t, float>>> &res) {
+    std::vector<std::vector<std::pair<uint64_t, float>>> &res,
+    int server_index) {
+  if (server_index != -1) {
+    res.resize(node_ids.size());
+    DownpourBrpcClosure *closure = new DownpourBrpcClosure(1, [&](void *done) {
+      int ret = 0;
+      auto *closure = (DownpourBrpcClosure *)done;
+      if (closure->check_response(0, PS_GRAPH_SAMPLE_NODES_FROM_ONE_SERVER) !=
+          0) {
+        ret = -1;
+      } else {
+        auto &res_io_buffer = closure->cntl(0)->response_attachment();
+        butil::IOBufBytesIterator io_buffer_itr(res_io_buffer);
+        size_t bytes_size = io_buffer_itr.bytes_left();
+        std::unique_ptr<char[]> buffer_wrapper(new char[bytes_size]);
+        char *buffer = buffer_wrapper.get();
+        io_buffer_itr.copy_and_forward((void *)(buffer), bytes_size);
+
+        size_t node_num = *(size_t *)buffer;
+        int *actual_sizes = (int *)(buffer + sizeof(size_t));
+        char *node_buffer = buffer + sizeof(size_t) + sizeof(int) * node_num;
+
+        int offset = 0;
+        for (size_t node_idx = 0; node_idx < node_num; ++node_idx) {
+          int actual_size = actual_sizes[node_idx];
+          int start = 0;
+          while (start < actual_size) {
+            res[node_idx].push_back(
+                {*(uint64_t *)(node_buffer + offset + start),
+                 *(float *)(node_buffer + offset + start +
+                            GraphNode::id_size)});
+            start += GraphNode::id_size + GraphNode::weight_size;
+          }
+          offset += actual_size;
+        }
+      }
+      closure->set_promise_value(ret);
+    });
+    auto promise = std::make_shared<std::promise<int32_t>>();
+    closure->add_promise(promise);
+    std::future<int32_t> fut = promise->get_future();
+    closure->request(0)->set_cmd_id(PS_GRAPH_SAMPLE_NODES_FROM_ONE_SERVER);
+    closure->request(0)->set_table_id(table_id);
+    closure->request(0)->set_client_id(_client_id);
+    closure->request(0)->add_params((char *)node_ids.data(),
+                                    sizeof(uint64_t) * node_ids.size());
+    closure->request(0)->add_params((char *)&sample_size, sizeof(int));
+    // PsService_Stub rpc_stub(get_cmd_channel(server_index));
+    GraphPsService_Stub rpc_stub =
+        getServiceStub(get_cmd_channel(server_index));
+    closure->cntl(0)->set_log_id(butil::gettimeofday_ms());
+    rpc_stub.service(closure->cntl(0), closure->request(0),
+                     closure->response(0), closure);
+    return fut;
+  }
   std::vector<int> request2server;
   std::vector<int> server2request(server_size, -1);
   res.clear();
diff --git a/paddle/fluid/distributed/service/graph_brpc_client.h b/paddle/fluid/distributed/service/graph_brpc_client.h
index 8acb2047b8e9724d4591a02e3e257cba5282f0cd..1fbb3fa9b0550e2e658c98b4ca41acd3e55a440a 100644
--- a/paddle/fluid/distributed/service/graph_brpc_client.h
+++ b/paddle/fluid/distributed/service/graph_brpc_client.h
@@ -64,7 +64,8 @@ class GraphBrpcClient : public BrpcPsClient {
   // given a batch of nodes, sample graph_neighboors for each of them
   virtual std::future<int32_t> batch_sample_neighboors(
       uint32_t table_id, std::vector<uint64_t> node_ids, int sample_size,
-      std::vector<std::vector<std::pair<uint64_t, float>>>& res);
+      std::vector<std::vector<std::pair<uint64_t, float>>>& res,
+      int server_index = -1);
 
   virtual std::future<int32_t> pull_graph_list(uint32_t table_id,
                                                int server_index, int start,
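Note: the single-server path above parses a packed response attachment. As read by the client code, the layout is a size_t node count, then one int byte-length per node, then each node's neighbors packed back to back as (uint64_t id, float weight) records. The standalone sketch below is not part of the patch; DecodeSampleResponse and the k* constants are illustrative stand-ins mirroring GraphNode::id_size and GraphNode::weight_size.

#include <cstdint>
#include <cstring>
#include <utility>
#include <vector>

// Illustrative stand-ins for GraphNode::id_size and GraphNode::weight_size.
constexpr int kIdSize = sizeof(uint64_t);
constexpr int kWeightSize = sizeof(float);

std::vector<std::vector<std::pair<uint64_t, float>>> DecodeSampleResponse(
    const char *buffer) {
  // Header: number of queried nodes.
  size_t node_num = 0;
  std::memcpy(&node_num, buffer, sizeof(size_t));
  // Per-node byte lengths follow the header.
  const int *actual_sizes =
      reinterpret_cast<const int *>(buffer + sizeof(size_t));
  // Neighbor records are packed back to back after the length table.
  const char *payload = buffer + sizeof(size_t) + sizeof(int) * node_num;

  std::vector<std::vector<std::pair<uint64_t, float>>> res(node_num);
  size_t offset = 0;
  for (size_t i = 0; i < node_num; ++i) {
    for (int start = 0; start < actual_sizes[i];
         start += kIdSize + kWeightSize) {
      uint64_t id = 0;
      float weight = 0.f;
      std::memcpy(&id, payload + offset + start, kIdSize);
      std::memcpy(&weight, payload + offset + start + kIdSize, kWeightSize);
      res[i].emplace_back(id, weight);
    }
    offset += actual_sizes[i];
  }
  return res;
}

Using memcpy instead of the patch's reinterpret-casts also sidesteps unaligned reads, since nothing in the wire format guarantees 8-byte alignment of the packed records.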
diff --git a/paddle/fluid/distributed/service/graph_brpc_server.cc b/paddle/fluid/distributed/service/graph_brpc_server.cc
index 110d4406fc5569e6abc9eabe83228c2fdca17316..b404082f7c4102a70edb97965b381880cc36bc3c 100644
--- a/paddle/fluid/distributed/service/graph_brpc_server.cc
+++ b/paddle/fluid/distributed/service/graph_brpc_server.cc
@@ -61,6 +61,10 @@ int32_t GraphBrpcServer::initialize() {
   return 0;
 }
 
+brpc::Channel *GraphBrpcServer::get_cmd_channel(size_t server_index) {
+  return _pserver_channels[server_index].get();
+}
+
 uint64_t GraphBrpcServer::start(const std::string &ip, uint32_t port) {
   std::unique_lock<std::mutex> lock(mutex_);
@@ -80,6 +84,42 @@ uint64_t GraphBrpcServer::start(const std::string &ip, uint32_t port) {
   return 0;
 }
 
+int32_t GraphBrpcServer::build_peer2peer_connection(int rank) {
+  this->rank = rank;
+  auto _env = environment();
+  brpc::ChannelOptions options;
+  options.protocol = "baidu_std";
+  options.timeout_ms = 500000;
+  options.connection_type = "pooled";
+  options.connect_timeout_ms = 10000;
+  options.max_retry = 3;
+
+  std::vector<PSHost> server_list = _env->get_ps_servers();
+  _pserver_channels.resize(server_list.size());
+  std::ostringstream os;
+  std::string server_ip_port;
+  for (size_t i = 0; i < server_list.size(); ++i) {
+    server_ip_port.assign(server_list[i].ip.c_str());
+    server_ip_port.append(":");
+    server_ip_port.append(std::to_string(server_list[i].port));
+    _pserver_channels[i].reset(new brpc::Channel());
+    if (_pserver_channels[i]->Init(server_ip_port.c_str(), "", &options) !=
+        0) {
+      VLOG(0) << "GraphServer connect to Server:" << server_ip_port
+              << " Failed! Try again.";
+      std::string int_ip_port =
+          GetIntTypeEndpoint(server_list[i].ip, server_list[i].port);
+      if (_pserver_channels[i]->Init(int_ip_port.c_str(), "", &options) != 0) {
+        LOG(ERROR) << "GraphServer connect to Server:" << int_ip_port
+                   << " Failed!";
+        return -1;
+      }
+    }
+    os << server_ip_port << ",";
+  }
+  LOG(INFO) << "servers peer2peer connection success:" << os.str();
+  return 0;
+}
+
 int32_t GraphBrpcService::clear_nodes(Table *table,
                                       const PsRequestMessage &request,
                                       PsResponseMessage &response,
@@ -160,6 +200,9 @@ int32_t GraphBrpcService::initialize() {
       &GraphBrpcService::remove_graph_node;
   _service_handler_map[PS_GRAPH_SET_NODE_FEAT] =
       &GraphBrpcService::graph_set_node_feat;
+  _service_handler_map[PS_GRAPH_SAMPLE_NODES_FROM_ONE_SERVER] =
+      &GraphBrpcService::sample_neighboors_across_multi_servers;
+
   // Shard initialization: server_list's shard info is only available from
   // the env after the server starts.
   initialize_shard_info();
@@ -172,10 +215,10 @@ int32_t GraphBrpcService::initialize_shard_info() {
   if (_is_initialize_shard_info) {
     return 0;
   }
-  size_t shard_num = _server->environment()->get_ps_servers().size();
+  server_size = _server->environment()->get_ps_servers().size();
   auto &table_map = *(_server->table());
   for (auto itr : table_map) {
-    itr.second->set_shard(_rank, shard_num);
+    itr.second->set_shard(_rank, server_size);
   }
   _is_initialize_shard_info = true;
 }
@@ -209,7 +252,9 @@ void GraphBrpcService::service(google::protobuf::RpcController *cntl_base,
   int service_ret = (this->*handler_func)(table, *request, *response, cntl);
   if (service_ret != 0) {
     response->set_err_code(service_ret);
-    response->set_err_msg("server internal error");
+    if (!response->has_err_msg()) {
+      response->set_err_msg("server internal error");
+    }
   }
 }
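Note: the handler registered above (implemented in the next hunk) buckets the queried ids by owning server and deduplicates target servers in first-seen order, keeping a pair of inverse maps. A minimal standalone sketch of that scheme follows; OwnerServer is a hypothetical stand-in for GraphTable::get_server_index_by_id, and none of these names are part of the patch.

#include <cstdint>
#include <vector>

// Hypothetical router standing in for GraphTable::get_server_index_by_id.
int OwnerServer(uint64_t id, int shard_num, int shards_per_server) {
  return static_cast<int>(id % shard_num) / shards_per_server;
}

// First pass assigns each owning server a request slot in first-seen order;
// second pass fills one id bucket per slot. server2request is the inverse
// of request2server, with -1 meaning "no request goes to that server".
void BucketByServer(const std::vector<uint64_t> &ids, int server_size,
                    int shard_num, int shards_per_server,
                    std::vector<int> *request2server,
                    std::vector<int> *server2request,
                    std::vector<std::vector<uint64_t>> *buckets) {
  request2server->clear();
  server2request->assign(server_size, -1);
  for (uint64_t id : ids) {
    int s = OwnerServer(id, shard_num, shards_per_server);
    if ((*server2request)[s] == -1) {
      (*server2request)[s] = static_cast<int>(request2server->size());
      request2server->push_back(s);
    }
  }
  buckets->assign(request2server->size(), {});
  for (uint64_t id : ids) {
    int slot =
        (*server2request)[OwnerServer(id, shard_num, shards_per_server)];
    (*buckets)[slot].push_back(id);
  }
}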
@@ -403,7 +448,156 @@ int32_t GraphBrpcService::graph_get_node_feat(Table *table,
   return 0;
 }
-
+int32_t GraphBrpcService::sample_neighboors_across_multi_servers(
+    Table *table, const PsRequestMessage &request, PsResponseMessage &response,
+    brpc::Controller *cntl) {
+  CHECK_TABLE_EXIST(table, request, response)
+  if (request.params_size() < 2) {
+    set_response_code(
+        response, -1,
+        "graph_random_sample request requires at least 2 arguments");
+    return 0;
+  }
+  size_t node_num = request.params(0).size() / sizeof(uint64_t);
+  uint64_t *node_data = (uint64_t *)(request.params(0).c_str());
+  int sample_size = *(int *)(request.params(1).c_str());
+  std::vector<int> request2server;
+  std::vector<int> server2request(server_size, -1);
+  size_t rank = get_rank();
+  for (size_t query_idx = 0; query_idx < node_num; ++query_idx) {
+    int server_index =
+        ((GraphTable *)table)->get_server_index_by_id(node_data[query_idx]);
+    if (server2request[server_index] == -1) {
+      server2request[server_index] = request2server.size();
+      request2server.push_back(server_index);
+    }
+  }
+  if (server2request[rank] != -1) {
+    auto pos = server2request[rank];
+    std::swap(request2server[pos],
+              request2server[(int)request2server.size() - 1]);
+    server2request[request2server[pos]] = pos;
+    server2request[request2server[(int)request2server.size() - 1]] =
+        request2server.size() - 1;
+  }
+  size_t request_call_num = request2server.size();
+  std::vector<std::shared_ptr<char>> local_buffers;
+  std::vector<int> local_actual_sizes;
+  std::vector<size_t> seq;
+  std::vector<std::vector<uint64_t>> node_id_buckets(request_call_num);
+  std::vector<std::vector<int>> query_idx_buckets(request_call_num);
+  for (size_t query_idx = 0; query_idx < node_num; ++query_idx) {
+    int server_index =
+        ((GraphTable *)table)->get_server_index_by_id(node_data[query_idx]);
+    int request_idx = server2request[server_index];
+    node_id_buckets[request_idx].push_back(node_data[query_idx]);
+    query_idx_buckets[request_idx].push_back(query_idx);
+    seq.push_back(request_idx);
+  }
+  size_t remote_call_num = request_call_num;
+  if (request2server.size() != 0 && request2server.back() == rank) {
+    remote_call_num--;
+    local_buffers.resize(node_id_buckets.back().size());
+    local_actual_sizes.resize(node_id_buckets.back().size());
+  }
+  cntl->response_attachment().append(&node_num, sizeof(size_t));
+  auto local_promise = std::make_shared<std::promise<int32_t>>();
+  std::future<int32_t> local_fut = local_promise->get_future();
+  std::vector<bool> failed(server_size, false);
+  std::function<void(void *)> func = [&, node_id_buckets, query_idx_buckets,
+                                      request_call_num](void *done) {
+    local_fut.get();
+    std::vector<int> actual_size;
+    auto *closure = (DownpourBrpcClosure *)done;
+    std::vector<std::shared_ptr<butil::IOBufBytesIterator>> res(
+        remote_call_num);
+    size_t fail_num = 0;
+    for (size_t request_idx = 0; request_idx < remote_call_num;
+         ++request_idx) {
+      if (closure->check_response(request_idx, PS_GRAPH_SAMPLE_NEIGHBOORS) !=
+          0) {
+        ++fail_num;
+        failed[request2server[request_idx]] = true;
+      } else {
+        auto &res_io_buffer =
+            closure->cntl(request_idx)->response_attachment();
+        res[request_idx].reset(new butil::IOBufBytesIterator(res_io_buffer));
+        size_t num;
+        res[request_idx]->copy_and_forward(&num, sizeof(size_t));
+      }
+    }
+    int size;
+    int local_index = 0;
+    for (size_t i = 0; i < node_num; i++) {
+      if (fail_num > 0 && failed[seq[i]]) {
+        size = 0;
+      } else if (request2server[seq[i]] != rank) {
+        res[seq[i]]->copy_and_forward(&size, sizeof(int));
+      } else {
+        size = local_actual_sizes[local_index++];
+      }
+      actual_size.push_back(size);
+    }
+    cntl->response_attachment().append(actual_size.data(),
+                                       actual_size.size() * sizeof(int));
+
+    local_index = 0;
+    for (size_t i = 0; i < node_num; i++) {
+      if (fail_num > 0 && failed[seq[i]]) {
+        continue;
+      } else if (request2server[seq[i]] != rank) {
+        char temp[actual_size[i] + 1];
+        res[seq[i]]->copy_and_forward(temp, actual_size[i]);
+        cntl->response_attachment().append(temp, actual_size[i]);
+      } else {
+        char *temp = local_buffers[local_index++].get();
+        cntl->response_attachment().append(temp, actual_size[i]);
+      }
+    }
+    closure->set_promise_value(0);
+  };
+
+  DownpourBrpcClosure *closure =
+      new DownpourBrpcClosure(remote_call_num, func);
+
+  auto promise = std::make_shared<std::promise<int32_t>>();
+  closure->add_promise(promise);
+  std::future<int32_t> fut = promise->get_future();
+
+  for (size_t request_idx = 0; request_idx < remote_call_num; ++request_idx) {
+    int server_index = request2server[request_idx];
+    closure->request(request_idx)->set_cmd_id(PS_GRAPH_SAMPLE_NEIGHBOORS);
+    closure->request(request_idx)->set_table_id(request.table_id());
+    closure->request(request_idx)->set_client_id(rank);
+    size_t node_num = node_id_buckets[request_idx].size();
+    closure->request(request_idx)
+        ->add_params((char *)node_id_buckets[request_idx].data(),
+                     sizeof(uint64_t) * node_num);
+    closure->request(request_idx)
+        ->add_params((char *)&sample_size, sizeof(int));
+    PsService_Stub rpc_stub(
+        ((GraphBrpcServer *)get_server())->get_cmd_channel(server_index));
+    // GraphPsService_Stub rpc_stub =
+    //     getServiceStub(get_cmd_channel(server_index));
+    closure->cntl(request_idx)->set_log_id(butil::gettimeofday_ms());
+    rpc_stub.service(closure->cntl(request_idx), closure->request(request_idx),
+                     closure->response(request_idx), closure);
+  }
+  if (server2request[rank] != -1) {
+    ((GraphTable *)table)
+        ->random_sample_neighboors(node_id_buckets.back().data(), sample_size,
+                                   local_buffers, local_actual_sizes);
+  }
+  local_promise.get()->set_value(0);
+  if (remote_call_num == 0) func(closure);
+  fut.get();
+  return 0;
+}
 int32_t GraphBrpcService::graph_set_node_feat(Table *table,
                                               const PsRequestMessage &request,
                                               PsResponseMessage &response,
@@ -412,7 +606,7 @@ int32_t GraphBrpcService::graph_set_node_feat(Table *table,
   if (request.params_size() < 3) {
     set_response_code(
        response, -1,
-        "graph_set_node_feat request requires at least 2 arguments");
+        "graph_set_node_feat request requires at least 3 arguments");
     return 0;
   }
   size_t node_num = request.params(0).size() / sizeof(uint64_t);
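Note: a subtle step in the handler above is swapping this server's own bucket to the back of request2server, so that RPC slots [0, remote_call_num) are all remote and the last bucket can be answered in-process via random_sample_neighboors. A minimal sketch of just that swap; MoveLocalToBack is an illustrative name, the patch inlines this logic.

#include <utility>
#include <vector>

void MoveLocalToBack(size_t rank, std::vector<int> *request2server,
                     std::vector<int> *server2request) {
  int pos = (*server2request)[rank];
  if (pos == -1) return;  // this server owns none of the queried ids
  int last = static_cast<int>(request2server->size()) - 1;
  std::swap((*request2server)[pos], (*request2server)[last]);
  // Re-point the inverse map at the swapped slots so later lookups stay valid.
  (*server2request)[(*request2server)[pos]] = pos;
  (*server2request)[(*request2server)[last]] = last;
}

Keeping the inverse map consistent matters because the response-assembly loop later tests request2server[seq[i]] != rank to decide whether a query's bytes come from a remote iterator or from the local buffers.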
diff --git a/paddle/fluid/distributed/service/graph_brpc_server.h b/paddle/fluid/distributed/service/graph_brpc_server.h
index 6b4853fa679923e39578d803272f1ddf978b632c..817fe08331165daf8953154f673eca14ea9d2e72 100644
--- a/paddle/fluid/distributed/service/graph_brpc_server.h
+++ b/paddle/fluid/distributed/service/graph_brpc_server.h
@@ -32,6 +32,8 @@ class GraphBrpcServer : public PSServer {
   virtual ~GraphBrpcServer() {}
   PsBaseService *get_service() { return _service.get(); }
   virtual uint64_t start(const std::string &ip, uint32_t port);
+  virtual int32_t build_peer2peer_connection(int rank);
+  virtual brpc::Channel *get_cmd_channel(size_t server_index);
   virtual int32_t stop() {
     std::unique_lock<std::mutex> lock(mutex_);
     if (stoped_) return 0;
@@ -50,6 +52,7 @@ class GraphBrpcServer : public PSServer {
   mutable std::mutex mutex_;
   std::condition_variable cv_;
   bool stoped_ = false;
+  int rank;
   brpc::Server _server;
   std::shared_ptr<GraphBrpcService> _service;
   std::vector<std::shared_ptr<brpc::Channel>> _pserver_channels;
@@ -113,12 +116,18 @@ class GraphBrpcService : public PsBaseService {
   int32_t print_table_stat(Table *table, const PsRequestMessage &request,
                            PsResponseMessage &response,
                            brpc::Controller *cntl);
+  int32_t sample_neighboors_across_multi_servers(
+      Table *table, const PsRequestMessage &request,
+      PsResponseMessage &response, brpc::Controller *cntl);
+
  private:
   bool _is_initialize_shard_info;
   std::mutex _initialize_shard_mutex;
   std::unordered_map<int32_t, serviceHandlerFunc> _msg_handler_map;
   std::vector _ori_values;
   const int sample_nodes_ranges = 23;
+  size_t server_size;
+  std::shared_ptr<::ThreadPool> task_pool;
 };
 
 }  // namespace distributed
diff --git a/paddle/fluid/distributed/service/graph_py_service.cc b/paddle/fluid/distributed/service/graph_py_service.cc
index b41596270131744195ea29a18a12b1f8ba5f95f5..498805136417f2f6520b6952465c0d90204efb0a 100644
--- a/paddle/fluid/distributed/service/graph_py_service.cc
+++ b/paddle/fluid/distributed/service/graph_py_service.cc
@@ -107,6 +107,7 @@ void GraphPyServer::start_server(bool block) {
   empty_vec.push_back(empty_prog);
   pserver_ptr->configure(server_proto, _ps_env, rank, empty_vec);
   pserver_ptr->start(ip, port);
+  pserver_ptr->build_peer2peer_connection(rank);
   std::condition_variable* cv_ = pserver_ptr->export_cv();
   if (block) {
     std::mutex mutex_;
diff --git a/paddle/fluid/distributed/service/sendrecv.proto b/paddle/fluid/distributed/service/sendrecv.proto
index 696c950d9b33ba5fe86d8a20ff19d55591384761..42e25258ec3fe1eb8061129576e751284a9fa30d 100644
--- a/paddle/fluid/distributed/service/sendrecv.proto
+++ b/paddle/fluid/distributed/service/sendrecv.proto
@@ -56,6 +56,7 @@ enum PsCmdID {
   PS_GRAPH_ADD_GRAPH_NODE = 35;
   PS_GRAPH_REMOVE_GRAPH_NODE = 36;
   PS_GRAPH_SET_NODE_FEAT = 37;
+  PS_GRAPH_SAMPLE_NODES_FROM_ONE_SERVER = 38;
 }
 
 message PsRequestMessage {
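Note: the new enum value is what ties the wire command to the handler, via the _service_handler_map assignment shown in graph_brpc_server.cc above. A minimal standalone version of that member-function dispatch pattern; MiniService and its names are illustrative only.

#include <cstdint>
#include <unordered_map>

// The real table maps PsCmdID values to GraphBrpcService member functions.
class MiniService {
 public:
  using Handler = int32_t (MiniService::*)();
  MiniService() {
    handlers_[38] = &MiniService::SampleAcrossServers;  // new cmd id
  }
  int32_t Dispatch(int32_t cmd_id) {
    auto it = handlers_.find(cmd_id);
    if (it == handlers_.end()) return -1;  // unknown command
    return (this->*(it->second))();
  }

 private:
  int32_t SampleAcrossServers() { return 0; }
  std::unordered_map<int32_t, Handler> handlers_;
};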
diff --git a/paddle/fluid/distributed/service/server.h b/paddle/fluid/distributed/service/server.h
index 89b089386f501835b7c384477b84f98f94c2a4a9..dffe19545ce52bb29ab339f0e76fde939c762b84 100644
--- a/paddle/fluid/distributed/service/server.h
+++ b/paddle/fluid/distributed/service/server.h
@@ -147,7 +147,7 @@ class PsBaseService : public PsService {
  public:
   PsBaseService() : _rank(0), _server(NULL), _config(NULL) {}
   virtual ~PsBaseService() {}
-
+  virtual size_t get_rank() { return _rank; }
   virtual int32_t configure(PSServer *server) {
     _server = server;
     _rank = _server->rank();
@@ -167,6 +167,7 @@ class PsBaseService : public PsService {
   }
 
   virtual int32_t initialize() = 0;
+  PSServer *get_server() { return _server; }
 
  protected:
   size_t _rank;
diff --git a/paddle/fluid/distributed/table/common_graph_table.cc b/paddle/fluid/distributed/table/common_graph_table.cc
index 41f4b0dac4d96e9466d57341e20709f04bdabdf6..2c20e79b3b2d3497adbb596628e8105925c6a96c 100644
--- a/paddle/fluid/distributed/table/common_graph_table.cc
+++ b/paddle/fluid/distributed/table/common_graph_table.cc
@@ -305,12 +305,12 @@ Node *GraphTable::find_node(uint64_t id) {
   return node;
 }
 uint32_t GraphTable::get_thread_pool_index(uint64_t node_id) {
-  return node_id % shard_num % shard_num_per_table % task_pool_size_;
+  return node_id % shard_num % shard_num_per_server % task_pool_size_;
 }
 
 uint32_t GraphTable::get_thread_pool_index_by_shard_index(
     uint64_t shard_index) {
-  return shard_index % shard_num_per_table % task_pool_size_;
+  return shard_index % shard_num_per_server % task_pool_size_;
 }
 
 int32_t GraphTable::clear_nodes() {
@@ -575,6 +575,11 @@ int32_t GraphTable::pull_graph_list(int start, int total_size,
   actual_size = size;
   return 0;
 }
+
+int32_t GraphTable::get_server_index_by_id(uint64_t id) {
+  return id % shard_num / shard_num_per_server;
+}
+
 int32_t GraphTable::initialize() {
   _shards_task_pool.resize(task_pool_size_);
   for (size_t i = 0; i < _shards_task_pool.size(); ++i) {
@@ -611,13 +616,12 @@ int32_t GraphTable::initialize() {
   shard_num = _config.shard_num();
   VLOG(0) << "in init graph table shard num = " << shard_num << " shard_idx"
           << _shard_idx;
-  shard_num_per_table = sparse_local_shard_num(shard_num, server_num);
-  shard_start = _shard_idx * shard_num_per_table;
-  shard_end = shard_start + shard_num_per_table;
+  shard_num_per_server = sparse_local_shard_num(shard_num, server_num);
+  shard_start = _shard_idx * shard_num_per_server;
+  shard_end = shard_start + shard_num_per_server;
   VLOG(0) << "in init graph table shard idx = " << _shard_idx << " shard_start "
           << shard_start << " shard_end " << shard_end;
-  // shards.resize(shard_num_per_table);
-  shards = std::vector<GraphShard>(shard_num_per_table, GraphShard(shard_num));
+  shards = std::vector<GraphShard>(shard_num_per_server, GraphShard(shard_num));
   return 0;
 }
 }  // namespace distributed
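Note: get_server_index_by_id, added above, composes the two renamed quantities. An id first hashes to a shard (id % shard_num), and each server owns a contiguous run of shard_num_per_server shards, so dividing the shard index by that run length yields the owner. A worked example under assumed numbers (shard_num = 8 spread over 2 servers, so shard_num_per_server = 4):

#include <cassert>
#include <cstdint>

int server_index_by_id(uint64_t id, int shard_num, int shard_num_per_server) {
  return static_cast<int>(id % shard_num) / shard_num_per_server;
}

int main() {
  // Assumed layout: 8 shards over 2 servers, 4 shards per server.
  assert(server_index_by_id(11, 8, 4) == 0);  // 11 % 8 = shard 3 -> server 0
  assert(server_index_by_id(13, 8, 4) == 1);  // 13 % 8 = shard 5 -> server 1
  return 0;
}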
diff --git a/paddle/fluid/distributed/table/common_graph_table.h b/paddle/fluid/distributed/table/common_graph_table.h
index f643337a80f7c24fa320ece5269ec69d10d5fd79..d681262c664807943bd3dda9bce4512495a441ed 100644
--- a/paddle/fluid/distributed/table/common_graph_table.h
+++ b/paddle/fluid/distributed/table/common_graph_table.h
@@ -94,6 +94,7 @@ class GraphTable : public SparseTable {
 
   int32_t remove_graph_node(std::vector<uint64_t> &id_list);
 
+  int32_t get_server_index_by_id(uint64_t id);
   Node *find_node(uint64_t id);
 
   virtual int32_t pull_sparse(float *values,
@@ -128,9 +129,11 @@ class GraphTable : public SparseTable {
                                 const std::vector<std::string> &feature_names,
                                 const std::vector<std::vector<std::string>> &res);
 
+  size_t get_server_num() { return server_num; }
+
  protected:
   std::vector<GraphShard> shards;
-  size_t shard_start, shard_end, server_num, shard_num_per_table, shard_num;
+  size_t shard_start, shard_end, server_num, shard_num_per_server, shard_num;
   const int task_pool_size_ = 24;
   const int random_sample_nodes_ranges = 3;
diff --git a/paddle/fluid/distributed/test/graph_node_test.cc b/paddle/fluid/distributed/test/graph_node_test.cc
index 810530cdbec94d37cdd60b935b46342848feb27d..613770220f9d7995242da79f3b5fd70142c119f0 100644
--- a/paddle/fluid/distributed/test/graph_node_test.cc
+++ b/paddle/fluid/distributed/test/graph_node_test.cc
@@ -138,6 +138,10 @@ void testSingleSampleNeighboor(
   for (auto g : s) {
     ASSERT_EQ(true, s1.find(g) != s1.end());
   }
+  vs.clear();
+  pull_status = worker_ptr_->batch_sample_neighboors(0, {96, 37}, 4, vs, 0);
+  pull_status.wait();
+  ASSERT_EQ(vs.size(), 2);
 }
 
 void testAddNode(
@@ -356,6 +360,7 @@ void RunServer() {
   pserver_ptr_->configure(server_proto, _ps_env, 0, empty_vec);
   LOG(INFO) << "first server, run start(ip,port)";
   pserver_ptr_->start(ip_, port_);
+  pserver_ptr_->build_peer2peer_connection(0);
   LOG(INFO) << "init first server Done";
 }
 
@@ -373,6 +378,7 @@ void RunServer2() {
   empty_vec2.push_back(empty_prog2);
   pserver_ptr2->configure(server_proto2, _ps_env2, 1, empty_vec2);
   pserver_ptr2->start(ip2, port2);
+  pserver_ptr2->build_peer2peer_connection(1);
 }
 
 void RunClient(
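Note: for reference, a hedged sketch of the two call paths the updated test exercises, wrapped in a helper so it is self-contained. SampleBothWays is an illustrative name; the client is assumed to be configured and connected the way the test fixture above configures worker_ptr_.

#include <cstdint>
#include <utility>
#include <vector>

#include "paddle/fluid/distributed/service/graph_brpc_client.h"

void SampleBothWays(paddle::distributed::GraphBrpcClient *client) {
  std::vector<std::vector<std::pair<uint64_t, float>>> vs;
  // Default path: ids are bucketed by owner and fanned out to all servers.
  auto status = client->batch_sample_neighboors(0, {96, 37}, 4, vs);
  status.wait();
  vs.clear();
  // New path: pin the whole batch to server 0, which forwards requests
  // for ids it does not own to their owners.
  status =
      client->batch_sample_neighboors(0, {96, 37}, 4, vs, /*server_index=*/0);
  status.wait();
}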