diff --git a/paddle/fluid/distributed/fleet_executor/fleet_executor.h b/paddle/fluid/distributed/fleet_executor/fleet_executor.h index c939f70955c61337c861425f28e12941d58ec626..6343e21f7dcb836758b479d43227b2fd85eb5704 100644 --- a/paddle/fluid/distributed/fleet_executor/fleet_executor.h +++ b/paddle/fluid/distributed/fleet_executor/fleet_executor.h @@ -31,12 +31,11 @@ class MessageBus; class FleetExecutor final { public: FleetExecutor() = delete; - FleetExecutor(const std::string& exe_desc_str); + explicit FleetExecutor(const std::string& exe_desc_str); ~FleetExecutor(); void Init(const paddle::framework::ProgramDesc& program_desc); void Run(); void Release(); - static std::shared_ptr GetCarrier(); private: DISABLE_COPY_AND_ASSIGN(FleetExecutor); diff --git a/paddle/fluid/distributed/fleet_executor/interceptor.cc b/paddle/fluid/distributed/fleet_executor/interceptor.cc index 696f7dd752eec3f2f5e800f8f7c76bf7e8befb1d..52806da6ad0ffc1009108ffa13ad0dbeb8bfa551 100644 --- a/paddle/fluid/distributed/fleet_executor/interceptor.cc +++ b/paddle/fluid/distributed/fleet_executor/interceptor.cc @@ -115,5 +115,27 @@ bool Interceptor::FetchRemoteMailbox() { return true; } +static InterceptorFactory::CreateInterceptorMap& GetInterceptorMap() { + static InterceptorFactory::CreateInterceptorMap interceptorMap; + return interceptorMap; +} + +std::unique_ptr InterceptorFactory::Create(const std::string& type, + int64_t id, + TaskNode* node) { + auto& interceptor_map = GetInterceptorMap(); + auto iter = interceptor_map.find(type); + PADDLE_ENFORCE_NE( + iter, interceptor_map.end(), + platform::errors::NotFound("interceptor %s is not register", type)); + return iter->second(id, node); +} + +void InterceptorFactory::Register( + const std::string& type, InterceptorFactory::CreateInterceptorFunc func) { + auto& interceptor_map = GetInterceptorMap(); + interceptor_map.emplace(type, func); +} + } // namespace distributed } // namespace paddle diff --git a/paddle/fluid/distributed/fleet_executor/interceptor.h b/paddle/fluid/distributed/fleet_executor/interceptor.h index 2e86dc2fe525d44d491aad7b8ef730317b50777f..9497d7f3de0fa0ac99ea45ff04a90afe1653c11a 100644 --- a/paddle/fluid/distributed/fleet_executor/interceptor.h +++ b/paddle/fluid/distributed/fleet_executor/interceptor.h @@ -98,5 +98,32 @@ class Interceptor { std::queue local_mailbox_; }; +class InterceptorFactory { + public: + using CreateInterceptorFunc = std::unique_ptr (*)(int64_t, + TaskNode*); + using CreateInterceptorMap = + std::unordered_map; + + static void Register(const std::string& type, CreateInterceptorFunc func); + + static std::unique_ptr Create(const std::string& type, + int64_t id, TaskNode* node); +}; + +#define REGISTER_INTERCEPTOR(interceptor_type, interceptor_class) \ + std::unique_ptr CreatorInterceptor_##interceptor_type( \ + int64_t id, TaskNode* node) { \ + return std::make_unique(id, node); \ + } \ + class __RegisterInterceptor_##interceptor_type { \ + public: \ + __RegisterInterceptor_##interceptor_type() { \ + InterceptorFactory::Register(#interceptor_type, \ + CreatorInterceptor_##interceptor_type); \ + } \ + }; \ + __RegisterInterceptor_##interceptor_type g_register_##interceptor_type; + } // namespace distributed } // namespace paddle diff --git a/paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_test.cc b/paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_test.cc index 783c924398a70307b47c28b90082241fb711b344..52df12395d55a6173098cdb44b519c69bc682d78 100644 --- a/paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_test.cc +++ b/paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_test.cc @@ -51,6 +51,8 @@ class PingPongInterceptor : public Interceptor { int count_{0}; }; +REGISTER_INTERCEPTOR(PingPong, PingPongInterceptor); + TEST(InterceptorTest, PingPong) { MessageBus& msg_bus = MessageBus::Instance(); msg_bus.Init({{0, 0}, {1, 0}}, {{0, "127.0.0.0:0"}}, "127.0.0.0:0"); @@ -58,7 +60,8 @@ TEST(InterceptorTest, PingPong) { Carrier& carrier = Carrier::Instance(); Interceptor* a = carrier.SetInterceptor( - 0, std::make_unique(0, nullptr)); + 0, InterceptorFactory::Create("PingPong", 0, nullptr)); + carrier.SetInterceptor(1, std::make_unique(1, nullptr)); carrier.SetCreatingFlag(false);