diff --git a/paddle/fluid/distributed/ps/service/graph_brpc_server.h b/paddle/fluid/distributed/ps/service/graph_brpc_server.h index 9a8e24dcbcd8aaf43f0f89aa38f7728bc670b1f7..f933ddaacd6e5927579adc3779cf786e9ff88eb8 100644 --- a/paddle/fluid/distributed/ps/service/graph_brpc_server.h +++ b/paddle/fluid/distributed/ps/service/graph_brpc_server.h @@ -63,97 +63,98 @@ class GraphBrpcService; typedef int32_t (GraphBrpcService::*serviceFunc)( Table *table, const PsRequestMessage &request, - PsResponseMessage &response, + PsResponseMessage &response, // NOLINT brpc::Controller *cntl); class GraphBrpcService : public PsBaseService { public: - virtual int32_t Initialize() override; + int32_t Initialize() override; - virtual void service(::google::protobuf::RpcController *controller, - const PsRequestMessage *request, - PsResponseMessage *response, - ::google::protobuf::Closure *done) override; + void service(::google::protobuf::RpcController *controller, + const PsRequestMessage *request, + PsResponseMessage *response, + ::google::protobuf::Closure *done) override; protected: std::unordered_map _service_handler_map; int32_t InitializeShardInfo(); int32_t pull_graph_list(Table *table, const PsRequestMessage &request, - PsResponseMessage &response, + PsResponseMessage &response, // NOLINT brpc::Controller *cntl); int32_t graph_random_sample_neighbors(Table *table, const PsRequestMessage &request, - PsResponseMessage &response, + PsResponseMessage &response, // NOLINT brpc::Controller *cntl); int32_t graph_random_sample_nodes(Table *table, const PsRequestMessage &request, - PsResponseMessage &response, + PsResponseMessage &response, // NOLINT brpc::Controller *cntl); int32_t graph_get_node_feat(Table *table, const PsRequestMessage &request, - PsResponseMessage &response, + PsResponseMessage &response, // NOLINT brpc::Controller *cntl); int32_t graph_set_node_feat(Table *table, const PsRequestMessage &request, - PsResponseMessage &response, + PsResponseMessage &response, // NOLINT brpc::Controller *cntl); int32_t clear_nodes(Table *table, const PsRequestMessage &request, - PsResponseMessage &response, + PsResponseMessage &response, // NOLINT brpc::Controller *cntl); int32_t add_graph_node(Table *table, const PsRequestMessage &request, - PsResponseMessage &response, + PsResponseMessage &response, // NOLINT brpc::Controller *cntl); int32_t remove_graph_node(Table *table, const PsRequestMessage &request, - PsResponseMessage &response, + PsResponseMessage &response, // NOLINT brpc::Controller *cntl); int32_t Barrier(Table *table, const PsRequestMessage &request, - PsResponseMessage &response, + PsResponseMessage &response, // NOLINT brpc::Controller *cntl); int32_t LoadOneTable(Table *table, const PsRequestMessage &request, - PsResponseMessage &response, + PsResponseMessage &response, // NOLINT brpc::Controller *cntl); int32_t LoadAllTable(Table *table, const PsRequestMessage &request, - PsResponseMessage &response, + PsResponseMessage &response, // NOLINT brpc::Controller *cntl); int32_t StopServer(Table *table, const PsRequestMessage &request, - PsResponseMessage &response, + PsResponseMessage &response, // NOLINT brpc::Controller *cntl); int32_t StartProfiler(Table *table, const PsRequestMessage &request, - PsResponseMessage &response, + PsResponseMessage &response, // NOLINT brpc::Controller *cntl); int32_t StopProfiler(Table *table, const PsRequestMessage &request, - PsResponseMessage &response, + PsResponseMessage &response, // NOLINT brpc::Controller *cntl); int32_t PrintTableStat(Table *table, const PsRequestMessage &request, - PsResponseMessage &response, + PsResponseMessage &response, // NOLINT brpc::Controller *cntl); - int32_t sample_neighbors_across_multi_servers(Table *table, - const PsRequestMessage &request, - PsResponseMessage &response, - brpc::Controller *cntl); + int32_t sample_neighbors_across_multi_servers( + Table *table, + const PsRequestMessage &request, + PsResponseMessage &response, // NOLINT + brpc::Controller *cntl); int32_t use_neighbors_sample_cache(Table *table, const PsRequestMessage &request, - PsResponseMessage &response, + PsResponseMessage &response, // NOLINT brpc::Controller *cntl); int32_t load_graph_split_config(Table *table, const PsRequestMessage &request, - PsResponseMessage &response, + PsResponseMessage &response, // NOLINT brpc::Controller *cntl); private: diff --git a/paddle/fluid/distributed/ps/service/ps_local_client.cc b/paddle/fluid/distributed/ps/service/ps_local_client.cc index 8d7071ad9ea69981dd2c25d68debfb1c853583eb..ecc8819102ed324abe24ead304539505cad422f3 100644 --- a/paddle/fluid/distributed/ps/service/ps_local_client.cc +++ b/paddle/fluid/distributed/ps/service/ps_local_client.cc @@ -16,7 +16,7 @@ #include "paddle/fluid/distributed/ps/table/table.h" -//#define pslib_debug_dense_compress +// #define pslib_debug_dense_compress namespace paddle { namespace distributed { @@ -36,13 +36,13 @@ int32_t PsLocalClient::Initialize() { ::std::future PsLocalClient::Shrink(uint32_t table_id, const std::string threshold) { - // TODO + // TODO // NOLINT return done(); } ::std::future PsLocalClient::Load(const std::string& epoch, const std::string& mode) { - // TODO + // TODO // NOLINT for (auto& it : _table_map) { Load(it.first, epoch, mode); } @@ -51,7 +51,7 @@ int32_t PsLocalClient::Initialize() { ::std::future PsLocalClient::Load(uint32_t table_id, const std::string& epoch, const std::string& mode) { - // TODO + // TODO // NOLINT auto* table_ptr = GetTable(table_id); table_ptr->Load(epoch, mode); return done(); @@ -59,7 +59,7 @@ int32_t PsLocalClient::Initialize() { ::std::future PsLocalClient::Save(const std::string& epoch, const std::string& mode) { - // TODO + // TODO // NOLINT for (auto& it : _table_map) { Save(it.first, epoch, mode); } @@ -68,7 +68,7 @@ int32_t PsLocalClient::Initialize() { ::std::future PsLocalClient::Save(uint32_t table_id, const std::string& epoch, const std::string& mode) { - // TODO + // TODO // NOLINT auto* table_ptr = GetTable(table_id); table_ptr->Flush(); table_ptr->Save(epoch, mode); @@ -76,11 +76,11 @@ int32_t PsLocalClient::Initialize() { } ::std::future PsLocalClient::Clear() { - // TODO + // TODO // NOLINT return done(); } ::std::future PsLocalClient::Clear(uint32_t table_id) { - // TODO + // TODO // NOLINT return done(); } @@ -125,8 +125,10 @@ int32_t PsLocalClient::Initialize() { while (shard_buffer_remain > 0 && region_idx < region_num) { auto& region = regions[region_idx]; if (region.size - region_data_idx >= shard_buffer_remain) { - memcpy((void*)(region.data + region_data_idx), - (uint8_t*)(void*)(region_buffer.data()) + index, + memcpy(reinterpret_cast(region.data + region_data_idx), + reinterpret_cast( + reinterpret_cast(region_buffer.data())) + + index, shard_buffer_remain); region_data_idx += shard_buffer_remain; shard_buffer_remain = 0; @@ -134,8 +136,10 @@ int32_t PsLocalClient::Initialize() { ++region_idx; region_data_idx = 0; } else { - memcpy((void*)(region.data + region_data_idx), - (uint8_t*)(void*)(region_buffer.data()) + index, + memcpy(reinterpret_cast(region.data + region_data_idx), + reinterpret_cast( + reinterpret_cast(region_buffer.data())) + + index, region.size - region_data_idx); shard_buffer_remain -= (region.size - region_data_idx); index += (region.size - region_data_idx); @@ -230,7 +234,7 @@ int32_t PsLocalClient::Initialize() { return done(); } -//::std::future PsLocalClient::PullSparse(float** select_values, +// ::std::future PsLocalClient::PullSparse(float** select_values, // size_t table_id, // const uint64_t* keys, // size_t num) { @@ -271,7 +275,7 @@ int32_t PsLocalClient::Initialize() { // std::make_shared("pslib_downpour_client_pull_sparse"); // auto local_timer = // std::make_shared("pslib_downpour_client_pull_sparse_local"); - //将key拆分到各shard请求,并记录原始对应value指针 + // 将key拆分到各shard请求,并记录原始对应value指针 auto* table_ptr = GetTable(table_id); TableContext table_context; diff --git a/paddle/fluid/distributed/ps/service/ps_local_client.h b/paddle/fluid/distributed/ps/service/ps_local_client.h index 583ea8052eb01d6658568cec011757d4fc2d9eb8..725290b28d3db042dfcafe968d82bf358db42395 100644 --- a/paddle/fluid/distributed/ps/service/ps_local_client.h +++ b/paddle/fluid/distributed/ps/service/ps_local_client.h @@ -32,26 +32,26 @@ class PsLocalClient : public PSClient { return 0; } - virtual ::std::future Shrink(uint32_t table_id, - const std::string threshold) override; - virtual ::std::future Load(const std::string& epoch, - const std::string& mode) override; - virtual ::std::future Load(uint32_t table_id, - const std::string& epoch, - const std::string& mode) override; - - virtual ::std::future Save(const std::string& epoch, - const std::string& mode) override; - virtual ::std::future Save(uint32_t table_id, - const std::string& epoch, - const std::string& mode) override; - - virtual ::std::future Clear() override; - virtual ::std::future Clear(uint32_t table_id) override; - - virtual ::std::future StopServer() override; - - virtual void FinalizeWorker() override {} + ::std::future Shrink(uint32_t table_id, + const std::string threshold) override; + ::std::future Load(const std::string& epoch, + const std::string& mode) override; + ::std::future Load(uint32_t table_id, + const std::string& epoch, + const std::string& mode) override; + + ::std::future Save(const std::string& epoch, + const std::string& mode) override; + ::std::future Save(uint32_t table_id, + const std::string& epoch, + const std::string& mode) override; + + ::std::future Clear() override; + ::std::future Clear(uint32_t table_id) override; + + ::std::future StopServer() override; + + void FinalizeWorker() override {} virtual ::std::future PullDense(Region* regions, size_t region_num, size_t table_id); @@ -102,7 +102,7 @@ class PsLocalClient : public PSClient { prom.set_value(0); return fut; - }; + } virtual std::future StopProfiler() { std::promise prom; @@ -147,8 +147,9 @@ class PsLocalClient : public PSClient { return 0; } - virtual ::std::future SendClient2ClientMsg( - int msg_type, int to_client_id, const std::string& msg) override { + ::std::future SendClient2ClientMsg(int msg_type, + int to_client_id, + const std::string& msg) override { std::promise prom; std::future fut = prom.get_future(); prom.set_value(0); @@ -157,25 +158,23 @@ class PsLocalClient : public PSClient { } virtual size_t GetServerNums() { return 1; } - virtual std::future PushDenseRawGradient(int table_id, - float* total_send_data, - size_t total_send_data_size, - void* callback) override; - - virtual std::future PushSparseRawGradient( - size_t table_id, - const uint64_t* keys, - const float** update_values, - size_t num, - void* callback) override; - - virtual std::future PushSparseRawGradientPartial( - size_t table_id, - const uint64_t* keys, - const float** update_values, - uint32_t num, - void* done, - int pserver_idx) override { + std::future PushDenseRawGradient(int table_id, + float* total_send_data, + size_t total_send_data_size, + void* callback) override; + + std::future PushSparseRawGradient(size_t table_id, + const uint64_t* keys, + const float** update_values, + size_t num, + void* callback) override; + + std::future PushSparseRawGradientPartial(size_t table_id, + const uint64_t* keys, + const float** update_values, + uint32_t num, + void* done, + int pserver_idx) override { std::promise prom; std::future fut = prom.get_future(); prom.set_value(0); @@ -183,11 +182,11 @@ class PsLocalClient : public PSClient { return fut; } - virtual std::future PushSparseParam(size_t table_id, - const uint64_t* keys, - const float** update_values, - size_t num, - void* done) override { + std::future PushSparseParam(size_t table_id, + const uint64_t* keys, + const float** update_values, + size_t num, + void* done) override { std::promise prom; std::future fut = prom.get_future(); prom.set_value(0); @@ -196,7 +195,7 @@ class PsLocalClient : public PSClient { } private: - virtual int32_t Initialize() override; + int32_t Initialize() override; std::future done() { std::shared_ptr> prom =