diff --git a/paddle/fluid/distributed/service/graph_brpc_client.cc b/paddle/fluid/distributed/service/graph_brpc_client.cc index 70f2da6d7252cee0268bdd35999926a232bc5b34..68d9c9669b6972f5042c86273db936995bec6a9e 100644 --- a/paddle/fluid/distributed/service/graph_brpc_client.cc +++ b/paddle/fluid/distributed/service/graph_brpc_client.cc @@ -479,6 +479,102 @@ std::future GraphBrpcClient::pull_graph_list( closure); return fut; } + +std::future GraphBrpcClient::set_node_feat( + const uint32_t &table_id, const std::vector &node_ids, + const std::vector &feature_names, + const std::vector> &features) { + std::vector request2server; + std::vector server2request(server_size, -1); + for (int query_idx = 0; query_idx < node_ids.size(); ++query_idx) { + int server_index = get_server_index_by_id(node_ids[query_idx]); + if (server2request[server_index] == -1) { + server2request[server_index] = request2server.size(); + request2server.push_back(server_index); + } + } + size_t request_call_num = request2server.size(); + std::vector> node_id_buckets(request_call_num); + std::vector> query_idx_buckets(request_call_num); + std::vector>> features_idx_buckets( + request_call_num); + for (int query_idx = 0; query_idx < node_ids.size(); ++query_idx) { + int server_index = get_server_index_by_id(node_ids[query_idx]); + int request_idx = server2request[server_index]; + node_id_buckets[request_idx].push_back(node_ids[query_idx]); + query_idx_buckets[request_idx].push_back(query_idx); + if (features_idx_buckets[request_idx].size() == 0) { + features_idx_buckets[request_idx].resize(feature_names.size()); + } + for (int feat_idx = 0; feat_idx < feature_names.size(); ++feat_idx) { + features_idx_buckets[request_idx][feat_idx].push_back( + features[feat_idx][query_idx]); + } + } + + DownpourBrpcClosure *closure = new DownpourBrpcClosure( + request_call_num, + [&, node_id_buckets, query_idx_buckets, request_call_num](void *done) { + int ret = 0; + auto *closure = (DownpourBrpcClosure *)done; + size_t fail_num = 0; + for (int request_idx = 0; request_idx < request_call_num; + ++request_idx) { + if (closure->check_response(request_idx, PS_GRAPH_SET_NODE_FEAT) != + 0) { + ++fail_num; + } + if (fail_num == request_call_num) { + ret = -1; + } + } + closure->set_promise_value(ret); + }); + + auto promise = std::make_shared>(); + closure->add_promise(promise); + std::future fut = promise->get_future(); + + for (int request_idx = 0; request_idx < request_call_num; ++request_idx) { + int server_index = request2server[request_idx]; + closure->request(request_idx)->set_cmd_id(PS_GRAPH_SET_NODE_FEAT); + closure->request(request_idx)->set_table_id(table_id); + closure->request(request_idx)->set_client_id(_client_id); + size_t node_num = node_id_buckets[request_idx].size(); + + closure->request(request_idx) + ->add_params((char *)node_id_buckets[request_idx].data(), + sizeof(uint64_t) * node_num); + std::string joint_feature_name = + paddle::string::join_strings(feature_names, '\t'); + closure->request(request_idx) + ->add_params(joint_feature_name.c_str(), joint_feature_name.size()); + + // set features + std::string set_feature = ""; + for (size_t feat_idx = 0; feat_idx < feature_names.size(); ++feat_idx) { + for (size_t node_idx = 0; node_idx < node_num; ++node_idx) { + size_t feat_len = + features_idx_buckets[request_idx][feat_idx][node_idx].size(); + set_feature.append((char *)&feat_len, sizeof(size_t)); + set_feature.append( + features_idx_buckets[request_idx][feat_idx][node_idx].data(), + feat_len); + } + } + closure->request(request_idx) + ->add_params(set_feature.c_str(), set_feature.size()); + + GraphPsService_Stub rpc_stub = + getServiceStub(get_cmd_channel(server_index)); + closure->cntl(request_idx)->set_log_id(butil::gettimeofday_ms()); + rpc_stub.service(closure->cntl(request_idx), closure->request(request_idx), + closure->response(request_idx), closure); + } + + return fut; +} + int32_t GraphBrpcClient::initialize() { // set_shard_num(_config.shard_num()); BrpcPsClient::initialize(); diff --git a/paddle/fluid/distributed/service/graph_brpc_client.h b/paddle/fluid/distributed/service/graph_brpc_client.h index 5696e8b08037b7027939f472f58ec79925143e4f..8acb2047b8e9724d4591a02e3e257cba5282f0cd 100644 --- a/paddle/fluid/distributed/service/graph_brpc_client.h +++ b/paddle/fluid/distributed/service/graph_brpc_client.h @@ -79,6 +79,11 @@ class GraphBrpcClient : public BrpcPsClient { const std::vector& feature_names, std::vector>& res); + virtual std::future set_node_feat( + const uint32_t& table_id, const std::vector& node_ids, + const std::vector& feature_names, + const std::vector>& features); + virtual std::future clear_nodes(uint32_t table_id); virtual std::future add_graph_node( uint32_t table_id, std::vector& node_id_list, diff --git a/paddle/fluid/distributed/service/graph_brpc_server.cc b/paddle/fluid/distributed/service/graph_brpc_server.cc index 52ac8c5d688a4ada72212923bdd478b788e422ee..110d4406fc5569e6abc9eabe83228c2fdca17316 100644 --- a/paddle/fluid/distributed/service/graph_brpc_server.cc +++ b/paddle/fluid/distributed/service/graph_brpc_server.cc @@ -16,6 +16,7 @@ #include "paddle/fluid/distributed/service/brpc_ps_server.h" #include // NOLINT +#include #include "butil/endpoint.h" #include "iomanip" #include "paddle/fluid/distributed/service/brpc_ps_client.h" @@ -157,6 +158,8 @@ int32_t GraphBrpcService::initialize() { &GraphBrpcService::add_graph_node; _service_handler_map[PS_GRAPH_REMOVE_GRAPH_NODE] = &GraphBrpcService::remove_graph_node; + _service_handler_map[PS_GRAPH_SET_NODE_FEAT] = + &GraphBrpcService::graph_set_node_feat; // shard初始化,server启动后才可从env获取到server_list的shard信息 initialize_shard_info(); @@ -400,5 +403,44 @@ int32_t GraphBrpcService::graph_get_node_feat(Table *table, return 0; } + +int32_t GraphBrpcService::graph_set_node_feat(Table *table, + const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { + CHECK_TABLE_EXIST(table, request, response) + if (request.params_size() < 3) { + set_response_code( + response, -1, + "graph_set_node_feat request requires at least 2 arguments"); + return 0; + } + size_t node_num = request.params(0).size() / sizeof(uint64_t); + uint64_t *node_data = (uint64_t *)(request.params(0).c_str()); + std::vector node_ids(node_data, node_data + node_num); + + std::vector feature_names = + paddle::string::split_string(request.params(1), "\t"); + + std::vector> features( + feature_names.size(), std::vector(node_num)); + + const char *buffer = request.params(2).c_str(); + + for (size_t feat_idx = 0; feat_idx < feature_names.size(); ++feat_idx) { + for (size_t node_idx = 0; node_idx < node_num; ++node_idx) { + size_t feat_len = *(size_t *)(buffer); + buffer += sizeof(size_t); + auto feat = std::string(buffer, feat_len); + features[feat_idx][node_idx] = feat; + buffer += feat_len; + } + } + + ((GraphTable *)table)->set_node_feat(node_ids, feature_names, features); + + return 0; +} + } // namespace distributed } // namespace paddle diff --git a/paddle/fluid/distributed/service/graph_brpc_server.h b/paddle/fluid/distributed/service/graph_brpc_server.h index 47c370572826ac2807e4ea5cb36cf3a667dfed10..6b4853fa679923e39578d803272f1ddf978b632c 100644 --- a/paddle/fluid/distributed/service/graph_brpc_server.h +++ b/paddle/fluid/distributed/service/graph_brpc_server.h @@ -83,9 +83,13 @@ class GraphBrpcService : public PsBaseService { const PsRequestMessage &request, PsResponseMessage &response, brpc::Controller *cntl); + int32_t graph_get_node_feat(Table *table, const PsRequestMessage &request, PsResponseMessage &response, brpc::Controller *cntl); + int32_t graph_set_node_feat(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl); int32_t clear_nodes(Table *table, const PsRequestMessage &request, PsResponseMessage &response, brpc::Controller *cntl); int32_t add_graph_node(Table *table, const PsRequestMessage &request, diff --git a/paddle/fluid/distributed/service/graph_py_service.cc b/paddle/fluid/distributed/service/graph_py_service.cc index 39befb1a112c854a183903d76a71d9e6c920b215..b41596270131744195ea29a18a12b1f8ba5f95f5 100644 --- a/paddle/fluid/distributed/service/graph_py_service.cc +++ b/paddle/fluid/distributed/service/graph_py_service.cc @@ -330,6 +330,19 @@ std::vector> GraphPyClient::get_node_feat( return v; } +void GraphPyClient::set_node_feat( + std::string node_type, std::vector node_ids, + std::vector feature_names, + const std::vector> features) { + if (this->table_id_map.count(node_type)) { + uint32_t table_id = this->table_id_map[node_type]; + auto status = + worker_ptr->set_node_feat(table_id, node_ids, feature_names, features); + status.wait(); + } + return; +} + std::vector GraphPyClient::pull_graph_list(std::string name, int server_index, int start, int size, diff --git a/paddle/fluid/distributed/service/graph_py_service.h b/paddle/fluid/distributed/service/graph_py_service.h index da027fbae3e6f0ca1e902795b0640cee1e0b76cc..8e03938801ce99e0fbbcafbe52174f8db86f8183 100644 --- a/paddle/fluid/distributed/service/graph_py_service.h +++ b/paddle/fluid/distributed/service/graph_py_service.h @@ -155,6 +155,9 @@ class GraphPyClient : public GraphPyService { std::vector> get_node_feat( std::string node_type, std::vector node_ids, std::vector feature_names); + void set_node_feat(std::string node_type, std::vector node_ids, + std::vector feature_names, + const std::vector> features); std::vector pull_graph_list(std::string name, int server_index, int start, int size, int step = 1); ::paddle::distributed::PSParameter GetWorkerProto(); diff --git a/paddle/fluid/distributed/service/sendrecv.proto b/paddle/fluid/distributed/service/sendrecv.proto index a4b811e950a3b56443261ceac37fa658007d519d..696c950d9b33ba5fe86d8a20ff19d55591384761 100644 --- a/paddle/fluid/distributed/service/sendrecv.proto +++ b/paddle/fluid/distributed/service/sendrecv.proto @@ -55,6 +55,7 @@ enum PsCmdID { PS_GRAPH_CLEAR = 34; PS_GRAPH_ADD_GRAPH_NODE = 35; PS_GRAPH_REMOVE_GRAPH_NODE = 36; + PS_GRAPH_SET_NODE_FEAT = 37; } message PsRequestMessage { diff --git a/paddle/fluid/distributed/table/common_graph_table.cc b/paddle/fluid/distributed/table/common_graph_table.cc index 29bcc04d9c1dfb3f3a5d32040162c4f5c6371672..41f4b0dac4d96e9466d57341e20709f04bdabdf6 100644 --- a/paddle/fluid/distributed/table/common_graph_table.cc +++ b/paddle/fluid/distributed/table/common_graph_table.cc @@ -469,6 +469,34 @@ int32_t GraphTable::get_node_feat(const std::vector &node_ids, return 0; } +int32_t GraphTable::set_node_feat( + const std::vector &node_ids, + const std::vector &feature_names, + const std::vector> &res) { + size_t node_num = node_ids.size(); + std::vector> tasks; + for (size_t idx = 0; idx < node_num; ++idx) { + uint64_t node_id = node_ids[idx]; + tasks.push_back(_shards_task_pool[get_thread_pool_index(node_id)]->enqueue( + [&, idx, node_id]() -> int { + size_t index = node_id % this->shard_num - this->shard_start; + auto node = shards[index].add_feature_node(node_id); + node->set_feature_size(this->feat_name.size()); + for (int feat_idx = 0; feat_idx < feature_names.size(); ++feat_idx) { + const std::string &feature_name = feature_names[feat_idx]; + if (feat_id_map.find(feature_name) != feat_id_map.end()) { + node->set_feature(feat_id_map[feature_name], res[feat_idx][idx]); + } + } + return 0; + })); + } + for (size_t idx = 0; idx < node_num; ++idx) { + tasks[idx].get(); + } + return 0; +} + std::pair GraphTable::parse_feature( std::string feat_str) { // Return (feat_id, btyes) if name are in this->feat_name, else return (-1, diff --git a/paddle/fluid/distributed/table/common_graph_table.h b/paddle/fluid/distributed/table/common_graph_table.h index 6ccce44c7ead6983efb57718999f1b36499b34e8..f643337a80f7c24fa320ece5269ec69d10d5fd79 100644 --- a/paddle/fluid/distributed/table/common_graph_table.h +++ b/paddle/fluid/distributed/table/common_graph_table.h @@ -46,6 +46,7 @@ class GraphShard { } return res; } + GraphNode *add_graph_node(uint64_t id); FeatureNode *add_feature_node(uint64_t id); Node *find_node(uint64_t id); @@ -122,6 +123,11 @@ class GraphTable : public SparseTable { const std::vector &feature_names, std::vector> &res); + virtual int32_t set_node_feat( + const std::vector &node_ids, + const std::vector &feature_names, + const std::vector> &res); + protected: std::vector shards; size_t shard_start, shard_end, server_num, shard_num_per_table, shard_num; diff --git a/paddle/fluid/distributed/test/graph_node_test.cc b/paddle/fluid/distributed/test/graph_node_test.cc index b8630aed02ffe60181ddb6b41810f5bea602b733..810530cdbec94d37cdd60b935b46342848feb27d 100644 --- a/paddle/fluid/distributed/test/graph_node_test.cc +++ b/paddle/fluid/distributed/test/graph_node_test.cc @@ -558,6 +558,17 @@ void RunBrpcPushSparse() { VLOG(0) << "get_node_feat: " << node_feat[1][0]; VLOG(0) << "get_node_feat: " << node_feat[1][1]; + node_feat[1][0] = "helloworld"; + + client1.set_node_feat(std::string("user"), node_ids, feature_names, + node_feat); + + // sleep(5); + node_feat = + client1.get_node_feat(std::string("user"), node_ids, feature_names); + VLOG(0) << "get_node_feat: " << node_feat[1][0]; + ASSERT_TRUE(node_feat[1][0] == "helloworld"); + // Test string node_ids.clear(); node_ids.push_back(37); diff --git a/paddle/fluid/pybind/fleet_py.cc b/paddle/fluid/pybind/fleet_py.cc index a6b542f53ae1785252b8993982345fd233902458..ea9faf57ac52b6f0369be932eea852122149b801 100644 --- a/paddle/fluid/pybind/fleet_py.cc +++ b/paddle/fluid/pybind/fleet_py.cc @@ -205,6 +205,7 @@ void BindGraphPyClient(py::module* m) { .def("pull_graph_list", &GraphPyClient::pull_graph_list) .def("start_client", &GraphPyClient::start_client) .def("batch_sample_neighboors", &GraphPyClient::batch_sample_neighboors) + .def("remove_graph_node", &GraphPyClient::remove_graph_node) .def("random_sample_nodes", &GraphPyClient::random_sample_nodes) .def("stop_server", &GraphPyClient::stop_server) .def("get_node_feat", @@ -221,6 +222,20 @@ void BindGraphPyClient(py::module* m) { } return bytes_feats; }) + .def("set_node_feat", + [](GraphPyClient& self, std::string node_type, + std::vector node_ids, + std::vector feature_names, + std::vector> bytes_feats) { + std::vector> feats(bytes_feats.size()); + for (int i = 0; i < bytes_feats.size(); ++i) { + for (int j = 0; j < bytes_feats[i].size(); ++j) { + feats[i].push_back(std::string(bytes_feats[i][j])); + } + } + self.set_node_feat(node_type, node_ids, feature_names, feats); + return; + }) .def("bind_local_server", &GraphPyClient::bind_local_server); }