未验证 提交 c3efabeb 编写于 作者: S seemingwang 提交者: GitHub

set node feature (#34994)

上级 77a8a394
...@@ -479,6 +479,102 @@ std::future<int32_t> GraphBrpcClient::pull_graph_list( ...@@ -479,6 +479,102 @@ std::future<int32_t> GraphBrpcClient::pull_graph_list(
closure); closure);
return fut; return fut;
} }
std::future<int32_t> GraphBrpcClient::set_node_feat(
const uint32_t &table_id, const std::vector<uint64_t> &node_ids,
const std::vector<std::string> &feature_names,
const std::vector<std::vector<std::string>> &features) {
std::vector<int> request2server;
std::vector<int> server2request(server_size, -1);
for (int query_idx = 0; query_idx < node_ids.size(); ++query_idx) {
int server_index = get_server_index_by_id(node_ids[query_idx]);
if (server2request[server_index] == -1) {
server2request[server_index] = request2server.size();
request2server.push_back(server_index);
}
}
size_t request_call_num = request2server.size();
std::vector<std::vector<uint64_t>> node_id_buckets(request_call_num);
std::vector<std::vector<int>> query_idx_buckets(request_call_num);
std::vector<std::vector<std::vector<std::string>>> features_idx_buckets(
request_call_num);
for (int query_idx = 0; query_idx < node_ids.size(); ++query_idx) {
int server_index = get_server_index_by_id(node_ids[query_idx]);
int request_idx = server2request[server_index];
node_id_buckets[request_idx].push_back(node_ids[query_idx]);
query_idx_buckets[request_idx].push_back(query_idx);
if (features_idx_buckets[request_idx].size() == 0) {
features_idx_buckets[request_idx].resize(feature_names.size());
}
for (int feat_idx = 0; feat_idx < feature_names.size(); ++feat_idx) {
features_idx_buckets[request_idx][feat_idx].push_back(
features[feat_idx][query_idx]);
}
}
DownpourBrpcClosure *closure = new DownpourBrpcClosure(
request_call_num,
[&, node_id_buckets, query_idx_buckets, request_call_num](void *done) {
int ret = 0;
auto *closure = (DownpourBrpcClosure *)done;
size_t fail_num = 0;
for (int request_idx = 0; request_idx < request_call_num;
++request_idx) {
if (closure->check_response(request_idx, PS_GRAPH_SET_NODE_FEAT) !=
0) {
++fail_num;
}
if (fail_num == request_call_num) {
ret = -1;
}
}
closure->set_promise_value(ret);
});
auto promise = std::make_shared<std::promise<int32_t>>();
closure->add_promise(promise);
std::future<int> fut = promise->get_future();
for (int request_idx = 0; request_idx < request_call_num; ++request_idx) {
int server_index = request2server[request_idx];
closure->request(request_idx)->set_cmd_id(PS_GRAPH_SET_NODE_FEAT);
closure->request(request_idx)->set_table_id(table_id);
closure->request(request_idx)->set_client_id(_client_id);
size_t node_num = node_id_buckets[request_idx].size();
closure->request(request_idx)
->add_params((char *)node_id_buckets[request_idx].data(),
sizeof(uint64_t) * node_num);
std::string joint_feature_name =
paddle::string::join_strings(feature_names, '\t');
closure->request(request_idx)
->add_params(joint_feature_name.c_str(), joint_feature_name.size());
// set features
std::string set_feature = "";
for (size_t feat_idx = 0; feat_idx < feature_names.size(); ++feat_idx) {
for (size_t node_idx = 0; node_idx < node_num; ++node_idx) {
size_t feat_len =
features_idx_buckets[request_idx][feat_idx][node_idx].size();
set_feature.append((char *)&feat_len, sizeof(size_t));
set_feature.append(
features_idx_buckets[request_idx][feat_idx][node_idx].data(),
feat_len);
}
}
closure->request(request_idx)
->add_params(set_feature.c_str(), set_feature.size());
GraphPsService_Stub rpc_stub =
getServiceStub(get_cmd_channel(server_index));
closure->cntl(request_idx)->set_log_id(butil::gettimeofday_ms());
rpc_stub.service(closure->cntl(request_idx), closure->request(request_idx),
closure->response(request_idx), closure);
}
return fut;
}
int32_t GraphBrpcClient::initialize() { int32_t GraphBrpcClient::initialize() {
// set_shard_num(_config.shard_num()); // set_shard_num(_config.shard_num());
BrpcPsClient::initialize(); BrpcPsClient::initialize();
......
...@@ -79,6 +79,11 @@ class GraphBrpcClient : public BrpcPsClient { ...@@ -79,6 +79,11 @@ class GraphBrpcClient : public BrpcPsClient {
const std::vector<std::string>& feature_names, const std::vector<std::string>& feature_names,
std::vector<std::vector<std::string>>& res); std::vector<std::vector<std::string>>& res);
virtual std::future<int32_t> set_node_feat(
const uint32_t& table_id, const std::vector<uint64_t>& node_ids,
const std::vector<std::string>& feature_names,
const std::vector<std::vector<std::string>>& features);
virtual std::future<int32_t> clear_nodes(uint32_t table_id); virtual std::future<int32_t> clear_nodes(uint32_t table_id);
virtual std::future<int32_t> add_graph_node( virtual std::future<int32_t> add_graph_node(
uint32_t table_id, std::vector<uint64_t>& node_id_list, uint32_t table_id, std::vector<uint64_t>& node_id_list,
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include "paddle/fluid/distributed/service/brpc_ps_server.h" #include "paddle/fluid/distributed/service/brpc_ps_server.h"
#include <thread> // NOLINT #include <thread> // NOLINT
#include <utility>
#include "butil/endpoint.h" #include "butil/endpoint.h"
#include "iomanip" #include "iomanip"
#include "paddle/fluid/distributed/service/brpc_ps_client.h" #include "paddle/fluid/distributed/service/brpc_ps_client.h"
...@@ -157,6 +158,8 @@ int32_t GraphBrpcService::initialize() { ...@@ -157,6 +158,8 @@ int32_t GraphBrpcService::initialize() {
&GraphBrpcService::add_graph_node; &GraphBrpcService::add_graph_node;
_service_handler_map[PS_GRAPH_REMOVE_GRAPH_NODE] = _service_handler_map[PS_GRAPH_REMOVE_GRAPH_NODE] =
&GraphBrpcService::remove_graph_node; &GraphBrpcService::remove_graph_node;
_service_handler_map[PS_GRAPH_SET_NODE_FEAT] =
&GraphBrpcService::graph_set_node_feat;
// shard初始化,server启动后才可从env获取到server_list的shard信息 // shard初始化,server启动后才可从env获取到server_list的shard信息
initialize_shard_info(); initialize_shard_info();
...@@ -400,5 +403,44 @@ int32_t GraphBrpcService::graph_get_node_feat(Table *table, ...@@ -400,5 +403,44 @@ int32_t GraphBrpcService::graph_get_node_feat(Table *table,
return 0; return 0;
} }
int32_t GraphBrpcService::graph_set_node_feat(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
CHECK_TABLE_EXIST(table, request, response)
if (request.params_size() < 3) {
set_response_code(
response, -1,
"graph_set_node_feat request requires at least 2 arguments");
return 0;
}
size_t node_num = request.params(0).size() / sizeof(uint64_t);
uint64_t *node_data = (uint64_t *)(request.params(0).c_str());
std::vector<uint64_t> node_ids(node_data, node_data + node_num);
std::vector<std::string> feature_names =
paddle::string::split_string<std::string>(request.params(1), "\t");
std::vector<std::vector<std::string>> features(
feature_names.size(), std::vector<std::string>(node_num));
const char *buffer = request.params(2).c_str();
for (size_t feat_idx = 0; feat_idx < feature_names.size(); ++feat_idx) {
for (size_t node_idx = 0; node_idx < node_num; ++node_idx) {
size_t feat_len = *(size_t *)(buffer);
buffer += sizeof(size_t);
auto feat = std::string(buffer, feat_len);
features[feat_idx][node_idx] = feat;
buffer += feat_len;
}
}
((GraphTable *)table)->set_node_feat(node_ids, feature_names, features);
return 0;
}
} // namespace distributed } // namespace distributed
} // namespace paddle } // namespace paddle
...@@ -83,9 +83,13 @@ class GraphBrpcService : public PsBaseService { ...@@ -83,9 +83,13 @@ class GraphBrpcService : public PsBaseService {
const PsRequestMessage &request, const PsRequestMessage &request,
PsResponseMessage &response, PsResponseMessage &response,
brpc::Controller *cntl); brpc::Controller *cntl);
int32_t graph_get_node_feat(Table *table, const PsRequestMessage &request, int32_t graph_get_node_feat(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, PsResponseMessage &response,
brpc::Controller *cntl); brpc::Controller *cntl);
int32_t graph_set_node_feat(Table *table, const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl);
int32_t clear_nodes(Table *table, const PsRequestMessage &request, int32_t clear_nodes(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl); PsResponseMessage &response, brpc::Controller *cntl);
int32_t add_graph_node(Table *table, const PsRequestMessage &request, int32_t add_graph_node(Table *table, const PsRequestMessage &request,
......
...@@ -330,6 +330,19 @@ std::vector<std::vector<std::string>> GraphPyClient::get_node_feat( ...@@ -330,6 +330,19 @@ std::vector<std::vector<std::string>> GraphPyClient::get_node_feat(
return v; return v;
} }
void GraphPyClient::set_node_feat(
std::string node_type, std::vector<uint64_t> node_ids,
std::vector<std::string> feature_names,
const std::vector<std::vector<std::string>> features) {
if (this->table_id_map.count(node_type)) {
uint32_t table_id = this->table_id_map[node_type];
auto status =
worker_ptr->set_node_feat(table_id, node_ids, feature_names, features);
status.wait();
}
return;
}
std::vector<FeatureNode> GraphPyClient::pull_graph_list(std::string name, std::vector<FeatureNode> GraphPyClient::pull_graph_list(std::string name,
int server_index, int server_index,
int start, int size, int start, int size,
......
...@@ -155,6 +155,9 @@ class GraphPyClient : public GraphPyService { ...@@ -155,6 +155,9 @@ class GraphPyClient : public GraphPyService {
std::vector<std::vector<std::string>> get_node_feat( std::vector<std::vector<std::string>> get_node_feat(
std::string node_type, std::vector<uint64_t> node_ids, std::string node_type, std::vector<uint64_t> node_ids,
std::vector<std::string> feature_names); std::vector<std::string> feature_names);
void set_node_feat(std::string node_type, std::vector<uint64_t> node_ids,
std::vector<std::string> feature_names,
const std::vector<std::vector<std::string>> features);
std::vector<FeatureNode> pull_graph_list(std::string name, int server_index, std::vector<FeatureNode> pull_graph_list(std::string name, int server_index,
int start, int size, int step = 1); int start, int size, int step = 1);
::paddle::distributed::PSParameter GetWorkerProto(); ::paddle::distributed::PSParameter GetWorkerProto();
......
...@@ -55,6 +55,7 @@ enum PsCmdID { ...@@ -55,6 +55,7 @@ enum PsCmdID {
PS_GRAPH_CLEAR = 34; PS_GRAPH_CLEAR = 34;
PS_GRAPH_ADD_GRAPH_NODE = 35; PS_GRAPH_ADD_GRAPH_NODE = 35;
PS_GRAPH_REMOVE_GRAPH_NODE = 36; PS_GRAPH_REMOVE_GRAPH_NODE = 36;
PS_GRAPH_SET_NODE_FEAT = 37;
} }
message PsRequestMessage { message PsRequestMessage {
......
...@@ -469,6 +469,34 @@ int32_t GraphTable::get_node_feat(const std::vector<uint64_t> &node_ids, ...@@ -469,6 +469,34 @@ int32_t GraphTable::get_node_feat(const std::vector<uint64_t> &node_ids,
return 0; return 0;
} }
int32_t GraphTable::set_node_feat(
const std::vector<uint64_t> &node_ids,
const std::vector<std::string> &feature_names,
const std::vector<std::vector<std::string>> &res) {
size_t node_num = node_ids.size();
std::vector<std::future<int>> tasks;
for (size_t idx = 0; idx < node_num; ++idx) {
uint64_t node_id = node_ids[idx];
tasks.push_back(_shards_task_pool[get_thread_pool_index(node_id)]->enqueue(
[&, idx, node_id]() -> int {
size_t index = node_id % this->shard_num - this->shard_start;
auto node = shards[index].add_feature_node(node_id);
node->set_feature_size(this->feat_name.size());
for (int feat_idx = 0; feat_idx < feature_names.size(); ++feat_idx) {
const std::string &feature_name = feature_names[feat_idx];
if (feat_id_map.find(feature_name) != feat_id_map.end()) {
node->set_feature(feat_id_map[feature_name], res[feat_idx][idx]);
}
}
return 0;
}));
}
for (size_t idx = 0; idx < node_num; ++idx) {
tasks[idx].get();
}
return 0;
}
std::pair<int32_t, std::string> GraphTable::parse_feature( std::pair<int32_t, std::string> GraphTable::parse_feature(
std::string feat_str) { std::string feat_str) {
// Return (feat_id, btyes) if name are in this->feat_name, else return (-1, // Return (feat_id, btyes) if name are in this->feat_name, else return (-1,
......
...@@ -46,6 +46,7 @@ class GraphShard { ...@@ -46,6 +46,7 @@ class GraphShard {
} }
return res; return res;
} }
GraphNode *add_graph_node(uint64_t id); GraphNode *add_graph_node(uint64_t id);
FeatureNode *add_feature_node(uint64_t id); FeatureNode *add_feature_node(uint64_t id);
Node *find_node(uint64_t id); Node *find_node(uint64_t id);
...@@ -122,6 +123,11 @@ class GraphTable : public SparseTable { ...@@ -122,6 +123,11 @@ class GraphTable : public SparseTable {
const std::vector<std::string> &feature_names, const std::vector<std::string> &feature_names,
std::vector<std::vector<std::string>> &res); std::vector<std::vector<std::string>> &res);
virtual int32_t set_node_feat(
const std::vector<uint64_t> &node_ids,
const std::vector<std::string> &feature_names,
const std::vector<std::vector<std::string>> &res);
protected: protected:
std::vector<GraphShard> shards; std::vector<GraphShard> shards;
size_t shard_start, shard_end, server_num, shard_num_per_table, shard_num; size_t shard_start, shard_end, server_num, shard_num_per_table, shard_num;
......
...@@ -558,6 +558,17 @@ void RunBrpcPushSparse() { ...@@ -558,6 +558,17 @@ void RunBrpcPushSparse() {
VLOG(0) << "get_node_feat: " << node_feat[1][0]; VLOG(0) << "get_node_feat: " << node_feat[1][0];
VLOG(0) << "get_node_feat: " << node_feat[1][1]; VLOG(0) << "get_node_feat: " << node_feat[1][1];
node_feat[1][0] = "helloworld";
client1.set_node_feat(std::string("user"), node_ids, feature_names,
node_feat);
// sleep(5);
node_feat =
client1.get_node_feat(std::string("user"), node_ids, feature_names);
VLOG(0) << "get_node_feat: " << node_feat[1][0];
ASSERT_TRUE(node_feat[1][0] == "helloworld");
// Test string // Test string
node_ids.clear(); node_ids.clear();
node_ids.push_back(37); node_ids.push_back(37);
......
...@@ -205,6 +205,7 @@ void BindGraphPyClient(py::module* m) { ...@@ -205,6 +205,7 @@ void BindGraphPyClient(py::module* m) {
.def("pull_graph_list", &GraphPyClient::pull_graph_list) .def("pull_graph_list", &GraphPyClient::pull_graph_list)
.def("start_client", &GraphPyClient::start_client) .def("start_client", &GraphPyClient::start_client)
.def("batch_sample_neighboors", &GraphPyClient::batch_sample_neighboors) .def("batch_sample_neighboors", &GraphPyClient::batch_sample_neighboors)
.def("remove_graph_node", &GraphPyClient::remove_graph_node)
.def("random_sample_nodes", &GraphPyClient::random_sample_nodes) .def("random_sample_nodes", &GraphPyClient::random_sample_nodes)
.def("stop_server", &GraphPyClient::stop_server) .def("stop_server", &GraphPyClient::stop_server)
.def("get_node_feat", .def("get_node_feat",
...@@ -221,6 +222,20 @@ void BindGraphPyClient(py::module* m) { ...@@ -221,6 +222,20 @@ void BindGraphPyClient(py::module* m) {
} }
return bytes_feats; return bytes_feats;
}) })
.def("set_node_feat",
[](GraphPyClient& self, std::string node_type,
std::vector<uint64_t> node_ids,
std::vector<std::string> feature_names,
std::vector<std::vector<py::bytes>> bytes_feats) {
std::vector<std::vector<std::string>> feats(bytes_feats.size());
for (int i = 0; i < bytes_feats.size(); ++i) {
for (int j = 0; j < bytes_feats[i].size(); ++j) {
feats[i].push_back(std::string(bytes_feats[i][j]));
}
}
self.set_node_feat(node_type, node_ids, feature_names, feats);
return;
})
.def("bind_local_server", &GraphPyClient::bind_local_server); .def("bind_local_server", &GraphPyClient::bind_local_server);
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册