diff --git a/paddle/fluid/distributed/service/graph_brpc_client.cc b/paddle/fluid/distributed/service/graph_brpc_client.cc
index eafb4d596cc1671db26189b84ea9d0c0c31ea398..70f2da6d7252cee0268bdd35999926a232bc5b34 100644
--- a/paddle/fluid/distributed/service/graph_brpc_client.cc
+++ b/paddle/fluid/distributed/service/graph_brpc_client.cc
@@ -80,11 +80,11 @@ std::future<int32_t> GraphBrpcClient::get_node_feat(
       [&, node_id_buckets, query_idx_buckets, request_call_num](void *done) {
         int ret = 0;
         auto *closure = (DownpourBrpcClosure *)done;
-        int fail_num = 0;
+        size_t fail_num = 0;
         for (int request_idx = 0; request_idx < request_call_num;
              ++request_idx) {
-          if (closure->check_response(request_idx,
-                                      PS_GRAPH_SAMPLE_NEIGHBOORS) != 0) {
+          if (closure->check_response(request_idx, PS_GRAPH_GET_NODE_FEAT) !=
+              0) {
             ++fail_num;
           } else {
             auto &res_io_buffer =
@@ -144,6 +144,163 @@ std::future<int32_t> GraphBrpcClient::get_node_feat(
   return fut;
 }
 
+std::future<int32_t> GraphBrpcClient::clear_nodes(uint32_t table_id) {
+  DownpourBrpcClosure *closure = new DownpourBrpcClosure(
+      server_size, [&, server_size = this->server_size ](void *done) {
+        int ret = 0;
+        auto *closure = (DownpourBrpcClosure *)done;
+        size_t fail_num = 0;
+        for (size_t request_idx = 0; request_idx < server_size;
+             ++request_idx) {
+          if (closure->check_response(request_idx, PS_GRAPH_CLEAR) != 0) {
+            ++fail_num;
+            break;
+          }
+        }
+        ret = fail_num == 0 ? 0 : -1;
+        closure->set_promise_value(ret);
+      });
+  auto promise = std::make_shared<std::promise<int32_t>>();
+  closure->add_promise(promise);
+  std::future<int> fut = promise->get_future();
+  for (size_t i = 0; i < server_size; i++) {
+    int server_index = i;
+    closure->request(server_index)->set_cmd_id(PS_GRAPH_CLEAR);
+    closure->request(server_index)->set_table_id(table_id);
+    closure->request(server_index)->set_client_id(_client_id);
+
+    GraphPsService_Stub rpc_stub =
+        getServiceStub(get_cmd_channel(server_index));
+    closure->cntl(server_index)->set_log_id(butil::gettimeofday_ms());
+    rpc_stub.service(closure->cntl(server_index),
+                     closure->request(server_index),
+                     closure->response(server_index), closure);
+  }
+  return fut;
+}
+std::future<int32_t> GraphBrpcClient::add_graph_node(
+    uint32_t table_id, std::vector<uint64_t> &node_id_list,
+    std::vector<bool> &is_weighted_list) {
+  std::vector<std::vector<uint64_t>> request_bucket;
+  std::vector<std::vector<bool>> is_weighted_bucket;
+  bool add_weight = is_weighted_list.size() > 0;
+  std::vector<int> server_index_arr;
+  std::vector<int> index_mapping(server_size, -1);
+  for (size_t query_idx = 0; query_idx < node_id_list.size(); ++query_idx) {
+    int server_index = get_server_index_by_id(node_id_list[query_idx]);
+    if (index_mapping[server_index] == -1) {
+      index_mapping[server_index] = request_bucket.size();
+      server_index_arr.push_back(server_index);
+      request_bucket.push_back(std::vector<uint64_t>());
+      if (add_weight) is_weighted_bucket.push_back(std::vector<bool>());
+    }
+    request_bucket[index_mapping[server_index]].push_back(
+        node_id_list[query_idx]);
+    if (add_weight)
+      is_weighted_bucket[index_mapping[server_index]].push_back(
+          query_idx < is_weighted_list.size() ? is_weighted_list[query_idx]
+                                              : false);
+  }
+  size_t request_call_num = request_bucket.size();
+  DownpourBrpcClosure *closure = new DownpourBrpcClosure(
+      request_call_num, [&, request_call_num](void *done) {
+        int ret = 0;
+        auto *closure = (DownpourBrpcClosure *)done;
+        size_t fail_num = 0;
+        for (size_t request_idx = 0; request_idx < request_call_num;
+             ++request_idx) {
+          if (closure->check_response(request_idx, PS_GRAPH_ADD_GRAPH_NODE) !=
+              0) {
+            ++fail_num;
+          }
+        }
+        ret = fail_num == request_call_num ? -1 : 0;
+        closure->set_promise_value(ret);
+      });
+  auto promise = std::make_shared<std::promise<int32_t>>();
+  closure->add_promise(promise);
+  std::future<int> fut = promise->get_future();
+
+  for (size_t request_idx = 0; request_idx < request_call_num; ++request_idx) {
+    int server_index = server_index_arr[request_idx];
+    closure->request(request_idx)->set_cmd_id(PS_GRAPH_ADD_GRAPH_NODE);
+    closure->request(request_idx)->set_table_id(table_id);
+    closure->request(request_idx)->set_client_id(_client_id);
+    size_t node_num = request_bucket[request_idx].size();
+    closure->request(request_idx)
+        ->add_params((char *)request_bucket[request_idx].data(),
+                     sizeof(uint64_t) * node_num);
+    if (add_weight) {
+      bool weighted[is_weighted_bucket[request_idx].size() + 1];
+      for (size_t j = 0; j < is_weighted_bucket[request_idx].size(); j++)
+        weighted[j] = is_weighted_bucket[request_idx][j];
+      closure->request(request_idx)
+          ->add_params((char *)weighted,
+                       sizeof(bool) * is_weighted_bucket[request_idx].size());
+    }
+    // PsService_Stub rpc_stub(get_cmd_channel(server_index));
+    GraphPsService_Stub rpc_stub =
+        getServiceStub(get_cmd_channel(server_index));
+    closure->cntl(request_idx)->set_log_id(butil::gettimeofday_ms());
+    rpc_stub.service(closure->cntl(request_idx), closure->request(request_idx),
+                     closure->response(request_idx), closure);
+  }
+  return fut;
+}
+std::future<int32_t> GraphBrpcClient::remove_graph_node(
+    uint32_t table_id, std::vector<uint64_t> &node_id_list) {
+  std::vector<std::vector<uint64_t>> request_bucket;
+  std::vector<int> server_index_arr;
+  std::vector<int> index_mapping(server_size, -1);
+  for (size_t query_idx = 0; query_idx < node_id_list.size(); ++query_idx) {
+    int server_index = get_server_index_by_id(node_id_list[query_idx]);
+    if (index_mapping[server_index] == -1) {
+      index_mapping[server_index] = request_bucket.size();
+      server_index_arr.push_back(server_index);
+      request_bucket.push_back(std::vector<uint64_t>());
+    }
+    request_bucket[index_mapping[server_index]].push_back(
+        node_id_list[query_idx]);
+  }
+  size_t request_call_num = request_bucket.size();
+  DownpourBrpcClosure *closure = new DownpourBrpcClosure(
+      request_call_num, [&, request_call_num](void *done) {
+        int ret = 0;
+        auto *closure = (DownpourBrpcClosure *)done;
+        int fail_num = 0;
+        for (size_t request_idx = 0; request_idx < request_call_num;
+             ++request_idx) {
+          if (closure->check_response(request_idx,
+                                      PS_GRAPH_REMOVE_GRAPH_NODE) != 0) {
+            ++fail_num;
+          }
+        }
+        ret = fail_num == request_call_num ? -1 : 0;
+        closure->set_promise_value(ret);
+      });
+  auto promise = std::make_shared<std::promise<int32_t>>();
+  closure->add_promise(promise);
+  std::future<int> fut = promise->get_future();
+
+  for (size_t request_idx = 0; request_idx < request_call_num; ++request_idx) {
+    int server_index = server_index_arr[request_idx];
+    closure->request(request_idx)->set_cmd_id(PS_GRAPH_REMOVE_GRAPH_NODE);
+    closure->request(request_idx)->set_table_id(table_id);
+    closure->request(request_idx)->set_client_id(_client_id);
+    size_t node_num = request_bucket[request_idx].size();
+
+    closure->request(request_idx)
+        ->add_params((char *)request_bucket[request_idx].data(),
+                     sizeof(uint64_t) * node_num);
+    // PsService_Stub rpc_stub(get_cmd_channel(server_index));
+    GraphPsService_Stub rpc_stub =
+        getServiceStub(get_cmd_channel(server_index));
+    closure->cntl(request_idx)->set_log_id(butil::gettimeofday_ms());
+    rpc_stub.service(closure->cntl(request_idx), closure->request(request_idx),
+                     closure->response(request_idx), closure);
+  }
+  return fut;
+}
 // char* &buffer,int &actual_size
 std::future<int32_t> GraphBrpcClient::batch_sample_neighboors(
     uint32_t table_id, std::vector<uint64_t> node_ids, int sample_size,
@@ -174,8 +331,8 @@ std::future<int32_t> GraphBrpcClient::batch_sample_neighboors(
       [&, node_id_buckets, query_idx_buckets, request_call_num](void *done) {
         int ret = 0;
         auto *closure = (DownpourBrpcClosure *)done;
-        int fail_num = 0;
-        for (int request_idx = 0; request_idx < request_call_num;
+        size_t fail_num = 0;
+        for (size_t request_idx = 0; request_idx < request_call_num;
              ++request_idx) {
           if (closure->check_response(request_idx,
                                       PS_GRAPH_SAMPLE_NEIGHBOORS) != 0) {
@@ -254,13 +411,14 @@ std::future<int32_t> GraphBrpcClient::random_sample_nodes(
       auto &res_io_buffer = closure->cntl(0)->response_attachment();
       butil::IOBufBytesIterator io_buffer_itr(res_io_buffer);
       size_t bytes_size = io_buffer_itr.bytes_left();
-      char buffer[bytes_size];
+      char *buffer = new char[bytes_size];
       auto size = io_buffer_itr.copy_and_forward((void *)(buffer), bytes_size);
       int index = 0;
       while (index < bytes_size) {
         ids.push_back(*(uint64_t *)(buffer + index));
         index += GraphNode::id_size;
       }
+      delete[] buffer;
     }
     closure->set_promise_value(ret);
   });
@@ -292,7 +450,7 @@ std::future<int32_t> GraphBrpcClient::pull_graph_list(
       auto &res_io_buffer = closure->cntl(0)->response_attachment();
       butil::IOBufBytesIterator io_buffer_itr(res_io_buffer);
       size_t bytes_size = io_buffer_itr.bytes_left();
-      char buffer[bytes_size];
+      char *buffer = new char[bytes_size];
       io_buffer_itr.copy_and_forward((void *)(buffer), bytes_size);
       int index = 0;
       while (index < bytes_size) {
@@ -301,6 +459,7 @@ std::future<int32_t> GraphBrpcClient::pull_graph_list(
         index += node.get_size(false);
         res.push_back(node);
       }
+      delete buffer;
     }
     closure->set_promise_value(ret);
   });
diff --git a/paddle/fluid/distributed/service/graph_brpc_client.h b/paddle/fluid/distributed/service/graph_brpc_client.h
index 4e6775a4bedaf1a4028fe483f58be818ef1e3581..5696e8b08037b7027939f472f58ec79925143e4f 100644
--- a/paddle/fluid/distributed/service/graph_brpc_client.h
+++ b/paddle/fluid/distributed/service/graph_brpc_client.h
@@ -78,6 +78,13 @@ class GraphBrpcClient : public BrpcPsClient {
       const uint32_t& table_id, const std::vector<uint64_t>& node_ids,
       const std::vector<std::string>& feature_names,
       std::vector<std::vector<std::string>>& res);
+
+  virtual std::future<int32_t> clear_nodes(uint32_t table_id);
+  virtual std::future<int32_t> add_graph_node(
+      uint32_t table_id, std::vector<uint64_t>& node_id_list,
+      std::vector<bool>& is_weighted_list);
+  virtual std::future<int32_t> remove_graph_node(
+      uint32_t table_id, std::vector<uint64_t>& node_id_list);
   virtual int32_t initialize();
   int get_shard_num() { return shard_num; }
   void set_shard_num(int shard_num) { this->shard_num = shard_num; }
diff --git a/paddle/fluid/distributed/service/graph_brpc_server.cc b/paddle/fluid/distributed/service/graph_brpc_server.cc
index bdd926278b624b9e9bfdf19a4f293784bef6e28f..52ac8c5d688a4ada72212923bdd478b788e422ee 100644
--- a/paddle/fluid/distributed/service/graph_brpc_server.cc
+++ b/paddle/fluid/distributed/service/graph_brpc_server.cc
@@ -24,6 +24,14 @@
 namespace paddle {
 namespace distributed {
 
+#define CHECK_TABLE_EXIST(table, request, response)        \
+  if (table == NULL) {                                     \
+    std::string err_msg("table not found with table_id:"); \
+    err_msg.append(std::to_string(request.table_id()));    \
+    set_response_code(response, -1, err_msg.c_str());      \
+    return -1;                                             \
+  }
+
 int32_t GraphBrpcServer::initialize() {
   auto &service_config = _config.downpour_server_param().service_param();
   if (!service_config.has_service_class()) {
@@ -71,6 +79,58 @@ uint64_t GraphBrpcServer::start(const std::string &ip, uint32_t port) {
   return 0;
 }
 
+int32_t GraphBrpcService::clear_nodes(Table *table,
+                                      const PsRequestMessage &request,
+                                      PsResponseMessage &response,
+                                      brpc::Controller *cntl) {
+  ((GraphTable *)table)->clear_nodes();
+  return 0;
+}
+
+int32_t GraphBrpcService::add_graph_node(Table *table,
+                                         const PsRequestMessage &request,
+                                         PsResponseMessage &response,
+                                         brpc::Controller *cntl) {
+  CHECK_TABLE_EXIST(table, request, response)
+  if (request.params_size() < 1) {
+    set_response_code(
+        response, -1,
+        "graph_get_node_feat request requires at least 2 arguments");
+    return 0;
+  }
+
+  size_t node_num = request.params(0).size() / sizeof(uint64_t);
+  uint64_t *node_data = (uint64_t *)(request.params(0).c_str());
+  std::vector<uint64_t> node_ids(node_data, node_data + node_num);
+  std::vector<bool> is_weighted_list;
+  if (request.params_size() == 2) {
+    size_t weight_list_size = request.params(1).size() / sizeof(bool);
+    bool *is_weighted_buffer = (bool *)(request.params(1).c_str());
+    is_weighted_list = std::vector<bool>(
+        is_weighted_buffer, is_weighted_buffer + weight_list_size);
+  }
+
+  ((GraphTable *)table)->add_graph_node(node_ids, is_weighted_list);
+  return 0;
+}
+int32_t GraphBrpcService::remove_graph_node(Table *table,
+                                            const PsRequestMessage &request,
+                                            PsResponseMessage &response,
+                                            brpc::Controller *cntl) {
+  CHECK_TABLE_EXIST(table, request, response)
+  if (request.params_size() < 1) {
+    set_response_code(
+        response, -1,
+        "graph_get_node_feat request requires at least 1 argument");
+    return 0;
+  }
+  size_t node_num = request.params(0).size() / sizeof(uint64_t);
+  uint64_t *node_data = (uint64_t *)(request.params(0).c_str());
+  std::vector<uint64_t> node_ids(node_data, node_data + node_num);
+
+  ((GraphTable *)table)->remove_graph_node(node_ids);
+  return 0;
+}
 int32_t GraphBrpcServer::port() { return _server.listen_address().port; }
 
 int32_t GraphBrpcService::initialize() {
@@ -92,21 +152,17 @@ int32_t GraphBrpcService::initialize() {
       &GraphBrpcService::graph_random_sample_nodes;
   _service_handler_map[PS_GRAPH_GET_NODE_FEAT] =
       &GraphBrpcService::graph_get_node_feat;
-
+  _service_handler_map[PS_GRAPH_CLEAR] = &GraphBrpcService::clear_nodes;
+  _service_handler_map[PS_GRAPH_ADD_GRAPH_NODE] =
+      &GraphBrpcService::add_graph_node;
+  _service_handler_map[PS_GRAPH_REMOVE_GRAPH_NODE] =
+      &GraphBrpcService::remove_graph_node;
   // shard初始化,server启动后才可从env获取到server_list的shard信息
   initialize_shard_info();
 
   return 0;
 }
 
-#define CHECK_TABLE_EXIST(table, request, response)        \
-  if (table == NULL) {                                     \
-    std::string err_msg("table not found with table_id:"); \
-    err_msg.append(std::to_string(request.table_id()));    \
-    set_response_code(response, -1, err_msg.c_str());      \
-    return -1;                                             \
-  }
-
 int32_t GraphBrpcService::initialize_shard_info() {
   if (!_is_initialize_shard_info) {
     std::lock_guard<std::mutex> guard(_initialize_shard_mutex);
diff --git a/paddle/fluid/distributed/service/graph_brpc_server.h b/paddle/fluid/distributed/service/graph_brpc_server.h
index 32c572f9e6c2bf759c59190679bcf7570a807f2d..47c370572826ac2807e4ea5cb36cf3a667dfed10 100644
--- a/paddle/fluid/distributed/service/graph_brpc_server.h
+++ b/paddle/fluid/distributed/service/graph_brpc_server.h
@@ -86,6 +86,13 @@ class GraphBrpcService : public PsBaseService {
   int32_t graph_get_node_feat(Table *table, const PsRequestMessage &request,
                               PsResponseMessage &response,
                               brpc::Controller *cntl);
+  int32_t clear_nodes(Table *table, const PsRequestMessage &request,
+                      PsResponseMessage &response, brpc::Controller *cntl);
+  int32_t add_graph_node(Table *table, const PsRequestMessage &request,
+                         PsResponseMessage &response, brpc::Controller *cntl);
+  int32_t remove_graph_node(Table *table, const PsRequestMessage &request,
+                            PsResponseMessage &response,
+                            brpc::Controller *cntl);
   int32_t barrier(Table *table, const PsRequestMessage &request,
                   PsResponseMessage &response, brpc::Controller *cntl);
   int32_t load_one_table(Table *table, const PsRequestMessage &request,
diff --git a/paddle/fluid/distributed/service/graph_py_service.cc b/paddle/fluid/distributed/service/graph_py_service.cc
index 61e4e0cf7bb9155d25c630296c2b55a7d3400bfc..39befb1a112c854a183903d76a71d9e6c920b215 100644
--- a/paddle/fluid/distributed/service/graph_py_service.cc
+++ b/paddle/fluid/distributed/service/graph_py_service.cc
@@ -44,6 +44,9 @@ void GraphPyService::add_table_feat_conf(std::string table_name,
   }
 }
 
+void add_graph_node(std::vector<uint64_t> node_ids,
+                    std::vector<bool> weight_list) {}
+void remove_graph_node(std::vector<uint64_t> node_ids) {}
 void GraphPyService::set_up(std::string ips_str, int shard_num,
                             std::vector<std::string> node_types,
                             std::vector<std::string> edge_types) {
@@ -247,6 +250,34 @@ void GraphPyClient::load_edge_file(std::string name, std::string filepath,
   }
 }
 
+void GraphPyClient::clear_nodes(std::string name) {
+  if (this->table_id_map.count(name)) {
+    uint32_t table_id = this->table_id_map[name];
+    auto status = get_ps_client()->clear_nodes(table_id);
+    status.wait();
+  }
+}
+
+void GraphPyClient::add_graph_node(std::string name,
+                                   std::vector<uint64_t>& node_ids,
+                                   std::vector<bool>& weight_list) {
+  if (this->table_id_map.count(name)) {
+    uint32_t table_id = this->table_id_map[name];
+    auto status =
+        get_ps_client()->add_graph_node(table_id, node_ids, weight_list);
+    status.wait();
+  }
+}
+
+void GraphPyClient::remove_graph_node(std::string name,
+                                      std::vector<uint64_t>& node_ids) {
+  if (this->table_id_map.count(name)) {
+    uint32_t table_id = this->table_id_map[name];
+    auto status = get_ps_client()->remove_graph_node(table_id, node_ids);
+    status.wait();
+  }
+}
+
 void GraphPyClient::load_node_file(std::string name, std::string filepath) {
   // 'n' means load nodes and 'node_type' follows
   std::string params = "n" + name;
diff --git a/paddle/fluid/distributed/service/graph_py_service.h b/paddle/fluid/distributed/service/graph_py_service.h
index c6657be96ba446d2f7538943aab43dd47e1868fb..da027fbae3e6f0ca1e902795b0640cee1e0b76cc 100644
--- a/paddle/fluid/distributed/service/graph_py_service.h
+++ b/paddle/fluid/distributed/service/graph_py_service.h
@@ -141,6 +141,10 @@ class GraphPyClient : public GraphPyService {
   void finalize_worker();
   void load_edge_file(std::string name, std::string filepath, bool reverse);
   void load_node_file(std::string name, std::string filepath);
+  void clear_nodes(std::string name);
+  void add_graph_node(std::string name, std::vector<uint64_t>& node_ids,
+                      std::vector<bool>& weight_list);
+  void remove_graph_node(std::string name, std::vector<uint64_t>& node_ids);
   int get_client_id() { return client_id; }
   void set_client_id(int client_id) { this->client_id = client_id; }
   void start_client();
diff --git a/paddle/fluid/distributed/service/sendrecv.proto b/paddle/fluid/distributed/service/sendrecv.proto
index d908c26da9870a93d81c0242ac03e26cfebdb976..a4b811e950a3b56443261ceac37fa658007d519d 100644
--- a/paddle/fluid/distributed/service/sendrecv.proto
+++ b/paddle/fluid/distributed/service/sendrecv.proto
@@ -52,6 +52,9 @@ enum PsCmdID {
   PS_GRAPH_SAMPLE_NEIGHBOORS = 31;
   PS_GRAPH_SAMPLE_NODES = 32;
   PS_GRAPH_GET_NODE_FEAT = 33;
+  PS_GRAPH_CLEAR = 34;
+  PS_GRAPH_ADD_GRAPH_NODE = 35;
+  PS_GRAPH_REMOVE_GRAPH_NODE = 36;
 }
 
 message PsRequestMessage {
diff --git a/paddle/fluid/distributed/table/common_graph_table.cc b/paddle/fluid/distributed/table/common_graph_table.cc
index 0dc99de1bfe82a691fdacb834acd1ad606dcb04b..92f8304a8bf62178988fef447fcf8c309d8589ea 100644
--- a/paddle/fluid/distributed/table/common_graph_table.cc
+++ b/paddle/fluid/distributed/table/common_graph_table.cc
@@ -35,6 +35,77 @@ std::vector<Node *> GraphShard::get_batch(int start, int end, int step) {
 
 size_t GraphShard::get_size() { return bucket.size(); }
 
+int32_t GraphTable::add_graph_node(std::vector<uint64_t> &id_list,
+                                   std::vector<bool> &is_weight_list) {
+  size_t node_size = id_list.size();
+  std::vector<std::vector<std::pair<uint64_t, bool>>> batch(task_pool_size_);
+  for (size_t i = 0; i < node_size; i++) {
+    size_t shard_id = id_list[i] % shard_num;
+    if (shard_id >= shard_end || shard_id < shard_start) {
+      continue;
+    }
+    batch[get_thread_pool_index(id_list[i])].push_back(
+        {id_list[i], i < is_weight_list.size() ? is_weight_list[i] : false});
+  }
+  std::vector<std::future<int>> tasks;
+  for (size_t i = 0; i < batch.size(); ++i) {
+    if (!batch[i].size()) continue;
+    tasks.push_back(_shards_task_pool[i]->enqueue([&batch, i, this]() -> int {
+      for (auto &p : batch[i]) {
+        size_t index = p.first % this->shard_num - this->shard_start;
+        this->shards[index].add_graph_node(p.first)->build_edges(p.second);
+      }
+      return 0;
+    }));
+  }
+  for (size_t i = 0; i < tasks.size(); i++) tasks[i].get();
+  return 0;
+}
+
+int32_t GraphTable::remove_graph_node(std::vector<uint64_t> &id_list) {
+  size_t node_size = id_list.size();
+  std::vector<std::vector<uint64_t>> batch(task_pool_size_);
+  for (size_t i = 0; i < node_size; i++) {
+    size_t shard_id = id_list[i] % shard_num;
+    if (shard_id >= shard_end || shard_id < shard_start) continue;
+    batch[get_thread_pool_index(id_list[i])].push_back(id_list[i]);
+  }
+  std::vector<std::future<int>> tasks;
+  for (size_t i = 0; i < batch.size(); ++i) {
+    if (!batch[i].size()) continue;
+    tasks.push_back(_shards_task_pool[i]->enqueue([&batch, i, this]() -> int {
+      for (auto &p : batch[i]) {
+        size_t index = p % this->shard_num - this->shard_start;
+        this->shards[index].delete_node(p);
+      }
+      return 0;
+    }));
+  }
+  for (size_t i = 0; i < tasks.size(); i++) tasks[i].get();
+  return 0;
+}
+
+void GraphShard::clear() {
+  for (size_t i = 0; i < bucket.size(); i++) {
+    delete bucket[i];
+  }
+  bucket.clear();
+  node_location.clear();
+}
+
+GraphShard::~GraphShard() { clear(); }
+void GraphShard::delete_node(uint64_t id) {
+  auto iter = node_location.find(id);
+  if (iter == node_location.end()) return;
+  int pos = iter->second;
+  delete bucket[pos];
+  if (pos != (int)bucket.size() - 1) {
+    bucket[pos] = bucket.back();
+    node_location[bucket.back()->get_id()] = pos;
+  }
+  node_location.erase(id);
+  bucket.pop_back();
+}
 GraphNode *GraphShard::add_graph_node(uint64_t id) {
   if (node_location.find(id) == node_location.end()) {
     node_location[id] = bucket.size();
@@ -79,11 +150,7 @@ int32_t GraphTable::get_nodes_ids_by_ranges(
   int start = 0, end, index = 0, total_size = 0;
   res.clear();
   std::vector<std::future<std::vector<uint64_t>>> tasks;
-  // std::string temp = "";
-  // for(int i = 0;i < shards.size();i++)
-  //   temp+= std::to_string((int)shards[i].get_size()) + " ";
-  // VLOG(0)<<"range distribution "<<temp;
         tasks.push_back(_shards_task_pool[i % task_pool_size_]->enqueue(
             [this, first, second, i]() -> std::vector<uint64_t> {
               return shards[i].get_ids_by_range(first, second);
             }));
@@ -106,7 +172,7 @@ int32_t GraphTable::get_nodes_ids_by_ranges(
     }
     total_size += shards[i].get_size();
   }
-  for (int i = 0; i < tasks.size(); i++) {
+  for (size_t i = 0; i < tasks.size(); i++) {
     auto vec = tasks[i].get();
     for (auto &id : vec) {
      res.push_back(id);
@@ -219,7 +285,7 @@ int32_t GraphTable::load_edges(const std::string &path, bool reverse_edge) {
 
   for (auto &shard : shards) {
     auto bucket = shard.get_bucket();
-    for (int i = 0; i < bucket.size(); i++) {
+    for (size_t i = 0; i < bucket.size(); i++) {
       bucket[i]->build_sampler(sample_type);
     }
   }
@@ -238,10 +304,29 @@ Node *GraphTable::find_node(uint64_t id) {
 uint32_t GraphTable::get_thread_pool_index(uint64_t node_id) {
   return node_id % shard_num % shard_num_per_table % task_pool_size_;
 }
+
+uint32_t GraphTable::get_thread_pool_index_by_shard_index(
+    uint64_t shard_index) {
+  return shard_index % shard_num_per_table % task_pool_size_;
+}
+
+int32_t GraphTable::clear_nodes() {
+  std::vector<std::future<int>> tasks;
+  for (size_t i = 0; i < shards.size(); i++) {
+    tasks.push_back(
+        _shards_task_pool[get_thread_pool_index_by_shard_index(i)]->enqueue(
+            [this, i]() -> int {
+              this->shards[i].clear();
+              return 0;
+            }));
+  }
+  for (size_t i = 0; i < tasks.size(); i++) tasks[i].get();
+  return 0;
+}
+
 int32_t GraphTable::random_sample_nodes(int sample_size,
                                         std::unique_ptr<char[]> &buffer,
                                         int &actual_size) {
-  bool need_feature = false;
   int total_size = 0;
   for (int i = 0; i < shards.size(); i++) {
     total_size += shards[i].get_size();
   }
@@ -281,7 +366,7 @@ int32_t GraphTable::random_sample_nodes(int sample_size,
   }
   std::vector<std::pair<int, int>> first_half, second_half;
   int start_index = rand() % total_size;
-  for (int i = 0; i < ranges_len.size() && i < ranges_pos.size(); i++) {
+  for (size_t i = 0; i < ranges_len.size() && i < ranges_pos.size(); i++) {
     if (ranges_pos[i] + ranges_len[i] - 1 + start_index < total_size)
       first_half.push_back({ranges_pos[i] + start_index,
                             ranges_pos[i] + ranges_len[i] + start_index});
@@ -386,7 +471,6 @@ std::pair<int32_t, std::string> GraphTable::parse_feature(
   if (this->feat_id_map.count(fields[0])) {
     int32_t id = this->feat_id_map[fields[0]];
     std::string dtype = this->feat_dtype[id];
-    int32_t shape = this->feat_shape[id];
     std::vector<std::string> values(fields.begin() + 1, fields.end());
     if (dtype == "feasign") {
       return std::make_pair(
diff --git a/paddle/fluid/distributed/table/common_graph_table.h b/paddle/fluid/distributed/table/common_graph_table.h
index b18da82abe61c9695712f542e187ac48fd5edc9d..5eeb3915f5b1f251dd5edf1a5199621a7cd0069b 100644
--- a/paddle/fluid/distributed/table/common_graph_table.h
+++ b/paddle/fluid/distributed/table/common_graph_table.h
@@ -36,11 +36,12 @@ class GraphShard {
   size_t get_size();
   GraphShard() {}
   GraphShard(int shard_num) { this->shard_num = shard_num; }
+  ~GraphShard();
   std::vector<Node *> &get_bucket() { return bucket; }
   std::vector<Node *> get_batch(int start, int end, int step);
   std::vector<uint64_t> get_ids_by_range(int start, int end) {
     std::vector<uint64_t> res;
-    for (int i = start; i < end && i < bucket.size(); i++) {
+    for (int i = start; i < end && i < (int)bucket.size(); i++) {
       res.push_back(bucket[i]->get_id());
     }
     return res;
@@ -48,6 +49,8 @@ class GraphShard {
   GraphNode *add_graph_node(uint64_t id);
   FeatureNode *add_feature_node(uint64_t id);
   Node *find_node(uint64_t id);
+  void delete_node(uint64_t id);
+  void clear();
   void add_neighboor(uint64_t id, uint64_t dst_id, float weight);
   std::unordered_map<uint64_t, int> get_node_location() {
     return node_location;
@@ -85,6 +88,11 @@ class GraphTable : public SparseTable {
 
   int32_t load_nodes(const std::string &path, std::string node_type);
 
+  int32_t add_graph_node(std::vector<uint64_t> &id_list,
+                         std::vector<bool> &is_weight_list);
+
+  int32_t remove_graph_node(std::vector<uint64_t> &id_list);
+
   Node *find_node(uint64_t id);
 
   virtual int32_t pull_sparse(float *values,
@@ -97,6 +105,7 @@ class GraphTable : public SparseTable {
     return 0;
   }
 
+  virtual int32_t clear_nodes();
   virtual void clear() {}
   virtual int32_t flush() { return 0; }
   virtual int32_t shrink(const std::string &param) { return 0; }
@@ -105,6 +114,7 @@ class GraphTable : public SparseTable {
     return 0;
   }
   virtual int32_t initialize_shard() { return 0; }
+  virtual uint32_t get_thread_pool_index_by_shard_index(uint64_t shard_index);
   virtual uint32_t get_thread_pool_index(uint64_t node_id);
   virtual std::pair<int32_t, std::string> parse_feature(std::string feat_str);
@@ -128,4 +138,5 @@ class GraphTable : public SparseTable {
   std::vector<std::shared_ptr<::ThreadPool>> _shards_task_pool;
 };
 }  // namespace distributed
+
 };  // namespace paddle
diff --git a/paddle/fluid/distributed/test/graph_node_test.cc b/paddle/fluid/distributed/test/graph_node_test.cc
index b268bb449e14619048e89c8933dbae7daf66537b..b8630aed02ffe60181ddb6b41810f5bea602b733 100644
--- a/paddle/fluid/distributed/test/graph_node_test.cc
+++ b/paddle/fluid/distributed/test/graph_node_test.cc
@@ -124,7 +124,6 @@ void testSingleSampleNeighboor(
   for (auto g : s) {
     ASSERT_EQ(true, s1.find(g) != s1.end());
   }
-  VLOG(0) << "test single done";
   s.clear();
   s1.clear();
   vs.clear();
@@ -141,6 +140,57 @@ void testSingleSampleNeighboor(
   }
 }
 
+void testAddNode(
+    std::shared_ptr<paddle::distributed::GraphBrpcClient>& worker_ptr_) {
+  worker_ptr_->clear_nodes(0);
+  int total_num = 270000;
+  uint64_t id;
+  std::unordered_set<uint64_t> id_set;
+  for (int i = 0; i < total_num; i++) {
+    while (id_set.find(id = rand()) != id_set.end())
+      ;
+    id_set.insert(id);
+  }
+  std::vector<uint64_t> id_list(id_set.begin(), id_set.end());
+  std::vector<bool> weight_list;
+  auto status = worker_ptr_->add_graph_node(0, id_list, weight_list);
+  status.wait();
+  std::vector<uint64_t> ids[2];
+  for (int i = 0; i < 2; i++) {
+    auto sample_status =
+        worker_ptr_->random_sample_nodes(0, i, total_num, ids[i]);
+    sample_status.wait();
+  }
+  std::unordered_set<uint64_t> id_set_check(ids[0].begin(), ids[0].end());
+  for (auto x : ids[1]) id_set_check.insert(x);
+  ASSERT_EQ(id_set.size(), id_set_check.size());
+  for (auto x : id_set) {
+    ASSERT_EQ(id_set_check.find(x) != id_set_check.end(), true);
+  }
+  std::vector<uint64_t> remove_ids;
+  for (auto p : id_set_check) {
+    if (remove_ids.size() == 0)
+      remove_ids.push_back(p);
+    else if (remove_ids.size() < total_num / 2 && rand() % 2 == 1) {
+      remove_ids.push_back(p);
+    }
+  }
+  for (auto p : remove_ids) id_set_check.erase(p);
+  status = worker_ptr_->remove_graph_node(0, remove_ids);
+  status.wait();
+  for (int i = 0; i < 2; i++) ids[i].clear();
+  for (int i = 0; i < 2; i++) {
+    auto sample_status =
+        worker_ptr_->random_sample_nodes(0, i, total_num, ids[i]);
+    sample_status.wait();
+  }
+  std::unordered_set<uint64_t> id_set_check1(ids[0].begin(), ids[0].end());
+  for (auto x : ids[1]) id_set_check1.insert(x);
+  ASSERT_EQ(id_set_check1.size(), id_set_check.size());
+  for (auto x : id_set_check1) {
+    ASSERT_EQ(id_set_check.find(x) != id_set_check.end(), true);
+  }
+}
 void testBatchSampleNeighboor(
     std::shared_ptr<paddle::distributed::GraphBrpcClient>& worker_ptr_) {
   std::vector<std::vector<std::pair<uint64_t, float>>> vs;
@@ -527,6 +577,7 @@ void RunBrpcPushSparse() {
 
   std::remove(edge_file_name);
   std::remove(node_file_name);
+  testAddNode(worker_ptr_);
   LOG(INFO) << "Run stop_server";
   worker_ptr_->stop_server();
   LOG(INFO) << "Run finalize_worker";
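
For reference, a minimal usage sketch of the new client-side calls, modeled on testAddNode above. It is illustrative only and not part of the patch: it assumes a GraphBrpcClient that is already connected (as worker_ptr_ is in graph_node_test.cc) and a graph table with id 0 on the servers; the helper name sketch_add_remove is hypothetical.

// Minimal sketch (not part of the patch): exercise the new RPCs against a
// connected GraphBrpcClient, mirroring testAddNode above. Assumes table 0
// already exists on the server side, as in graph_node_test.cc.
#include <memory>
#include <vector>

#include "paddle/fluid/distributed/service/graph_brpc_client.h"

void sketch_add_remove(
    std::shared_ptr<paddle::distributed::GraphBrpcClient>& worker_ptr_) {
  // Drop whatever the table currently holds.
  worker_ptr_->clear_nodes(0).wait();

  // Insert three unweighted nodes (an empty weight list means "no weights").
  std::vector<uint64_t> ids = {1, 2, 3};
  std::vector<bool> weights;
  worker_ptr_->add_graph_node(0, ids, weights).wait();

  // Remove one of them again.
  std::vector<uint64_t> to_remove = {2};
  worker_ptr_->remove_graph_node(0, to_remove).wait();
}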