未验证 提交 31776199 编写于 作者: S seemingwang 提交者: GitHub

merge cpu and gpu graph engines (#40597)

* extract sub-graph

* graph-engine merging

* fix

* fix

* fix heter-ps config
上级 313bff6b
......@@ -115,6 +115,7 @@ message TableParameter {
optional CommonAccessorParameter common = 6;
optional TableType type = 7;
optional bool compress_in_save = 8 [ default = false ];
optional GraphParameter graph_parameter = 9;
}
message TableAccessorParameter {
......@@ -211,3 +212,25 @@ message SparseAdamSGDParameter { // SparseAdamSGDRule
optional double ada_epsilon = 5 [ default = 1e-08 ];
repeated float weight_bounds = 6;
}
message GraphParameter {
optional int32 task_pool_size = 1 [ default = 24 ];
optional bool gpups_mode = 2 [ default = false ];
optional string gpups_graph_sample_class = 3
[ default = "CompleteGraphSampler" ];
optional string gpups_graph_sample_args = 4 [ default = "" ];
optional bool use_cache = 5 [ default = true ];
optional float cache_ratio = 6 [ default = 0.3 ];
optional int32 cache_ttl = 7 [ default = 5 ];
optional GraphFeature graph_feature = 8;
optional string table_name = 9 [ default = "" ];
optional string table_type = 10 [ default = "" ];
optional int32 gpups_mode_shard_num = 11 [ default = 127 ];
optional int32 gpu_num = 12 [ default = 1 ];
}
message GraphFeature {
repeated string name = 1;
repeated string dtype = 2;
repeated int32 shape = 3;
}
\ No newline at end of file
......@@ -44,7 +44,7 @@ void GraphPsService_Stub::service(
}
}
int GraphBrpcClient::get_server_index_by_id(uint64_t id) {
int GraphBrpcClient::get_server_index_by_id(int64_t id) {
int shard_num = get_shard_num();
int shard_per_server = shard_num % server_size == 0
? shard_num / server_size
......@@ -53,7 +53,7 @@ int GraphBrpcClient::get_server_index_by_id(uint64_t id) {
}
std::future<int32_t> GraphBrpcClient::get_node_feat(
const uint32_t &table_id, const std::vector<uint64_t> &node_ids,
const uint32_t &table_id, const std::vector<int64_t> &node_ids,
const std::vector<std::string> &feature_names,
std::vector<std::vector<std::string>> &res) {
std::vector<int> request2server;
......@@ -66,7 +66,7 @@ std::future<int32_t> GraphBrpcClient::get_node_feat(
}
}
size_t request_call_num = request2server.size();
std::vector<std::vector<uint64_t>> node_id_buckets(request_call_num);
std::vector<std::vector<int64_t>> node_id_buckets(request_call_num);
std::vector<std::vector<int>> query_idx_buckets(request_call_num);
for (int query_idx = 0; query_idx < node_ids.size(); ++query_idx) {
int server_index = get_server_index_by_id(node_ids[query_idx]);
......@@ -129,7 +129,7 @@ std::future<int32_t> GraphBrpcClient::get_node_feat(
closure->request(request_idx)
->add_params((char *)node_id_buckets[request_idx].data(),
sizeof(uint64_t) * node_num);
sizeof(int64_t) * node_num);
std::string joint_feature_name =
paddle::string::join_strings(feature_names, '\t');
closure->request(request_idx)
......@@ -179,9 +179,9 @@ std::future<int32_t> GraphBrpcClient::clear_nodes(uint32_t table_id) {
return fut;
}
std::future<int32_t> GraphBrpcClient::add_graph_node(
uint32_t table_id, std::vector<uint64_t> &node_id_list,
uint32_t table_id, std::vector<int64_t> &node_id_list,
std::vector<bool> &is_weighted_list) {
std::vector<std::vector<uint64_t>> request_bucket;
std::vector<std::vector<int64_t>> request_bucket;
std::vector<std::vector<bool>> is_weighted_bucket;
bool add_weight = is_weighted_list.size() > 0;
std::vector<int> server_index_arr;
......@@ -191,7 +191,7 @@ std::future<int32_t> GraphBrpcClient::add_graph_node(
if (index_mapping[server_index] == -1) {
index_mapping[server_index] = request_bucket.size();
server_index_arr.push_back(server_index);
request_bucket.push_back(std::vector<uint64_t>());
request_bucket.push_back(std::vector<int64_t>());
if (add_weight) is_weighted_bucket.push_back(std::vector<bool>());
}
request_bucket[index_mapping[server_index]].push_back(
......@@ -229,7 +229,7 @@ std::future<int32_t> GraphBrpcClient::add_graph_node(
size_t node_num = request_bucket[request_idx].size();
closure->request(request_idx)
->add_params((char *)request_bucket[request_idx].data(),
sizeof(uint64_t) * node_num);
sizeof(int64_t) * node_num);
if (add_weight) {
bool weighted[is_weighted_bucket[request_idx].size() + 1];
for (size_t j = 0; j < is_weighted_bucket[request_idx].size(); j++)
......@@ -248,8 +248,8 @@ std::future<int32_t> GraphBrpcClient::add_graph_node(
return fut;
}
std::future<int32_t> GraphBrpcClient::remove_graph_node(
uint32_t table_id, std::vector<uint64_t> &node_id_list) {
std::vector<std::vector<uint64_t>> request_bucket;
uint32_t table_id, std::vector<int64_t> &node_id_list) {
std::vector<std::vector<int64_t>> request_bucket;
std::vector<int> server_index_arr;
std::vector<int> index_mapping(server_size, -1);
for (size_t query_idx = 0; query_idx < node_id_list.size(); ++query_idx) {
......@@ -257,7 +257,7 @@ std::future<int32_t> GraphBrpcClient::remove_graph_node(
if (index_mapping[server_index] == -1) {
index_mapping[server_index] = request_bucket.size();
server_index_arr.push_back(server_index);
request_bucket.push_back(std::vector<uint64_t>());
request_bucket.push_back(std::vector<int64_t>());
}
request_bucket[index_mapping[server_index]].push_back(
node_id_list[query_idx]);
......@@ -291,7 +291,7 @@ std::future<int32_t> GraphBrpcClient::remove_graph_node(
closure->request(request_idx)
->add_params((char *)request_bucket[request_idx].data(),
sizeof(uint64_t) * node_num);
sizeof(int64_t) * node_num);
// PsService_Stub rpc_stub(get_cmd_channel(server_index));
GraphPsService_Stub rpc_stub =
getServiceStub(get_cmd_channel(server_index));
......@@ -303,9 +303,9 @@ std::future<int32_t> GraphBrpcClient::remove_graph_node(
}
// char* &buffer,int &actual_size
std::future<int32_t> GraphBrpcClient::batch_sample_neighbors(
uint32_t table_id, std::vector<uint64_t> node_ids, int sample_size,
// std::vector<std::vector<std::pair<uint64_t, float>>> &res,
std::vector<std::vector<uint64_t>> &res,
uint32_t table_id, std::vector<int64_t> node_ids, int sample_size,
// std::vector<std::vector<std::pair<int64_t, float>>> &res,
std::vector<std::vector<int64_t>> &res,
std::vector<std::vector<float>> &res_weight, bool need_weight,
int server_index) {
if (server_index != -1) {
......@@ -337,7 +337,7 @@ std::future<int32_t> GraphBrpcClient::batch_sample_neighbors(
int start = 0;
while (start < actual_size) {
res[node_idx].emplace_back(
*(uint64_t *)(node_buffer + offset + start));
*(int64_t *)(node_buffer + offset + start));
start += GraphNode::id_size;
if (need_weight) {
res_weight[node_idx].emplace_back(
......@@ -358,7 +358,7 @@ std::future<int32_t> GraphBrpcClient::batch_sample_neighbors(
closure->request(0)->set_table_id(table_id);
closure->request(0)->set_client_id(_client_id);
closure->request(0)->add_params((char *)node_ids.data(),
sizeof(uint64_t) * node_ids.size());
sizeof(int64_t) * node_ids.size());
closure->request(0)->add_params((char *)&sample_size, sizeof(int));
closure->request(0)->add_params((char *)&need_weight, sizeof(bool));
;
......@@ -380,14 +380,14 @@ std::future<int32_t> GraphBrpcClient::batch_sample_neighbors(
server2request[server_index] = request2server.size();
request2server.push_back(server_index);
}
// res.push_back(std::vector<std::pair<uint64_t, float>>());
// res.push_back(std::vector<std::pair<int64_t, float>>());
res.push_back({});
if (need_weight) {
res_weight.push_back({});
}
}
size_t request_call_num = request2server.size();
std::vector<std::vector<uint64_t>> node_id_buckets(request_call_num);
std::vector<std::vector<int64_t>> node_id_buckets(request_call_num);
std::vector<std::vector<int>> query_idx_buckets(request_call_num);
for (int query_idx = 0; query_idx < node_ids.size(); ++query_idx) {
int server_index = get_server_index_by_id(node_ids[query_idx]);
......@@ -428,7 +428,7 @@ std::future<int32_t> GraphBrpcClient::batch_sample_neighbors(
int start = 0;
while (start < actual_size) {
res[query_idx].emplace_back(
*(uint64_t *)(node_buffer + offset + start));
*(int64_t *)(node_buffer + offset + start));
start += GraphNode::id_size;
if (need_weight) {
res_weight[query_idx].emplace_back(
......@@ -459,7 +459,7 @@ std::future<int32_t> GraphBrpcClient::batch_sample_neighbors(
closure->request(request_idx)
->add_params((char *)node_id_buckets[request_idx].data(),
sizeof(uint64_t) * node_num);
sizeof(int64_t) * node_num);
closure->request(request_idx)
->add_params((char *)&sample_size, sizeof(int));
closure->request(request_idx)
......@@ -476,7 +476,7 @@ std::future<int32_t> GraphBrpcClient::batch_sample_neighbors(
}
std::future<int32_t> GraphBrpcClient::random_sample_nodes(
uint32_t table_id, int server_index, int sample_size,
std::vector<uint64_t> &ids) {
std::vector<int64_t> &ids) {
DownpourBrpcClosure *closure = new DownpourBrpcClosure(1, [&](void *done) {
int ret = 0;
auto *closure = (DownpourBrpcClosure *)done;
......@@ -490,7 +490,7 @@ std::future<int32_t> GraphBrpcClient::random_sample_nodes(
auto size = io_buffer_itr.copy_and_forward((void *)(buffer), bytes_size);
int index = 0;
while (index < bytes_size) {
ids.push_back(*(uint64_t *)(buffer + index));
ids.push_back(*(int64_t *)(buffer + index));
index += GraphNode::id_size;
}
delete[] buffer;
......@@ -633,7 +633,7 @@ std::future<int32_t> GraphBrpcClient::pull_graph_list(
}
std::future<int32_t> GraphBrpcClient::set_node_feat(
const uint32_t &table_id, const std::vector<uint64_t> &node_ids,
const uint32_t &table_id, const std::vector<int64_t> &node_ids,
const std::vector<std::string> &feature_names,
const std::vector<std::vector<std::string>> &features) {
std::vector<int> request2server;
......@@ -646,7 +646,7 @@ std::future<int32_t> GraphBrpcClient::set_node_feat(
}
}
size_t request_call_num = request2server.size();
std::vector<std::vector<uint64_t>> node_id_buckets(request_call_num);
std::vector<std::vector<int64_t>> node_id_buckets(request_call_num);
std::vector<std::vector<int>> query_idx_buckets(request_call_num);
std::vector<std::vector<std::vector<std::string>>> features_idx_buckets(
request_call_num);
......@@ -696,7 +696,7 @@ std::future<int32_t> GraphBrpcClient::set_node_feat(
closure->request(request_idx)
->add_params((char *)node_id_buckets[request_idx].data(),
sizeof(uint64_t) * node_num);
sizeof(int64_t) * node_num);
std::string joint_feature_name =
paddle::string::join_strings(feature_names, '\t');
closure->request(request_idx)
......
......@@ -63,8 +63,8 @@ class GraphBrpcClient : public BrpcPsClient {
virtual ~GraphBrpcClient() {}
// given a batch of nodes, sample graph_neighbors for each of them
virtual std::future<int32_t> batch_sample_neighbors(
uint32_t table_id, std::vector<uint64_t> node_ids, int sample_size,
std::vector<std::vector<uint64_t>>& res,
uint32_t table_id, std::vector<int64_t> node_ids, int sample_size,
std::vector<std::vector<int64_t>>& res,
std::vector<std::vector<float>>& res_weight, bool need_weight,
int server_index = -1);
......@@ -75,20 +75,20 @@ class GraphBrpcClient : public BrpcPsClient {
virtual std::future<int32_t> random_sample_nodes(uint32_t table_id,
int server_index,
int sample_size,
std::vector<uint64_t>& ids);
std::vector<int64_t>& ids);
virtual std::future<int32_t> get_node_feat(
const uint32_t& table_id, const std::vector<uint64_t>& node_ids,
const uint32_t& table_id, const std::vector<int64_t>& node_ids,
const std::vector<std::string>& feature_names,
std::vector<std::vector<std::string>>& res);
virtual std::future<int32_t> set_node_feat(
const uint32_t& table_id, const std::vector<uint64_t>& node_ids,
const uint32_t& table_id, const std::vector<int64_t>& node_ids,
const std::vector<std::string>& feature_names,
const std::vector<std::vector<std::string>>& features);
virtual std::future<int32_t> clear_nodes(uint32_t table_id);
virtual std::future<int32_t> add_graph_node(
uint32_t table_id, std::vector<uint64_t>& node_id_list,
uint32_t table_id, std::vector<int64_t>& node_id_list,
std::vector<bool>& is_weighted_list);
virtual std::future<int32_t> use_neighbors_sample_cache(uint32_t table_id,
size_t size_limit,
......@@ -96,11 +96,11 @@ class GraphBrpcClient : public BrpcPsClient {
virtual std::future<int32_t> load_graph_split_config(uint32_t table_id,
std::string path);
virtual std::future<int32_t> remove_graph_node(
uint32_t table_id, std::vector<uint64_t>& node_id_list);
uint32_t table_id, std::vector<int64_t>& node_id_list);
virtual int32_t initialize();
int get_shard_num() { return shard_num; }
void set_shard_num(int shard_num) { this->shard_num = shard_num; }
int get_server_index_by_id(uint64_t id);
int get_server_index_by_id(int64_t id);
void set_local_channel(int index) {
this->local_channel = get_cmd_channel(index);
}
......
......@@ -140,9 +140,9 @@ int32_t GraphBrpcService::add_graph_node(Table *table,
return 0;
}
size_t node_num = request.params(0).size() / sizeof(uint64_t);
uint64_t *node_data = (uint64_t *)(request.params(0).c_str());
std::vector<uint64_t> node_ids(node_data, node_data + node_num);
size_t node_num = request.params(0).size() / sizeof(int64_t);
int64_t *node_data = (int64_t *)(request.params(0).c_str());
std::vector<int64_t> node_ids(node_data, node_data + node_num);
std::vector<bool> is_weighted_list;
if (request.params_size() == 2) {
size_t weight_list_size = request.params(1).size() / sizeof(bool);
......@@ -165,9 +165,9 @@ int32_t GraphBrpcService::remove_graph_node(Table *table,
"graph_get_node_feat request requires at least 1 argument");
return 0;
}
size_t node_num = request.params(0).size() / sizeof(uint64_t);
uint64_t *node_data = (uint64_t *)(request.params(0).c_str());
std::vector<uint64_t> node_ids(node_data, node_data + node_num);
size_t node_num = request.params(0).size() / sizeof(int64_t);
int64_t *node_data = (int64_t *)(request.params(0).c_str());
std::vector<int64_t> node_ids(node_data, node_data + node_num);
((GraphTable *)table)->remove_graph_node(node_ids);
return 0;
......@@ -386,9 +386,9 @@ int32_t GraphBrpcService::graph_random_sample_neighbors(
"graph_random_sample_neighbors request requires at least 3 arguments");
return 0;
}
size_t node_num = request.params(0).size() / sizeof(uint64_t);
uint64_t *node_data = (uint64_t *)(request.params(0).c_str());
int sample_size = *(uint64_t *)(request.params(1).c_str());
size_t node_num = request.params(0).size() / sizeof(int64_t);
int64_t *node_data = (int64_t *)(request.params(0).c_str());
int sample_size = *(int64_t *)(request.params(1).c_str());
bool need_weight = *(bool *)(request.params(2).c_str());
std::vector<std::shared_ptr<char>> buffers(node_num);
std::vector<int> actual_sizes(node_num, 0);
......@@ -407,7 +407,7 @@ int32_t GraphBrpcService::graph_random_sample_neighbors(
int32_t GraphBrpcService::graph_random_sample_nodes(
Table *table, const PsRequestMessage &request, PsResponseMessage &response,
brpc::Controller *cntl) {
size_t size = *(uint64_t *)(request.params(0).c_str());
size_t size = *(int64_t *)(request.params(0).c_str());
std::unique_ptr<char[]> buffer;
int actual_size;
if (((GraphTable *)table)->random_sample_nodes(size, buffer, actual_size) ==
......@@ -430,9 +430,9 @@ int32_t GraphBrpcService::graph_get_node_feat(Table *table,
"graph_get_node_feat request requires at least 2 arguments");
return 0;
}
size_t node_num = request.params(0).size() / sizeof(uint64_t);
uint64_t *node_data = (uint64_t *)(request.params(0).c_str());
std::vector<uint64_t> node_ids(node_data, node_data + node_num);
size_t node_num = request.params(0).size() / sizeof(int64_t);
int64_t *node_data = (int64_t *)(request.params(0).c_str());
std::vector<int64_t> node_ids(node_data, node_data + node_num);
std::vector<std::string> feature_names =
paddle::string::split_string<std::string>(request.params(1), "\t");
......@@ -464,16 +464,16 @@ int32_t GraphBrpcService::sample_neighbors_across_multi_servers(
"at least 3 arguments");
return 0;
}
size_t node_num = request.params(0).size() / sizeof(uint64_t),
size_t node_num = request.params(0).size() / sizeof(int64_t),
size_of_size_t = sizeof(size_t);
uint64_t *node_data = (uint64_t *)(request.params(0).c_str());
int sample_size = *(uint64_t *)(request.params(1).c_str());
bool need_weight = *(uint64_t *)(request.params(2).c_str());
// std::vector<uint64_t> res = ((GraphTable
int64_t *node_data = (int64_t *)(request.params(0).c_str());
int sample_size = *(int64_t *)(request.params(1).c_str());
bool need_weight = *(int64_t *)(request.params(2).c_str());
// std::vector<int64_t> res = ((GraphTable
// *)table).filter_out_non_exist_nodes(node_data, sample_size);
std::vector<int> request2server;
std::vector<int> server2request(server_size, -1);
std::vector<uint64_t> local_id;
std::vector<int64_t> local_id;
std::vector<int> local_query_idx;
size_t rank = get_rank();
for (int query_idx = 0; query_idx < node_num; ++query_idx) {
......@@ -496,7 +496,7 @@ int32_t GraphBrpcService::sample_neighbors_across_multi_servers(
std::vector<std::shared_ptr<char>> local_buffers;
std::vector<int> local_actual_sizes;
std::vector<size_t> seq;
std::vector<std::vector<uint64_t>> node_id_buckets(request_call_num);
std::vector<std::vector<int64_t>> node_id_buckets(request_call_num);
std::vector<std::vector<int>> query_idx_buckets(request_call_num);
for (int query_idx = 0; query_idx < node_num; ++query_idx) {
int server_index =
......@@ -583,7 +583,7 @@ int32_t GraphBrpcService::sample_neighbors_across_multi_servers(
closure->request(request_idx)
->add_params((char *)node_id_buckets[request_idx].data(),
sizeof(uint64_t) * node_num);
sizeof(int64_t) * node_num);
closure->request(request_idx)
->add_params((char *)&sample_size, sizeof(int));
closure->request(request_idx)
......@@ -618,9 +618,9 @@ int32_t GraphBrpcService::graph_set_node_feat(Table *table,
"graph_set_node_feat request requires at least 3 arguments");
return 0;
}
size_t node_num = request.params(0).size() / sizeof(uint64_t);
uint64_t *node_data = (uint64_t *)(request.params(0).c_str());
std::vector<uint64_t> node_ids(node_data, node_data + node_num);
size_t node_num = request.params(0).size() / sizeof(int64_t);
int64_t *node_data = (int64_t *)(request.params(0).c_str());
std::vector<int64_t> node_ids(node_data, node_data + node_num);
std::vector<std::string> feature_names =
paddle::string::split_string<std::string>(request.params(1), "\t");
......
......@@ -44,9 +44,9 @@ void GraphPyService::add_table_feat_conf(std::string table_name,
}
}
void add_graph_node(std::vector<uint64_t> node_ids,
void add_graph_node(std::vector<int64_t> node_ids,
std::vector<bool> weight_list) {}
void remove_graph_node(std::vector<uint64_t> node_ids) {}
void remove_graph_node(std::vector<int64_t> node_ids) {}
void GraphPyService::set_up(std::string ips_str, int shard_num,
std::vector<std::string> node_types,
std::vector<std::string> edge_types) {
......@@ -260,7 +260,7 @@ void GraphPyClient::clear_nodes(std::string name) {
}
void GraphPyClient::add_graph_node(std::string name,
std::vector<uint64_t>& node_ids,
std::vector<int64_t>& node_ids,
std::vector<bool>& weight_list) {
if (this->table_id_map.count(name)) {
uint32_t table_id = this->table_id_map[name];
......@@ -271,7 +271,7 @@ void GraphPyClient::add_graph_node(std::string name,
}
void GraphPyClient::remove_graph_node(std::string name,
std::vector<uint64_t>& node_ids) {
std::vector<int64_t>& node_ids) {
if (this->table_id_map.count(name)) {
uint32_t table_id = this->table_id_map[name];
auto status = get_ps_client()->remove_graph_node(table_id, node_ids);
......@@ -290,13 +290,12 @@ void GraphPyClient::load_node_file(std::string name, std::string filepath) {
}
}
std::pair<std::vector<std::vector<uint64_t>>, std::vector<float>>
std::pair<std::vector<std::vector<int64_t>>, std::vector<float>>
GraphPyClient::batch_sample_neighbors(std::string name,
std::vector<uint64_t> node_ids,
std::vector<int64_t> node_ids,
int sample_size, bool return_weight,
bool return_edges) {
// std::vector<std::vector<std::pair<uint64_t, float>>> v;
std::vector<std::vector<uint64_t>> v;
std::vector<std::vector<int64_t>> v;
std::vector<std::vector<float>> v1;
if (this->table_id_map.count(name)) {
uint32_t table_id = this->table_id_map[name];
......@@ -309,7 +308,7 @@ GraphPyClient::batch_sample_neighbors(std::string name,
// res.first[1]: slice index
// res.first[2]: src nodes
// res.second: edges weight
std::pair<std::vector<std::vector<uint64_t>>, std::vector<float>> res;
std::pair<std::vector<std::vector<int64_t>>, std::vector<float>> res;
res.first.push_back({});
res.first.push_back({});
if (return_edges) res.first.push_back({});
......@@ -342,10 +341,10 @@ void GraphPyClient::use_neighbors_sample_cache(std::string name,
status.wait();
}
}
std::vector<uint64_t> GraphPyClient::random_sample_nodes(std::string name,
int server_index,
int sample_size) {
std::vector<uint64_t> v;
std::vector<int64_t> GraphPyClient::random_sample_nodes(std::string name,
int server_index,
int sample_size) {
std::vector<int64_t> v;
if (this->table_id_map.count(name)) {
uint32_t table_id = this->table_id_map[name];
auto status =
......@@ -357,7 +356,7 @@ std::vector<uint64_t> GraphPyClient::random_sample_nodes(std::string name,
// (name, dtype, ndarray)
std::vector<std::vector<std::string>> GraphPyClient::get_node_feat(
std::string node_type, std::vector<uint64_t> node_ids,
std::string node_type, std::vector<int64_t> node_ids,
std::vector<std::string> feature_names) {
std::vector<std::vector<std::string>> v(
feature_names.size(), std::vector<std::string>(node_ids.size()));
......@@ -371,7 +370,7 @@ std::vector<std::vector<std::string>> GraphPyClient::get_node_feat(
}
void GraphPyClient::set_node_feat(
std::string node_type, std::vector<uint64_t> node_ids,
std::string node_type, std::vector<int64_t> node_ids,
std::vector<std::string> feature_names,
const std::vector<std::vector<std::string>> features) {
if (this->table_id_map.count(node_type)) {
......
......@@ -70,18 +70,34 @@ class GraphPyService {
::paddle::distributed::TableAccessorParameter* accessor_proto =
sparse_table_proto->mutable_accessor();
::paddle::distributed::CommonAccessorParameter* common_proto =
sparse_table_proto->mutable_common();
// ::paddle::distributed::CommonAccessorParameter* common_proto =
// sparse_table_proto->mutable_common();
::paddle::distributed::GraphParameter* graph_proto =
sparse_table_proto->mutable_graph_parameter();
::paddle::distributed::GraphFeature* graph_feature =
graph_proto->mutable_graph_feature();
graph_proto->set_task_pool_size(24);
graph_proto->set_table_name(table_name);
graph_proto->set_table_type(table_type);
graph_proto->set_use_cache(false);
// Set GraphTable Parameter
common_proto->set_table_name(table_name);
common_proto->set_name(table_type);
// common_proto->set_table_name(table_name);
// common_proto->set_name(table_type);
// for (size_t i = 0; i < feat_name.size(); i++) {
// common_proto->add_params(feat_dtype[i]);
// common_proto->add_dims(feat_shape[i]);
// common_proto->add_attributes(feat_name[i]);
// }
for (size_t i = 0; i < feat_name.size(); i++) {
common_proto->add_params(feat_dtype[i]);
common_proto->add_dims(feat_shape[i]);
common_proto->add_attributes(feat_name[i]);
graph_feature->add_dtype(feat_dtype[i]);
graph_feature->add_shape(feat_shape[i]);
graph_feature->add_name(feat_name[i]);
}
accessor_proto->set_accessor_class("CommMergeAccessor");
}
......@@ -143,24 +159,24 @@ class GraphPyClient : public GraphPyService {
void load_edge_file(std::string name, std::string filepath, bool reverse);
void load_node_file(std::string name, std::string filepath);
void clear_nodes(std::string name);
void add_graph_node(std::string name, std::vector<uint64_t>& node_ids,
void add_graph_node(std::string name, std::vector<int64_t>& node_ids,
std::vector<bool>& weight_list);
void remove_graph_node(std::string name, std::vector<uint64_t>& node_ids);
void remove_graph_node(std::string name, std::vector<int64_t>& node_ids);
int get_client_id() { return client_id; }
void set_client_id(int client_id) { this->client_id = client_id; }
void start_client();
std::pair<std::vector<std::vector<uint64_t>>, std::vector<float>>
batch_sample_neighbors(std::string name, std::vector<uint64_t> node_ids,
std::pair<std::vector<std::vector<int64_t>>, std::vector<float>>
batch_sample_neighbors(std::string name, std::vector<int64_t> node_ids,
int sample_size, bool return_weight,
bool return_edges);
std::vector<uint64_t> random_sample_nodes(std::string name, int server_index,
int sample_size);
std::vector<int64_t> random_sample_nodes(std::string name, int server_index,
int sample_size);
std::vector<std::vector<std::string>> get_node_feat(
std::string node_type, std::vector<uint64_t> node_ids,
std::string node_type, std::vector<int64_t> node_ids,
std::vector<std::string> feature_names);
void use_neighbors_sample_cache(std::string name, size_t total_size_limit,
size_t ttl);
void set_node_feat(std::string node_type, std::vector<uint64_t> node_ids,
void set_node_feat(std::string node_type, std::vector<int64_t> node_ids,
std::vector<std::string> feature_names,
const std::vector<std::vector<std::string>> features);
std::vector<FeatureNode> pull_graph_list(std::string name, int server_index,
......
......@@ -53,7 +53,6 @@ cc_library(memory_sparse_table SRCS memory_sparse_table.cc DEPS ps_framework_pro
set_source_files_properties(memory_sparse_geo_table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_library(memory_sparse_geo_table SRCS memory_sparse_geo_table.cc DEPS ps_framework_proto ${TABLE_DEPS} common_table)
cc_library(table SRCS table.cc DEPS memory_sparse_table memory_sparse_geo_table common_table tensor_accessor tensor_table ps_framework_proto string_helper device_context gflags glog boost)
target_link_libraries(table -fopenmp)
......@@ -38,10 +38,14 @@
#include <vector>
#include "paddle/fluid/distributed/ps/table/accessor.h"
#include "paddle/fluid/distributed/ps/table/common_table.h"
#include "paddle/fluid/distributed/ps/table/graph/class_macro.h"
#include "paddle/fluid/distributed/ps/table/graph/graph_node.h"
#include "paddle/fluid/string/string_helper.h"
#include "paddle/phi/core/utils/rw_lock.h"
#ifdef PADDLE_WITH_HETERPS
#include "paddle/fluid/framework/fleet/heter_ps/gpu_graph_node.h"
#endif
namespace paddle {
namespace distributed {
class GraphShard {
......@@ -51,37 +55,37 @@ class GraphShard {
~GraphShard();
std::vector<Node *> &get_bucket() { return bucket; }
std::vector<Node *> get_batch(int start, int end, int step);
std::vector<uint64_t> get_ids_by_range(int start, int end) {
std::vector<uint64_t> res;
std::vector<int64_t> get_ids_by_range(int start, int end) {
std::vector<int64_t> res;
for (int i = start; i < end && i < (int)bucket.size(); i++) {
res.push_back(bucket[i]->get_id());
}
return res;
}
GraphNode *add_graph_node(uint64_t id);
GraphNode *add_graph_node(int64_t id);
GraphNode *add_graph_node(Node *node);
FeatureNode *add_feature_node(uint64_t id);
Node *find_node(uint64_t id);
void delete_node(uint64_t id);
FeatureNode *add_feature_node(int64_t id);
Node *find_node(int64_t id);
void delete_node(int64_t id);
void clear();
void add_neighbor(uint64_t id, uint64_t dst_id, float weight);
std::unordered_map<uint64_t, int> &get_node_location() {
void add_neighbor(int64_t id, int64_t dst_id, float weight);
std::unordered_map<int64_t, int> &get_node_location() {
return node_location;
}
private:
std::unordered_map<uint64_t, int> node_location;
std::unordered_map<int64_t, int> node_location;
std::vector<Node *> bucket;
};
enum LRUResponse { ok = 0, blocked = 1, err = 2 };
struct SampleKey {
uint64_t node_key;
int64_t node_key;
size_t sample_size;
bool is_weighted;
SampleKey(uint64_t _node_key, size_t _sample_size, bool _is_weighted)
SampleKey(int64_t _node_key, size_t _sample_size, bool _is_weighted)
: node_key(_node_key),
sample_size(_sample_size),
is_weighted(_is_weighted) {}
......@@ -300,7 +304,7 @@ class ScaledLRU {
node_size += lru_pool[i].node_size - lru_pool[i].remove_count;
}
if (node_size <= size_t(1.1 * size_limit) + 1) return 0;
if ((size_t)node_size <= size_t(1.1 * size_limit) + 1) return 0;
if (pthread_rwlock_wrlock(&rwlock) == 0) {
// VLOG(0)<"in shrink\n";
global_count = 0;
......@@ -308,9 +312,9 @@ class ScaledLRU {
global_count += lru_pool[i].node_size - lru_pool[i].remove_count;
}
// VLOG(0)<<"global_count "<<global_count<<"\n";
if (global_count > size_limit) {
if ((size_t)global_count > size_limit) {
size_t remove = global_count - size_limit;
for (int i = 0; i < lru_pool.size(); i++) {
for (size_t i = 0; i < lru_pool.size(); i++) {
lru_pool[i].total_diff = 0;
lru_pool[i].remove_count +=
1.0 * (lru_pool[i].node_size - lru_pool[i].remove_count) /
......@@ -352,9 +356,69 @@ class ScaledLRU {
friend class RandomSampleLRU<K, V>;
};
#ifdef PADDLE_WITH_HETERPS
enum GraphSamplerStatus { waiting = 0, running = 1, terminating = 2 };
class GraphTable;
class GraphSampler {
public:
GraphSampler() {
status = GraphSamplerStatus::waiting;
thread_pool.reset(new ::ThreadPool(1));
callback = [](std::vector<paddle::framework::GpuPsCommGraph> &res) {
return;
};
}
virtual int run_graph_sampling() = 0;
virtual int start_graph_sampling() {
if (status != GraphSamplerStatus::waiting) {
return -1;
}
std::promise<int> prom;
std::future<int> fut = prom.get_future();
graph_sample_task_over = thread_pool->enqueue([&prom, this]() {
prom.set_value(0);
status = GraphSamplerStatus::running;
return run_graph_sampling();
});
return fut.get();
}
virtual void init(size_t gpu_num, GraphTable *graph_table,
std::vector<std::string> args) = 0;
virtual void set_graph_sample_callback(
std::function<void(std::vector<paddle::framework::GpuPsCommGraph> &)>
callback) {
this->callback = callback;
}
virtual int end_graph_sampling() {
if (status == GraphSamplerStatus::running) {
status = GraphSamplerStatus::terminating;
return graph_sample_task_over.get();
}
return -1;
}
virtual GraphSamplerStatus get_graph_sampler_status() { return status; }
protected:
std::function<void(std::vector<paddle::framework::GpuPsCommGraph> &)>
callback;
std::shared_ptr<::ThreadPool> thread_pool;
GraphSamplerStatus status;
std::future<int> graph_sample_task_over;
std::vector<paddle::framework::GpuPsCommGraph> sample_res;
};
#endif
class GraphTable : public SparseTable {
public:
GraphTable() { use_cache = false; }
GraphTable() {
use_cache = false;
shard_num = 0;
#ifdef PADDLE_WITH_HETERPS
gpups_mode = false;
#endif
rw_lock.reset(new pthread_rwlock_t());
}
virtual ~GraphTable();
virtual int32_t pull_graph_list(int start, int size,
std::unique_ptr<char[]> &buffer,
......@@ -362,7 +426,7 @@ class GraphTable : public SparseTable {
int step);
virtual int32_t random_sample_neighbors(
uint64_t *node_ids, int sample_size,
int64_t *node_ids, int sample_size,
std::vector<std::shared_ptr<char>> &buffers,
std::vector<int> &actual_sizes, bool need_weight);
......@@ -370,9 +434,11 @@ class GraphTable : public SparseTable {
int &actual_sizes);
virtual int32_t get_nodes_ids_by_ranges(
std::vector<std::pair<int, int>> ranges, std::vector<uint64_t> &res);
virtual int32_t initialize();
std::vector<std::pair<int, int>> ranges, std::vector<int64_t> &res);
virtual int32_t initialize() { return 0; }
virtual int32_t initialize(const TableParameter &config,
const FsClientParameter &fs_config);
virtual int32_t initialize(const GraphParameter &config);
int32_t load(const std::string &path, const std::string &param);
int32_t load_graph_split_config(const std::string &path);
......@@ -380,13 +446,13 @@ class GraphTable : public SparseTable {
int32_t load_nodes(const std::string &path, std::string node_type);
int32_t add_graph_node(std::vector<uint64_t> &id_list,
int32_t add_graph_node(std::vector<int64_t> &id_list,
std::vector<bool> &is_weight_list);
int32_t remove_graph_node(std::vector<uint64_t> &id_list);
int32_t remove_graph_node(std::vector<int64_t> &id_list);
int32_t get_server_index_by_id(uint64_t id);
Node *find_node(uint64_t id);
int32_t get_server_index_by_id(int64_t id);
Node *find_node(int64_t id);
virtual int32_t pull_sparse(float *values,
const PullSparseValue &pull_value) {
......@@ -407,16 +473,27 @@ class GraphTable : public SparseTable {
return 0;
}
virtual int32_t initialize_shard() { return 0; }
virtual uint32_t get_thread_pool_index_by_shard_index(uint64_t shard_index);
virtual uint32_t get_thread_pool_index(uint64_t node_id);
virtual int32_t set_shard(size_t shard_idx, size_t server_num) {
_shard_idx = shard_idx;
/*
_shard_num is not used in graph_table, this following operation is for the
purpose of
being compatible with base class table.
*/
_shard_num = server_num;
this->server_num = server_num;
return 0;
}
virtual uint32_t get_thread_pool_index_by_shard_index(int64_t shard_index);
virtual uint32_t get_thread_pool_index(int64_t node_id);
virtual std::pair<int32_t, std::string> parse_feature(std::string feat_str);
virtual int32_t get_node_feat(const std::vector<uint64_t> &node_ids,
virtual int32_t get_node_feat(const std::vector<int64_t> &node_ids,
const std::vector<std::string> &feature_names,
std::vector<std::vector<std::string>> &res);
virtual int32_t set_node_feat(
const std::vector<uint64_t> &node_ids,
const std::vector<int64_t> &node_ids,
const std::vector<std::string> &feature_names,
const std::vector<std::vector<std::string>> &res);
......@@ -433,11 +510,25 @@ class GraphTable : public SparseTable {
}
return 0;
}
#ifdef PADDLE_WITH_HETERPS
virtual int32_t start_graph_sampling() {
return this->graph_sampler->start_graph_sampling();
}
virtual int32_t end_graph_sampling() {
return this->graph_sampler->end_graph_sampling();
}
virtual int32_t set_graph_sample_callback(
std::function<void(std::vector<paddle::framework::GpuPsCommGraph> &)>
callback) {
graph_sampler->set_graph_sample_callback(callback);
return 0;
}
// virtual GraphSampler *get_graph_sampler() { return graph_sampler.get(); }
#endif
protected:
std::vector<GraphShard *> shards, extra_shards;
size_t shard_start, shard_end, server_num, shard_num_per_server, shard_num;
const int task_pool_size_ = 24;
int task_pool_size_ = 24;
const int random_sample_nodes_ranges = 3;
std::vector<std::string> feat_name;
......@@ -450,11 +541,61 @@ class GraphTable : public SparseTable {
std::vector<std::shared_ptr<::ThreadPool>> _shards_task_pool;
std::vector<std::shared_ptr<std::mt19937_64>> _shards_task_rng_pool;
std::shared_ptr<ScaledLRU<SampleKey, SampleResult>> scaled_lru;
std::unordered_set<uint64_t> extra_nodes;
std::unordered_map<uint64_t, size_t> extra_nodes_to_thread_index;
std::unordered_set<int64_t> extra_nodes;
std::unordered_map<int64_t, size_t> extra_nodes_to_thread_index;
bool use_cache, use_duplicate_nodes;
mutable std::mutex mutex_;
std::shared_ptr<pthread_rwlock_t> rw_lock;
#ifdef PADDLE_WITH_HETERPS
// paddle::framework::GpuPsGraphTable gpu_graph_table;
bool gpups_mode;
// std::shared_ptr<::ThreadPool> graph_sample_pool;
std::shared_ptr<GraphSampler> graph_sampler;
REGISTER_GRAPH_FRIEND_CLASS(2, CompleteGraphSampler, BasicBfsGraphSampler)
#endif
};
#ifdef PADDLE_WITH_HETERPS
REGISTER_PSCORE_REGISTERER(GraphSampler);
class CompleteGraphSampler : public GraphSampler {
public:
CompleteGraphSampler() {}
~CompleteGraphSampler() {}
// virtual pthread_rwlock_t *export_rw_lock();
virtual int run_graph_sampling();
virtual void init(size_t gpu_num, GraphTable *graph_table,
std::vector<std::string> args_);
protected:
GraphTable *graph_table;
std::vector<std::vector<paddle::framework::GpuPsGraphNode>> sample_nodes;
std::vector<std::vector<int64_t>> sample_neighbors;
// std::vector<GpuPsCommGraph> sample_res;
// std::shared_ptr<std::mt19937_64> random;
int gpu_num;
};
class BasicBfsGraphSampler : public GraphSampler {
public:
BasicBfsGraphSampler() {}
~BasicBfsGraphSampler() {}
// virtual pthread_rwlock_t *export_rw_lock();
virtual int run_graph_sampling();
virtual void init(size_t gpu_num, GraphTable *graph_table,
std::vector<std::string> args_);
protected:
GraphTable *graph_table;
// std::vector<std::vector<GpuPsGraphNode>> sample_nodes;
std::vector<std::vector<paddle::framework::GpuPsGraphNode>> sample_nodes;
std::vector<std::vector<int64_t>> sample_neighbors;
size_t gpu_num;
int node_num_for_each_shard, edge_num_for_each_node;
int rounds, interval;
std::vector<std::unordered_map<int64_t, std::vector<int64_t>>>
sample_neighbors_map;
};
#endif
} // namespace distributed
}; // namespace paddle
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#define DECLARE_GRAPH_FRIEND_CLASS(a) friend class a;
#define DECLARE_1_FRIEND_CLASS(a, ...) DECLARE_GRAPH_FRIEND_CLASS(a)
#define DECLARE_2_FRIEND_CLASS(a, ...) \
DECLARE_GRAPH_FRIEND_CLASS(a) DECLARE_1_FRIEND_CLASS(__VA_ARGS__)
#define DECLARE_3_FRIEND_CLASS(a, ...) \
DECLARE_GRAPH_FRIEND_CLASS(a) DECLARE_2_FRIEND_CLASS(__VA_ARGS__)
#define DECLARE_4_FRIEND_CLASS(a, ...) \
DECLARE_GRAPH_FRIEND_CLASS(a) DECLARE_3_FRIEND_CLASS(__VA_ARGS__)
#define DECLARE_5_FRIEND_CLASS(a, ...) \
DECLARE_GRAPH_FRIEND_CLASS(a) DECLARE_4_FRIEND_CLASS(__VA_ARGS__)
#define DECLARE_6_FRIEND_CLASS(a, ...) \
DECLARE_GRAPH_FRIEND_CLASS(a) DECLARE_5_FRIEND_CLASS(__VA_ARGS__)
#define DECLARE_7_FRIEND_CLASS(a, ...) \
DECLARE_GRAPH_FRIEND_CLASS(a) DECLARE_6_FRIEND_CLASS(__VA_ARGS__)
#define DECLARE_8_FRIEND_CLASS(a, ...) \
DECLARE_GRAPH_FRIEND_CLASS(a) DECLARE_7_FRIEND_CLASS(__VA_ARGS__)
#define DECLARE_9_FRIEND_CLASS(a, ...) \
DECLARE_GRAPH_FRIEND_CLASS(a) DECLARE_8_FRIEND_CLASS(__VA_ARGS__)
#define DECLARE_10_FRIEND_CLASS(a, ...) \
DECLARE_GRAPH_FRIEND_CLASS(a) DECLARE_9_FRIEND_CLASS(__VA_ARGS__)
#define DECLARE_11_FRIEND_CLASS(a, ...) \
DECLARE_GRAPH_FRIEND_CLASS(a) DECLARE_10_FRIEND_CLASS(__VA_ARGS__)
#define REGISTER_GRAPH_FRIEND_CLASS(n, ...) \
DECLARE_##n##_FRIEND_CLASS(__VA_ARGS__)
......@@ -17,11 +17,11 @@
namespace paddle {
namespace distributed {
void GraphEdgeBlob::add_edge(uint64_t id, float weight = 1) {
void GraphEdgeBlob::add_edge(int64_t id, float weight = 1) {
id_arr.push_back(id);
}
void WeightedGraphEdgeBlob::add_edge(uint64_t id, float weight = 1) {
void WeightedGraphEdgeBlob::add_edge(int64_t id, float weight = 1) {
id_arr.push_back(id);
weight_arr.push_back(weight);
}
......
......@@ -24,19 +24,20 @@ class GraphEdgeBlob {
GraphEdgeBlob() {}
virtual ~GraphEdgeBlob() {}
size_t size() { return id_arr.size(); }
virtual void add_edge(uint64_t id, float weight);
uint64_t get_id(int idx) { return id_arr[idx]; }
virtual void add_edge(int64_t id, float weight);
int64_t get_id(int idx) { return id_arr[idx]; }
virtual float get_weight(int idx) { return 1; }
std::vector<int64_t>& export_id_array() { return id_arr; }
protected:
std::vector<uint64_t> id_arr;
std::vector<int64_t> id_arr;
};
class WeightedGraphEdgeBlob : public GraphEdgeBlob {
public:
WeightedGraphEdgeBlob() {}
virtual ~WeightedGraphEdgeBlob() {}
virtual void add_edge(uint64_t id, float weight);
virtual void add_edge(int64_t id, float weight);
virtual float get_weight(int idx) { return weight_arr[idx]; }
protected:
......
......@@ -48,6 +48,7 @@ class Node {
virtual void set_feature(int idx, std::string str) {}
virtual void set_feature_size(int size) {}
virtual int get_feature_size() { return 0; }
virtual size_t get_neighbor_size() { return 0; }
protected:
uint64_t id;
......@@ -70,6 +71,7 @@ class GraphNode : public Node {
}
virtual uint64_t get_neighbor_id(int idx) { return edges->get_id(idx); }
virtual float get_neighbor_weight(int idx) { return edges->get_weight(idx); }
virtual size_t get_neighbor_size() { return edges->size(); }
protected:
Sampler *sampler;
......
......@@ -37,6 +37,8 @@ REGISTER_PSCORE_CLASS(Table, CommonDenseTable);
REGISTER_PSCORE_CLASS(Table, CommonSparseTable);
#ifdef PADDLE_WITH_HETERPS
REGISTER_PSCORE_CLASS(Table, SSDSparseTable);
REGISTER_PSCORE_CLASS(GraphSampler, CompleteGraphSampler);
REGISTER_PSCORE_CLASS(GraphSampler, BasicBfsGraphSampler);
#endif
REGISTER_PSCORE_CLASS(Table, SparseGeoTable);
REGISTER_PSCORE_CLASS(Table, BarrierTable);
......
......@@ -24,6 +24,9 @@ cc_test(graph_node_test SRCS graph_node_test.cc DEPS graph_py_service scope serv
set_source_files_properties(graph_node_split_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(graph_node_split_test SRCS graph_node_split_test.cc DEPS graph_py_service scope server client communicator ps_service boost table ps_framework_proto ${COMMON_DEPS})
set_source_files_properties(graph_table_sample_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(graph_table_sample_test SRCS graph_table_sample_test.cc DEPS scope server communicator ps_service boost table ps_framework_proto ${COMMON_DEPS})
set_source_files_properties(feature_value_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(feature_value_test SRCS feature_value_test.cc DEPS ${COMMON_DEPS} boost table)
......
......@@ -236,7 +236,7 @@ void RunGraphSplit() {
sleep(2);
std::map<uint64_t, std::vector<paddle::distributed::Region>> dense_regions;
dense_regions.insert(
std::pair<uint64_t, std::vector<paddle::distributed::Region>>(0, {}));
std::pair<int64_t, std::vector<paddle::distributed::Region>>(0, {}));
auto regions = dense_regions[0];
RunClient(dense_regions, 0, pserver_ptr_->get_service());
......@@ -250,16 +250,16 @@ void RunGraphSplit() {
worker_ptr_->load(0, std::string(edge_file_name), std::string("e>"));
srand(time(0));
pull_status.wait();
std::vector<std::vector<uint64_t>> _vs;
std::vector<std::vector<int64_t>> _vs;
std::vector<std::vector<float>> vs;
pull_status = worker_ptr_->batch_sample_neighbors(
0, std::vector<uint64_t>(1, 10240001024), 4, _vs, vs, true);
0, std::vector<int64_t>(1, 10240001024), 4, _vs, vs, true);
pull_status.wait();
ASSERT_EQ(0, _vs[0].size());
_vs.clear();
vs.clear();
pull_status = worker_ptr_->batch_sample_neighbors(
0, std::vector<uint64_t>(1, 97), 4, _vs, vs, true);
0, std::vector<int64_t>(1, 97), 4, _vs, vs, true);
pull_status.wait();
ASSERT_EQ(3, _vs[0].size());
std::remove(edge_file_name);
......
......@@ -48,10 +48,10 @@ namespace distributed = paddle::distributed;
void testSampleNodes(
std::shared_ptr<paddle::distributed::GraphBrpcClient>& worker_ptr_) {
std::vector<uint64_t> ids;
std::vector<int64_t> ids;
auto pull_status = worker_ptr_->random_sample_nodes(0, 0, 6, ids);
std::unordered_set<uint64_t> s;
std::unordered_set<uint64_t> s1 = {37, 59};
std::unordered_set<int64_t> s;
std::unordered_set<int64_t> s1 = {37, 59};
pull_status.wait();
for (auto id : ids) s.insert(id);
ASSERT_EQ(true, s.size() == s1.size());
......@@ -106,14 +106,14 @@ void testFeatureNodeSerializeFloat64() {
void testSingleSampleNeighboor(
std::shared_ptr<paddle::distributed::GraphBrpcClient>& worker_ptr_) {
std::vector<std::vector<uint64_t>> vs;
std::vector<std::vector<int64_t>> vs;
std::vector<std::vector<float>> vs1;
auto pull_status = worker_ptr_->batch_sample_neighbors(
0, std::vector<uint64_t>(1, 37), 4, vs, vs1, true);
0, std::vector<int64_t>(1, 37), 4, vs, vs1, true);
pull_status.wait();
std::unordered_set<uint64_t> s;
std::unordered_set<uint64_t> s1 = {112, 45, 145};
std::unordered_set<int64_t> s;
std::unordered_set<int64_t> s1 = {112, 45, 145};
for (auto g : vs[0]) {
s.insert(g);
}
......@@ -126,7 +126,7 @@ void testSingleSampleNeighboor(
vs.clear();
vs1.clear();
pull_status = worker_ptr_->batch_sample_neighbors(
0, std::vector<uint64_t>(1, 96), 4, vs, vs1, true);
0, std::vector<int64_t>(1, 96), 4, vs, vs1, true);
pull_status.wait();
s1 = {111, 48, 247};
for (auto g : vs[0]) {
......@@ -147,30 +147,30 @@ void testAddNode(
std::shared_ptr<paddle::distributed::GraphBrpcClient>& worker_ptr_) {
worker_ptr_->clear_nodes(0);
int total_num = 270000;
uint64_t id;
std::unordered_set<uint64_t> id_set;
int64_t id;
std::unordered_set<int64_t> id_set;
for (int i = 0; i < total_num; i++) {
while (id_set.find(id = rand()) != id_set.end())
;
id_set.insert(id);
}
std::vector<uint64_t> id_list(id_set.begin(), id_set.end());
std::vector<int64_t> id_list(id_set.begin(), id_set.end());
std::vector<bool> weight_list;
auto status = worker_ptr_->add_graph_node(0, id_list, weight_list);
status.wait();
std::vector<uint64_t> ids[2];
std::vector<int64_t> ids[2];
for (int i = 0; i < 2; i++) {
auto sample_status =
worker_ptr_->random_sample_nodes(0, i, total_num, ids[i]);
sample_status.wait();
}
std::unordered_set<uint64_t> id_set_check(ids[0].begin(), ids[0].end());
std::unordered_set<int64_t> id_set_check(ids[0].begin(), ids[0].end());
for (auto x : ids[1]) id_set_check.insert(x);
ASSERT_EQ(id_set.size(), id_set_check.size());
for (auto x : id_set) {
ASSERT_EQ(id_set_check.find(x) != id_set_check.end(), true);
}
std::vector<uint64_t> remove_ids;
std::vector<int64_t> remove_ids;
for (auto p : id_set_check) {
if (remove_ids.size() == 0)
remove_ids.push_back(p);
......@@ -187,7 +187,7 @@ void testAddNode(
worker_ptr_->random_sample_nodes(0, i, total_num, ids[i]);
sample_status.wait();
}
std::unordered_set<uint64_t> id_set_check1(ids[0].begin(), ids[0].end());
std::unordered_set<int64_t> id_set_check1(ids[0].begin(), ids[0].end());
for (auto x : ids[1]) id_set_check1.insert(x);
ASSERT_EQ(id_set_check1.size(), id_set_check.size());
for (auto x : id_set_check1) {
......@@ -196,14 +196,14 @@ void testAddNode(
}
void testBatchSampleNeighboor(
std::shared_ptr<paddle::distributed::GraphBrpcClient>& worker_ptr_) {
std::vector<std::vector<uint64_t>> vs;
std::vector<std::vector<int64_t>> vs;
std::vector<std::vector<float>> vs1;
std::vector<std::uint64_t> v = {37, 96};
std::vector<std::int64_t> v = {37, 96};
auto pull_status =
worker_ptr_->batch_sample_neighbors(0, v, 4, vs, vs1, false);
pull_status.wait();
std::unordered_set<uint64_t> s;
std::unordered_set<uint64_t> s1 = {112, 45, 145};
std::unordered_set<int64_t> s;
std::unordered_set<int64_t> s1 = {112, 45, 145};
for (auto g : vs[0]) {
s.insert(g);
}
......@@ -417,7 +417,7 @@ void RunBrpcPushSparse() {
std::map<uint64_t, std::vector<paddle::distributed::Region>> dense_regions;
dense_regions.insert(
std::pair<uint64_t, std::vector<paddle::distributed::Region>>(0, {}));
std::pair<int64_t, std::vector<paddle::distributed::Region>>(0, {}));
auto regions = dense_regions[0];
RunClient(dense_regions, 0, pserver_ptr_->get_service());
......@@ -427,14 +427,14 @@ void RunBrpcPushSparse() {
worker_ptr_->load(0, std::string(edge_file_name), std::string("e>"));
srand(time(0));
pull_status.wait();
std::vector<std::vector<uint64_t>> _vs;
std::vector<std::vector<int64_t>> _vs;
std::vector<std::vector<float>> vs;
testSampleNodes(worker_ptr_);
sleep(5);
testSingleSampleNeighboor(worker_ptr_);
testBatchSampleNeighboor(worker_ptr_);
pull_status = worker_ptr_->batch_sample_neighbors(
0, std::vector<uint64_t>(1, 10240001024), 4, _vs, vs, true);
0, std::vector<int64_t>(1, 10240001024), 4, _vs, vs, true);
pull_status.wait();
ASSERT_EQ(0, _vs[0].size());
paddle::distributed::GraphTable* g =
......@@ -445,14 +445,14 @@ void RunBrpcPushSparse() {
while (round--) {
vs.clear();
pull_status = worker_ptr_->batch_sample_neighbors(
0, std::vector<uint64_t>(1, 37), 1, _vs, vs, false);
0, std::vector<int64_t>(1, 37), 1, _vs, vs, false);
pull_status.wait();
for (int i = 0; i < ttl; i++) {
std::vector<std::vector<uint64_t>> vs1;
std::vector<std::vector<int64_t>> vs1;
std::vector<std::vector<float>> vs2;
pull_status = worker_ptr_->batch_sample_neighbors(
0, std::vector<uint64_t>(1, 37), 1, vs1, vs2, false);
0, std::vector<int64_t>(1, 37), 1, vs1, vs2, false);
pull_status.wait();
ASSERT_EQ(_vs[0].size(), vs1[0].size());
......@@ -540,7 +540,7 @@ void RunBrpcPushSparse() {
// Test Pull by step
std::unordered_set<uint64_t> count_item_nodes;
std::unordered_set<int64_t> count_item_nodes;
// pull by step 2
for (int test_step = 1; test_step < 4; test_step++) {
count_item_nodes.clear();
......@@ -558,18 +558,18 @@ void RunBrpcPushSparse() {
ASSERT_EQ(count_item_nodes.size(), 12);
}
std::pair<std::vector<std::vector<uint64_t>>, std::vector<float>> res;
std::pair<std::vector<std::vector<int64_t>>, std::vector<float>> res;
res = client1.batch_sample_neighbors(
std::string("user2item"), std::vector<uint64_t>(1, 96), 4, true, false);
std::string("user2item"), std::vector<int64_t>(1, 96), 4, true, false);
ASSERT_EQ(res.first[0].size(), 3);
std::vector<uint64_t> node_ids;
std::vector<int64_t> node_ids;
node_ids.push_back(96);
node_ids.push_back(37);
res = client1.batch_sample_neighbors(std::string("user2item"), node_ids, 4,
true, false);
ASSERT_EQ(res.first[1].size(), 1);
std::vector<uint64_t> nodes_ids = client2.random_sample_nodes("user", 0, 6);
std::vector<int64_t> nodes_ids = client2.random_sample_nodes("user", 0, 6);
ASSERT_EQ(nodes_ids.size(), 2);
ASSERT_EQ(true, (nodes_ids[0] == 59 && nodes_ids[1] == 37) ||
(nodes_ids[0] == 37 && nodes_ids[1] == 59));
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <unistd.h>
#include <condition_variable> // NOLINT
#include <fstream>
#include <iomanip>
#include <string>
#include <thread> // NOLINT
#include <unordered_set>
#include <vector>
#include "google/protobuf/text_format.h"
#include <chrono>
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/ps.pb.h"
#include "paddle/fluid/distributed/ps/service/env.h"
#include "paddle/fluid/distributed/ps/service/sendrecv.pb.h"
#include "paddle/fluid/distributed/ps/table/common_graph_table.h"
#include "paddle/fluid/distributed/ps/table/graph/graph_node.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/string/printf.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace framework = paddle::framework;
namespace platform = paddle::platform;
namespace operators = paddle::operators;
namespace memory = paddle::memory;
namespace distributed = paddle::distributed;
std::vector<std::string> edges = {
std::string("37\t45\t0.34"), std::string("37\t145\t0.31"),
std::string("37\t112\t0.21"), std::string("96\t48\t1.4"),
std::string("96\t247\t0.31"), std::string("96\t111\t1.21"),
std::string("59\t45\t0.34"), std::string("59\t145\t0.31"),
std::string("59\t122\t0.21"), std::string("97\t48\t0.34"),
std::string("97\t247\t0.31"), std::string("97\t111\t0.21")};
// odd id:96 48 122 112
char edge_file_name[] = "edges.txt";
std::vector<std::string> nodes = {
std::string("user\t37\ta 0.34\tb 13 14\tc hello\td abc"),
std::string("user\t96\ta 0.31\tb 15 10\tc 96hello\td abcd"),
std::string("user\t59\ta 0.11\tb 11 14"),
std::string("user\t97\ta 0.11\tb 12 11"),
std::string("item\t45\ta 0.21"),
std::string("item\t145\ta 0.21"),
std::string("item\t112\ta 0.21"),
std::string("item\t48\ta 0.21"),
std::string("item\t247\ta 0.21"),
std::string("item\t111\ta 0.21"),
std::string("item\t46\ta 0.21"),
std::string("item\t146\ta 0.21"),
std::string("item\t122\ta 0.21"),
std::string("item\t49\ta 0.21"),
std::string("item\t248\ta 0.21"),
std::string("item\t113\ta 0.21")};
char node_file_name[] = "nodes.txt";
void prepare_file(char file_name[], std::vector<std::string> data) {
std::ofstream ofile;
ofile.open(file_name);
for (auto x : data) {
ofile << x << std::endl;
}
ofile.close();
}
void testGraphSample() {
#ifdef PADDLE_WITH_HETERPS
::paddle::distributed::GraphParameter table_proto;
table_proto.set_gpups_mode(true);
table_proto.set_gpups_mode_shard_num(127);
table_proto.set_gpu_num(2);
distributed::GraphTable graph_table, graph_table1;
graph_table.initialize(table_proto);
prepare_file(edge_file_name, edges);
graph_table.load(std::string(edge_file_name), std::string("e>"));
std::vector<paddle::framework::GpuPsCommGraph> res;
std::promise<int> prom;
std::future<int> fut = prom.get_future();
graph_table.set_graph_sample_callback(
[&res, &prom](std::vector<paddle::framework::GpuPsCommGraph> &res0) {
res = res0;
prom.set_value(0);
});
graph_table.start_graph_sampling();
fut.get();
graph_table.end_graph_sampling();
ASSERT_EQ(2, res.size());
// 37 59 97
for (int i = 0; i < (int)res[1].node_size; i++) {
std::cout << res[1].node_list[i].node_id << std::endl;
}
ASSERT_EQ(3, res[1].node_size);
::paddle::distributed::GraphParameter table_proto1;
table_proto1.set_gpups_mode(true);
table_proto1.set_gpups_mode_shard_num(127);
table_proto1.set_gpu_num(2);
table_proto1.set_gpups_graph_sample_class("BasicBfsGraphSampler");
table_proto1.set_gpups_graph_sample_args("5,5,1,1");
graph_table1.initialize(table_proto1);
graph_table1.load(std::string(edge_file_name), std::string("e>"));
std::vector<paddle::framework::GpuPsCommGraph> res1;
std::promise<int> prom1;
std::future<int> fut1 = prom1.get_future();
graph_table1.set_graph_sample_callback(
[&res1, &prom1](std::vector<paddle::framework::GpuPsCommGraph> &res0) {
res1 = res0;
prom1.set_value(0);
});
graph_table1.start_graph_sampling();
fut1.get();
graph_table1.end_graph_sampling();
// distributed::BasicBfsGraphSampler *sampler1 =
// (distributed::BasicBfsGraphSampler *)graph_table1.get_graph_sampler();
// sampler1->start_graph_sampling();
// std::this_thread::sleep_for (std::chrono::seconds(1));
// std::vector<paddle::framework::GpuPsCommGraph> res1;// =
// sampler1->fetch_sample_res();
ASSERT_EQ(2, res1.size());
// odd id:96 48 122 112
for (int i = 0; i < (int)res1[0].node_size; i++) {
std::cout << res1[0].node_list[i].node_id << std::endl;
}
ASSERT_EQ(4, res1[0].node_size);
#endif
}
TEST(testGraphSample, Run) { testGraphSample(); }
......@@ -10,8 +10,9 @@ IF(WITH_GPU)
nv_library(heter_comm SRCS heter_comm.h feature_value.h heter_resource.cc heter_resource.h hashtable.h mem_pool.h DEPS ${HETERPS_DEPS})
nv_test(test_heter_comm SRCS feature_value.h DEPS heter_comm)
nv_library(heter_ps SRCS heter_ps.cu DEPS heter_comm)
nv_library(graph_gpu_ps SRCS graph_gpu_ps_table.h DEPS heter_comm)
nv_library(graph_gpu_ps SRCS graph_gpu_ps_table.h DEPS heter_comm table)
nv_test(test_graph_comm SRCS test_graph.cu DEPS graph_gpu_ps)
nv_test(test_cpu_graph_sample SRCS test_cpu_graph_sample.cu DEPS graph_gpu_ps)
ENDIF()
IF(WITH_ROCM)
hip_library(heter_comm SRCS heter_comm.h feature_value.h heter_resource.cc heter_resource.h hashtable.h DEPS cub device_context)
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#ifdef PADDLE_WITH_HETERPS
namespace paddle {
namespace framework {
struct GpuPsGraphNode {
int64_t node_id;
int neighbor_size, neighbor_offset;
// this node's neighbor is stored on [neighbor_offset,neighbor_offset +
// neighbor_size) of int64_t *neighbor_list;
};
struct GpuPsCommGraph {
int64_t *neighbor_list;
GpuPsGraphNode *node_list;
int neighbor_size, node_size;
// the size of neighbor array and graph_node_list array
GpuPsCommGraph()
: neighbor_list(NULL), node_list(NULL), neighbor_size(0), node_size(0) {}
GpuPsCommGraph(int64_t *neighbor_list_, GpuPsGraphNode *node_list_,
int neighbor_size_, int node_size_)
: neighbor_list(neighbor_list_),
node_list(node_list_),
neighbor_size(neighbor_size_),
node_size(node_size_) {}
};
/*
suppose we have a graph like this
0----3-----5----7
\ |\ |\
17 8 9 1 2
we save the nodes in arbitrary order,
in this example,the order is
[0,5,1,2,7,3,8,9,17]
let us name this array u_id;
we record each node's neighbors:
0:3,17
5:3,7
1:7
2:7
7:1,2,5
3:0,5,8,9
8:3
9:3
17:0
by concatenating each node's neighbor_list in the order we save the node id.
we get [3,17,3,7,7,7,1,2,5,0,5,8,9,3,3,0]
this is the neighbor_list of GpuPsCommGraph
given this neighbor_list and the order to save node id,
we know,
node 0's neighbors are in the range [0,1] of neighbor_list
node 5's neighbors are in the range [2,3] of neighbor_list
node 1's neighbors are in the range [4,4] of neighbor_list
node 2:[5,5]
node 7:[6,6]
node 3:[9,12]
node 8:[13,13]
node 9:[14,14]
node 17:[15,15]
...
by the above information,
we generate a node_list:GpuPsGraphNode *graph_node_list in GpuPsCommGraph
of size 9,
where node_list[i].id = u_id[i]
then we have:
node_list[0]-> node_id:0, neighbor_size:2, neighbor_offset:0
node_list[1]-> node_id:5, neighbor_size:2, neighbor_offset:2
node_list[2]-> node_id:1, neighbor_size:1, neighbor_offset:4
node_list[3]-> node_id:2, neighbor_size:1, neighbor_offset:5
node_list[4]-> node_id:7, neighbor_size:3, neighbor_offset:6
node_list[5]-> node_id:3, neighbor_size:4, neighbor_offset:9
node_list[6]-> node_id:8, neighbor_size:1, neighbor_offset:13
node_list[7]-> node_id:9, neighbor_size:1, neighbor_offset:14
node_list[8]-> node_id:17, neighbor_size:1, neighbor_offset:15
*/
struct NeighborSampleResult {
int64_t *val;
int *actual_sample_size, sample_size, key_size;
NeighborSampleResult(int _sample_size, int _key_size)
: sample_size(_sample_size), key_size(_key_size) {
actual_sample_size = NULL;
val = NULL;
};
~NeighborSampleResult() {
if (val != NULL) cudaFree(val);
if (actual_sample_size != NULL) cudaFree(actual_sample_size);
}
};
struct NodeQueryResult {
int64_t *val;
int actual_sample_size;
NodeQueryResult() {
val = NULL;
actual_sample_size = 0;
};
~NodeQueryResult() {
if (val != NULL) cudaFree(val);
}
};
}
};
#endif
......@@ -14,114 +14,25 @@
#pragma once
#include "heter_comm.h"
#include "paddle/fluid/distributed/ps/table/common_graph_table.h"
#include "paddle/fluid/framework/fleet/heter_ps/gpu_graph_node.h"
#include "paddle/fluid/platform/enforce.h"
#ifdef PADDLE_WITH_HETERPS
namespace paddle {
namespace framework {
struct GpuPsGraphNode {
int64_t node_id;
int neighbor_size, neighbor_offset;
// this node's neighbor is stored on [neighbor_offset,neighbor_offset +
// neighbor_size) of int64_t *neighbor_list;
};
struct GpuPsCommGraph {
int64_t *neighbor_list;
GpuPsGraphNode *node_list;
int neighbor_size, node_size;
// the size of neighbor array and graph_node_list array
GpuPsCommGraph()
: neighbor_list(NULL), node_list(NULL), neighbor_size(0), node_size(0) {}
GpuPsCommGraph(int64_t *neighbor_list_, GpuPsGraphNode *node_list_,
int neighbor_size_, int node_size_)
: neighbor_list(neighbor_list_),
node_list(node_list_),
neighbor_size(neighbor_size_),
node_size(node_size_) {}
};
/*
suppose we have a graph like this
0----3-----5----7
\ |\ |\
17 8 9 1 2
we save the nodes in arbitrary order,
in this example,the order is
[0,5,1,2,7,3,8,9,17]
let us name this array u_id;
we record each node's neighbors:
0:3,17
5:3,7
1:7
2:7
7:1,2,5
3:0,5,8,9
8:3
9:3
17:0
by concatenating each node's neighbor_list in the order we save the node id.
we get [3,17,3,7,7,7,1,2,5,0,5,8,9,3,3,0]
this is the neighbor_list of GpuPsCommGraph
given this neighbor_list and the order to save node id,
we know,
node 0's neighbors are in the range [0,1] of neighbor_list
node 5's neighbors are in the range [2,3] of neighbor_list
node 1's neighbors are in the range [4,4] of neighbor_list
node 2:[5,5]
node 7:[6,6]
node 3:[9,12]
node 8:[13,13]
node 9:[14,14]
node 17:[15,15]
...
by the above information,
we generate a node_list:GpuPsGraphNode *graph_node_list in GpuPsCommGraph
of size 9,
where node_list[i].id = u_id[i]
then we have:
node_list[0]-> node_id:0, neighbor_size:2, neighbor_offset:0
node_list[1]-> node_id:5, neighbor_size:2, neighbor_offset:2
node_list[2]-> node_id:1, neighbor_size:1, neighbor_offset:4
node_list[3]-> node_id:2, neighbor_size:1, neighbor_offset:5
node_list[4]-> node_id:7, neighbor_size:3, neighbor_offset:6
node_list[5]-> node_id:3, neighbor_size:4, neighbor_offset:9
node_list[6]-> node_id:8, neighbor_size:1, neighbor_offset:13
node_list[7]-> node_id:9, neighbor_size:1, neighbor_offset:14
node_list[8]-> node_id:17, neighbor_size:1, neighbor_offset:15
*/
struct NeighborSampleResult {
int64_t *val;
int *actual_sample_size, sample_size, key_size;
NeighborSampleResult(int _sample_size, int _key_size)
: sample_size(_sample_size), key_size(_key_size) {
actual_sample_size = NULL;
val = NULL;
};
~NeighborSampleResult() {
if (val != NULL) cudaFree(val);
if (actual_sample_size != NULL) cudaFree(actual_sample_size);
}
};
struct NodeQueryResult {
int64_t *val;
int actual_sample_size;
NodeQueryResult() {
val = NULL;
actual_sample_size = 0;
};
~NodeQueryResult() {
if (val != NULL) cudaFree(val);
}
};
class GpuPsGraphTable : public HeterComm<int64_t, int, int> {
public:
GpuPsGraphTable(std::shared_ptr<HeterPsResource> resource)
: HeterComm<int64_t, int, int>(1, resource) {
load_factor_ = 0.25;
rw_lock.reset(new pthread_rwlock_t());
cpu_table_status = -1;
}
~GpuPsGraphTable() {
if (cpu_table_status != -1) {
end_graph_sampling();
}
}
void build_graph_from_cpu(std::vector<GpuPsCommGraph> &cpu_node_list);
NodeQueryResult *graph_node_sample(int gpu_id, int sample_size);
......@@ -134,9 +45,19 @@ class GpuPsGraphTable : public HeterComm<int64_t, int, int> {
int *h_right,
int64_t *src_sample_res,
int *actual_sample_size);
int init_cpu_table(const paddle::distributed::GraphParameter &graph);
int load(const std::string &path, const std::string &param);
virtual int32_t end_graph_sampling() {
return cpu_graph_table->end_graph_sampling();
}
private:
std::vector<GpuPsCommGraph> gpu_graph_list;
std::shared_ptr<paddle::distributed::GraphTable> cpu_graph_table;
std::shared_ptr<pthread_rwlock_t> rw_lock;
mutable std::mutex mutex_;
std::condition_variable cv_;
int cpu_table_status;
};
}
};
......
......@@ -14,6 +14,7 @@
#pragma once
#ifdef PADDLE_WITH_HETERPS
//#include "paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table.h"
namespace paddle {
namespace framework {
/*
......@@ -45,6 +46,33 @@ __global__ void neighbor_sample_example(GpuPsCommGraph graph, int* index,
}
}
int GpuPsGraphTable::init_cpu_table(
const paddle::distributed::GraphParameter& graph) {
cpu_graph_table.reset(new paddle::distributed::GraphTable);
cpu_table_status = cpu_graph_table->initialize(graph);
if (cpu_table_status != 0) return cpu_table_status;
std::function<void(std::vector<GpuPsCommGraph>&)> callback =
[this](std::vector<GpuPsCommGraph>& res) {
pthread_rwlock_wrlock(this->rw_lock.get());
this->clear_graph_info();
this->build_graph_from_cpu(res);
pthread_rwlock_unlock(this->rw_lock.get());
cv_.notify_one();
};
cpu_graph_table->set_graph_sample_callback(callback);
return cpu_table_status;
}
int GpuPsGraphTable::load(const std::string& path, const std::string& param) {
int status = cpu_graph_table->load(path, param);
if (status != 0) {
return status;
}
std::unique_lock<std::mutex> lock(mutex_);
cpu_graph_table->start_graph_sampling();
cv_.wait(lock);
return 0;
}
/*
comment 1
......@@ -68,6 +96,7 @@ __global__ void neighbor_sample_example(GpuPsCommGraph graph, int* index,
that's what fill_dvals does.
*/
void GpuPsGraphTable::move_neighbor_sample_result_to_source_gpu(
int gpu_id, int gpu_num, int sample_size, int* h_left, int* h_right,
int64_t* src_sample_res, int* actual_sample_size) {
......@@ -258,7 +287,7 @@ NeighborSampleResult* GpuPsGraphTable::graph_neighbor_sample(int gpu_id,
auto d_shard_keys = memory::Alloc(place, len * sizeof(int64_t));
int64_t* d_shard_keys_ptr = reinterpret_cast<int64_t*>(d_shard_keys->ptr());
auto d_shard_vals = memory::Alloc(place, len * sizeof(int64_t));
auto d_shard_vals = memory::Alloc(place, sample_size * len * sizeof(int64_t));
int64_t* d_shard_vals_ptr = reinterpret_cast<int64_t*>(d_shard_vals->ptr());
auto d_shard_actual_sample_size = memory::Alloc(place, len * sizeof(int));
int* d_shard_actual_sample_size_ptr =
......
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#ifdef PADDLE_WITH_HETERPS
//#include "paddle/fluid/framework/fleet/heter_ps/heter_comm.h"
#include <queue>
namespace paddle {
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include <algorithm>
#include <vector>
#include "paddle/fluid/framework/fleet/heter_ps/feature_value.h"
#include "paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table.h"
#include "paddle/fluid/framework/fleet/heter_ps/heter_comm.h"
#include "paddle/fluid/framework/fleet/heter_ps/heter_resource.h"
#include "paddle/fluid/framework/fleet/heter_ps/optimizer.cuh.h"
#include "paddle/fluid/platform/cuda_device_guard.h"
using namespace paddle::framework;
void prepare_file(char file_name[], std::vector<std::string> data) {
std::ofstream ofile;
ofile.open(file_name);
for (auto x : data) {
ofile << x << std::endl;
}
ofile.close();
}
char edge_file_name[] = "edges.txt";
TEST(TEST_FLEET, graph_sample) {
std::vector<std::string> edges;
int gpu_count = 3;
std::vector<int> dev_ids;
dev_ids.push_back(0);
dev_ids.push_back(1);
dev_ids.push_back(2);
std::shared_ptr<HeterPsResource> resource =
std::make_shared<HeterPsResource>(dev_ids);
resource->enable_p2p();
GpuPsGraphTable g(resource);
int node_count = 10;
std::vector<std::vector<int64_t>> neighbors(node_count);
int ind = 0;
int64_t node_id = 0;
// std::vector<GpuPsCommGraph> graph_list(gpu_count);
while (ind < node_count) {
int neighbor_size = ind + 1;
while (neighbor_size--) {
edges.push_back(std::to_string(ind) + "\t" + std::to_string(node_id) +
"\t1.0");
node_id++;
}
ind++;
}
/*
gpu 0:
0,3,6,9
gpu 1:
1,4,7
gpu 2:
2,5,8
query(2,6) returns nodes [6,9,1,4,7,2]
*/
::paddle::distributed::GraphParameter table_proto;
table_proto.set_gpups_mode(true);
table_proto.set_gpups_mode_shard_num(127);
table_proto.set_gpu_num(3);
table_proto.set_gpups_graph_sample_class("BasicBfsGraphSampler");
table_proto.set_gpups_graph_sample_args("5,5,1,1");
prepare_file(edge_file_name, edges);
g.init_cpu_table(table_proto);
g.load(std::string(edge_file_name), std::string("e>"));
/*
node x's neighbor list = [(1+x)*x/2,(1+x)*x/2 + 1,.....,(1+x)*x/2 + x]
so node 6's neighbors are [21,22...,27]
node 7's neighbors are [28,29,..35]
node 0's neighbors are [0]
query([7,0,6],sample_size=3) should return [28,29,30,0,x,x,21,22,23]
6 --index-->2
0 --index--->0
7 --index-->2
*/
int64_t cpu_key[3] = {7, 0, 6};
void *key;
cudaMalloc((void **)&key, 3 * sizeof(int64_t));
cudaMemcpy(key, cpu_key, 3 * sizeof(int64_t), cudaMemcpyHostToDevice);
auto neighbor_sample_res = g.graph_neighbor_sample(0, (int64_t *)key, 3, 3);
int64_t *res = new int64_t[9];
cudaMemcpy(res, neighbor_sample_res->val, 72, cudaMemcpyDeviceToHost);
std::sort(res, res + 3);
std::sort(res + 6, res + 9);
int64_t expected_sample_val[] = {28, 29, 30, 0, -1, -1, 21, 22, 23};
for (int i = 0; i < 9; i++) {
if (expected_sample_val[i] != -1) {
ASSERT_EQ(res[i], expected_sample_val[i]);
}
}
delete[] res;
delete neighbor_sample_res;
}
......@@ -225,7 +225,7 @@ void BindGraphPyClient(py::module* m) {
.def("stop_server", &GraphPyClient::stop_server)
.def("get_node_feat",
[](GraphPyClient& self, std::string node_type,
std::vector<uint64_t> node_ids,
std::vector<int64_t> node_ids,
std::vector<std::string> feature_names) {
auto feats =
self.get_node_feat(node_type, node_ids, feature_names);
......@@ -239,7 +239,7 @@ void BindGraphPyClient(py::module* m) {
})
.def("set_node_feat",
[](GraphPyClient& self, std::string node_type,
std::vector<uint64_t> node_ids,
std::vector<int64_t> node_ids,
std::vector<std::string> feature_names,
std::vector<std::vector<py::bytes>> bytes_feats) {
std::vector<std::vector<std::string>> feats(bytes_feats.size());
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册