未验证 提交 c3efabeb 编写于 作者: S seemingwang 提交者: GitHub

set node feature (#34994)

上级 77a8a394
......@@ -479,6 +479,102 @@ std::future<int32_t> GraphBrpcClient::pull_graph_list(
closure);
return fut;
}
std::future<int32_t> GraphBrpcClient::set_node_feat(
const uint32_t &table_id, const std::vector<uint64_t> &node_ids,
const std::vector<std::string> &feature_names,
const std::vector<std::vector<std::string>> &features) {
std::vector<int> request2server;
std::vector<int> server2request(server_size, -1);
for (int query_idx = 0; query_idx < node_ids.size(); ++query_idx) {
int server_index = get_server_index_by_id(node_ids[query_idx]);
if (server2request[server_index] == -1) {
server2request[server_index] = request2server.size();
request2server.push_back(server_index);
}
}
size_t request_call_num = request2server.size();
std::vector<std::vector<uint64_t>> node_id_buckets(request_call_num);
std::vector<std::vector<int>> query_idx_buckets(request_call_num);
std::vector<std::vector<std::vector<std::string>>> features_idx_buckets(
request_call_num);
for (int query_idx = 0; query_idx < node_ids.size(); ++query_idx) {
int server_index = get_server_index_by_id(node_ids[query_idx]);
int request_idx = server2request[server_index];
node_id_buckets[request_idx].push_back(node_ids[query_idx]);
query_idx_buckets[request_idx].push_back(query_idx);
if (features_idx_buckets[request_idx].size() == 0) {
features_idx_buckets[request_idx].resize(feature_names.size());
}
for (int feat_idx = 0; feat_idx < feature_names.size(); ++feat_idx) {
features_idx_buckets[request_idx][feat_idx].push_back(
features[feat_idx][query_idx]);
}
}
DownpourBrpcClosure *closure = new DownpourBrpcClosure(
request_call_num,
[&, node_id_buckets, query_idx_buckets, request_call_num](void *done) {
int ret = 0;
auto *closure = (DownpourBrpcClosure *)done;
size_t fail_num = 0;
for (int request_idx = 0; request_idx < request_call_num;
++request_idx) {
if (closure->check_response(request_idx, PS_GRAPH_SET_NODE_FEAT) !=
0) {
++fail_num;
}
if (fail_num == request_call_num) {
ret = -1;
}
}
closure->set_promise_value(ret);
});
auto promise = std::make_shared<std::promise<int32_t>>();
closure->add_promise(promise);
std::future<int> fut = promise->get_future();
for (int request_idx = 0; request_idx < request_call_num; ++request_idx) {
int server_index = request2server[request_idx];
closure->request(request_idx)->set_cmd_id(PS_GRAPH_SET_NODE_FEAT);
closure->request(request_idx)->set_table_id(table_id);
closure->request(request_idx)->set_client_id(_client_id);
size_t node_num = node_id_buckets[request_idx].size();
closure->request(request_idx)
->add_params((char *)node_id_buckets[request_idx].data(),
sizeof(uint64_t) * node_num);
std::string joint_feature_name =
paddle::string::join_strings(feature_names, '\t');
closure->request(request_idx)
->add_params(joint_feature_name.c_str(), joint_feature_name.size());
// set features
std::string set_feature = "";
for (size_t feat_idx = 0; feat_idx < feature_names.size(); ++feat_idx) {
for (size_t node_idx = 0; node_idx < node_num; ++node_idx) {
size_t feat_len =
features_idx_buckets[request_idx][feat_idx][node_idx].size();
set_feature.append((char *)&feat_len, sizeof(size_t));
set_feature.append(
features_idx_buckets[request_idx][feat_idx][node_idx].data(),
feat_len);
}
}
closure->request(request_idx)
->add_params(set_feature.c_str(), set_feature.size());
GraphPsService_Stub rpc_stub =
getServiceStub(get_cmd_channel(server_index));
closure->cntl(request_idx)->set_log_id(butil::gettimeofday_ms());
rpc_stub.service(closure->cntl(request_idx), closure->request(request_idx),
closure->response(request_idx), closure);
}
return fut;
}
int32_t GraphBrpcClient::initialize() {
// set_shard_num(_config.shard_num());
BrpcPsClient::initialize();
......
......@@ -79,6 +79,11 @@ class GraphBrpcClient : public BrpcPsClient {
const std::vector<std::string>& feature_names,
std::vector<std::vector<std::string>>& res);
virtual std::future<int32_t> set_node_feat(
const uint32_t& table_id, const std::vector<uint64_t>& node_ids,
const std::vector<std::string>& feature_names,
const std::vector<std::vector<std::string>>& features);
virtual std::future<int32_t> clear_nodes(uint32_t table_id);
virtual std::future<int32_t> add_graph_node(
uint32_t table_id, std::vector<uint64_t>& node_id_list,
......
......@@ -16,6 +16,7 @@
#include "paddle/fluid/distributed/service/brpc_ps_server.h"
#include <thread> // NOLINT
#include <utility>
#include "butil/endpoint.h"
#include "iomanip"
#include "paddle/fluid/distributed/service/brpc_ps_client.h"
......@@ -157,6 +158,8 @@ int32_t GraphBrpcService::initialize() {
&GraphBrpcService::add_graph_node;
_service_handler_map[PS_GRAPH_REMOVE_GRAPH_NODE] =
&GraphBrpcService::remove_graph_node;
_service_handler_map[PS_GRAPH_SET_NODE_FEAT] =
&GraphBrpcService::graph_set_node_feat;
// shard初始化,server启动后才可从env获取到server_list的shard信息
initialize_shard_info();
......@@ -400,5 +403,44 @@ int32_t GraphBrpcService::graph_get_node_feat(Table *table,
return 0;
}
int32_t GraphBrpcService::graph_set_node_feat(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
CHECK_TABLE_EXIST(table, request, response)
if (request.params_size() < 3) {
set_response_code(
response, -1,
"graph_set_node_feat request requires at least 2 arguments");
return 0;
}
size_t node_num = request.params(0).size() / sizeof(uint64_t);
uint64_t *node_data = (uint64_t *)(request.params(0).c_str());
std::vector<uint64_t> node_ids(node_data, node_data + node_num);
std::vector<std::string> feature_names =
paddle::string::split_string<std::string>(request.params(1), "\t");
std::vector<std::vector<std::string>> features(
feature_names.size(), std::vector<std::string>(node_num));
const char *buffer = request.params(2).c_str();
for (size_t feat_idx = 0; feat_idx < feature_names.size(); ++feat_idx) {
for (size_t node_idx = 0; node_idx < node_num; ++node_idx) {
size_t feat_len = *(size_t *)(buffer);
buffer += sizeof(size_t);
auto feat = std::string(buffer, feat_len);
features[feat_idx][node_idx] = feat;
buffer += feat_len;
}
}
((GraphTable *)table)->set_node_feat(node_ids, feature_names, features);
return 0;
}
} // namespace distributed
} // namespace paddle
......@@ -83,9 +83,13 @@ class GraphBrpcService : public PsBaseService {
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl);
int32_t graph_get_node_feat(Table *table, const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl);
int32_t graph_set_node_feat(Table *table, const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl);
int32_t clear_nodes(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t add_graph_node(Table *table, const PsRequestMessage &request,
......
......@@ -330,6 +330,19 @@ std::vector<std::vector<std::string>> GraphPyClient::get_node_feat(
return v;
}
void GraphPyClient::set_node_feat(
std::string node_type, std::vector<uint64_t> node_ids,
std::vector<std::string> feature_names,
const std::vector<std::vector<std::string>> features) {
if (this->table_id_map.count(node_type)) {
uint32_t table_id = this->table_id_map[node_type];
auto status =
worker_ptr->set_node_feat(table_id, node_ids, feature_names, features);
status.wait();
}
return;
}
std::vector<FeatureNode> GraphPyClient::pull_graph_list(std::string name,
int server_index,
int start, int size,
......
......@@ -155,6 +155,9 @@ class GraphPyClient : public GraphPyService {
std::vector<std::vector<std::string>> get_node_feat(
std::string node_type, std::vector<uint64_t> node_ids,
std::vector<std::string> feature_names);
void set_node_feat(std::string node_type, std::vector<uint64_t> node_ids,
std::vector<std::string> feature_names,
const std::vector<std::vector<std::string>> features);
std::vector<FeatureNode> pull_graph_list(std::string name, int server_index,
int start, int size, int step = 1);
::paddle::distributed::PSParameter GetWorkerProto();
......
......@@ -55,6 +55,7 @@ enum PsCmdID {
PS_GRAPH_CLEAR = 34;
PS_GRAPH_ADD_GRAPH_NODE = 35;
PS_GRAPH_REMOVE_GRAPH_NODE = 36;
PS_GRAPH_SET_NODE_FEAT = 37;
}
message PsRequestMessage {
......
......@@ -469,6 +469,34 @@ int32_t GraphTable::get_node_feat(const std::vector<uint64_t> &node_ids,
return 0;
}
int32_t GraphTable::set_node_feat(
const std::vector<uint64_t> &node_ids,
const std::vector<std::string> &feature_names,
const std::vector<std::vector<std::string>> &res) {
size_t node_num = node_ids.size();
std::vector<std::future<int>> tasks;
for (size_t idx = 0; idx < node_num; ++idx) {
uint64_t node_id = node_ids[idx];
tasks.push_back(_shards_task_pool[get_thread_pool_index(node_id)]->enqueue(
[&, idx, node_id]() -> int {
size_t index = node_id % this->shard_num - this->shard_start;
auto node = shards[index].add_feature_node(node_id);
node->set_feature_size(this->feat_name.size());
for (int feat_idx = 0; feat_idx < feature_names.size(); ++feat_idx) {
const std::string &feature_name = feature_names[feat_idx];
if (feat_id_map.find(feature_name) != feat_id_map.end()) {
node->set_feature(feat_id_map[feature_name], res[feat_idx][idx]);
}
}
return 0;
}));
}
for (size_t idx = 0; idx < node_num; ++idx) {
tasks[idx].get();
}
return 0;
}
std::pair<int32_t, std::string> GraphTable::parse_feature(
std::string feat_str) {
// Return (feat_id, btyes) if name are in this->feat_name, else return (-1,
......
......@@ -46,6 +46,7 @@ class GraphShard {
}
return res;
}
GraphNode *add_graph_node(uint64_t id);
FeatureNode *add_feature_node(uint64_t id);
Node *find_node(uint64_t id);
......@@ -122,6 +123,11 @@ class GraphTable : public SparseTable {
const std::vector<std::string> &feature_names,
std::vector<std::vector<std::string>> &res);
virtual int32_t set_node_feat(
const std::vector<uint64_t> &node_ids,
const std::vector<std::string> &feature_names,
const std::vector<std::vector<std::string>> &res);
protected:
std::vector<GraphShard> shards;
size_t shard_start, shard_end, server_num, shard_num_per_table, shard_num;
......
......@@ -558,6 +558,17 @@ void RunBrpcPushSparse() {
VLOG(0) << "get_node_feat: " << node_feat[1][0];
VLOG(0) << "get_node_feat: " << node_feat[1][1];
node_feat[1][0] = "helloworld";
client1.set_node_feat(std::string("user"), node_ids, feature_names,
node_feat);
// sleep(5);
node_feat =
client1.get_node_feat(std::string("user"), node_ids, feature_names);
VLOG(0) << "get_node_feat: " << node_feat[1][0];
ASSERT_TRUE(node_feat[1][0] == "helloworld");
// Test string
node_ids.clear();
node_ids.push_back(37);
......
......@@ -205,6 +205,7 @@ void BindGraphPyClient(py::module* m) {
.def("pull_graph_list", &GraphPyClient::pull_graph_list)
.def("start_client", &GraphPyClient::start_client)
.def("batch_sample_neighboors", &GraphPyClient::batch_sample_neighboors)
.def("remove_graph_node", &GraphPyClient::remove_graph_node)
.def("random_sample_nodes", &GraphPyClient::random_sample_nodes)
.def("stop_server", &GraphPyClient::stop_server)
.def("get_node_feat",
......@@ -221,6 +222,20 @@ void BindGraphPyClient(py::module* m) {
}
return bytes_feats;
})
.def("set_node_feat",
[](GraphPyClient& self, std::string node_type,
std::vector<uint64_t> node_ids,
std::vector<std::string> feature_names,
std::vector<std::vector<py::bytes>> bytes_feats) {
std::vector<std::vector<std::string>> feats(bytes_feats.size());
for (int i = 0; i < bytes_feats.size(); ++i) {
for (int j = 0; j < bytes_feats[i].size(); ++j) {
feats[i].push_back(std::string(bytes_feats[i][j]));
}
}
self.set_node_feat(node_type, node_ids, feature_names, feats);
return;
})
.def("bind_local_server", &GraphPyClient::bind_local_server);
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册