Unverified commit 85c8c170, authored by seemingwang, committed by GitHub

simplify graph-engine's templates (#36990)

* graph engine demo

* upload unsaved changes

* fix dependency error

* fix shard_num problem

* py client

* remove lock and graph-type

* add load direct graph

* add load direct graph

* add load direct graph

* batch random_sample

* batch_sample_k

* fix num_nodes size

* batch brpc

* batch brpc

* add test

* add test

* add load_nodes; change add_node function

* change sample return type to pair

* resolve conflict

* resolved conflict

* resolved conflict

* separate server and client

* merge pair type

* fix

* resolved conflict

* fixed segmentation fault; high-level VLOG for load edges and load nodes

* random_sample return 0

* rm useless loop

* test:load edge

* fix ret -1

* test: rm sample

* rm sample

* random_sample return future

* random_sample return int

* test fake node

* fixed here

* memory leak

* remove test code

* fix return problem

* add common_graph_table

* random sample node & test & change data structure from linked list to vector

* add common_graph_table

* sample with srand

* add node_types

* optimize nodes sample

* recover test

* random sample

* destruct weighted sampler

* GraphEdgeBlob

* WeightedGraphEdgeBlob to GraphEdgeBlob

* WeightedGraphEdgeBlob to GraphEdgeBlob

* pybind sample nodes api

* pull nodes with step

* fixed pull_graph_list bug; add test for pull_graph_list by step

* add graph table;name

* add graph table;name

* add pybind

* add pybind

* add FeatureNode

* add FeatureNode

* add FeatureNode Serialize

* add FeatureNode Serialize

* get_feat_node

* avoid local rpc

* fix get_node_feat

* fix get_node_feat

* remove log

* get_node_feat returns py::bytes

* merge develop with graph_engine

* fix threadpool.h head

* fix

* fix typo

* resolve conflict

* fix conflict

* recover lost content

* fix pybind of FeatureNode

* recover cmake

* recover tools

* resolve conflict

* resolve linking problem

* code style

* change test_server port

* fix code problems

* remove shard_num config

* remove redundant threads

* optimize start server

* remove logs

* fix code problems by reviewers' suggestions

* move graph files into a folder

* code style change

* remove graph operations from base table

* optimize get_feat function of graph engine

* fix long long count problem

* remove redundant graph files

* remove unused shell

* recover dropout_op_pass.h

* fix potential stack overflow when request number is too large & node add & node clear & node remove

* when sample k is larger than neighbor num, return directly

* using random seed generator of paddle to speed up

* fix bug of random sample k

* fix code style

* fix code style

* add remove graph to fleet_py.cc

* fix blocking_queue problem

* fix style

* fix

* recover capacity check

* add remove graph node; add set_feature

* add remove graph node; add set_feature

* add remove graph node; add set_feature

* add remove graph node; add set_feature

* fix distributed op combining problems

* optimize

* remove logs

* fix MultiSlotDataGenerator error

* cache for graph engine

* fix type compare error

* more tests & fix thread terminating problem

* remove header

* change time interval of shrink

* use cache when sample nodes

* remove unused function

* change unique_ptr to shared_ptr

* simplify cache template

* cache api on client

* fix
Co-authored-by: Huang Zhengjie <270018958@qq.com>
Co-authored-by: Weiyue Su <weiyue.su@gmail.com>
Co-authored-by: suweiyue <suweiyue@baidu.com>
Co-authored-by: luobin06 <luobin06@baidu.com>
Co-authored-by: liweibin02 <liweibin02@baidu.com>
Co-authored-by: tangwei12 <tangwei12@baidu.com>
Parent 9d2dd727
@@ -302,7 +302,7 @@ std::future<int32_t> GraphBrpcClient::remove_graph_node(
   return fut;
 }
 // char* &buffer,int &actual_size
-std::future<int32_t> GraphBrpcClient::batch_sample_neighboors(
+std::future<int32_t> GraphBrpcClient::batch_sample_neighbors(
     uint32_t table_id, std::vector<uint64_t> node_ids, int sample_size,
     std::vector<std::vector<std::pair<uint64_t, float>>> &res,
     int server_index) {
@@ -390,8 +390,8 @@ std::future<int32_t> GraphBrpcClient::batch_sample_neighboors(
         size_t fail_num = 0;
         for (size_t request_idx = 0; request_idx < request_call_num;
              ++request_idx) {
-          if (closure->check_response(request_idx,
-                                      PS_GRAPH_SAMPLE_NEIGHBOORS) != 0) {
+          if (closure->check_response(request_idx, PS_GRAPH_SAMPLE_NEIGHBORS) !=
+              0) {
             ++fail_num;
           } else {
             auto &res_io_buffer =
@@ -435,7 +435,7 @@ std::future<int32_t> GraphBrpcClient::batch_sample_neighboors(
   for (int request_idx = 0; request_idx < request_call_num; ++request_idx) {
     int server_index = request2server[request_idx];
-    closure->request(request_idx)->set_cmd_id(PS_GRAPH_SAMPLE_NEIGHBOORS);
+    closure->request(request_idx)->set_cmd_id(PS_GRAPH_SAMPLE_NEIGHBORS);
     closure->request(request_idx)->set_table_id(table_id);
     closure->request(request_idx)->set_client_id(_client_id);
     size_t node_num = node_id_buckets[request_idx].size();
@@ -494,6 +494,47 @@ std::future<int32_t> GraphBrpcClient::random_sample_nodes(
       closure);
   return fut;
 }
+std::future<int32_t> GraphBrpcClient::use_neighbors_sample_cache(
+    uint32_t table_id, size_t total_size_limit, size_t ttl) {
+  DownpourBrpcClosure *closure = new DownpourBrpcClosure(
+      server_size, [&, server_size = this->server_size ](void *done) {
+        int ret = 0;
+        auto *closure = (DownpourBrpcClosure *)done;
+        size_t fail_num = 0;
+        for (size_t request_idx = 0; request_idx < server_size; ++request_idx) {
+          if (closure->check_response(
+                  request_idx, PS_GRAPH_USE_NEIGHBORS_SAMPLE_CACHE) != 0) {
+            ++fail_num;
+            break;
+          }
+        }
+        ret = fail_num == 0 ? 0 : -1;
+        closure->set_promise_value(ret);
+      });
+  auto promise = std::make_shared<std::promise<int32_t>>();
+  closure->add_promise(promise);
+  size_t size_limit = total_size_limit / server_size +
+                      (total_size_limit % server_size != 0 ? 1 : 0);
+  std::future<int> fut = promise->get_future();
+  for (size_t i = 0; i < server_size; i++) {
+    int server_index = i;
+    closure->request(server_index)
+        ->set_cmd_id(PS_GRAPH_USE_NEIGHBORS_SAMPLE_CACHE);
+    closure->request(server_index)->set_table_id(table_id);
+    closure->request(server_index)->set_client_id(_client_id);
+    closure->request(server_index)
+        ->add_params((char *)&size_limit, sizeof(size_t));
+    closure->request(server_index)->add_params((char *)&ttl, sizeof(size_t));
+    GraphPsService_Stub rpc_stub =
+        getServiceStub(get_cmd_channel(server_index));
+    closure->cntl(server_index)->set_log_id(butil::gettimeofday_ms());
+    rpc_stub.service(closure->cntl(server_index),
+                     closure->request(server_index),
+                     closure->response(server_index), closure);
+  }
+  return fut;
+}
 std::future<int32_t> GraphBrpcClient::pull_graph_list(
     uint32_t table_id, int server_index, int start, int size, int step,
     std::vector<FeatureNode> &res) {
@@ -515,7 +556,7 @@ std::future<int32_t> GraphBrpcClient::pull_graph_list(
       index += node.get_size(false);
       res.push_back(node);
     }
-    delete buffer;
+    delete[] buffer;
   }
   closure->set_promise_value(ret);
 });
......
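For readers of the new client entry point above, here is a minimal usage sketch (not part of this patch; the include path and the helper function name are illustrative): enable the neighbor-sample cache once per graph table, then call batch_sample_neighbors as before. Within the configured ttl, repeated requests for the same (node id, sample size) key can be served from the per-shard LRU on the servers.

```cpp
// Sketch only: assumes an initialized GraphBrpcClient connected to the graph servers.
#include <cstdint>
#include <utility>
#include <vector>
#include "paddle/fluid/distributed/service/graph_brpc_client.h"  // path may differ by branch

void enable_cache_then_sample(paddle::distributed::GraphBrpcClient& client) {
  uint32_t table_id = 0;  // hypothetical graph table id

  // Turn the cache on for every server; total_size_limit is split across servers.
  auto cache_status = client.use_neighbors_sample_cache(
      table_id, /*total_size_limit=*/100000, /*ttl=*/10);
  cache_status.wait();

  // Sampling calls are unchanged; repeated keys within ttl can hit the cache.
  std::vector<std::vector<std::pair<uint64_t, float>>> res;
  auto sample_status =
      client.batch_sample_neighbors(table_id, {37, 96}, /*sample_size=*/4, res);
  sample_status.wait();
}
```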
@@ -61,8 +61,8 @@ class GraphBrpcClient : public BrpcPsClient {
  public:
   GraphBrpcClient() {}
   virtual ~GraphBrpcClient() {}
-  // given a batch of nodes, sample graph_neighboors for each of them
-  virtual std::future<int32_t> batch_sample_neighboors(
+  // given a batch of nodes, sample graph_neighbors for each of them
+  virtual std::future<int32_t> batch_sample_neighbors(
       uint32_t table_id, std::vector<uint64_t> node_ids, int sample_size,
       std::vector<std::vector<std::pair<uint64_t, float>>>& res,
       int server_index = -1);
@@ -89,6 +89,9 @@ class GraphBrpcClient : public BrpcPsClient {
   virtual std::future<int32_t> add_graph_node(
       uint32_t table_id, std::vector<uint64_t>& node_id_list,
       std::vector<bool>& is_weighted_list);
+  virtual std::future<int32_t> use_neighbors_sample_cache(uint32_t table_id,
+                                                          size_t size_limit,
+                                                          size_t ttl);
   virtual std::future<int32_t> remove_graph_node(
       uint32_t table_id, std::vector<uint64_t>& node_id_list);
   virtual int32_t initialize();
......
@@ -187,8 +187,8 @@ int32_t GraphBrpcService::initialize() {
   _service_handler_map[PS_STOP_PROFILER] = &GraphBrpcService::stop_profiler;
   _service_handler_map[PS_PULL_GRAPH_LIST] = &GraphBrpcService::pull_graph_list;
-  _service_handler_map[PS_GRAPH_SAMPLE_NEIGHBOORS] =
-      &GraphBrpcService::graph_random_sample_neighboors;
+  _service_handler_map[PS_GRAPH_SAMPLE_NEIGHBORS] =
+      &GraphBrpcService::graph_random_sample_neighbors;
   _service_handler_map[PS_GRAPH_SAMPLE_NODES] =
       &GraphBrpcService::graph_random_sample_nodes;
   _service_handler_map[PS_GRAPH_GET_NODE_FEAT] =
@@ -201,8 +201,9 @@ int32_t GraphBrpcService::initialize() {
   _service_handler_map[PS_GRAPH_SET_NODE_FEAT] =
       &GraphBrpcService::graph_set_node_feat;
   _service_handler_map[PS_GRAPH_SAMPLE_NODES_FROM_ONE_SERVER] =
-      &GraphBrpcService::sample_neighboors_across_multi_servers;
+      &GraphBrpcService::sample_neighbors_across_multi_servers;
+  _service_handler_map[PS_GRAPH_USE_NEIGHBORS_SAMPLE_CACHE] =
+      &GraphBrpcService::use_neighbors_sample_cache;
   // shard初始化,server启动后才可从env获取到server_list的shard信息
   initialize_shard_info();
@@ -373,7 +374,7 @@ int32_t GraphBrpcService::pull_graph_list(Table *table,
   cntl->response_attachment().append(buffer.get(), actual_size);
   return 0;
 }
-int32_t GraphBrpcService::graph_random_sample_neighboors(
+int32_t GraphBrpcService::graph_random_sample_neighbors(
     Table *table, const PsRequestMessage &request, PsResponseMessage &response,
     brpc::Controller *cntl) {
   CHECK_TABLE_EXIST(table, request, response)
@@ -389,7 +390,7 @@ int32_t GraphBrpcService::graph_random_sample_neighboors(
   std::vector<std::shared_ptr<char>> buffers(node_num);
   std::vector<int> actual_sizes(node_num, 0);
   ((GraphTable *)table)
-      ->random_sample_neighboors(node_data, sample_size, buffers, actual_sizes);
+      ->random_sample_neighbors(node_data, sample_size, buffers, actual_sizes);
   cntl->response_attachment().append(&node_num, sizeof(size_t));
   cntl->response_attachment().append(actual_sizes.data(),
@@ -448,7 +449,7 @@ int32_t GraphBrpcService::graph_get_node_feat(Table *table,
   return 0;
 }
-int32_t GraphBrpcService::sample_neighboors_across_multi_servers(
+int32_t GraphBrpcService::sample_neighbors_across_multi_servers(
     Table *table, const PsRequestMessage &request, PsResponseMessage &response,
     brpc::Controller *cntl) {
   // sleep(5);
@@ -456,7 +457,7 @@ int32_t GraphBrpcService::sample_neighboors_across_multi_servers(
   if (request.params_size() < 2) {
     set_response_code(
         response, -1,
-        "graph_random_sample request requires at least 2 arguments");
+        "graph_random_neighbors_sample request requires at least 2 arguments");
     return 0;
   }
   size_t node_num = request.params(0).size() / sizeof(uint64_t),
@@ -519,7 +520,7 @@ int32_t GraphBrpcService::sample_neighboors_across_multi_servers(
       remote_call_num);
   size_t fail_num = 0;
   for (size_t request_idx = 0; request_idx < remote_call_num; ++request_idx) {
-    if (closure->check_response(request_idx, PS_GRAPH_SAMPLE_NEIGHBOORS) !=
+    if (closure->check_response(request_idx, PS_GRAPH_SAMPLE_NEIGHBORS) !=
         0) {
       ++fail_num;
       failed[request2server[request_idx]] = true;
@@ -570,7 +571,7 @@ int32_t GraphBrpcService::sample_neighboors_across_multi_servers(
   for (int request_idx = 0; request_idx < remote_call_num; ++request_idx) {
     int server_index = request2server[request_idx];
-    closure->request(request_idx)->set_cmd_id(PS_GRAPH_SAMPLE_NEIGHBOORS);
+    closure->request(request_idx)->set_cmd_id(PS_GRAPH_SAMPLE_NEIGHBORS);
     closure->request(request_idx)->set_table_id(request.table_id());
     closure->request(request_idx)->set_client_id(rank);
     size_t node_num = node_id_buckets[request_idx].size();
@@ -590,8 +591,8 @@ int32_t GraphBrpcService::sample_neighboors_across_multi_servers(
   }
   if (server2request[rank] != -1) {
     ((GraphTable *)table)
-        ->random_sample_neighboors(node_id_buckets.back().data(), sample_size,
-                                   local_buffers, local_actual_sizes);
+        ->random_sample_neighbors(node_id_buckets.back().data(), sample_size,
+                                  local_buffers, local_actual_sizes);
   }
   local_promise.get()->set_value(0);
   if (remote_call_num == 0) func(closure);
@@ -636,5 +637,20 @@ int32_t GraphBrpcService::graph_set_node_feat(Table *table,
   return 0;
 }
+int32_t GraphBrpcService::use_neighbors_sample_cache(
+    Table *table, const PsRequestMessage &request, PsResponseMessage &response,
+    brpc::Controller *cntl) {
+  CHECK_TABLE_EXIST(table, request, response)
+  if (request.params_size() < 2) {
+    set_response_code(response, -1,
+                      "use_neighbors_sample_cache request requires at least 2 "
+                      "arguments[cache_size, ttl]");
+    return 0;
+  }
+  size_t size_limit = *(size_t *)(request.params(0).c_str());
+  size_t ttl = *(size_t *)(request.params(1).c_str());
+  ((GraphTable *)table)->make_neighbor_sample_cache(size_limit, ttl);
+  return 0;
+}
 }  // namespace distributed
 }  // namespace paddle
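A note on the wire format used by the new PS_GRAPH_USE_NEIGHBORS_SAMPLE_CACHE command: the client appends size_limit and ttl as raw size_t bytes via add_params, and the handler above casts request.params(i) back. The standalone sketch below mirrors that round trip, with a plain std::string standing in for the protobuf bytes field; it assumes client and server agree on endianness and sizeof(size_t), just as the patch itself does.

```cpp
#include <cstddef>
#include <iostream>
#include <string>

int main() {
  size_t size_limit = 100000, ttl = 10;

  // Client side: add_params((char *)&value, sizeof(size_t)) copies the raw bytes.
  std::string param0(reinterpret_cast<const char*>(&size_limit), sizeof(size_t));
  std::string param1(reinterpret_cast<const char*>(&ttl), sizeof(size_t));

  // Server side: *(size_t *)(request.params(i).c_str()) reads them back.
  size_t recv_limit = *reinterpret_cast<const size_t*>(param0.c_str());
  size_t recv_ttl = *reinterpret_cast<const size_t*>(param1.c_str());

  std::cout << recv_limit << " " << recv_ttl << std::endl;  // prints: 100000 10
  return 0;
}
```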
@@ -78,10 +78,10 @@ class GraphBrpcService : public PsBaseService {
   int32_t initialize_shard_info();
   int32_t pull_graph_list(Table *table, const PsRequestMessage &request,
                           PsResponseMessage &response, brpc::Controller *cntl);
-  int32_t graph_random_sample_neighboors(Table *table,
-                                         const PsRequestMessage &request,
-                                         PsResponseMessage &response,
-                                         brpc::Controller *cntl);
+  int32_t graph_random_sample_neighbors(Table *table,
+                                        const PsRequestMessage &request,
+                                        PsResponseMessage &response,
+                                        brpc::Controller *cntl);
   int32_t graph_random_sample_nodes(Table *table,
                                     const PsRequestMessage &request,
                                     PsResponseMessage &response,
@@ -116,9 +116,15 @@ class GraphBrpcService : public PsBaseService {
   int32_t print_table_stat(Table *table, const PsRequestMessage &request,
                            PsResponseMessage &response, brpc::Controller *cntl);
-  int32_t sample_neighboors_across_multi_servers(
-      Table *table, const PsRequestMessage &request,
-      PsResponseMessage &response, brpc::Controller *cntl);
+  int32_t sample_neighbors_across_multi_servers(Table *table,
+                                                const PsRequestMessage &request,
+                                                PsResponseMessage &response,
+                                                brpc::Controller *cntl);
+  int32_t use_neighbors_sample_cache(Table *table,
+                                     const PsRequestMessage &request,
+                                     PsResponseMessage &response,
+                                     brpc::Controller *cntl);
  private:
   bool _is_initialize_shard_info;
......
@@ -290,19 +290,29 @@ void GraphPyClient::load_node_file(std::string name, std::string filepath) {
   }
 }
 std::vector<std::vector<std::pair<uint64_t, float>>>
-GraphPyClient::batch_sample_neighboors(std::string name,
-                                       std::vector<uint64_t> node_ids,
-                                       int sample_size) {
+GraphPyClient::batch_sample_neighbors(std::string name,
+                                      std::vector<uint64_t> node_ids,
+                                      int sample_size) {
   std::vector<std::vector<std::pair<uint64_t, float>>> v;
   if (this->table_id_map.count(name)) {
     uint32_t table_id = this->table_id_map[name];
     auto status =
-        worker_ptr->batch_sample_neighboors(table_id, node_ids, sample_size, v);
+        worker_ptr->batch_sample_neighbors(table_id, node_ids, sample_size, v);
     status.wait();
   }
   return v;
 }
+void GraphPyClient::use_neighbors_sample_cache(std::string name,
+                                               size_t total_size_limit,
+                                               size_t ttl) {
+  if (this->table_id_map.count(name)) {
+    uint32_t table_id = this->table_id_map[name];
+    auto status =
+        worker_ptr->use_neighbors_sample_cache(table_id, total_size_limit, ttl);
+    status.wait();
+  }
+}
 std::vector<uint64_t> GraphPyClient::random_sample_nodes(std::string name,
                                                          int server_index,
                                                          int sample_size) {
......
@@ -148,13 +148,15 @@ class GraphPyClient : public GraphPyService {
   int get_client_id() { return client_id; }
   void set_client_id(int client_id) { this->client_id = client_id; }
   void start_client();
-  std::vector<std::vector<std::pair<uint64_t, float>>> batch_sample_neighboors(
+  std::vector<std::vector<std::pair<uint64_t, float>>> batch_sample_neighbors(
       std::string name, std::vector<uint64_t> node_ids, int sample_size);
   std::vector<uint64_t> random_sample_nodes(std::string name, int server_index,
                                             int sample_size);
   std::vector<std::vector<std::string>> get_node_feat(
       std::string node_type, std::vector<uint64_t> node_ids,
       std::vector<std::string> feature_names);
+  void use_neighbors_sample_cache(std::string name, size_t total_size_limit,
+                                  size_t ttl);
   void set_node_feat(std::string node_type, std::vector<uint64_t> node_ids,
                      std::vector<std::string> feature_names,
                      const std::vector<std::vector<std::string>> features);
......
@@ -49,7 +49,7 @@ enum PsCmdID {
   PS_STOP_PROFILER = 28;
   PS_PUSH_GLOBAL_STEP = 29;
   PS_PULL_GRAPH_LIST = 30;
-  PS_GRAPH_SAMPLE_NEIGHBOORS = 31;
+  PS_GRAPH_SAMPLE_NEIGHBORS = 31;
   PS_GRAPH_SAMPLE_NODES = 32;
   PS_GRAPH_GET_NODE_FEAT = 33;
   PS_GRAPH_CLEAR = 34;
@@ -57,6 +57,7 @@ enum PsCmdID {
   PS_GRAPH_REMOVE_GRAPH_NODE = 36;
   PS_GRAPH_SET_NODE_FEAT = 37;
   PS_GRAPH_SAMPLE_NODES_FROM_ONE_SERVER = 38;
+  PS_GRAPH_USE_NEIGHBORS_SAMPLE_CACHE = 39;
 }
 message PsRequestMessage {
......
@@ -392,7 +392,7 @@ int32_t GraphTable::random_sample_nodes(int sample_size,
   memcpy(pointer, res.data(), actual_size);
   return 0;
 }
-int32_t GraphTable::random_sample_neighboors(
+int32_t GraphTable::random_sample_neighbors(
     uint64_t *node_ids, int sample_size,
     std::vector<std::shared_ptr<char>> &buffers,
     std::vector<int> &actual_sizes) {
......
@@ -89,12 +89,6 @@ struct SampleKey {
   }
 };
-struct SampleKeyHash {
-  size_t operator()(const SampleKey &s) const {
-    return s.node_key ^ s.sample_size;
-  }
-};
 class SampleResult {
  public:
   size_t actual_size;
@@ -121,13 +115,13 @@ class LRUNode {
   // time to live
   LRUNode<K, V> *pre, *next;
 };
-template <typename K, typename V, typename Hash = std::hash<K>>
+template <typename K, typename V>
 class ScaledLRU;
-template <typename K, typename V, typename Hash = std::hash<K>>
+template <typename K, typename V>
 class RandomSampleLRU {
  public:
-  RandomSampleLRU(ScaledLRU<K, V, Hash> *_father) : father(_father) {
+  RandomSampleLRU(ScaledLRU<K, V> *_father) : father(_father) {
     node_size = 0;
     node_head = node_end = NULL;
     global_ttl = father->ttl;
@@ -229,15 +223,15 @@ class RandomSampleLRU {
   }
  private:
-  std::unordered_map<K, LRUNode<K, V> *, Hash> key_map;
-  ScaledLRU<K, V, Hash> *father;
+  std::unordered_map<K, LRUNode<K, V> *> key_map;
+  ScaledLRU<K, V> *father;
   size_t global_ttl;
   int node_size;
   LRUNode<K, V> *node_head, *node_end;
-  friend class ScaledLRU<K, V, Hash>;
+  friend class ScaledLRU<K, V>;
 };
-template <typename K, typename V, typename Hash>
+template <typename K, typename V>
 class ScaledLRU {
  public:
   ScaledLRU(size_t shard_num, size_t size_limit, size_t _ttl)
@@ -246,8 +240,8 @@ class ScaledLRU {
     stop = false;
     thread_pool.reset(new ::ThreadPool(1));
     global_count = 0;
-    lru_pool = std::vector<RandomSampleLRU<K, V, Hash>>(
-        shard_num, RandomSampleLRU<K, V, Hash>(this));
+    lru_pool = std::vector<RandomSampleLRU<K, V>>(shard_num,
+                                                  RandomSampleLRU<K, V>(this));
     shrink_job = std::thread([this]() -> void {
       while (true) {
         {
@@ -352,16 +346,16 @@ class ScaledLRU {
   size_t ttl;
   bool stop;
   std::thread shrink_job;
-  std::vector<RandomSampleLRU<K, V, Hash>> lru_pool;
+  std::vector<RandomSampleLRU<K, V>> lru_pool;
   mutable std::mutex mutex_;
   std::condition_variable cv_;
   struct RemovedNode {
     LRUNode<K, V> *node;
-    RandomSampleLRU<K, V, Hash> *lru_pointer;
+    RandomSampleLRU<K, V> *lru_pointer;
     bool operator>(const RemovedNode &a) const { return node->ms > a.node->ms; }
   };
   std::shared_ptr<::ThreadPool> thread_pool;
-  friend class RandomSampleLRU<K, V, Hash>;
+  friend class RandomSampleLRU<K, V>;
 };
 class GraphTable : public SparseTable {
@@ -373,7 +367,7 @@ class GraphTable : public SparseTable {
                               int &actual_size, bool need_feature,
                               int step);
-  virtual int32_t random_sample_neighboors(
+  virtual int32_t random_sample_neighbors(
       uint64_t *node_ids, int sample_size,
       std::vector<std::shared_ptr<char>> &buffers,
       std::vector<int> &actual_sizes);
@@ -433,11 +427,11 @@ class GraphTable : public SparseTable {
   size_t get_server_num() { return server_num; }
-  virtual int32_t make_neigh_sample_cache(size_t size_limit, size_t ttl) {
+  virtual int32_t make_neighbor_sample_cache(size_t size_limit, size_t ttl) {
     {
       std::unique_lock<std::mutex> lock(mutex_);
       if (use_cache == false) {
-        scaled_lru.reset(new ScaledLRU<SampleKey, SampleResult, SampleKeyHash>(
+        scaled_lru.reset(new ScaledLRU<SampleKey, SampleResult>(
             shard_end - shard_start, size_limit, ttl));
         use_cache = true;
       }
@@ -460,10 +454,20 @@ class GraphTable : public SparseTable {
   std::vector<std::shared_ptr<::ThreadPool>> _shards_task_pool;
   std::vector<std::shared_ptr<std::mt19937_64>> _shards_task_rng_pool;
-  std::shared_ptr<ScaledLRU<SampleKey, SampleResult, SampleKeyHash>> scaled_lru;
+  std::shared_ptr<ScaledLRU<SampleKey, SampleResult>> scaled_lru;
   bool use_cache;
   mutable std::mutex mutex_;
 };
 }  // namespace distributed
 };  // namespace paddle
+namespace std {
+template <>
+struct hash<paddle::distributed::SampleKey> {
+  size_t operator()(const paddle::distributed::SampleKey &s) const {
+    return s.node_key ^ s.sample_size;
+  }
+};
+}
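The hunk above is the core of the template simplification: instead of threading a SampleKeyHash functor through ScaledLRU, RandomSampleLRU and GraphTable, the patch specializes std::hash for SampleKey, so std::unordered_map and the template defaults pick the hasher up automatically. A minimal standalone illustration of the same pattern follows (SampleKey here is a local stand-in and TinyCache is invented for the example; neither is Paddle code):

```cpp
#include <cstddef>
#include <cstdint>
#include <unordered_map>

struct SampleKey {
  uint64_t node_key;
  size_t sample_size;
  bool operator==(const SampleKey& other) const {
    return node_key == other.node_key && sample_size == other.sample_size;
  }
};

namespace std {
template <>
struct hash<SampleKey> {
  size_t operator()(const SampleKey& s) const {
    return s.node_key ^ s.sample_size;  // same combiner the patch uses
  }
};
}  // namespace std

// Before: template <typename K, typename V, typename Hash> class ScaledLRU;
// After:  template <typename K, typename V> class ScaledLRU;  (hash is implied)
template <typename K, typename V>
class TinyCache {
 public:
  void put(const K& k, const V& v) { map_[k] = v; }
  bool contains(const K& k) const { return map_.count(k) > 0; }

 private:
  std::unordered_map<K, V> map_;  // uses std::hash<K> by default
};

int main() {
  TinyCache<SampleKey, int> cache;
  cache.put({37, 4}, 1);
  return cache.contains({37, 4}) ? 0 : 1;
}
```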
@@ -111,7 +111,7 @@ void testFeatureNodeSerializeFloat64() {
 void testSingleSampleNeighboor(
     std::shared_ptr<paddle::distributed::GraphBrpcClient>& worker_ptr_) {
   std::vector<std::vector<std::pair<uint64_t, float>>> vs;
-  auto pull_status = worker_ptr_->batch_sample_neighboors(
+  auto pull_status = worker_ptr_->batch_sample_neighbors(
       0, std::vector<uint64_t>(1, 37), 4, vs);
   pull_status.wait();
@@ -127,7 +127,7 @@ void testSingleSampleNeighboor(
   s.clear();
   s1.clear();
   vs.clear();
-  pull_status = worker_ptr_->batch_sample_neighboors(
+  pull_status = worker_ptr_->batch_sample_neighbors(
       0, std::vector<uint64_t>(1, 96), 4, vs);
   pull_status.wait();
   s1 = {111, 48, 247};
@@ -139,7 +139,7 @@ void testSingleSampleNeighboor(
     ASSERT_EQ(true, s1.find(g) != s1.end());
   }
   vs.clear();
-  pull_status = worker_ptr_->batch_sample_neighboors(0, {96, 37}, 4, vs, 0);
+  pull_status = worker_ptr_->batch_sample_neighbors(0, {96, 37}, 4, vs, 0);
   pull_status.wait();
   ASSERT_EQ(vs.size(), 2);
 }
@@ -199,7 +199,7 @@ void testBatchSampleNeighboor(
     std::shared_ptr<paddle::distributed::GraphBrpcClient>& worker_ptr_) {
   std::vector<std::vector<std::pair<uint64_t, float>>> vs;
   std::vector<std::uint64_t> v = {37, 96};
-  auto pull_status = worker_ptr_->batch_sample_neighboors(0, v, 4, vs);
+  auto pull_status = worker_ptr_->batch_sample_neighbors(0, v, 4, vs);
   pull_status.wait();
   std::unordered_set<uint64_t> s;
   std::unordered_set<uint64_t> s1 = {112, 45, 145};
@@ -401,7 +401,6 @@ void RunClient(
 }
 void RunBrpcPushSparse() {
-  std::cout << "in test cache";
   testCache();
   setenv("http_proxy", "", 1);
   setenv("https_proxy", "", 1);
@@ -436,24 +435,24 @@ void RunBrpcPushSparse() {
   sleep(5);
   testSingleSampleNeighboor(worker_ptr_);
   testBatchSampleNeighboor(worker_ptr_);
-  pull_status = worker_ptr_->batch_sample_neighboors(
+  pull_status = worker_ptr_->batch_sample_neighbors(
       0, std::vector<uint64_t>(1, 10240001024), 4, vs);
   pull_status.wait();
   ASSERT_EQ(0, vs[0].size());
   paddle::distributed::GraphTable* g =
       (paddle::distributed::GraphTable*)pserver_ptr_->table(0);
   size_t ttl = 6;
-  g->make_neigh_sample_cache(4, ttl);
+  g->make_neighbor_sample_cache(4, ttl);
   int round = 5;
   while (round--) {
     vs.clear();
-    pull_status = worker_ptr_->batch_sample_neighboors(
+    pull_status = worker_ptr_->batch_sample_neighbors(
         0, std::vector<uint64_t>(1, 37), 1, vs);
     pull_status.wait();
     for (int i = 0; i < ttl; i++) {
       std::vector<std::vector<std::pair<uint64_t, float>>> vs1;
-      pull_status = worker_ptr_->batch_sample_neighboors(
+      pull_status = worker_ptr_->batch_sample_neighbors(
           0, std::vector<uint64_t>(1, 37), 1, vs1);
       pull_status.wait();
       ASSERT_EQ(vs[0].size(), vs1[0].size());
@@ -560,13 +559,13 @@ void RunBrpcPushSparse() {
     ASSERT_EQ(count_item_nodes.size(), 12);
   }
-  vs = client1.batch_sample_neighboors(std::string("user2item"),
-                                       std::vector<uint64_t>(1, 96), 4);
+  vs = client1.batch_sample_neighbors(std::string("user2item"),
+                                      std::vector<uint64_t>(1, 96), 4);
   ASSERT_EQ(vs[0].size(), 3);
   std::vector<uint64_t> node_ids;
   node_ids.push_back(96);
   node_ids.push_back(37);
-  vs = client1.batch_sample_neighboors(std::string("user2item"), node_ids, 4);
+  vs = client1.batch_sample_neighbors(std::string("user2item"), node_ids, 4);
   ASSERT_EQ(vs.size(), 2);
   std::vector<uint64_t> nodes_ids = client2.random_sample_nodes("user", 0, 6);
@@ -635,8 +634,7 @@ void RunBrpcPushSparse() {
 void testCache() {
   ::paddle::distributed::ScaledLRU<::paddle::distributed::SampleKey,
-                                   ::paddle::distributed::SampleResult,
-                                   ::paddle::distributed::SampleKeyHash>
+                                   ::paddle::distributed::SampleResult>
       st(1, 2, 4);
   char* str = new char[7];
   strcpy(str, "54321");
......
@@ -205,7 +205,8 @@ void BindGraphPyClient(py::module* m) {
       .def("add_table_feat_conf", &GraphPyClient::add_table_feat_conf)
       .def("pull_graph_list", &GraphPyClient::pull_graph_list)
       .def("start_client", &GraphPyClient::start_client)
-      .def("batch_sample_neighboors", &GraphPyClient::batch_sample_neighboors)
+      .def("batch_sample_neighboors", &GraphPyClient::batch_sample_neighbors)
+      .def("batch_sample_neighbors", &GraphPyClient::batch_sample_neighbors)
       .def("remove_graph_node", &GraphPyClient::remove_graph_node)
       .def("random_sample_nodes", &GraphPyClient::random_sample_nodes)
       .def("stop_server", &GraphPyClient::stop_server)
......