未验证 提交 f11e843a 编写于 作者: W WangXi 提交者: GitHub

[fleet_executor] Add interceptor register (#37338)

上级 715fd051
......@@ -31,12 +31,11 @@ class MessageBus;
class FleetExecutor final {
public:
FleetExecutor() = delete;
FleetExecutor(const std::string& exe_desc_str);
explicit FleetExecutor(const std::string& exe_desc_str);
~FleetExecutor();
void Init(const paddle::framework::ProgramDesc& program_desc);
void Run();
void Release();
static std::shared_ptr<Carrier> GetCarrier();
private:
DISABLE_COPY_AND_ASSIGN(FleetExecutor);
......
......@@ -115,5 +115,27 @@ bool Interceptor::FetchRemoteMailbox() {
return true;
}
static InterceptorFactory::CreateInterceptorMap& GetInterceptorMap() {
static InterceptorFactory::CreateInterceptorMap interceptorMap;
return interceptorMap;
}
std::unique_ptr<Interceptor> InterceptorFactory::Create(const std::string& type,
int64_t id,
TaskNode* node) {
auto& interceptor_map = GetInterceptorMap();
auto iter = interceptor_map.find(type);
PADDLE_ENFORCE_NE(
iter, interceptor_map.end(),
platform::errors::NotFound("interceptor %s is not register", type));
return iter->second(id, node);
}
void InterceptorFactory::Register(
const std::string& type, InterceptorFactory::CreateInterceptorFunc func) {
auto& interceptor_map = GetInterceptorMap();
interceptor_map.emplace(type, func);
}
} // namespace distributed
} // namespace paddle
......@@ -98,5 +98,32 @@ class Interceptor {
std::queue<InterceptorMessage> local_mailbox_;
};
class InterceptorFactory {
public:
using CreateInterceptorFunc = std::unique_ptr<Interceptor> (*)(int64_t,
TaskNode*);
using CreateInterceptorMap =
std::unordered_map<std::string, CreateInterceptorFunc>;
static void Register(const std::string& type, CreateInterceptorFunc func);
static std::unique_ptr<Interceptor> Create(const std::string& type,
int64_t id, TaskNode* node);
};
#define REGISTER_INTERCEPTOR(interceptor_type, interceptor_class) \
std::unique_ptr<Interceptor> CreatorInterceptor_##interceptor_type( \
int64_t id, TaskNode* node) { \
return std::make_unique<interceptor_class>(id, node); \
} \
class __RegisterInterceptor_##interceptor_type { \
public: \
__RegisterInterceptor_##interceptor_type() { \
InterceptorFactory::Register(#interceptor_type, \
CreatorInterceptor_##interceptor_type); \
} \
}; \
__RegisterInterceptor_##interceptor_type g_register_##interceptor_type;
} // namespace distributed
} // namespace paddle
......@@ -51,6 +51,8 @@ class PingPongInterceptor : public Interceptor {
int count_{0};
};
REGISTER_INTERCEPTOR(PingPong, PingPongInterceptor);
TEST(InterceptorTest, PingPong) {
MessageBus& msg_bus = MessageBus::Instance();
msg_bus.Init({{0, 0}, {1, 0}}, {{0, "127.0.0.0:0"}}, "127.0.0.0:0");
......@@ -58,7 +60,8 @@ TEST(InterceptorTest, PingPong) {
Carrier& carrier = Carrier::Instance();
Interceptor* a = carrier.SetInterceptor(
0, std::make_unique<PingPongInterceptor>(0, nullptr));
0, InterceptorFactory::Create("PingPong", 0, nullptr));
carrier.SetInterceptor(1, std::make_unique<PingPongInterceptor>(1, nullptr));
carrier.SetCreatingFlag(false);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册