diff --git a/paddle/fluid/framework/executor.cc b/paddle/fluid/framework/executor.cc index b30a9806eb19ee12d2a70afe3ca806224b0f75d6..179ad1abcaaa6f65987f27020d7dbafda96a76e4 100644 --- a/paddle/fluid/framework/executor.cc +++ b/paddle/fluid/framework/executor.cc @@ -21,7 +21,7 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/reader.h" #ifdef PADDLE_WITH_DISTRIBUTE -#include "paddle/fluid/operators/detail/grpc_client.h" +#include "paddle/fluid/operators/distributed/grpc_client.h" #endif #include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/profiler.h" @@ -49,8 +49,8 @@ Executor::Executor(const platform::Place& place) : place_(place) {} #ifdef PADDLE_WITH_DISTRIBUTE void Executor::Complete() { - ::paddle::operators::detail::RPCClient::GetInstance< - ::paddle::operators::detail::GRPCClient>() + ::paddle::operators::distributed::RPCClient::GetInstance< + ::paddle::operators::distributed::GRPCClient>() ->SendComplete(); } #endif diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index d6a36eff09c7f70803d3be619b26d16660da1ec2..fe58ca17b25fdad8b75f5250d480d5801ac55566 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -184,8 +184,8 @@ else() set(DEPS_OPS ${DEPS_OPS} nccl_op) endif() -add_subdirectory(detail) if(WITH_DISTRIBUTE) + add_subdirectory(distributed) set(DISTRIBUTE_DEPS "") if(WITH_GRPC) diff --git a/paddle/fluid/operators/detail/macros.h b/paddle/fluid/operators/detail/macros.h index da1de72dad00db3ffe609e17bd198ef0a56bbfcd..b9e385994efcea0388756e8bd780ebfc719ed08d 100644 --- a/paddle/fluid/operators/detail/macros.h +++ b/paddle/fluid/operators/detail/macros.h @@ -15,13 +15,13 @@ #pragma once #ifdef PADDLE_WITH_GRPC -#include "paddle/fluid/operators/detail/grpc_client.h" -#include "paddle/fluid/operators/detail/grpc_server.h" -#define RPCSERVER_T detail::AsyncGRPCServer -#define RPCCLIENT_T detail::GRPCClient +#include "paddle/fluid/operators/distributed/grpc_client.h" +#include "paddle/fluid/operators/distributed/grpc_server.h" +#define RPCSERVER_T distributed::AsyncGRPCServer +#define RPCCLIENT_T distributed::GRPCClient #else -#include "paddle/fluid/operators/detail/brpc_client.h" -#include "paddle/fluid/operators/detail/brpc_server.h" -#define RPCSERVER_T detail::AsyncBRPCServer -#define RPCCLIENT_T detail::BRPCClient +#include "paddle/fluid/operators/distributed/brpc_client.h" +#include "paddle/fluid/operators/distributed/brpc_server.h" +#define RPCSERVER_T distributed::AsyncBRPCServer +#define RPCCLIENT_T distributed::BRPCClient #endif diff --git a/paddle/fluid/operators/detail/CMakeLists.txt b/paddle/fluid/operators/distributed/CMakeLists.txt similarity index 97% rename from paddle/fluid/operators/detail/CMakeLists.txt rename to paddle/fluid/operators/distributed/CMakeLists.txt index abc5aad0430e71928a441c9488dda16dfdd63b9c..312f80e09077f21a47985c1c936c2ac41c292ead 100644 --- a/paddle/fluid/operators/detail/CMakeLists.txt +++ b/paddle/fluid/operators/distributed/CMakeLists.txt @@ -1,8 +1,3 @@ -if(NOT WITH_DISTRIBUTE) - return() -endif() - - if(WITH_GRPC) grpc_library(sendrecvop_grpc SRCS bytebuffer_stream.cc sendrecvop_utils.cc grpc_client.cc request_handler_impl.cc rpc_client.cc rpc_server.cc grpc_server.cc variable_response.cc PROTO send_recv.proto DEPS lod_tensor diff --git a/paddle/fluid/operators/detail/brpc_client.cc b/paddle/fluid/operators/distributed/brpc_client.cc similarity index 98% rename from paddle/fluid/operators/detail/brpc_client.cc rename to paddle/fluid/operators/distributed/brpc_client.cc index 9a4e410f1d83e93883438fae116c38eb60787673..b394c678fb6503eb73a1e11e6feb814251e9e940 100644 --- a/paddle/fluid/operators/detail/brpc_client.cc +++ b/paddle/fluid/operators/distributed/brpc_client.cc @@ -12,12 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/operators/detail/brpc_client.h" +#include "paddle/fluid/operators/distributed/brpc_client.h" #include "paddle/fluid/framework/threadpool.h" namespace paddle { namespace operators { -namespace detail { +namespace distributed { DEFINE_int32(brpc_channel_num, 24, "Number of channels to send requests connected to one server"); @@ -175,6 +175,6 @@ ChannelQueuePtr BRPCClient::GetChannel(const std::string& ep) { return q; } -} // namespace detail +} // namespace distributed } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/detail/brpc_client.h b/paddle/fluid/operators/distributed/brpc_client.h similarity index 94% rename from paddle/fluid/operators/detail/brpc_client.h rename to paddle/fluid/operators/distributed/brpc_client.h index 1e953ea431d51a9586bfd0b352c7f27d079ff1a8..34f140687f91d866536f5e2b647c7445a6624736 100644 --- a/paddle/fluid/operators/detail/brpc_client.h +++ b/paddle/fluid/operators/distributed/brpc_client.h @@ -31,13 +31,13 @@ limitations under the License. */ #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/selected_rows.h" -#include "paddle/fluid/operators/detail/rpc_client.h" -#include "paddle/fluid/operators/detail/send_recv.pb.h" +#include "paddle/fluid/operators/distributed/rpc_client.h" +#include "paddle/fluid/operators/distributed/send_recv.pb.h" #include "paddle/fluid/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN namespace paddle { namespace operators { -namespace detail { +namespace distributed { struct ChannelContext { brpc::Channel channel; @@ -95,6 +95,6 @@ class BRPCClient : public RPCClient { DISABLE_COPY_AND_ASSIGN(BRPCClient); }; -} // namespace detail +} // namespace distributed } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/detail/brpc_server.cc b/paddle/fluid/operators/distributed/brpc_server.cc similarity index 86% rename from paddle/fluid/operators/detail/brpc_server.cc rename to paddle/fluid/operators/distributed/brpc_server.cc index 2170abe679f9ededff3b53e3139e56f8aad227cb..862167f02084cfe81db1c0936bbfb0415fa85721 100644 --- a/paddle/fluid/operators/detail/brpc_server.cc +++ b/paddle/fluid/operators/distributed/brpc_server.cc @@ -12,13 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/operators/detail/brpc_server.h" -#include "paddle/fluid/operators/detail/request_handler.h" +#include "paddle/fluid/operators/distributed/brpc_server.h" +#include "paddle/fluid/operators/distributed/request_handler.h" namespace sendrecv { typedef std::unordered_map + paddle::operators::distributed::RequestHandler*> HandlerMap; class BRPCServiceImpl : public SendRecvService { @@ -27,17 +27,17 @@ class BRPCServiceImpl : public SendRecvService { : request_send_h_(nullptr), request_get_h_(nullptr), request_prefetch_h_(nullptr) { - auto it = rpc_call_map.find(paddle::operators::detail::kRequestSend); + auto it = rpc_call_map.find(paddle::operators::distributed::kRequestSend); if (it != rpc_call_map.end()) { request_send_h_ = it->second; } - it = rpc_call_map.find(paddle::operators::detail::kRequestSend); + it = rpc_call_map.find(paddle::operators::distributed::kRequestSend); if (it != rpc_call_map.end()) { request_get_h_ = it->second; } - it = rpc_call_map.find(paddle::operators::detail::kRequestPrefetch); + it = rpc_call_map.find(paddle::operators::distributed::kRequestPrefetch); if (it != rpc_call_map.end()) { request_prefetch_h_ = it->second; } @@ -88,15 +88,15 @@ class BRPCServiceImpl : public SendRecvService { } private: - paddle::operators::detail::RequestHandler* request_send_h_; - paddle::operators::detail::RequestHandler* request_get_h_; - paddle::operators::detail::RequestHandler* request_prefetch_h_; + paddle::operators::distributed::RequestHandler* request_send_h_; + paddle::operators::distributed::RequestHandler* request_get_h_; + paddle::operators::distributed::RequestHandler* request_prefetch_h_; }; } // namespace sendrecv namespace paddle { namespace operators { -namespace detail { +namespace distributed { void AsyncBRPCServer::StartServer() { // Instance of your service. @@ -139,6 +139,6 @@ void AsyncBRPCServer::WaitServerReady() { VLOG(3) << "AsyncGRPCServer WaitSeverReady"; } -}; // namespace detail +}; // namespace distributed }; // namespace operators }; // namespace paddle diff --git a/paddle/fluid/operators/detail/brpc_server.h b/paddle/fluid/operators/distributed/brpc_server.h similarity index 88% rename from paddle/fluid/operators/detail/brpc_server.h rename to paddle/fluid/operators/distributed/brpc_server.h index 0105c8074a46849031d8fa9c21a5507a982ec3c3..85a7ad0dfe843dad483d43631b69a79d75211ce9 100644 --- a/paddle/fluid/operators/detail/brpc_server.h +++ b/paddle/fluid/operators/distributed/brpc_server.h @@ -19,12 +19,12 @@ limitations under the License. */ #include #include "brpc/server.h" -#include "paddle/fluid/operators/detail/rpc_server.h" -#include "paddle/fluid/operators/detail/send_recv.pb.h" +#include "paddle/fluid/operators/distributed/rpc_server.h" +#include "paddle/fluid/operators/distributed/send_recv.pb.h" namespace paddle { namespace operators { -namespace detail { +namespace distributed { class AsyncBRPCServer final : public RPCServer { public: @@ -48,6 +48,6 @@ class AsyncBRPCServer final : public RPCServer { int ready_; }; -}; // namespace detail +}; // namespace distributed }; // namespace operators }; // namespace paddle diff --git a/paddle/fluid/operators/detail/bytebuffer_stream.cc b/paddle/fluid/operators/distributed/bytebuffer_stream.cc similarity index 94% rename from paddle/fluid/operators/detail/bytebuffer_stream.cc rename to paddle/fluid/operators/distributed/bytebuffer_stream.cc index a14171563edb0ac9a22b7ae493c965de3efb7823..6e91b447db838c9095432eda22e9e1171e938d31 100644 --- a/paddle/fluid/operators/detail/bytebuffer_stream.cc +++ b/paddle/fluid/operators/distributed/bytebuffer_stream.cc @@ -17,11 +17,11 @@ limitations under the License. */ // file and did some modifications so that we can send gRPC // requests without too much copying of the tensor data. -#include "paddle/fluid/operators/detail/bytebuffer_stream.h" +#include "paddle/fluid/operators/distributed/bytebuffer_stream.h" namespace paddle { namespace operators { -namespace detail { +namespace distributed { GrpcByteBufferSource::GrpcByteBufferSource() {} @@ -83,6 +83,6 @@ google::protobuf::int64 GrpcByteBufferSource::ByteCount() const { return byte_count_; } -} // namespace detail +} // namespace distributed } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/detail/bytebuffer_stream.h b/paddle/fluid/operators/distributed/bytebuffer_stream.h similarity index 99% rename from paddle/fluid/operators/detail/bytebuffer_stream.h rename to paddle/fluid/operators/distributed/bytebuffer_stream.h index 054dd4ff294414cca55d7e033f2c5403bbb85526..e7de172c79c30761483b5d96f5bad19860208832 100644 --- a/paddle/fluid/operators/detail/bytebuffer_stream.h +++ b/paddle/fluid/operators/distributed/bytebuffer_stream.h @@ -106,7 +106,7 @@ class GrpcBufferReader final namespace paddle { namespace operators { -namespace detail { +namespace distributed { // Source provides a way for a particular RPC implementation to provide // received data to ParseFrom. class Source { @@ -183,6 +183,6 @@ class GrpcByteSource : public Source { char space_[sizeof(Reader)]; }; -} // namespace detail +} // namespace distributed } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/detail/grpc_client.cc b/paddle/fluid/operators/distributed/grpc_client.cc similarity index 97% rename from paddle/fluid/operators/detail/grpc_client.cc rename to paddle/fluid/operators/distributed/grpc_client.cc index ea004f7cd340030e61571825941a50e89735ef05..65d63784c1486f3304c053a2bd3c58b8b30eda2f 100644 --- a/paddle/fluid/operators/detail/grpc_client.cc +++ b/paddle/fluid/operators/distributed/grpc_client.cc @@ -12,19 +12,19 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/detail/grpc_client.h" +#include "paddle/fluid/operators/distributed/grpc_client.h" #include #include #include "paddle/fluid/framework/threadpool.h" -#include "paddle/fluid/operators/detail/request_handler.h" +#include "paddle/fluid/operators/distributed/request_handler.h" #include "paddle/fluid/platform/profiler.h" namespace paddle { namespace operators { -namespace detail { +namespace distributed { void GRPCClient::InitImpl() { InitEventLoop(); } @@ -276,6 +276,6 @@ std::shared_ptr GRPCClient::GetChannel(const std::string& ep) { return ch; } -} // namespace detail +} // namespace distributed } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/detail/grpc_client.h b/paddle/fluid/operators/distributed/grpc_client.h similarity index 97% rename from paddle/fluid/operators/detail/grpc_client.h rename to paddle/fluid/operators/distributed/grpc_client.h index 44000c028b499d9ad1a0e0dd40a5e287cd61d143..a6efa7dfd11e87caafa6109391f133b0233d58dd 100644 --- a/paddle/fluid/operators/detail/grpc_client.h +++ b/paddle/fluid/operators/distributed/grpc_client.h @@ -38,13 +38,13 @@ limitations under the License. */ #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/selected_rows.h" -#include "paddle/fluid/operators/detail/rpc_client.h" -#include "paddle/fluid/operators/detail/sendrecvop_utils.h" +#include "paddle/fluid/operators/distributed/rpc_client.h" +#include "paddle/fluid/operators/distributed/sendrecvop_utils.h" #include "paddle/fluid/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN namespace paddle { namespace operators { -namespace detail { +namespace distributed { struct VarHandle { std::string ep; @@ -226,6 +226,6 @@ class GRPCClient : public RPCClient { DISABLE_COPY_AND_ASSIGN(GRPCClient); }; -} // namespace detail +} // namespace distributed } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/detail/grpc_serde_test.cc b/paddle/fluid/operators/distributed/grpc_serde_test.cc similarity index 93% rename from paddle/fluid/operators/detail/grpc_serde_test.cc rename to paddle/fluid/operators/distributed/grpc_serde_test.cc index 15892295e6901fe649788c9e34604008fc8cbdfa..3d107b533bcb7bfef3f9b13ec99afbd579a62e52 100644 --- a/paddle/fluid/operators/detail/grpc_serde_test.cc +++ b/paddle/fluid/operators/distributed/grpc_serde_test.cc @@ -21,8 +21,8 @@ limitations under the License. */ #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/tensor_util.h" #include "paddle/fluid/framework/variable.h" -#include "paddle/fluid/operators/detail/sendrecvop_utils.h" -#include "paddle/fluid/operators/detail/variable_response.h" +#include "paddle/fluid/operators/distributed/sendrecvop_utils.h" +#include "paddle/fluid/operators/distributed/variable_response.h" #include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/platform/place.h" #include "paddle/fluid/string/printf.h" @@ -50,7 +50,7 @@ void RunSerdeTestSelectedRows(platform::Place place) { for (int i = 0; i < 564; ++i) rows->push_back(i); ::grpc::ByteBuffer msg; - operators::detail::SerializeToByteBuffer("myvar", &var, ctx, &msg); + operators::distributed::SerializeToByteBuffer("myvar", &var, ctx, &msg); EXPECT_GT(msg.Length(), static_cast(0)); // deserialize @@ -81,10 +81,10 @@ void RunSerdeTestSelectedRows(platform::Place place) { // deserialize zero-copy // framework::Variable var2; - // operators::detail::DeserializeFromByteBuffer(msg, ctx, &var2); + // operators::distributed::DeserializeFromByteBuffer(msg, ctx, &var2); framework::Scope scope; scope.Var("myvar"); - operators::detail::VariableResponse resp(&scope, &ctx); + operators::distributed::VariableResponse resp(&scope, &ctx); EXPECT_EQ(resp.Parse(msg), 0); framework::Variable* var2 = resp.GetVar(); @@ -128,7 +128,7 @@ void RunTestLodTensor(platform::Place place, int from_type = 0) { math::set_constant(ctx, tensor, 31.9); ::grpc::ByteBuffer msg; - operators::detail::SerializeToByteBuffer("myvar", &var, ctx, &msg); + operators::distributed::SerializeToByteBuffer("myvar", &var, ctx, &msg); EXPECT_GT(msg.Length(), static_cast(0)); // deserialize @@ -171,7 +171,7 @@ void RunTestLodTensor(platform::Place place, int from_type = 0) { // deserialize zero-copy framework::Scope scope; scope.Var("myvar"); - operators::detail::VariableResponse resp(&scope, &ctx); + operators::distributed::VariableResponse resp(&scope, &ctx); if (from_type == 0) { EXPECT_EQ(resp.Parse(msg), 0); } else { diff --git a/paddle/fluid/operators/detail/grpc_server.cc b/paddle/fluid/operators/distributed/grpc_server.cc similarity index 96% rename from paddle/fluid/operators/detail/grpc_server.cc rename to paddle/fluid/operators/distributed/grpc_server.cc index 5a87258901c6563fe793d4041f344011a56d9a01..707f665a29d83d7cdf6e4e80624f2402a7b0a2e7 100644 --- a/paddle/fluid/operators/detail/grpc_server.cc +++ b/paddle/fluid/operators/distributed/grpc_server.cc @@ -15,13 +15,13 @@ limitations under the License. */ #include #include -#include "paddle/fluid/operators/detail/grpc_server.h" +#include "paddle/fluid/operators/distributed/grpc_server.h" using ::grpc::ServerAsyncResponseWriter; namespace paddle { namespace operators { -namespace detail { +namespace distributed { enum CallStatus { PROCESS = 0, FINISH }; // reference: @@ -74,7 +74,7 @@ class RequestSend final : public RequestBase { request_.reset(new VariableResponse(request_handler->scope(), request_handler->dev_ctx(), !request_handler->sync_mode())); - int method_id = static_cast(detail::GrpcMethod::kSendVariable); + int method_id = static_cast(distributed::GrpcMethod::kSendVariable); service_->RequestAsyncUnary( method_id, &ctx_, request_.get(), &responder_, cq_, cq_, reinterpret_cast(static_cast(req_id))); @@ -106,7 +106,7 @@ class RequestGet final : public RequestBase { ::grpc::ServerCompletionQueue* cq, RequestHandler* request_handler, int req_id) : RequestBase(service, cq, request_handler, req_id), responder_(&ctx_) { - auto method_id = static_cast(detail::GrpcMethod::kGetVariable); + auto method_id = static_cast(distributed::GrpcMethod::kGetVariable); service_->RequestAsyncUnary( method_id, &ctx_, &request_, &responder_, cq_, cq_, reinterpret_cast(static_cast(req_id))); @@ -150,7 +150,8 @@ class RequestPrefetch final : public RequestBase { local_scope_(nullptr) { request_.reset(new VariableResponse(request_handler->scope(), request_handler->dev_ctx(), true)); - int method_id = static_cast(detail::GrpcMethod::kPrefetchVariable); + int method_id = + static_cast(distributed::GrpcMethod::kPrefetchVariable); service_->RequestAsyncUnary( method_id, &ctx_, request_.get(), &responder_, cq_, cq_, reinterpret_cast(static_cast(req_id))); @@ -354,6 +355,6 @@ void AsyncGRPCServer::HandleRequest( } } -} // namespace detail +} // namespace distributed } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/detail/grpc_server.h b/paddle/fluid/operators/distributed/grpc_server.h similarity index 85% rename from paddle/fluid/operators/detail/grpc_server.h rename to paddle/fluid/operators/distributed/grpc_server.h index f1db7590f6f14d5d44acc12453861a446e278cd2..d2524f5e65db6dedab78f45e17380359b58a3d11 100644 --- a/paddle/fluid/operators/detail/grpc_server.h +++ b/paddle/fluid/operators/distributed/grpc_server.h @@ -29,17 +29,17 @@ limitations under the License. */ #include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/selected_rows.h" #include "paddle/fluid/framework/var_type.h" -#include "paddle/fluid/operators/detail/grpc_service.h" -#include "paddle/fluid/operators/detail/request_handler.h" -#include "paddle/fluid/operators/detail/rpc_server.h" -#include "paddle/fluid/operators/detail/send_recv.grpc.pb.h" -#include "paddle/fluid/operators/detail/send_recv.pb.h" -#include "paddle/fluid/operators/detail/sendrecvop_utils.h" +#include "paddle/fluid/operators/distributed/grpc_service.h" +#include "paddle/fluid/operators/distributed/request_handler.h" +#include "paddle/fluid/operators/distributed/rpc_server.h" +#include "paddle/fluid/operators/distributed/send_recv.grpc.pb.h" +#include "paddle/fluid/operators/distributed/send_recv.pb.h" +#include "paddle/fluid/operators/distributed/sendrecvop_utils.h" #include "paddle/fluid/platform/profiler.h" namespace paddle { namespace operators { -namespace detail { +namespace distributed { class RequestBase; @@ -84,6 +84,6 @@ class AsyncGRPCServer final : public RPCServer { std::map> rpc_reqs_; }; -}; // namespace detail +}; // namespace distributed }; // namespace operators }; // namespace paddle diff --git a/paddle/fluid/operators/detail/grpc_service.h b/paddle/fluid/operators/distributed/grpc_service.h similarity index 87% rename from paddle/fluid/operators/detail/grpc_service.h rename to paddle/fluid/operators/distributed/grpc_service.h index e0505c2b9d0903837713d7e0032b01ab091c2e04..141be3e68012743a32e4df5de148a55717f8e9a2 100644 --- a/paddle/fluid/operators/detail/grpc_service.h +++ b/paddle/fluid/operators/distributed/grpc_service.h @@ -23,7 +23,7 @@ #include #include #include -#include "paddle/fluid/operators/detail/variable_response.h" +#include "paddle/fluid/operators/distributed/variable_response.h" #include "paddle/fluid/platform/profiler.h" @@ -42,24 +42,25 @@ class ServerContext; // Support parsing/unparsing of tensorflow::VariableResponse. // Wire-format is identical to RecvVariableResponse. template <> -class SerializationTraits { +class SerializationTraits { public: static Status Serialize( - const paddle::operators::detail::VariableResponse& msg, + const paddle::operators::distributed::VariableResponse& msg, grpc_byte_buffer** bp, bool* own_buffer) { PADDLE_ENFORCE(false, "SerializationTraits::Serialize not implemented!"); return Status(); } - static Status Deserialize(grpc_byte_buffer* buffer, - paddle::operators::detail::VariableResponse* msg, - int max_message_size = INT_MAX) { + static Status Deserialize( + grpc_byte_buffer* buffer, + paddle::operators::distributed::VariableResponse* msg, + int max_message_size = INT_MAX) { if (buffer == nullptr) { return Status(StatusCode::INTERNAL, "No payload"); } Status result = g_core_codegen_interface->ok(); if (result.ok()) { - paddle::operators::detail::GrpcByteSource source(buffer); + paddle::operators::distributed::GrpcByteSource source(buffer); int ret = msg->Parse(&source); if (ret != 0) { result = Status(StatusCode::INTERNAL, "VariableResponse parse error"); @@ -73,7 +74,7 @@ class SerializationTraits { namespace paddle { namespace operators { -namespace detail { +namespace distributed { enum class GrpcMethod { kSendVariable, @@ -118,6 +119,6 @@ class GrpcService final { }; }; -} // namespace detail +} // namespace distributed } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/detail/proto_encoder_helper.h b/paddle/fluid/operators/distributed/proto_encoder_helper.h similarity index 98% rename from paddle/fluid/operators/detail/proto_encoder_helper.h rename to paddle/fluid/operators/distributed/proto_encoder_helper.h index d91d054b2507f32d1e948dde33da06a70cabe775..2fab02e32fe18ee04f86a69bb5bae1cbe7c6762c 100644 --- a/paddle/fluid/operators/detail/proto_encoder_helper.h +++ b/paddle/fluid/operators/distributed/proto_encoder_helper.h @@ -26,7 +26,7 @@ limitations under the License. */ namespace paddle { namespace operators { -namespace detail { +namespace distributed { char* EncodeVarint32(char* dst, uint32_t v) { // Operate on characters as unsigneds @@ -144,6 +144,6 @@ class ProtoEncodeHelper { char* limit_; // Just for CHECKs }; -} // namespace detail +} // namespace distributed } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/detail/request_handler.h b/paddle/fluid/operators/distributed/request_handler.h similarity index 98% rename from paddle/fluid/operators/detail/request_handler.h rename to paddle/fluid/operators/distributed/request_handler.h index a2d08747d59220d30a5b8fd56074fd2739ae3bab..cf106656aa56c2130d8be8dbe7478c3397f9b9ad 100644 --- a/paddle/fluid/operators/detail/request_handler.h +++ b/paddle/fluid/operators/distributed/request_handler.h @@ -31,7 +31,7 @@ namespace paddle { namespace operators { -namespace detail { +namespace distributed { constexpr char kRequestSend[] = "RequestSend"; constexpr char kRequestGet[] = "RequestGet"; @@ -124,6 +124,6 @@ class RequestHandler { RPCServer* rpc_server_; }; -} // namespace detail +} // namespace distributed } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/detail/request_handler_impl.cc b/paddle/fluid/operators/distributed/request_handler_impl.cc similarity index 95% rename from paddle/fluid/operators/detail/request_handler_impl.cc rename to paddle/fluid/operators/distributed/request_handler_impl.cc index 7425bee798cd9ba0af8cd777a6db63862c8a4031..cb78c15c01e8e7f47ec759a75090f9a6b880b493 100644 --- a/paddle/fluid/operators/detail/request_handler_impl.cc +++ b/paddle/fluid/operators/distributed/request_handler_impl.cc @@ -20,12 +20,12 @@ #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/selected_rows.h" -#include "paddle/fluid/operators/detail/request_handler_impl.h" -#include "paddle/fluid/operators/detail/rpc_server.h" +#include "paddle/fluid/operators/distributed/request_handler_impl.h" +#include "paddle/fluid/operators/distributed/rpc_server.h" namespace paddle { namespace operators { -namespace detail { +namespace distributed { bool RequestSendHandler::Handle(const std::string& varname, framework::Scope* scope, @@ -119,6 +119,6 @@ bool RequestPrefetchHandler::Handle(const std::string& varname, return true; } -} // namespace detail +} // namespace distributed } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/detail/request_handler_impl.h b/paddle/fluid/operators/distributed/request_handler_impl.h similarity index 95% rename from paddle/fluid/operators/detail/request_handler_impl.h rename to paddle/fluid/operators/distributed/request_handler_impl.h index 3f77c09a9598b431d747f1b824615e49d939098e..abbe8778911a21ece3090bc9790d51a3cb31b6d7 100644 --- a/paddle/fluid/operators/detail/request_handler_impl.h +++ b/paddle/fluid/operators/distributed/request_handler_impl.h @@ -28,11 +28,11 @@ #include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/selected_rows.h" #include "paddle/fluid/framework/var_type.h" -#include "paddle/fluid/operators/detail/request_handler.h" +#include "paddle/fluid/operators/distributed/request_handler.h" namespace paddle { namespace operators { -namespace detail { +namespace distributed { class RequestSendHandler final : public RequestHandler { public: @@ -66,6 +66,6 @@ class RequestPrefetchHandler final : public RequestHandler { const std::string& out_var_name = "") override; }; -} // namespace detail +} // namespace distributed } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/detail/rpc_client.cc b/paddle/fluid/operators/distributed/rpc_client.cc similarity index 88% rename from paddle/fluid/operators/detail/rpc_client.cc rename to paddle/fluid/operators/distributed/rpc_client.cc index 9a791403e3d6b99c5d4de5183e83e1af655d7d4c..c71edf977c18e554c502732e9bf4bb4ea99f8f99 100644 --- a/paddle/fluid/operators/detail/rpc_client.cc +++ b/paddle/fluid/operators/distributed/rpc_client.cc @@ -12,15 +12,15 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/operators/detail/rpc_client.h" +#include "paddle/fluid/operators/distributed/rpc_client.h" namespace paddle { namespace operators { -namespace detail { +namespace distributed { std::once_flag RPCClient::init_flag_; std::unique_ptr RPCClient::rpc_client_(nullptr); -} // namespace detail +} // namespace distributed } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/detail/rpc_client.h b/paddle/fluid/operators/distributed/rpc_client.h similarity index 98% rename from paddle/fluid/operators/detail/rpc_client.h rename to paddle/fluid/operators/distributed/rpc_client.h index 47c6ffb4fd7a002fc0bd8053fb3314a2fbf18fd3..72fa6d940886bc676e9d03d13f12d07772f5f5a7 100644 --- a/paddle/fluid/operators/detail/rpc_client.h +++ b/paddle/fluid/operators/distributed/rpc_client.h @@ -22,7 +22,7 @@ namespace paddle { namespace operators { -namespace detail { +namespace distributed { class RPCClient { public: @@ -84,6 +84,6 @@ class RPCClient { static std::once_flag init_flag_; static std::unique_ptr rpc_client_; }; -} // namespace detail +} // namespace distributed } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/detail/rpc_server.cc b/paddle/fluid/operators/distributed/rpc_server.cc similarity index 96% rename from paddle/fluid/operators/detail/rpc_server.cc rename to paddle/fluid/operators/distributed/rpc_server.cc index cd0fe96e2301ee3304fe9a2967df58b9f7072d8d..fa0cb71b3056de92f65139c5402132fc8cbb7a87 100644 --- a/paddle/fluid/operators/detail/rpc_server.cc +++ b/paddle/fluid/operators/distributed/rpc_server.cc @@ -17,11 +17,11 @@ #include #include -#include "paddle/fluid/operators/detail/rpc_server.h" +#include "paddle/fluid/operators/distributed/rpc_server.h" namespace paddle { namespace operators { -namespace detail { +namespace distributed { void RPCServer::ShutDown() { LOG(INFO) << "RPCServer ShutDown "; @@ -112,6 +112,6 @@ void RPCServer::WaitCond(const std::string& rpc_name) { lock, [=] { return (cur_cond_.load() == cond || exit_flag_.load()); }); } -} // namespace detail +} // namespace distributed } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/detail/rpc_server.h b/paddle/fluid/operators/distributed/rpc_server.h similarity index 95% rename from paddle/fluid/operators/detail/rpc_server.h rename to paddle/fluid/operators/distributed/rpc_server.h index 2e3342428cb56c34abaca655d5906668cda8f140..cf25e78435bb470b25a46db647ca818571cc83a5 100644 --- a/paddle/fluid/operators/detail/rpc_server.h +++ b/paddle/fluid/operators/distributed/rpc_server.h @@ -19,11 +19,11 @@ #include // NOLINT #include #include -#include "paddle/fluid/operators/detail/request_handler.h" +#include "paddle/fluid/operators/distributed/request_handler.h" namespace paddle { namespace operators { -namespace detail { +namespace distributed { class RPCServer { public: @@ -86,6 +86,6 @@ class RPCServer { friend class RequestHandler; }; -}; // namespace detail +}; // namespace distributed }; // namespace operators }; // namespace paddle diff --git a/paddle/fluid/operators/detail/rpc_server_test.cc b/paddle/fluid/operators/distributed/rpc_server_test.cc similarity index 87% rename from paddle/fluid/operators/detail/rpc_server_test.cc rename to paddle/fluid/operators/distributed/rpc_server_test.cc index 463a7b80cfac280de5afe91ee85caaaf074cef32..a0693cffabcc561b0adfafc2c49027a890dd5efc 100644 --- a/paddle/fluid/operators/detail/rpc_server_test.cc +++ b/paddle/fluid/operators/distributed/rpc_server_test.cc @@ -22,18 +22,18 @@ limitations under the License. */ #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/operators/detail/macros.h" -#include "paddle/fluid/operators/detail/request_handler_impl.h" -#include "paddle/fluid/operators/detail/rpc_client.h" -#include "paddle/fluid/operators/detail/rpc_server.h" +#include "paddle/fluid/operators/distributed/request_handler_impl.h" +#include "paddle/fluid/operators/distributed/rpc_client.h" +#include "paddle/fluid/operators/distributed/rpc_server.h" namespace framework = paddle::framework; namespace platform = paddle::platform; -namespace detail = paddle::operators::detail; +namespace distributed = paddle::operators::distributed; USE_OP(lookup_table); -std::unique_ptr g_rpc_service; -std::unique_ptr g_req_handler; +std::unique_ptr g_rpc_service; +std::unique_ptr g_req_handler; framework::BlockDesc* AppendPrefetchBlcok(framework::ProgramDesc* program) { auto root_block = program->MutableBlock(0); @@ -113,19 +113,21 @@ void StartServer() { g_req_handler->SetScope(&scope); g_req_handler->SetExecutor(&exe); - g_rpc_service->RegisterRPC(detail::kRequestPrefetch, g_req_handler.get()); + g_rpc_service->RegisterRPC(distributed::kRequestPrefetch, + g_req_handler.get()); g_req_handler->SetRPCServer(g_rpc_service.get()); std::thread server_thread( - std::bind(&detail::RPCServer::StartServer, g_rpc_service.get())); + std::bind(&distributed::RPCServer::StartServer, g_rpc_service.get())); server_thread.join(); } TEST(PREFETCH, CPU) { - g_req_handler.reset(new detail::RequestPrefetchHandler(true)); + g_req_handler.reset(new distributed::RequestPrefetchHandler(true)); g_rpc_service.reset(new RPCSERVER_T("127.0.0.1:0", 1)); - detail::RPCClient* client = detail::RPCClient::GetInstance(); + distributed::RPCClient* client = + distributed::RPCClient::GetInstance(); std::thread server_thread(StartServer); g_rpc_service->WaitServerReady(); diff --git a/paddle/fluid/operators/detail/send_recv.proto b/paddle/fluid/operators/distributed/send_recv.proto similarity index 100% rename from paddle/fluid/operators/detail/send_recv.proto rename to paddle/fluid/operators/distributed/send_recv.proto diff --git a/paddle/fluid/operators/detail/sendrecvop_utils.cc b/paddle/fluid/operators/distributed/sendrecvop_utils.cc similarity index 95% rename from paddle/fluid/operators/detail/sendrecvop_utils.cc rename to paddle/fluid/operators/distributed/sendrecvop_utils.cc index 507b465435609a91ebca97dd70b176c3b79bee02..98129d9f1014c39347e3409533f2bc10092611d2 100644 --- a/paddle/fluid/operators/detail/sendrecvop_utils.cc +++ b/paddle/fluid/operators/distributed/sendrecvop_utils.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/detail/sendrecvop_utils.h" +#include "paddle/fluid/operators/distributed/sendrecvop_utils.h" #ifdef PADDLE_WITH_CUDA #include @@ -23,14 +23,14 @@ limitations under the License. */ #include "google/protobuf/io/coded_stream.h" #include "google/protobuf/io/zero_copy_stream.h" #include "paddle/fluid/framework/data_type.h" -#include "paddle/fluid/operators/detail/bytebuffer_stream.h" -#include "paddle/fluid/operators/detail/proto_encoder_helper.h" -#include "paddle/fluid/operators/detail/variable_response.h" +#include "paddle/fluid/operators/distributed/bytebuffer_stream.h" +#include "paddle/fluid/operators/distributed/proto_encoder_helper.h" +#include "paddle/fluid/operators/distributed/variable_response.h" #include "paddle/fluid/platform/profiler.h" namespace paddle { namespace operators { -namespace detail { +namespace distributed { using VarMsg = sendrecv::VariableMessage; @@ -222,11 +222,11 @@ void DeserializeFromByteBuffer(const ::grpc::ByteBuffer& msg, const platform::DeviceContext& ctx, const framework::Scope* scope, framework::Variable** var) { - operators::detail::VariableResponse resp(scope, &ctx); + operators::distributed::VariableResponse resp(scope, &ctx); PADDLE_ENFORCE(resp.Parse(msg) == 0, "parse bytebuffer to tensor error!"); *var = resp.GetVar(); } -} // namespace detail +} // namespace distributed } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/detail/sendrecvop_utils.h b/paddle/fluid/operators/distributed/sendrecvop_utils.h similarity index 92% rename from paddle/fluid/operators/detail/sendrecvop_utils.h rename to paddle/fluid/operators/distributed/sendrecvop_utils.h index bd16bf1dab8d933ffd18b6d6d9e3ce1c7d73029b..fe25e73fa608727ba0bb912a82776b330ec8d83a 100644 --- a/paddle/fluid/operators/detail/sendrecvop_utils.h +++ b/paddle/fluid/operators/distributed/sendrecvop_utils.h @@ -25,12 +25,12 @@ limitations under the License. */ #include "paddle/fluid/framework/tensor_util.h" #include "paddle/fluid/framework/var_type.h" -#include "paddle/fluid/operators/detail/send_recv.grpc.pb.h" -#include "paddle/fluid/operators/detail/send_recv.pb.h" +#include "paddle/fluid/operators/distributed/send_recv.grpc.pb.h" +#include "paddle/fluid/operators/distributed/send_recv.pb.h" namespace paddle { namespace operators { -namespace detail { +namespace distributed { typedef void (*DestroyCallback)(void*); @@ -61,6 +61,6 @@ inline std::type_index ToTypeIndex(sendrecv::VariableMessage::Type type) { } } -} // namespace detail +} // namespace distributed } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/detail/variable_response.cc b/paddle/fluid/operators/distributed/variable_response.cc similarity index 96% rename from paddle/fluid/operators/detail/variable_response.cc rename to paddle/fluid/operators/distributed/variable_response.cc index 24cb91a3bb820a0e5d51aaa49154434919080f69..619890b1939be8777b89b94e415a3c2d63376658 100644 --- a/paddle/fluid/operators/detail/variable_response.cc +++ b/paddle/fluid/operators/distributed/variable_response.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/operators/detail/variable_response.h" +#include "paddle/fluid/operators/distributed/variable_response.h" #include #include @@ -22,12 +22,12 @@ #endif #include "paddle/fluid/platform/profiler.h" -#include "paddle/fluid/operators/detail/send_recv.pb.h" -#include "paddle/fluid/operators/detail/sendrecvop_utils.h" +#include "paddle/fluid/operators/distributed/send_recv.pb.h" +#include "paddle/fluid/operators/distributed/sendrecvop_utils.h" namespace paddle { namespace operators { -namespace detail { +namespace distributed { enum WireType { WIRETYPE_VARINT = 0, @@ -158,13 +158,13 @@ bool VariableResponse::CopySelectRowsTensorData( slr->set_height(meta_.slr_height()); auto* tensor = slr->mutable_value(); tensor->Resize(dims); - PADDLE_ENFORCE_EQ( - static_cast(tensor->numel()), - length / framework::SizeOfType( - paddle::operators::detail::ToTypeIndex(meta_.data_type()))); + PADDLE_ENFORCE_EQ(static_cast(tensor->numel()), + length / framework::SizeOfType( + paddle::operators::distributed::ToTypeIndex( + meta_.data_type()))); void* tensor_data = tensor->mutable_data( ctx.GetPlace(), - paddle::operators::detail::ToTypeIndex(meta_.data_type())); + paddle::operators::distributed::ToTypeIndex(meta_.data_type())); if (!ReadRaw(input, ctx, tensor->place(), tensor_data, length)) { return false; @@ -480,6 +480,6 @@ int VariableResponse::Parse(Source* source) { return 0; } -}; // namespace detail +}; // namespace distributed }; // namespace operators }; // namespace paddle diff --git a/paddle/fluid/operators/detail/variable_response.h b/paddle/fluid/operators/distributed/variable_response.h similarity index 92% rename from paddle/fluid/operators/detail/variable_response.h rename to paddle/fluid/operators/distributed/variable_response.h index 69cfd784f8dd4f129f50c6882061e53e8535b949..1db4a0a522654ff2497b8bd9ee1381b5ab64067a 100644 --- a/paddle/fluid/operators/detail/variable_response.h +++ b/paddle/fluid/operators/distributed/variable_response.h @@ -22,17 +22,17 @@ #include "paddle/fluid/framework/selected_rows.h" #include "paddle/fluid/framework/var_type.h" -#include "paddle/fluid/operators/detail/send_recv.grpc.pb.h" -#include "paddle/fluid/operators/detail/send_recv.pb.h" +#include "paddle/fluid/operators/distributed/send_recv.grpc.pb.h" +#include "paddle/fluid/operators/distributed/send_recv.pb.h" #include "google/protobuf/io/coded_stream.h" #include "google/protobuf/io/zero_copy_stream.h" #include "paddle/fluid/framework/tensor.h" -#include "paddle/fluid/operators/detail/bytebuffer_stream.h" +#include "paddle/fluid/operators/distributed/bytebuffer_stream.h" namespace paddle { namespace operators { -namespace detail { +namespace distributed { class VariableResponse { public: @@ -99,6 +99,6 @@ class VariableResponse { sendrecv::VariableMessage meta_; }; -}; // namespace detail +}; // namespace distributed }; // namespace operators }; // namespace paddle diff --git a/paddle/fluid/operators/fetch_barrier_op.cc b/paddle/fluid/operators/fetch_barrier_op.cc index 98b051afb551f373009d2bd3df1a8daa64b7e6c7..02beb80fc8a9f451393dcdd54492c4f88f908497 100644 --- a/paddle/fluid/operators/fetch_barrier_op.cc +++ b/paddle/fluid/operators/fetch_barrier_op.cc @@ -42,8 +42,8 @@ class FetchBarrierOp : public framework::OperatorBase { // For profiling platform::RecordEvent record_event(Type(), &ctx); - detail::RPCClient* rpc_client = - detail::RPCClient::GetInstance(); + distributed::RPCClient* rpc_client = + distributed::RPCClient::GetInstance(); rpc_client->Wait(); diff --git a/paddle/fluid/operators/gen_nccl_id_op.cc b/paddle/fluid/operators/gen_nccl_id_op.cc index f824eee4e7d1ef19c9a38fd5d3369265f9c549a0..697c239e59d158428ae9ba9f7feded19637dff28 100644 --- a/paddle/fluid/operators/gen_nccl_id_op.cc +++ b/paddle/fluid/operators/gen_nccl_id_op.cc @@ -22,7 +22,7 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/threadpool.h" #include "paddle/fluid/operators/detail/macros.h" -#include "paddle/fluid/operators/detail/request_handler_impl.h" +#include "paddle/fluid/operators/distributed/request_handler_impl.h" #include "paddle/fluid/platform/nccl_helper.h" namespace paddle { @@ -60,7 +60,8 @@ class GenNCCLIdOp : public framework::OperatorBase { std::vector endpoint_list = Attr>("endpoint_list"); - detail::RPCClient* client = detail::RPCClient::GetInstance(); + distributed::RPCClient* client = + distributed::RPCClient::GetInstance(); for (auto& ep : endpoint_list) { VLOG(3) << "sending nccl id to " << ep; @@ -80,11 +81,11 @@ class GenNCCLIdOp : public framework::OperatorBase { // NOTE: Can not use unique_ptr here because the default // deleter will call GRPC Server's base class's dtor and // that will cause a wired crash. - detail::RequestSendHandler rpc_h(true); - std::unique_ptr rpc_service( + distributed::RequestSendHandler rpc_h(true); + std::unique_ptr rpc_service( new RPCSERVER_T(endpoint, 1)); - rpc_service->RegisterRPC(detail::kRequestSend, &rpc_h); + rpc_service->RegisterRPC(distributed::kRequestSend, &rpc_h); rpc_h.SetRPCServer(rpc_service.get()); framework::ProgramDesc empty_program; @@ -95,11 +96,11 @@ class GenNCCLIdOp : public framework::OperatorBase { rpc_h.SetExecutor(&executor); std::thread server_thread( - std::bind(&detail::RPCServer::StartServer, rpc_service.get())); + std::bind(&distributed::RPCServer::StartServer, rpc_service.get())); - rpc_service->SetCond(detail::kRequestSend); + rpc_service->SetCond(distributed::kRequestSend); VLOG(3) << "start getting nccl id from trainer 0..."; - rpc_service->WaitBarrier(detail::kRequestSend); + rpc_service->WaitBarrier(distributed::kRequestSend); VLOG(3) << "got nccl id and stop server..."; rpc_service->ShutDown(); VLOG(3) << "rpc server stopped"; diff --git a/paddle/fluid/operators/listen_and_serv_op.cc b/paddle/fluid/operators/listen_and_serv_op.cc index 57c2ce457791d830e4230aa25e1c5b358f476782..f840064ecaca4bc38191727da39d07676dc18ee1 100644 --- a/paddle/fluid/operators/listen_and_serv_op.cc +++ b/paddle/fluid/operators/listen_and_serv_op.cc @@ -21,14 +21,14 @@ limitations under the License. */ #include "paddle/fluid/operators/detail/macros.h" -#include "paddle/fluid/operators/detail/request_handler_impl.h" +#include "paddle/fluid/operators/distributed/request_handler_impl.h" #include "paddle/fluid/operators/listen_and_serv_op.h" #include "paddle/fluid/platform/profiler.h" namespace paddle { namespace operators { -void RunServer(std::shared_ptr service) { +void RunServer(std::shared_ptr service) { service->StartServer(); VLOG(4) << "RunServer thread end"; } @@ -121,12 +121,12 @@ void ListenAndServOp::RunSyncLoop( while (true) { // Get from multiple trainers, we don't care about the order in which // the gradients arrives, just add suffix 0~n and merge the gradient. - rpc_service_->SetCond(detail::kRequestSend); - rpc_service_->WaitBarrier(detail::kRequestSend); + rpc_service_->SetCond(distributed::kRequestSend); + rpc_service_->WaitBarrier(distributed::kRequestSend); if (rpc_service_->IsExit()) { LOG(WARNING) << "get exit!rpc_processor break!"; - rpc_service_->SetCond(detail::kRequestGet); + rpc_service_->SetCond(distributed::kRequestGet); break; } @@ -154,11 +154,11 @@ void ListenAndServOp::RunSyncLoop( recv_scope); VLOG(2) << "run all blocks spent " << GetTimestamp() - ts << "(ms)"; - rpc_service_->SetCond(detail::kRequestGet); - rpc_service_->WaitBarrier(detail::kRequestGet); + rpc_service_->SetCond(distributed::kRequestGet); + rpc_service_->WaitBarrier(distributed::kRequestGet); rpc_service_->ResetBarrierCounter(); // reset received sparse vars to avoid reuse it in the next mini-batch - dynamic_cast(request_send_handler_.get()) + dynamic_cast(request_send_handler_.get()) ->ResetSparseVarRecorder(); } // while(true) } @@ -215,13 +215,13 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor, } static void FillRequestCtx( - detail::RequestHandler *h, framework::Scope *scope, + distributed::RequestHandler *h, framework::Scope *scope, platform::DeviceContext *dev_ctx, framework::Executor *executor, framework::ProgramDesc *program, std::unordered_map> *prefetch_ctx, - detail::RPCServer *rpc_server) { + distributed::RPCServer *rpc_server) { h->SetScope(scope); h->SetDevCtx(dev_ctx); h->SetExecutor(executor); @@ -249,14 +249,16 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, rpc_service_.reset(new RPCSERVER_T(endpoint, fan_in)); - request_send_handler_.reset(new detail::RequestSendHandler(sync_mode)); - request_get_handler_.reset(new detail::RequestGetHandler(sync_mode)); + request_send_handler_.reset(new distributed::RequestSendHandler(sync_mode)); + request_get_handler_.reset(new distributed::RequestGetHandler(sync_mode)); request_prefetch_handler_.reset( - new detail::RequestPrefetchHandler(sync_mode)); + new distributed::RequestPrefetchHandler(sync_mode)); - rpc_service_->RegisterRPC(detail::kRequestSend, request_send_handler_.get()); - rpc_service_->RegisterRPC(detail::kRequestGet, request_get_handler_.get()); - rpc_service_->RegisterRPC(detail::kRequestPrefetch, + rpc_service_->RegisterRPC(distributed::kRequestSend, + request_send_handler_.get()); + rpc_service_->RegisterRPC(distributed::kRequestGet, + request_get_handler_.get()); + rpc_service_->RegisterRPC(distributed::kRequestPrefetch, request_prefetch_handler_.get()); auto *optimize_block = Attr(kOptimizeBlock); diff --git a/paddle/fluid/operators/listen_and_serv_op.h b/paddle/fluid/operators/listen_and_serv_op.h index 46c3a19e20b3f2dd970a672bb99f98e83d3e25bf..9aa322ad602d7a72bb90aaa4a67e7f2f7a3a54cd 100644 --- a/paddle/fluid/operators/listen_and_serv_op.h +++ b/paddle/fluid/operators/listen_and_serv_op.h @@ -24,8 +24,8 @@ limitations under the License. */ #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/threadpool.h" -#include "paddle/fluid/operators/detail/request_handler.h" -#include "paddle/fluid/operators/detail/rpc_server.h" +#include "paddle/fluid/operators/distributed/request_handler.h" +#include "paddle/fluid/operators/distributed/rpc_server.h" namespace paddle { namespace operators { @@ -33,7 +33,7 @@ namespace operators { constexpr char kOptimizeBlock[] = "OptimizeBlock"; constexpr char kPrefetchVarNameToBlockId[] = "prefetch_var_name_to_block_id"; -void RunServer(std::shared_ptr service); +void RunServer(std::shared_ptr service); class ListenAndServOp : public framework::OperatorBase { public: @@ -62,10 +62,11 @@ class ListenAndServOp : public framework::OperatorBase { const platform::Place& dev_place) const override; protected: - mutable std::shared_ptr rpc_service_; - mutable std::shared_ptr request_send_handler_; - mutable std::shared_ptr request_get_handler_; - mutable std::shared_ptr request_prefetch_handler_; + mutable std::shared_ptr rpc_service_; + mutable std::shared_ptr request_send_handler_; + mutable std::shared_ptr request_get_handler_; + mutable std::shared_ptr + request_prefetch_handler_; mutable std::shared_ptr server_thread_; }; diff --git a/paddle/fluid/operators/prefetch_op.cc b/paddle/fluid/operators/prefetch_op.cc index f71ba84b318c1f8b0604310f3db8a0826124e207..8734282fe496b8e90af19abd5549566d62316fc3 100644 --- a/paddle/fluid/operators/prefetch_op.cc +++ b/paddle/fluid/operators/prefetch_op.cc @@ -41,8 +41,8 @@ class PrefetchOp : public framework::OperatorBase { platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); auto& ctx = *pool.Get(place); - detail::RPCClient* rpc_client = - detail::RPCClient::GetInstance(); + distributed::RPCClient* rpc_client = + distributed::RPCClient::GetInstance(); for (size_t i = 0; i < ins.size(); i++) { if (NeedSend(scope, ins[i])) { diff --git a/paddle/fluid/operators/recv_op.cc b/paddle/fluid/operators/recv_op.cc index 15dfb5469bf51330b98d6699fb3ce708222212ed..9854a31f5b10f5ecd940c0d41c2c3e468fc17bad 100644 --- a/paddle/fluid/operators/recv_op.cc +++ b/paddle/fluid/operators/recv_op.cc @@ -43,8 +43,8 @@ class RecvOp : public framework::OperatorBase { // For profiling platform::RecordEvent record_event(Type(), &ctx); - detail::RPCClient* rpc_client = - detail::RPCClient::GetInstance(); + distributed::RPCClient* rpc_client = + distributed::RPCClient::GetInstance(); for (size_t i = 0; i < outs.size(); i++) { VLOG(3) << "getting " << outs[i] << " from " << epmap[i]; diff --git a/paddle/fluid/operators/send_barrier_op.cc b/paddle/fluid/operators/send_barrier_op.cc index c6c975a23ce846464388c72af5d8902144ceb16a..6b4572dcccc21e783f1df0b9bcde11d532ff4ba8 100644 --- a/paddle/fluid/operators/send_barrier_op.cc +++ b/paddle/fluid/operators/send_barrier_op.cc @@ -44,8 +44,8 @@ class SendBarrierOp : public framework::OperatorBase { // For profiling platform::RecordEvent record_event(Type(), &ctx); - detail::RPCClient* rpc_client = - detail::RPCClient::GetInstance(); + distributed::RPCClient* rpc_client = + distributed::RPCClient::GetInstance(); VLOG(3) << "SendBarrierOp sync_mode:" << sync_mode; diff --git a/paddle/fluid/operators/send_op.cc b/paddle/fluid/operators/send_op.cc index 84ec36625314572d16e5c537884b6efec420cc60..0cac329aafa8c4c67cae48ba62a48575f5edba92 100644 --- a/paddle/fluid/operators/send_op.cc +++ b/paddle/fluid/operators/send_op.cc @@ -45,8 +45,8 @@ class SendOp : public framework::OperatorBase { // For profiling platform::RecordEvent record_event(Type(), &ctx); - detail::RPCClient* rpc_client = - detail::RPCClient::GetInstance(); + distributed::RPCClient* rpc_client = + distributed::RPCClient::GetInstance(); for (size_t i = 0; i < ins.size(); i++) { if (NeedSend(scope, ins[i])) { diff --git a/paddle/fluid/operators/test_send_nccl_id.cc b/paddle/fluid/operators/test_send_nccl_id.cc index 5015b1005569ba70b147ebb795243e24ab81ea5c..e2b7b6b8e447381229e4ad594b7974bc0aa159d5 100644 --- a/paddle/fluid/operators/test_send_nccl_id.cc +++ b/paddle/fluid/operators/test_send_nccl_id.cc @@ -21,7 +21,7 @@ limitations under the License. */ #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/operators/detail/macros.h" -#include "paddle/fluid/operators/detail/request_handler_impl.h" +#include "paddle/fluid/operators/distributed/request_handler_impl.h" #include "paddle/fluid/operators/listen_and_serv_op.h" #include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/selected_rows_functor.h" @@ -37,11 +37,11 @@ USE_NO_KERNEL_OP(listen_and_serv); namespace f = paddle::framework; namespace p = paddle::platform; namespace m = paddle::operators::math; -namespace detail = paddle::operators::detail; +namespace distributed = paddle::operators::distributed; namespace string = paddle::string; -std::unique_ptr g_rpc_service; -std::unique_ptr g_req_handler; +std::unique_ptr g_rpc_service; +std::unique_ptr g_req_handler; void StartServer() { f::Scope scope; @@ -57,14 +57,14 @@ void StartServer() { g_req_handler->SetProgram(&empty_program); g_req_handler->SetExecutor(&executor); - g_rpc_service->RegisterRPC(detail::kRequestSend, g_req_handler.get()); + g_rpc_service->RegisterRPC(distributed::kRequestSend, g_req_handler.get()); g_req_handler->SetRPCServer(g_rpc_service.get()); std::thread server_thread( - std::bind(&detail::RPCServer::StartServer, g_rpc_service.get())); + std::bind(&distributed::RPCServer::StartServer, g_rpc_service.get())); - g_rpc_service->SetCond(detail::kRequestSend); - g_rpc_service->WaitBarrier(detail::kRequestSend); + g_rpc_service->SetCond(distributed::kRequestSend); + g_rpc_service->WaitBarrier(distributed::kRequestSend); LOG(INFO) << "got nccl id and stop server..."; g_rpc_service->ShutDown(); @@ -72,7 +72,7 @@ void StartServer() { } TEST(SendNcclId, RPCServer) { - g_req_handler.reset(new detail::RequestSendHandler(true)); + g_req_handler.reset(new distributed::RequestSendHandler(true)); g_rpc_service.reset(new RPCSERVER_T("127.0.0.1:0", 1)); std::thread server_thread(StartServer); @@ -91,7 +91,8 @@ TEST(SendNcclId, RPCServer) { std::string ep = string::Sprintf("127.0.0.1:%d", port); - detail::RPCClient* client = detail::RPCClient::GetInstance(); + distributed::RPCClient* client = + distributed::RPCClient::GetInstance(); LOG(INFO) << "connect to server" << ep; client->AsyncSendVar(ep, dev_ctx, scope, NCCL_ID_VARNAME);