From 876aa71776c701e458703832a758ee65e0b15124 Mon Sep 17 00:00:00 2001
From: seemingwang
Date: Thu, 2 Dec 2021 16:14:55 +0800
Subject: [PATCH] support distributed graph_split load and query. (#37740)

---
 .../distributed/service/graph_brpc_client.cc  |  36 +++
 .../distributed/service/graph_brpc_client.h   |   2 +
 .../distributed/service/graph_brpc_server.cc  |  17 ++
 .../distributed/service/graph_brpc_server.h   |   4 +
 .../fluid/distributed/service/sendrecv.proto  |   1 +
 .../distributed/table/common_graph_table.cc   | 212 ++++++++++++--
 .../distributed/table/common_graph_table.h    |  14 +-
 .../distributed/table/graph/graph_node.cc     |   3 +
 paddle/fluid/distributed/test/CMakeLists.txt  |   3 +
 .../distributed/test/graph_node_split_test.cc | 275 ++++++++++++++++++
 10 files changed, 534 insertions(+), 33 deletions(-)
 create mode 100644 paddle/fluid/distributed/test/graph_node_split_test.cc

diff --git a/paddle/fluid/distributed/service/graph_brpc_client.cc b/paddle/fluid/distributed/service/graph_brpc_client.cc
index c5ad4b00994..a9682d6a6ef 100644
--- a/paddle/fluid/distributed/service/graph_brpc_client.cc
+++ b/paddle/fluid/distributed/service/graph_brpc_client.cc
@@ -514,6 +514,42 @@ std::future<int32_t> GraphBrpcClient::random_sample_nodes(
   return fut;
 }
 
+std::future<int32_t> GraphBrpcClient::load_graph_split_config(
+    uint32_t table_id, std::string path) {
+  DownpourBrpcClosure *closure = new DownpourBrpcClosure(
+      server_size, [&, server_size = this->server_size ](void *done) {
+        int ret = 0;
+        auto *closure = (DownpourBrpcClosure *)done;
+        size_t fail_num = 0;
+        for (size_t request_idx = 0; request_idx < server_size; ++request_idx) {
+          if (closure->check_response(request_idx,
+                                      PS_GRAPH_LOAD_GRAPH_SPLIT_CONFIG) != 0) {
+            ++fail_num;
+            break;
+          }
+        }
+        ret = fail_num == 0 ? 0 : -1;
+        closure->set_promise_value(ret);
+      });
+  auto promise = std::make_shared<std::promise<int32_t>>();
+  closure->add_promise(promise);
+  std::future<int> fut = promise->get_future();
+  for (size_t i = 0; i < server_size; i++) {
+    int server_index = i;
+    closure->request(server_index)
+        ->set_cmd_id(PS_GRAPH_LOAD_GRAPH_SPLIT_CONFIG);
+    closure->request(server_index)->set_table_id(table_id);
+    closure->request(server_index)->set_client_id(_client_id);
+    closure->request(server_index)->add_params(path);
+    GraphPsService_Stub rpc_stub =
+        getServiceStub(get_cmd_channel(server_index));
+    closure->cntl(server_index)->set_log_id(butil::gettimeofday_ms());
+    rpc_stub.service(closure->cntl(server_index),
+                     closure->request(server_index),
+                     closure->response(server_index), closure);
+  }
+  return fut;
+}
 std::future<int32_t> GraphBrpcClient::use_neighbors_sample_cache(
     uint32_t table_id, size_t total_size_limit, size_t ttl) {
   DownpourBrpcClosure *closure = new DownpourBrpcClosure(
diff --git a/paddle/fluid/distributed/service/graph_brpc_client.h b/paddle/fluid/distributed/service/graph_brpc_client.h
index e3d2ff1d32d..2e5d5b6ee93 100644
--- a/paddle/fluid/distributed/service/graph_brpc_client.h
+++ b/paddle/fluid/distributed/service/graph_brpc_client.h
@@ -93,6 +93,8 @@ class GraphBrpcClient : public BrpcPsClient {
   virtual std::future<int32_t> use_neighbors_sample_cache(uint32_t table_id,
                                                           size_t size_limit,
                                                           size_t ttl);
+  virtual std::future<int32_t> load_graph_split_config(uint32_t table_id,
+                                                       std::string path);
   virtual std::future<int32_t> remove_graph_node(
       uint32_t table_id, std::vector<uint64_t>& node_id_list);
   virtual int32_t initialize();
diff --git a/paddle/fluid/distributed/service/graph_brpc_server.cc b/paddle/fluid/distributed/service/graph_brpc_server.cc
index 094ecbbd402..c1348e4804e 100644
---
a/paddle/fluid/distributed/service/graph_brpc_server.cc
+++ b/paddle/fluid/distributed/service/graph_brpc_server.cc
@@ -204,6 +204,8 @@ int32_t GraphBrpcService::initialize() {
       &GraphBrpcService::sample_neighbors_across_multi_servers;
   _service_handler_map[PS_GRAPH_USE_NEIGHBORS_SAMPLE_CACHE] =
       &GraphBrpcService::use_neighbors_sample_cache;
+  _service_handler_map[PS_GRAPH_LOAD_GRAPH_SPLIT_CONFIG] =
+      &GraphBrpcService::load_graph_split_config;
   // Shard initialization: the shard info of server_list can only be read from env after the server has started.
   initialize_shard_info();
 
@@ -658,5 +660,20 @@ int32_t GraphBrpcService::use_neighbors_sample_cache(
   ((GraphTable *)table)->make_neighbor_sample_cache(size_limit, ttl);
   return 0;
 }
+
+int32_t GraphBrpcService::load_graph_split_config(
+    Table *table, const PsRequestMessage &request, PsResponseMessage &response,
+    brpc::Controller *cntl) {
+  CHECK_TABLE_EXIST(table, request, response)
+  if (request.params_size() < 1) {
+    set_response_code(response, -1,
+                      "load_graph_split_config request requires at least 1 "
+                      "argument[file_path]");
+    return 0;
+  }
+  ((GraphTable *)table)->load_graph_split_config(request.params(0));
+  return 0;
+}
+
 }  // namespace distributed
 }  // namespace paddle
diff --git a/paddle/fluid/distributed/service/graph_brpc_server.h b/paddle/fluid/distributed/service/graph_brpc_server.h
index d1a6aa63604..ecd78d28ca8 100644
--- a/paddle/fluid/distributed/service/graph_brpc_server.h
+++ b/paddle/fluid/distributed/service/graph_brpc_server.h
@@ -126,6 +126,10 @@ class GraphBrpcService : public PsBaseService {
                                      PsResponseMessage &response,
                                      brpc::Controller *cntl);
 
+  int32_t load_graph_split_config(Table *table, const PsRequestMessage &request,
+                                  PsResponseMessage &response,
+                                  brpc::Controller *cntl);
+
  private:
   bool _is_initialize_shard_info;
   std::mutex _initialize_shard_mutex;
diff --git a/paddle/fluid/distributed/service/sendrecv.proto b/paddle/fluid/distributed/service/sendrecv.proto
index 8ee9b359072..6dfaff1ffa1 100644
--- a/paddle/fluid/distributed/service/sendrecv.proto
+++ b/paddle/fluid/distributed/service/sendrecv.proto
@@ -58,6 +58,7 @@ enum PsCmdID {
   PS_GRAPH_SET_NODE_FEAT = 37;
   PS_GRAPH_SAMPLE_NODES_FROM_ONE_SERVER = 38;
   PS_GRAPH_USE_NEIGHBORS_SAMPLE_CACHE = 39;
+  PS_GRAPH_LOAD_GRAPH_SPLIT_CONFIG = 40;
 }
 
 message PsRequestMessage {
diff --git a/paddle/fluid/distributed/table/common_graph_table.cc b/paddle/fluid/distributed/table/common_graph_table.cc
index b690d71eab8..042a4dee62b 100644
--- a/paddle/fluid/distributed/table/common_graph_table.cc
+++ b/paddle/fluid/distributed/table/common_graph_table.cc
@@ -56,7 +56,7 @@ int32_t GraphTable::add_graph_node(std::vector<uint64_t> &id_list,
     tasks.push_back(_shards_task_pool[i]->enqueue([&batch, i, this]() -> int {
       for (auto &p : batch[i]) {
         size_t index = p.first % this->shard_num - this->shard_start;
-        this->shards[index].add_graph_node(p.first)->build_edges(p.second);
+        this->shards[index]->add_graph_node(p.first)->build_edges(p.second);
       }
       return 0;
     }));
@@ -79,7 +79,7 @@ int32_t GraphTable::remove_graph_node(std::vector<uint64_t> &id_list) {
     tasks.push_back(_shards_task_pool[i]->enqueue([&batch, i, this]() -> int {
       for (auto &p : batch[i]) {
         size_t index = p % this->shard_num - this->shard_start;
-        this->shards[index].delete_node(p);
+        this->shards[index]->delete_node(p);
       }
       return 0;
     }));
@@ -97,6 +97,7 @@ void GraphShard::clear() {
 }
 
 GraphShard::~GraphShard() { clear(); }
+
 void GraphShard::delete_node(uint64_t id) {
   auto iter = node_location.find(id);
   if (iter == node_location.end()) return;
@@ -117,6 +118,14 @@ GraphNode
*GraphShard::add_graph_node(uint64_t id) { return (GraphNode *)bucket[node_location[id]]; } +GraphNode *GraphShard::add_graph_node(Node *node) { + auto id = node->get_id(); + if (node_location.find(id) == node_location.end()) { + node_location[id] = bucket.size(); + bucket.push_back(node); + } + return (GraphNode *)bucket[node_location[id]]; +} FeatureNode *GraphShard::add_feature_node(uint64_t id) { if (node_location.find(id) == node_location.end()) { node_location[id] = bucket.size(); @@ -134,6 +143,33 @@ Node *GraphShard::find_node(uint64_t id) { return iter == node_location.end() ? nullptr : bucket[iter->second]; } +GraphTable::~GraphTable() { + for (auto p : shards) { + delete p; + } + for (auto p : extra_shards) { + delete p; + } + shards.clear(); + extra_shards.clear(); +} + +int32_t GraphTable::load_graph_split_config(const std::string &path) { + VLOG(4) << "in server side load graph split config\n"; + std::ifstream file(path); + std::string line; + while (std::getline(file, line)) { + auto values = paddle::string::split_string(line, "\t"); + if (values.size() < 2) continue; + size_t index = (size_t)std::stoi(values[0]); + if (index != _shard_idx) continue; + auto dst_id = std::stoull(values[1]); + extra_nodes.insert(dst_id); + } + if (extra_nodes.size() != 0) use_duplicate_nodes = true; + return 0; +} + int32_t GraphTable::load(const std::string &path, const std::string ¶m) { bool load_edge = (param[0] == 'e'); bool load_node = (param[0] == 'n'); @@ -154,7 +190,7 @@ int32_t GraphTable::get_nodes_ids_by_ranges( res.clear(); std::vector>> tasks; for (size_t i = 0; i < shards.size() && index < (int)ranges.size(); i++) { - end = total_size + shards[i].get_size(); + end = total_size + shards[i]->get_size(); start = total_size; while (start < end && index < ranges.size()) { if (ranges[index].second <= start) @@ -169,11 +205,11 @@ int32_t GraphTable::get_nodes_ids_by_ranges( second -= total_size; tasks.push_back(_shards_task_pool[i % task_pool_size_]->enqueue( [this, first, second, i]() -> std::vector { - return shards[i].get_ids_by_range(first, second); + return shards[i]->get_ids_by_range(first, second); })); } } - total_size += shards[i].get_size(); + total_size += shards[i]->get_size(); } for (size_t i = 0; i < tasks.size(); i++) { auto vec = tasks[i].get(); @@ -217,7 +253,7 @@ int32_t GraphTable::load_nodes(const std::string &path, std::string node_type) { size_t index = shard_id - shard_start; - auto node = shards[index].add_feature_node(id); + auto node = shards[index]->add_feature_node(id); node->set_feature_size(feat_name.size()); @@ -245,7 +281,7 @@ int32_t GraphTable::load_edges(const std::string &path, bool reverse_edge) { std::string sample_type = "random"; bool is_weighted = false; int valid_count = 0; - + int extra_alloc_index = 0; for (auto path : paths) { std::ifstream file(path); std::string line; @@ -268,8 +304,24 @@ int32_t GraphTable::load_edges(const std::string &path, bool reverse_edge) { size_t src_shard_id = src_id % shard_num; if (src_shard_id >= shard_end || src_shard_id < shard_start) { - VLOG(4) << "will not load " << src_id << " from " << path - << ", please check id distribution"; + if (use_duplicate_nodes == false || + extra_nodes.find(src_id) == extra_nodes.end()) { + VLOG(4) << "will not load " << src_id << " from " << path + << ", please check id distribution"; + continue; + } + int index; + if (extra_nodes_to_thread_index.find(src_id) != + extra_nodes_to_thread_index.end()) { + index = extra_nodes_to_thread_index[src_id]; + } else { + index = 
extra_alloc_index++; + extra_alloc_index %= task_pool_size_; + extra_nodes_to_thread_index[src_id] = index; + } + extra_shards[index]->add_graph_node(src_id)->build_edges(is_weighted); + extra_shards[index]->add_neighbor(src_id, dst_id, weight); + valid_count++; continue; } if (count % 1000000 == 0) { @@ -278,36 +330,130 @@ int32_t GraphTable::load_edges(const std::string &path, bool reverse_edge) { } size_t index = src_shard_id - shard_start; - shards[index].add_graph_node(src_id)->build_edges(is_weighted); - shards[index].add_neighbor(src_id, dst_id, weight); + shards[index]->add_graph_node(src_id)->build_edges(is_weighted); + shards[index]->add_neighbor(src_id, dst_id, weight); valid_count++; } } VLOG(0) << valid_count << "/" << count << " edges are loaded successfully in " << path; + std::vector used(task_pool_size_, 0); // Build Sampler j for (auto &shard : shards) { - auto bucket = shard.get_bucket(); + auto bucket = shard->get_bucket(); for (size_t i = 0; i < bucket.size(); i++) { bucket[i]->build_sampler(sample_type); + used[get_thread_pool_index(bucket[i]->get_id())]++; } } + /*----------------------- + relocate the duplicate nodes to make them distributed evenly among threads. +*/ + for (auto &shard : extra_shards) { + auto bucket = shard->get_bucket(); + for (size_t i = 0; i < bucket.size(); i++) { + bucket[i]->build_sampler(sample_type); + } + } + int size = extra_nodes_to_thread_index.size(); + if (size == 0) return 0; + std::vector index; + for (int i = 0; i < used.size(); i++) index.push_back(i); + sort(index.begin(), index.end(), + [&](int &a, int &b) { return used[a] < used[b]; }); + + std::vector alloc(index.size(), 0), has_alloc(index.size(), 0); + int t = 1, aim = 0, mod = 0; + for (; t < used.size(); t++) { + if ((used[index[t]] - used[index[t - 1]]) * t >= size) { + break; + } else { + size -= (used[index[t]] - used[index[t - 1]]) * t; + } + } + aim = used[index[t - 1]] + size / t; + mod = size % t; + for (int x = t - 1; x >= 0; x--) { + alloc[index[x]] = aim; + if (t - x <= mod) alloc[index[x]]++; + alloc[index[x]] -= used[index[x]]; + } + std::vector vec[index.size()]; + for (auto p : extra_nodes_to_thread_index) { + has_alloc[p.second]++; + vec[p.second].push_back(p.first); + } + sort(index.begin(), index.end(), [&](int &a, int &b) { + return has_alloc[a] - alloc[a] < has_alloc[b] - alloc[b]; + }); + int left = 0, right = index.size() - 1; + while (left < right) { + if (has_alloc[index[right]] - alloc[index[right]] == 0) break; + int x = std::min(alloc[index[left]] - has_alloc[index[left]], + has_alloc[index[right]] - alloc[index[right]]); + has_alloc[index[left]] += x; + has_alloc[index[right]] -= x; + uint64_t id; + while (x--) { + id = vec[index[right]].back(); + vec[index[right]].pop_back(); + extra_nodes_to_thread_index[id] = index[left]; + vec[index[left]].push_back(id); + } + if (has_alloc[index[right]] - alloc[index[right]] == 0) right--; + if (alloc[index[left]] - has_alloc[index[left]] == 0) left++; + } + std::vector extra_shards_copy; + for (int i = 0; i < task_pool_size_; ++i) { + extra_shards_copy.push_back(new GraphShard()); + } + for (auto &shard : extra_shards) { + auto &bucket = shard->get_bucket(); + auto &node_location = shard->get_node_location(); + while (bucket.size()) { + Node *temp = bucket.back(); + bucket.pop_back(); + node_location.erase(temp->get_id()); + extra_shards_copy[extra_nodes_to_thread_index[temp->get_id()]] + ->add_graph_node(temp); + } + } + for (int i = 0; i < task_pool_size_; ++i) { + delete extra_shards[i]; + 
extra_shards[i] = extra_shards_copy[i]; + } return 0; } Node *GraphTable::find_node(uint64_t id) { size_t shard_id = id % shard_num; if (shard_id >= shard_end || shard_id < shard_start) { - return nullptr; + if (use_duplicate_nodes == false || extra_nodes_to_thread_index.size() == 0) + return nullptr; + auto iter = extra_nodes_to_thread_index.find(id); + if (iter == extra_nodes_to_thread_index.end()) + return nullptr; + else { + return extra_shards[iter->second]->find_node(id); + } } size_t index = shard_id - shard_start; - Node *node = shards[index].find_node(id); + Node *node = shards[index]->find_node(id); return node; } uint32_t GraphTable::get_thread_pool_index(uint64_t node_id) { - return node_id % shard_num % shard_num_per_server % task_pool_size_; + if (use_duplicate_nodes == false || extra_nodes_to_thread_index.size() == 0) + return node_id % shard_num % shard_num_per_server % task_pool_size_; + size_t src_shard_id = node_id % shard_num; + if (src_shard_id >= shard_end || src_shard_id < shard_start) { + auto iter = extra_nodes_to_thread_index.find(node_id); + if (iter != extra_nodes_to_thread_index.end()) { + return iter->second; + } + } + return src_shard_id % shard_num_per_server % task_pool_size_; } uint32_t GraphTable::get_thread_pool_index_by_shard_index( @@ -319,11 +465,16 @@ int32_t GraphTable::clear_nodes() { std::vector> tasks; for (size_t i = 0; i < shards.size(); i++) { tasks.push_back( - _shards_task_pool[get_thread_pool_index_by_shard_index(i)]->enqueue( - [this, i]() -> int { - this->shards[i].clear(); - return 0; - })); + _shards_task_pool[i % task_pool_size_]->enqueue([this, i]() -> int { + this->shards[i]->clear(); + return 0; + })); + } + for (size_t i = 0; i < extra_shards.size(); i++) { + tasks.push_back(_shards_task_pool[i]->enqueue([this, i]() -> int { + this->extra_shards[i]->clear(); + return 0; + })); } for (size_t i = 0; i < tasks.size(); i++) tasks[i].get(); return 0; @@ -334,7 +485,7 @@ int32_t GraphTable::random_sample_nodes(int sample_size, int &actual_size) { int total_size = 0; for (int i = 0; i < shards.size(); i++) { - total_size += shards[i].get_size(); + total_size += shards[i]->get_size(); } if (sample_size > total_size) sample_size = total_size; int range_num = random_sample_nodes_ranges; @@ -401,8 +552,8 @@ int32_t GraphTable::random_sample_neighbors( size_t node_num = buffers.size(); std::function char_del = [](char *c) { delete[] c; }; std::vector> tasks; - std::vector> seq_id(shard_end - shard_start); - std::vector> id_list(shard_end - shard_start); + std::vector> seq_id(task_pool_size_); + std::vector> id_list(task_pool_size_); size_t index; for (size_t idx = 0; idx < node_num; ++idx) { index = get_thread_pool_index(node_ids[idx]); @@ -524,7 +675,7 @@ int32_t GraphTable::set_node_feat( tasks.push_back(_shards_task_pool[get_thread_pool_index(node_id)]->enqueue( [&, idx, node_id]() -> int { size_t index = node_id % this->shard_num - this->shard_start; - auto node = shards[index].add_feature_node(node_id); + auto node = shards[index]->add_feature_node(node_id); node->set_feature_size(this->feat_name.size()); for (int feat_idx = 0; feat_idx < feature_names.size(); ++feat_idx) { const std::string &feature_name = feature_names[feat_idx]; @@ -581,7 +732,7 @@ int32_t GraphTable::pull_graph_list(int start, int total_size, int size = 0, cur_size; std::vector>> tasks; for (size_t i = 0; i < shards.size() && total_size > 0; i++) { - cur_size = shards[i].get_size(); + cur_size = shards[i]->get_size(); if (size + cur_size <= start) { size += cur_size; 
continue; @@ -590,7 +741,7 @@ int32_t GraphTable::pull_graph_list(int start, int total_size, int end = start + (count - 1) * step + 1; tasks.push_back(_shards_task_pool[i % task_pool_size_]->enqueue( [this, i, start, end, step, size]() -> std::vector { - return this->shards[i].get_batch(start - size, end - size, step); + return this->shards[i]->get_batch(start - size, end - size, step); })); start += count * step; total_size -= count; @@ -665,7 +816,14 @@ int32_t GraphTable::initialize() { shard_end = shard_start + shard_num_per_server; VLOG(0) << "in init graph table shard idx = " << _shard_idx << " shard_start " << shard_start << " shard_end " << shard_end; - shards = std::vector(shard_num_per_server, GraphShard(shard_num)); + for (int i = 0; i < shard_num_per_server; i++) { + shards.push_back(new GraphShard()); + } + use_duplicate_nodes = false; + for (int i = 0; i < task_pool_size_; i++) { + extra_shards.push_back(new GraphShard()); + } + return 0; } } // namespace distributed diff --git a/paddle/fluid/distributed/table/common_graph_table.h b/paddle/fluid/distributed/table/common_graph_table.h index 9ca59db3bb2..b76ab0ae950 100644 --- a/paddle/fluid/distributed/table/common_graph_table.h +++ b/paddle/fluid/distributed/table/common_graph_table.h @@ -47,7 +47,6 @@ class GraphShard { public: size_t get_size(); GraphShard() {} - GraphShard(int shard_num) { this->shard_num = shard_num; } ~GraphShard(); std::vector &get_bucket() { return bucket; } std::vector get_batch(int start, int end, int step); @@ -60,18 +59,18 @@ class GraphShard { } GraphNode *add_graph_node(uint64_t id); + GraphNode *add_graph_node(Node *node); FeatureNode *add_feature_node(uint64_t id); Node *find_node(uint64_t id); void delete_node(uint64_t id); void clear(); void add_neighbor(uint64_t id, uint64_t dst_id, float weight); - std::unordered_map get_node_location() { + std::unordered_map &get_node_location() { return node_location; } private: std::unordered_map node_location; - int shard_num; std::vector bucket; }; @@ -355,7 +354,7 @@ class ScaledLRU { class GraphTable : public SparseTable { public: GraphTable() { use_cache = false; } - virtual ~GraphTable() {} + virtual ~GraphTable(); virtual int32_t pull_graph_list(int start, int size, std::unique_ptr &buffer, int &actual_size, bool need_feature, @@ -374,6 +373,7 @@ class GraphTable : public SparseTable { virtual int32_t initialize(); int32_t load(const std::string &path, const std::string ¶m); + int32_t load_graph_split_config(const std::string &path); int32_t load_edges(const std::string &path, bool reverse); @@ -434,7 +434,7 @@ class GraphTable : public SparseTable { } protected: - std::vector shards; + std::vector shards, extra_shards; size_t shard_start, shard_end, server_num, shard_num_per_server, shard_num; const int task_pool_size_ = 24; const int random_sample_nodes_ranges = 3; @@ -449,7 +449,9 @@ class GraphTable : public SparseTable { std::vector> _shards_task_pool; std::vector> _shards_task_rng_pool; std::shared_ptr> scaled_lru; - bool use_cache; + std::unordered_set extra_nodes; + std::unordered_map extra_nodes_to_thread_index; + bool use_cache, use_duplicate_nodes; mutable std::mutex mutex_; }; } // namespace distributed diff --git a/paddle/fluid/distributed/table/graph/graph_node.cc b/paddle/fluid/distributed/table/graph/graph_node.cc index e2311cc307b..52c708be884 100644 --- a/paddle/fluid/distributed/table/graph/graph_node.cc +++ b/paddle/fluid/distributed/table/graph/graph_node.cc @@ -65,6 +65,9 @@ void GraphNode::build_edges(bool is_weighted) { } } 
void GraphNode::build_sampler(std::string sample_type) { + if (sampler != nullptr) { + return; + } if (sample_type == "random") { sampler = new RandomSampler(); } else if (sample_type == "weighted") { diff --git a/paddle/fluid/distributed/test/CMakeLists.txt b/paddle/fluid/distributed/test/CMakeLists.txt index 597a08973b9..62de82832e1 100644 --- a/paddle/fluid/distributed/test/CMakeLists.txt +++ b/paddle/fluid/distributed/test/CMakeLists.txt @@ -21,6 +21,9 @@ cc_test(brpc_utils_test SRCS brpc_utils_test.cc DEPS brpc_utils scope math_funct set_source_files_properties(graph_node_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) cc_test(graph_node_test SRCS graph_node_test.cc DEPS graph_py_service scope server client communicator ps_service boost table ps_framework_proto ${COMMON_DEPS}) +set_source_files_properties(graph_node_split_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) +cc_test(graph_node_split_test SRCS graph_node_split_test.cc DEPS graph_py_service scope server client communicator ps_service boost table ps_framework_proto ${COMMON_DEPS}) + set_source_files_properties(feature_value_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) cc_test(feature_value_test SRCS feature_value_test.cc DEPS ${COMMON_DEPS} boost table) diff --git a/paddle/fluid/distributed/test/graph_node_split_test.cc b/paddle/fluid/distributed/test/graph_node_split_test.cc new file mode 100644 index 00000000000..3fcddde787f --- /dev/null +++ b/paddle/fluid/distributed/test/graph_node_split_test.cc @@ -0,0 +1,275 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include +#include // NOLINT +#include +#include +#include +#include // NOLINT +#include +#include +#include "google/protobuf/text_format.h" + +#include "gtest/gtest.h" +#include "paddle/fluid/distributed/ps.pb.h" +#include "paddle/fluid/distributed/service/brpc_ps_client.h" +#include "paddle/fluid/distributed/service/brpc_ps_server.h" +#include "paddle/fluid/distributed/service/env.h" +#include "paddle/fluid/distributed/service/graph_brpc_client.h" +#include "paddle/fluid/distributed/service/graph_brpc_server.h" +#include "paddle/fluid/distributed/service/graph_py_service.h" +#include "paddle/fluid/distributed/service/ps_client.h" +#include "paddle/fluid/distributed/service/sendrecv.pb.h" +#include "paddle/fluid/distributed/service/service.h" +#include "paddle/fluid/distributed/table/graph/graph_node.h" +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/program_desc.h" +#include "paddle/fluid/framework/scope.h" +#include "paddle/fluid/framework/tensor_util.h" +#include "paddle/fluid/framework/variable.h" +#include "paddle/fluid/operators/math/math_function.h" +#include "paddle/fluid/platform/place.h" +#include "paddle/fluid/string/printf.h" + +namespace framework = paddle::framework; +namespace platform = paddle::platform; +namespace operators = paddle::operators; +namespace math = paddle::operators::math; +namespace memory = paddle::memory; +namespace distributed = paddle::distributed; + +std::vector edges = { + std::string("37\t45\t0.34"), std::string("37\t145\t0.31"), + std::string("37\t112\t0.21"), std::string("96\t48\t1.4"), + std::string("96\t247\t0.31"), std::string("96\t111\t1.21"), + std::string("59\t45\t0.34"), std::string("59\t145\t0.31"), + std::string("59\t122\t0.21"), std::string("97\t48\t0.34"), + std::string("97\t247\t0.31"), std::string("97\t111\t0.21")}; +char edge_file_name[] = "edges.txt"; + +std::vector nodes = { + std::string("user\t37\ta 0.34\tb 13 14\tc hello\td abc"), + std::string("user\t96\ta 0.31\tb 15 10\tc 96hello\td abcd"), + std::string("user\t59\ta 0.11\tb 11 14"), + std::string("user\t97\ta 0.11\tb 12 11"), + std::string("item\t45\ta 0.21"), + std::string("item\t145\ta 0.21"), + std::string("item\t112\ta 0.21"), + std::string("item\t48\ta 0.21"), + std::string("item\t247\ta 0.21"), + std::string("item\t111\ta 0.21"), + std::string("item\t46\ta 0.21"), + std::string("item\t146\ta 0.21"), + std::string("item\t122\ta 0.21"), + std::string("item\t49\ta 0.21"), + std::string("item\t248\ta 0.21"), + std::string("item\t113\ta 0.21")}; +char node_file_name[] = "nodes.txt"; + +std::vector graph_split = {std::string("0\t97")}; +char graph_split_file_name[] = "graph_split.txt"; + +void prepare_file(char file_name[], std::vector data) { + std::ofstream ofile; + ofile.open(file_name); + for (auto x : data) { + ofile << x << std::endl; + } + + ofile.close(); +} +void GetDownpourSparseTableProto( + ::paddle::distributed::TableParameter* sparse_table_proto) { + sparse_table_proto->set_table_id(0); + sparse_table_proto->set_table_class("GraphTable"); + sparse_table_proto->set_shard_num(127); + sparse_table_proto->set_type(::paddle::distributed::PS_SPARSE_TABLE); + ::paddle::distributed::TableAccessorParameter* accessor_proto = + sparse_table_proto->mutable_accessor(); + accessor_proto->set_accessor_class("CommMergeAccessor"); +} + +::paddle::distributed::PSParameter GetServerProto() { + // Generate server proto desc + ::paddle::distributed::PSParameter server_fleet_desc; + ::paddle::distributed::ServerParameter* server_proto = + 
server_fleet_desc.mutable_server_param(); + ::paddle::distributed::DownpourServerParameter* downpour_server_proto = + server_proto->mutable_downpour_server_param(); + ::paddle::distributed::ServerServiceParameter* server_service_proto = + downpour_server_proto->mutable_service_param(); + server_service_proto->set_service_class("GraphBrpcService"); + server_service_proto->set_server_class("GraphBrpcServer"); + server_service_proto->set_client_class("GraphBrpcClient"); + server_service_proto->set_start_server_port(0); + server_service_proto->set_server_thread_num(12); + + ::paddle::distributed::TableParameter* sparse_table_proto = + downpour_server_proto->add_downpour_table_param(); + GetDownpourSparseTableProto(sparse_table_proto); + return server_fleet_desc; +} + +::paddle::distributed::PSParameter GetWorkerProto() { + ::paddle::distributed::PSParameter worker_fleet_desc; + ::paddle::distributed::WorkerParameter* worker_proto = + worker_fleet_desc.mutable_worker_param(); + + ::paddle::distributed::DownpourWorkerParameter* downpour_worker_proto = + worker_proto->mutable_downpour_worker_param(); + + ::paddle::distributed::TableParameter* worker_sparse_table_proto = + downpour_worker_proto->add_downpour_table_param(); + GetDownpourSparseTableProto(worker_sparse_table_proto); + + ::paddle::distributed::ServerParameter* server_proto = + worker_fleet_desc.mutable_server_param(); + ::paddle::distributed::DownpourServerParameter* downpour_server_proto = + server_proto->mutable_downpour_server_param(); + ::paddle::distributed::ServerServiceParameter* server_service_proto = + downpour_server_proto->mutable_service_param(); + server_service_proto->set_service_class("GraphBrpcService"); + server_service_proto->set_server_class("GraphBrpcServer"); + server_service_proto->set_client_class("GraphBrpcClient"); + server_service_proto->set_start_server_port(0); + server_service_proto->set_server_thread_num(12); + + ::paddle::distributed::TableParameter* server_sparse_table_proto = + downpour_server_proto->add_downpour_table_param(); + GetDownpourSparseTableProto(server_sparse_table_proto); + + return worker_fleet_desc; +} + +/*-------------------------------------------------------------------------*/ + +std::string ip_ = "127.0.0.1", ip2 = "127.0.0.1"; +uint32_t port_ = 5209, port2 = 5210; + +std::vector host_sign_list_; + +std::shared_ptr pserver_ptr_, + pserver_ptr2; + +std::shared_ptr worker_ptr_; + +void RunServer() { + LOG(INFO) << "init first server"; + ::paddle::distributed::PSParameter server_proto = GetServerProto(); + + auto _ps_env = paddle::distributed::PaddlePSEnvironment(); + _ps_env.set_ps_servers(&host_sign_list_, 2); // test + pserver_ptr_ = std::shared_ptr( + (paddle::distributed::GraphBrpcServer*) + paddle::distributed::PSServerFactory::create(server_proto)); + std::vector empty_vec; + framework::ProgramDesc empty_prog; + empty_vec.push_back(empty_prog); + pserver_ptr_->configure(server_proto, _ps_env, 0, empty_vec); + LOG(INFO) << "first server, run start(ip,port)"; + pserver_ptr_->start(ip_, port_); + pserver_ptr_->build_peer2peer_connection(0); + LOG(INFO) << "init first server Done"; +} + +void RunServer2() { + LOG(INFO) << "init second server"; + ::paddle::distributed::PSParameter server_proto2 = GetServerProto(); + + auto _ps_env2 = paddle::distributed::PaddlePSEnvironment(); + _ps_env2.set_ps_servers(&host_sign_list_, 2); // test + pserver_ptr2 = std::shared_ptr( + (paddle::distributed::GraphBrpcServer*) + paddle::distributed::PSServerFactory::create(server_proto2)); + std::vector 
empty_vec2; + framework::ProgramDesc empty_prog2; + empty_vec2.push_back(empty_prog2); + pserver_ptr2->configure(server_proto2, _ps_env2, 1, empty_vec2); + pserver_ptr2->start(ip2, port2); + pserver_ptr2->build_peer2peer_connection(1); +} + +void RunClient( + std::map>& dense_regions, + int index, paddle::distributed::PsBaseService* service) { + ::paddle::distributed::PSParameter worker_proto = GetWorkerProto(); + paddle::distributed::PaddlePSEnvironment _ps_env; + auto servers_ = host_sign_list_.size(); + _ps_env = paddle::distributed::PaddlePSEnvironment(); + _ps_env.set_ps_servers(&host_sign_list_, servers_); + worker_ptr_ = std::shared_ptr( + (paddle::distributed::GraphBrpcClient*) + paddle::distributed::PSClientFactory::create(worker_proto)); + worker_ptr_->configure(worker_proto, dense_regions, _ps_env, 0); + worker_ptr_->set_shard_num(127); + worker_ptr_->set_local_channel(index); + worker_ptr_->set_local_graph_service( + (paddle::distributed::GraphBrpcService*)service); +} + +void RunGraphSplit() { + setenv("http_proxy", "", 1); + setenv("https_proxy", "", 1); + prepare_file(edge_file_name, edges); + prepare_file(node_file_name, nodes); + prepare_file(graph_split_file_name, graph_split); + auto ph_host = paddle::distributed::PSHost(ip_, port_, 0); + host_sign_list_.push_back(ph_host.serialize_to_string()); + + // test-start + auto ph_host2 = paddle::distributed::PSHost(ip2, port2, 1); + host_sign_list_.push_back(ph_host2.serialize_to_string()); + // test-end + // Srart Server + std::thread* server_thread = new std::thread(RunServer); + + std::thread* server_thread2 = new std::thread(RunServer2); + + sleep(2); + std::map> dense_regions; + dense_regions.insert( + std::pair>(0, {})); + auto regions = dense_regions[0]; + + RunClient(dense_regions, 0, pserver_ptr_->get_service()); + + /*-----------------------Test Server Init----------------------------------*/ + + auto pull_status = worker_ptr_->load_graph_split_config( + 0, std::string(graph_split_file_name)); + pull_status.wait(); + pull_status = + worker_ptr_->load(0, std::string(edge_file_name), std::string("e>")); + srand(time(0)); + pull_status.wait(); + std::vector> _vs; + std::vector> vs; + pull_status = worker_ptr_->batch_sample_neighbors( + 0, std::vector(1, 10240001024), 4, _vs, vs, true); + pull_status.wait(); + ASSERT_EQ(0, _vs[0].size()); + _vs.clear(); + vs.clear(); + pull_status = worker_ptr_->batch_sample_neighbors( + 0, std::vector(1, 97), 4, _vs, vs, true); + pull_status.wait(); + ASSERT_EQ(3, _vs[0].size()); + std::remove(edge_file_name); + std::remove(node_file_name); + std::remove(graph_split_file_name); + LOG(INFO) << "Run stop_server"; + worker_ptr_->stop_server(); + LOG(INFO) << "Run finalize_worker"; + worker_ptr_->finalize_worker(); +} + +TEST(RunGraphSplit, Run) { RunGraphSplit(); } \ No newline at end of file -- GitLab