Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
f11e843a
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
f11e843a
编写于
11月 19, 2021
作者:
W
WangXi
提交者:
GitHub
11月 19, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[fleet_executor] Add interceptor register (#37338)
上级
715fd051
变更
4
显示空白变更内容
内联
并排
Showing
4 changed file
with
54 addition
and
3 deletion
+54
-3
paddle/fluid/distributed/fleet_executor/fleet_executor.h
paddle/fluid/distributed/fleet_executor/fleet_executor.h
+1
-2
paddle/fluid/distributed/fleet_executor/interceptor.cc
paddle/fluid/distributed/fleet_executor/interceptor.cc
+22
-0
paddle/fluid/distributed/fleet_executor/interceptor.h
paddle/fluid/distributed/fleet_executor/interceptor.h
+27
-0
paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_test.cc
...ributed/fleet_executor/test/interceptor_ping_pong_test.cc
+4
-1
未找到文件。
paddle/fluid/distributed/fleet_executor/fleet_executor.h
浏览文件 @
f11e843a
...
@@ -31,12 +31,11 @@ class MessageBus;
...
@@ -31,12 +31,11 @@ class MessageBus;
class
FleetExecutor
final
{
class
FleetExecutor
final
{
public:
public:
FleetExecutor
()
=
delete
;
FleetExecutor
()
=
delete
;
FleetExecutor
(
const
std
::
string
&
exe_desc_str
);
explicit
FleetExecutor
(
const
std
::
string
&
exe_desc_str
);
~
FleetExecutor
();
~
FleetExecutor
();
void
Init
(
const
paddle
::
framework
::
ProgramDesc
&
program_desc
);
void
Init
(
const
paddle
::
framework
::
ProgramDesc
&
program_desc
);
void
Run
();
void
Run
();
void
Release
();
void
Release
();
static
std
::
shared_ptr
<
Carrier
>
GetCarrier
();
private:
private:
DISABLE_COPY_AND_ASSIGN
(
FleetExecutor
);
DISABLE_COPY_AND_ASSIGN
(
FleetExecutor
);
...
...
paddle/fluid/distributed/fleet_executor/interceptor.cc
浏览文件 @
f11e843a
...
@@ -115,5 +115,27 @@ bool Interceptor::FetchRemoteMailbox() {
...
@@ -115,5 +115,27 @@ bool Interceptor::FetchRemoteMailbox() {
return
true
;
return
true
;
}
}
static
InterceptorFactory
::
CreateInterceptorMap
&
GetInterceptorMap
()
{
static
InterceptorFactory
::
CreateInterceptorMap
interceptorMap
;
return
interceptorMap
;
}
std
::
unique_ptr
<
Interceptor
>
InterceptorFactory
::
Create
(
const
std
::
string
&
type
,
int64_t
id
,
TaskNode
*
node
)
{
auto
&
interceptor_map
=
GetInterceptorMap
();
auto
iter
=
interceptor_map
.
find
(
type
);
PADDLE_ENFORCE_NE
(
iter
,
interceptor_map
.
end
(),
platform
::
errors
::
NotFound
(
"interceptor %s is not register"
,
type
));
return
iter
->
second
(
id
,
node
);
}
void
InterceptorFactory
::
Register
(
const
std
::
string
&
type
,
InterceptorFactory
::
CreateInterceptorFunc
func
)
{
auto
&
interceptor_map
=
GetInterceptorMap
();
interceptor_map
.
emplace
(
type
,
func
);
}
}
// namespace distributed
}
// namespace distributed
}
// namespace paddle
}
// namespace paddle
paddle/fluid/distributed/fleet_executor/interceptor.h
浏览文件 @
f11e843a
...
@@ -98,5 +98,32 @@ class Interceptor {
...
@@ -98,5 +98,32 @@ class Interceptor {
std
::
queue
<
InterceptorMessage
>
local_mailbox_
;
std
::
queue
<
InterceptorMessage
>
local_mailbox_
;
};
};
class
InterceptorFactory
{
public:
using
CreateInterceptorFunc
=
std
::
unique_ptr
<
Interceptor
>
(
*
)(
int64_t
,
TaskNode
*
);
using
CreateInterceptorMap
=
std
::
unordered_map
<
std
::
string
,
CreateInterceptorFunc
>
;
static
void
Register
(
const
std
::
string
&
type
,
CreateInterceptorFunc
func
);
static
std
::
unique_ptr
<
Interceptor
>
Create
(
const
std
::
string
&
type
,
int64_t
id
,
TaskNode
*
node
);
};
#define REGISTER_INTERCEPTOR(interceptor_type, interceptor_class) \
std::unique_ptr<Interceptor> CreatorInterceptor_##interceptor_type( \
int64_t id, TaskNode* node) { \
return std::make_unique<interceptor_class>(id, node); \
} \
class __RegisterInterceptor_##interceptor_type { \
public: \
__RegisterInterceptor_##interceptor_type() { \
InterceptorFactory::Register(#interceptor_type, \
CreatorInterceptor_##interceptor_type); \
} \
}; \
__RegisterInterceptor_##interceptor_type g_register_##interceptor_type;
}
// namespace distributed
}
// namespace distributed
}
// namespace paddle
}
// namespace paddle
paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_test.cc
浏览文件 @
f11e843a
...
@@ -51,6 +51,8 @@ class PingPongInterceptor : public Interceptor {
...
@@ -51,6 +51,8 @@ class PingPongInterceptor : public Interceptor {
int
count_
{
0
};
int
count_
{
0
};
};
};
REGISTER_INTERCEPTOR
(
PingPong
,
PingPongInterceptor
);
TEST
(
InterceptorTest
,
PingPong
)
{
TEST
(
InterceptorTest
,
PingPong
)
{
MessageBus
&
msg_bus
=
MessageBus
::
Instance
();
MessageBus
&
msg_bus
=
MessageBus
::
Instance
();
msg_bus
.
Init
({{
0
,
0
},
{
1
,
0
}},
{{
0
,
"127.0.0.0:0"
}},
"127.0.0.0:0"
);
msg_bus
.
Init
({{
0
,
0
},
{
1
,
0
}},
{{
0
,
"127.0.0.0:0"
}},
"127.0.0.0:0"
);
...
@@ -58,7 +60,8 @@ TEST(InterceptorTest, PingPong) {
...
@@ -58,7 +60,8 @@ TEST(InterceptorTest, PingPong) {
Carrier
&
carrier
=
Carrier
::
Instance
();
Carrier
&
carrier
=
Carrier
::
Instance
();
Interceptor
*
a
=
carrier
.
SetInterceptor
(
Interceptor
*
a
=
carrier
.
SetInterceptor
(
0
,
std
::
make_unique
<
PingPongInterceptor
>
(
0
,
nullptr
));
0
,
InterceptorFactory
::
Create
(
"PingPong"
,
0
,
nullptr
));
carrier
.
SetInterceptor
(
1
,
std
::
make_unique
<
PingPongInterceptor
>
(
1
,
nullptr
));
carrier
.
SetInterceptor
(
1
,
std
::
make_unique
<
PingPongInterceptor
>
(
1
,
nullptr
));
carrier
.
SetCreatingFlag
(
false
);
carrier
.
SetCreatingFlag
(
false
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录