Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
2273471d
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
2273471d
编写于
1月 04, 2022
作者:
L
LiYuRio
提交者:
GitHub
1月 04, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[fleet_executor] Support multi carriers (#38650)
上级
2d2609ea
变更
16
显示空白变更内容
内联
并排
Showing
16 changed file
with
126 addition
and
199 deletion
+126
-199
paddle/fluid/distributed/fleet_executor/carrier.cc
paddle/fluid/distributed/fleet_executor/carrier.cc
+6
-19
paddle/fluid/distributed/fleet_executor/carrier.h
paddle/fluid/distributed/fleet_executor/carrier.h
+5
-5
paddle/fluid/distributed/fleet_executor/fleet_executor.cc
paddle/fluid/distributed/fleet_executor/fleet_executor.cc
+19
-29
paddle/fluid/distributed/fleet_executor/fleet_executor.h
paddle/fluid/distributed/fleet_executor/fleet_executor.h
+5
-3
paddle/fluid/distributed/fleet_executor/global_map.h
paddle/fluid/distributed/fleet_executor/global_map.h
+47
-0
paddle/fluid/distributed/fleet_executor/interceptor_message_service.cc
...distributed/fleet_executor/interceptor_message_service.cc
+2
-8
paddle/fluid/distributed/fleet_executor/runtime_graph.h
paddle/fluid/distributed/fleet_executor/runtime_graph.h
+0
-11
paddle/fluid/distributed/fleet_executor/test/CMakeLists.txt
paddle/fluid/distributed/fleet_executor/test/CMakeLists.txt
+0
-3
paddle/fluid/distributed/fleet_executor/test/compute_interceptor_run_op_test.cc
...ed/fleet_executor/test/compute_interceptor_run_op_test.cc
+4
-2
paddle/fluid/distributed/fleet_executor/test/compute_interceptor_test.cc
...stributed/fleet_executor/test/compute_interceptor_test.cc
+4
-2
paddle/fluid/distributed/fleet_executor/test/interceptor_pass_the_parcel_test.cc
...d/fleet_executor/test/interceptor_pass_the_parcel_test.cc
+0
-101
paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_test.cc
...ributed/fleet_executor/test/interceptor_ping_pong_test.cc
+4
-2
paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_with_brpc_test.cc
...eet_executor/test/interceptor_ping_pong_with_brpc_test.cc
+9
-5
paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_long_path_test.cc
...leet_executor/test/interceptor_pipeline_long_path_test.cc
+4
-3
paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_short_path_test.cc
...eet_executor/test/interceptor_pipeline_short_path_test.cc
+4
-2
python/paddle/fluid/executor.py
python/paddle/fluid/executor.py
+13
-4
未找到文件。
paddle/fluid/distributed/fleet_executor/carrier.cc
浏览文件 @
2273471d
...
@@ -30,11 +30,9 @@ USE_INTERCEPTOR(Amplifier);
...
@@ -30,11 +30,9 @@ USE_INTERCEPTOR(Amplifier);
void
Carrier
::
Init
(
void
Carrier
::
Init
(
int64_t
rank
,
int64_t
rank
,
const
std
::
unordered_map
<
int64_t
,
int64_t
>&
interceptor_id_to_rank
,
const
std
::
unordered_map
<
int64_t
,
int64_t
>&
interceptor_id_to_rank
)
{
const
std
::
unordered_set
<
int64_t
>&
interceptor_ids
)
{
rank_
=
rank
;
rank_
=
rank
;
interceptor_id_to_rank_
=
interceptor_id_to_rank
;
interceptor_id_to_rank_
=
interceptor_id_to_rank
;
interceptor_ids_
=
interceptor_ids
;
// TODO(fleet_exe dev): thread pool
// TODO(fleet_exe dev): thread pool
thread_num_
=
1
;
thread_num_
=
1
;
...
@@ -45,14 +43,12 @@ void Carrier::Init(
...
@@ -45,14 +43,12 @@ void Carrier::Init(
void
Carrier
::
Init
(
void
Carrier
::
Init
(
int64_t
rank
,
int64_t
rank
,
const
std
::
unordered_map
<
int64_t
,
int64_t
>&
interceptor_id_to_rank
,
const
std
::
unordered_map
<
int64_t
,
int64_t
>&
interceptor_id_to_rank
,
const
std
::
unordered_set
<
int64_t
>&
interceptor_ids
,
const
std
::
unordered_map
<
int64_t
,
TaskNode
*>&
interceptor_id_to_node
,
const
std
::
unordered_map
<
int64_t
,
TaskNode
*>&
interceptor_id_to_node
,
framework
::
Scope
*
root_scope
,
framework
::
Scope
*
minibatch_scope
,
framework
::
Scope
*
root_scope
,
framework
::
Scope
*
minibatch_scope
,
const
std
::
vector
<
framework
::
Scope
*>&
microbatch_scopes
,
const
std
::
vector
<
framework
::
Scope
*>&
microbatch_scopes
,
const
platform
::
Place
&
place
)
{
const
platform
::
Place
&
place
)
{
rank_
=
rank
;
rank_
=
rank
;
interceptor_id_to_rank_
=
interceptor_id_to_rank
;
interceptor_id_to_rank_
=
interceptor_id_to_rank
;
interceptor_ids_
=
interceptor_ids
;
interceptor_id_to_node_
=
interceptor_id_to_node
;
interceptor_id_to_node_
=
interceptor_id_to_node
;
minibatch_scope_
=
minibatch_scope
;
minibatch_scope_
=
minibatch_scope
;
microbatch_scopes_
=
microbatch_scopes
;
microbatch_scopes_
=
microbatch_scopes
;
...
@@ -156,9 +152,7 @@ bool Carrier::Send(const InterceptorMessage& msg) {
...
@@ -156,9 +152,7 @@ bool Carrier::Send(const InterceptorMessage& msg) {
if
(
src_rank
==
dst_rank
)
{
if
(
src_rank
==
dst_rank
)
{
VLOG
(
3
)
<<
"Send a message from interceptor "
<<
src_id
VLOG
(
3
)
<<
"Send a message from interceptor "
<<
src_id
<<
" to interceptor "
<<
dst_id
<<
", which are in the same ranks."
;
<<
" to interceptor "
<<
dst_id
<<
", which are in the same ranks."
;
int64_t
carrier_id
=
*
GlobalMap
<
int64_t
,
int64_t
>::
Get
(
dst_id
);
return
EnqueueInterceptorMessage
(
msg
);
return
GlobalMap
<
int64_t
,
Carrier
>::
Get
(
carrier_id
)
->
EnqueueInterceptorMessage
(
msg
);
}
else
{
}
else
{
PADDLE_ENFORCE_NOT_NULL
(
PADDLE_ENFORCE_NOT_NULL
(
msg_bus_
.
get
(),
msg_bus_
.
get
(),
...
@@ -192,9 +186,6 @@ Interceptor* Carrier::SetInterceptor(int64_t interceptor_id,
...
@@ -192,9 +186,6 @@ Interceptor* Carrier::SetInterceptor(int64_t interceptor_id,
loop
,
platform
::
errors
::
Fatal
(
"thread task loop must not null"
));
loop
,
platform
::
errors
::
Fatal
(
"thread task loop must not null"
));
interceptor
->
RegisterTaskLoop
(
loop
);
interceptor
->
RegisterTaskLoop
(
loop
);
// TODO(liyurui): Using struct InterceptorID replace int64_t
GlobalMap
<
int64_t
,
int64_t
>::
Create
(
interceptor_id
,
carrier_id_
);
auto
*
ptr
=
interceptor
.
get
();
auto
*
ptr
=
interceptor
.
get
();
interceptor_idx_to_interceptor_
.
insert
(
interceptor_idx_to_interceptor_
.
insert
(
std
::
make_pair
(
interceptor_id
,
std
::
move
(
interceptor
)));
std
::
make_pair
(
interceptor_id
,
std
::
move
(
interceptor
)));
...
@@ -220,19 +211,15 @@ static std::shared_ptr<framework::GarbageCollector> GetGC(
...
@@ -220,19 +211,15 @@ static std::shared_ptr<framework::GarbageCollector> GetGC(
}
}
void
Carrier
::
CreateInterceptors
()
{
void
Carrier
::
CreateInterceptors
()
{
if
(
interceptor_id
s
_
.
empty
())
return
;
if
(
interceptor_id
_to_node
_
.
empty
())
return
;
auto
gc
=
GetGC
(
place_
);
auto
gc
=
GetGC
(
place_
);
// create each Interceptor
// create each Interceptor
// no auto init since there is no config
// no auto init since there is no config
for
(
int64_t
interceptor_id
:
interceptor_ids_
)
{
for
(
const
auto
&
item
:
interceptor_id_to_node_
)
{
const
auto
&
task_node_iter
=
interceptor_id_to_node_
.
find
(
interceptor_id
);
int64_t
interceptor_id
=
item
.
first
;
PADDLE_ENFORCE_NE
(
TaskNode
*
task_node
=
item
.
second
;
task_node_iter
,
interceptor_id_to_node_
.
end
(),
platform
::
errors
::
NotFound
(
"Can not find task node for interceptor %ld"
,
interceptor_id
));
TaskNode
*
task_node
=
task_node_iter
->
second
;
PADDLE_ENFORCE_LT
(
PADDLE_ENFORCE_LT
(
task_node
->
run_at_offset
(),
task_node
->
run_per_steps
(),
task_node
->
run_at_offset
(),
task_node
->
run_per_steps
(),
...
...
paddle/fluid/distributed/fleet_executor/carrier.h
浏览文件 @
2273471d
...
@@ -43,17 +43,17 @@ class InterceptorMessageServiceImpl;
...
@@ -43,17 +43,17 @@ class InterceptorMessageServiceImpl;
class
RuntimeGraph
;
class
RuntimeGraph
;
class
MessageBus
;
class
MessageBus
;
// TODO(liyurui): Add CarrierId instead of std::string
class
Carrier
final
{
class
Carrier
final
{
public:
public:
explicit
Carrier
(
int64_t
carrier_id
)
:
carrier_id_
(
carrier_id
)
{}
explicit
Carrier
(
const
std
::
string
&
carrier_id
)
:
carrier_id_
(
carrier_id
)
{}
~
Carrier
();
~
Carrier
();
void
Init
(
int64_t
rank
,
void
Init
(
int64_t
rank
,
const
std
::
unordered_map
<
int64_t
,
int64_t
>&
interceptor_id_to_rank
,
const
std
::
unordered_map
<
int64_t
,
int64_t
>&
interceptor_id_to_rank
);
const
std
::
unordered_set
<
int64_t
>&
interceptor_ids
);
void
Init
(
void
Init
(
int64_t
rank
,
int64_t
rank
,
const
std
::
unordered_map
<
int64_t
,
int64_t
>&
interceptor_id_to_rank
,
const
std
::
unordered_map
<
int64_t
,
int64_t
>&
interceptor_id_to_rank
,
const
std
::
unordered_set
<
int64_t
>&
interceptor_ids
,
const
std
::
unordered_map
<
int64_t
,
TaskNode
*>&
interceptor_id_to_node
,
const
std
::
unordered_map
<
int64_t
,
TaskNode
*>&
interceptor_id_to_node
,
framework
::
Scope
*
root_scope
,
framework
::
Scope
*
minibatch_scope
,
framework
::
Scope
*
root_scope
,
framework
::
Scope
*
minibatch_scope
,
const
std
::
vector
<
framework
::
Scope
*>&
microbatch_scopes
,
const
std
::
vector
<
framework
::
Scope
*>&
microbatch_scopes
,
...
@@ -109,7 +109,7 @@ class Carrier final {
...
@@ -109,7 +109,7 @@ class Carrier final {
paddle
::
platform
::
DeviceContext
*
dev_ctx_
{
nullptr
};
paddle
::
platform
::
DeviceContext
*
dev_ctx_
{
nullptr
};
std
::
shared_ptr
<
MessageBus
>
msg_bus_
;
std
::
shared_ptr
<
MessageBus
>
msg_bus_
;
int64_t
rank_
;
int64_t
rank_
;
int64_t
carrier_id_
;
std
::
string
carrier_id_
;
std
::
unordered_map
<
int64_t
,
TaskNode
*>
interceptor_id_to_node_
;
std
::
unordered_map
<
int64_t
,
TaskNode
*>
interceptor_id_to_node_
;
std
::
unordered_map
<
int64_t
,
int64_t
>
interceptor_id_to_rank_
;
std
::
unordered_map
<
int64_t
,
int64_t
>
interceptor_id_to_rank_
;
int
thread_num_
;
int
thread_num_
;
...
...
paddle/fluid/distributed/fleet_executor/fleet_executor.cc
浏览文件 @
2273471d
...
@@ -36,14 +36,15 @@ FleetExecutor::FleetExecutor(const std::string& exe_desc_str) {
...
@@ -36,14 +36,15 @@ FleetExecutor::FleetExecutor(const std::string& exe_desc_str) {
FleetExecutor
::~
FleetExecutor
()
{
FleetExecutor
::~
FleetExecutor
()
{
root_scope_
->
DropKids
();
root_scope_
->
DropKids
();
for
(
const
auto
&
item
:
runtime_graph_
->
carrier_id_to_interceptor_ids
()
)
{
for
(
const
auto
&
carrier_id
:
carrier_ids_
)
{
GlobalMap
<
int64_t
,
Carrier
>::
Get
(
item
.
first
)
->
Release
();
GlobalMap
<
std
::
string
,
Carrier
>::
Get
(
carrier_id
)
->
Release
();
}
}
}
}
void
FleetExecutor
::
Init
(
void
FleetExecutor
::
Init
(
const
framework
::
ProgramDesc
&
program_desc
,
framework
::
Scope
*
scope
,
const
std
::
string
&
carrier_id
,
const
framework
::
ProgramDesc
&
program_desc
,
const
platform
::
Place
&
place
,
const
std
::
vector
<
TaskNode
*>&
task_nodes
,
framework
::
Scope
*
scope
,
const
platform
::
Place
&
place
,
const
std
::
vector
<
TaskNode
*>&
task_nodes
,
const
std
::
unordered_map
<
int64_t
,
int64_t
>&
task_id_to_rank
)
{
const
std
::
unordered_map
<
int64_t
,
int64_t
>&
task_id_to_rank
)
{
PADDLE_ENFORCE_GT
(
task_nodes
.
size
(),
0
,
PADDLE_ENFORCE_GT
(
task_nodes
.
size
(),
0
,
platform
::
errors
::
InvalidArgument
(
platform
::
errors
::
InvalidArgument
(
...
@@ -58,19 +59,13 @@ void FleetExecutor::Init(
...
@@ -58,19 +59,13 @@ void FleetExecutor::Init(
auto
unused_vars
=
framework
::
GetUnusedVars
(
program_desc
.
Block
(
0
),
ops
,
{});
auto
unused_vars
=
framework
::
GetUnusedVars
(
program_desc
.
Block
(
0
),
ops
,
{});
runtime_graph_
=
std
::
make_shared
<
RuntimeGraph
>
();
runtime_graph_
=
std
::
make_shared
<
RuntimeGraph
>
();
std
::
unordered_map
<
int64_t
,
TaskNode
*>
interceptor_id_to_task
;
std
::
unordered_map
<
int64_t
,
TaskNode
*>
interceptor_id_to_task
;
std
::
unordered_map
<
int64_t
,
std
::
unordered_set
<
int64_t
>>
carrier_id_to_interceptor_ids
;
std
::
unordered_set
<
int64_t
>
interceptor_ids
;
for
(
auto
task_node
:
task_nodes
)
{
for
(
auto
task_node
:
task_nodes
)
{
task_node
->
SetUnusedVars
(
unused_vars
);
task_node
->
SetUnusedVars
(
unused_vars
);
int64_t
interceptor_id
=
task_node
->
task_id
();
int64_t
interceptor_id
=
task_node
->
task_id
();
interceptor_id_to_task
.
emplace
(
interceptor_id
,
task_node
);
interceptor_id_to_task
.
emplace
(
interceptor_id
,
task_node
);
interceptor_ids
.
insert
(
interceptor_id
);
}
}
carrier_id_to_interceptor_ids
.
emplace
(
0
,
interceptor_ids
);
runtime_graph_
->
SetInterceptorIdToRank
(
task_id_to_rank
);
runtime_graph_
->
SetInterceptorIdToRank
(
task_id_to_rank
);
runtime_graph_
->
SetInterceptorIdToNode
(
interceptor_id_to_task
);
runtime_graph_
->
SetInterceptorIdToNode
(
interceptor_id_to_task
);
runtime_graph_
->
SetCarrierIdToInterceptorIds
(
carrier_id_to_interceptor_ids
);
for
(
auto
&
unique_op
:
ops
)
{
for
(
auto
&
unique_op
:
ops
)
{
unique_op
.
release
();
unique_op
.
release
();
}
}
...
@@ -87,27 +82,23 @@ void FleetExecutor::Init(
...
@@ -87,27 +82,23 @@ void FleetExecutor::Init(
}
}
VLOG
(
5
)
<<
runtime_graph_
->
DebugString
();
VLOG
(
5
)
<<
runtime_graph_
->
DebugString
();
msg_bus_
=
std
::
make_shared
<
MessageBus
>
();
msg_bus_
=
std
::
make_shared
<
MessageBus
>
();
for
(
const
auto
&
item
:
runtime_graph_
->
carrier_id_to_interceptor_ids
())
{
Carrier
*
carrier
=
GlobalMap
<
int64_t
,
Carrier
>::
Create
(
item
.
first
,
item
.
first
);
GlobalMap
<
std
::
string
,
Carrier
>::
Create
(
carrier_id
,
carrier_id
);
}
carrier_ids_
.
insert
(
carrier_id
);
InitCarrier
();
GlobalVal
<
std
::
string
>::
Set
(
carrier_id
);
// TODO(liyurui): Maybe message bus should be created only once
InitCarrier
(
carrier
);
InitMessageBus
();
InitMessageBus
();
// Wait for all message bus connected.
// Wait for all message bus connected.
msg_bus_
->
Barrier
();
msg_bus_
->
Barrier
();
}
}
void
FleetExecutor
::
InitCarrier
()
{
void
FleetExecutor
::
InitCarrier
(
Carrier
*
carrier
)
{
for
(
const
auto
&
item
:
runtime_graph_
->
carrier_id_to_interceptor_ids
())
{
Carrier
*
carrier
=
GlobalMap
<
int64_t
,
Carrier
>::
Get
(
item
.
first
);
PADDLE_ENFORCE_NOT_NULL
(
carrier
,
platform
::
errors
::
InvalidArgument
(
"Carrier has not been created."
));
carrier
->
SetMsgBus
(
msg_bus_
);
carrier
->
SetMsgBus
(
msg_bus_
);
carrier
->
Init
(
exe_desc_
.
cur_rank
(),
carrier
->
Init
(
exe_desc_
.
cur_rank
(),
runtime_graph_
->
interceptor_id_to_rank
(),
runtime_graph_
->
interceptor_id_to_rank
(),
item
.
second
,
runtime_graph_
->
interceptor_id_to_node
(),
root_scope_
,
runtime_graph_
->
interceptor_id_to_node
(),
root_scope_
,
minibatch_scope_
,
microbatch_scopes_
,
place_
);
minibatch_scope_
,
microbatch_scopes_
,
place_
);
}
}
}
void
FleetExecutor
::
InitMessageBus
()
{
void
FleetExecutor
::
InitMessageBus
()
{
...
@@ -145,10 +136,9 @@ void FleetExecutor::InitMessageBus() {
...
@@ -145,10 +136,9 @@ void FleetExecutor::InitMessageBus() {
}
}
}
}
void
FleetExecutor
::
Run
()
{
void
FleetExecutor
::
Run
(
const
std
::
string
&
carrier_id
)
{
for
(
const
auto
&
item
:
runtime_graph_
->
carrier_id_to_interceptor_ids
())
{
GlobalMap
<
std
::
string
,
Carrier
>::
Get
(
carrier_id
)
->
Start
();
GlobalMap
<
int64_t
,
Carrier
>::
Get
(
item
.
first
)
->
Start
();
GlobalVal
<
std
::
string
>::
Set
(
carrier_id
);
}
for
(
auto
*
micro_scop
:
microbatch_scopes_
)
{
for
(
auto
*
micro_scop
:
microbatch_scopes_
)
{
// By default, we should delete all kid scopes after run executor because
// By default, we should delete all kid scopes after run executor because
// some operators may create local scope when running, such as while_op.
// some operators may create local scope when running, such as while_op.
...
...
paddle/fluid/distributed/fleet_executor/fleet_executor.h
浏览文件 @
2273471d
...
@@ -37,16 +37,17 @@ class FleetExecutor final {
...
@@ -37,16 +37,17 @@ class FleetExecutor final {
FleetExecutor
()
=
delete
;
FleetExecutor
()
=
delete
;
explicit
FleetExecutor
(
const
std
::
string
&
exe_desc_str
);
explicit
FleetExecutor
(
const
std
::
string
&
exe_desc_str
);
~
FleetExecutor
();
~
FleetExecutor
();
void
Init
(
const
framework
::
ProgramDesc
&
program_desc
,
framework
::
Scope
*
scope
,
void
Init
(
const
std
::
string
&
carrier_id
,
const
framework
::
ProgramDesc
&
program_desc
,
framework
::
Scope
*
scope
,
const
platform
::
Place
&
place
,
const
platform
::
Place
&
place
,
const
std
::
vector
<
TaskNode
*>&
task_nodes
,
const
std
::
vector
<
TaskNode
*>&
task_nodes
,
const
std
::
unordered_map
<
int64_t
,
int64_t
>&
task_id_to_rank
);
const
std
::
unordered_map
<
int64_t
,
int64_t
>&
task_id_to_rank
);
void
Run
();
void
Run
(
const
std
::
string
&
carrier_id
);
private:
private:
DISABLE_COPY_AND_ASSIGN
(
FleetExecutor
);
DISABLE_COPY_AND_ASSIGN
(
FleetExecutor
);
void
InitMessageBus
();
void
InitMessageBus
();
void
InitCarrier
();
void
InitCarrier
(
Carrier
*
carrier
);
void
CopyParameters
(
int
microbatch_id
,
const
framework
::
ProgramDesc
&
program
);
void
CopyParameters
(
int
microbatch_id
,
const
framework
::
ProgramDesc
&
program
);
FleetExecutorDesc
exe_desc_
;
FleetExecutorDesc
exe_desc_
;
std
::
shared_ptr
<
RuntimeGraph
>
runtime_graph_
;
std
::
shared_ptr
<
RuntimeGraph
>
runtime_graph_
;
...
@@ -57,6 +58,7 @@ class FleetExecutor final {
...
@@ -57,6 +58,7 @@ class FleetExecutor final {
// The carriers under FleetExecutor will share message bus,
// The carriers under FleetExecutor will share message bus,
// using shared_ptr to manage lifetime and condition race.
// using shared_ptr to manage lifetime and condition race.
std
::
shared_ptr
<
MessageBus
>
msg_bus_
;
std
::
shared_ptr
<
MessageBus
>
msg_bus_
;
std
::
unordered_set
<
std
::
string
>
carrier_ids_
;
};
};
}
// namespace distributed
}
// namespace distributed
...
...
paddle/fluid/distributed/fleet_executor/global_map.h
浏览文件 @
2273471d
...
@@ -17,6 +17,24 @@
...
@@ -17,6 +17,24 @@
namespace
paddle
{
namespace
paddle
{
namespace
distributed
{
namespace
distributed
{
// TODO(liyurui): Change this file to global.h
template
<
typename
T
>
class
GlobalVal
final
{
public:
static
T
Get
()
{
return
*
GetPtr
();
}
static
T
Set
(
T
val
)
{
auto
*
ptr
=
GetPtr
();
*
ptr
=
val
;
return
val
;
}
private:
static
T
*
GetPtr
()
{
static
T
value
;
return
&
value
;
}
};
template
<
typename
KeyT
,
typename
ValueT
>
template
<
typename
KeyT
,
typename
ValueT
>
class
GlobalMap
final
{
class
GlobalMap
final
{
public:
public:
...
@@ -26,6 +44,7 @@ class GlobalMap final {
...
@@ -26,6 +44,7 @@ class GlobalMap final {
item
,
platform
::
errors
::
NotFound
(
"This value is not in global map."
));
item
,
platform
::
errors
::
NotFound
(
"This value is not in global map."
));
return
item
;
return
item
;
}
}
template
<
typename
...
Args
>
template
<
typename
...
Args
>
static
ValueT
*
Create
(
KeyT
id
,
Args
&&
...
args
)
{
static
ValueT
*
Create
(
KeyT
id
,
Args
&&
...
args
)
{
auto
*
ptr
=
GetPPtr
(
id
);
auto
*
ptr
=
GetPPtr
(
id
);
...
@@ -37,6 +56,34 @@ class GlobalMap final {
...
@@ -37,6 +56,34 @@ class GlobalMap final {
return
item
;
return
item
;
}
}
private:
static
std
::
unique_ptr
<
ValueT
>*
GetPPtr
(
KeyT
id
)
{
static
std
::
unordered_map
<
KeyT
,
std
::
unique_ptr
<
ValueT
>>
id_to_ptr
;
return
&
id_to_ptr
[
id
];
}
};
template
<
typename
KeyT
,
typename
ValueT
>
class
ThreadSafeGlobalMap
final
{
public:
static
ValueT
*
Get
(
KeyT
id
)
{
ValueT
*
item
=
GetPPtr
(
id
)
->
get
();
PADDLE_ENFORCE_NOT_NULL
(
item
,
platform
::
errors
::
NotFound
(
"This value is not in thread safe global map."
));
return
item
;
}
template
<
typename
...
Args
>
static
ValueT
*
Create
(
KeyT
id
,
Args
&&
...
args
)
{
auto
*
ptr
=
GetPPtr
(
id
);
PADDLE_ENFORCE_EQ
(
ptr
->
get
(),
nullptr
,
platform
::
errors
::
AlreadyExists
(
"This value has already in thread safe global map."
));
ValueT
*
item
=
new
ValueT
(
std
::
forward
<
Args
>
(
args
)...);
ptr
->
reset
(
item
);
return
item
;
}
private:
private:
static
std
::
unique_ptr
<
ValueT
>*
GetPPtr
(
KeyT
id
)
{
static
std
::
unique_ptr
<
ValueT
>*
GetPPtr
(
KeyT
id
)
{
static
std
::
mutex
mutex
;
static
std
::
mutex
mutex
;
...
...
paddle/fluid/distributed/fleet_executor/interceptor_message_service.cc
浏览文件 @
2273471d
...
@@ -29,14 +29,8 @@ void InterceptorMessageServiceImpl::InterceptorMessageService(
...
@@ -29,14 +29,8 @@ void InterceptorMessageServiceImpl::InterceptorMessageService(
VLOG
(
3
)
<<
"Interceptor Message Service receives a message from interceptor "
VLOG
(
3
)
<<
"Interceptor Message Service receives a message from interceptor "
<<
request
->
src_id
()
<<
" to interceptor "
<<
request
->
dst_id
()
<<
request
->
src_id
()
<<
" to interceptor "
<<
request
->
dst_id
()
<<
", with the message: "
<<
request
->
message_type
();
<<
", with the message: "
<<
request
->
message_type
();
// TODO(liyurui): Remove this hard code.
const
auto
&
carrier_id
=
GlobalVal
<
std
::
string
>::
Get
();
int64_t
carrier_id
;
bool
flag
=
GlobalMap
<
std
::
string
,
Carrier
>::
Get
(
carrier_id
)
if
(
request
->
ctrl_message
())
{
carrier_id
=
0
;
}
else
{
carrier_id
=
*
GlobalMap
<
int64_t
,
int64_t
>::
Get
(
request
->
dst_id
());
}
bool
flag
=
GlobalMap
<
int64_t
,
Carrier
>::
Get
(
carrier_id
)
->
EnqueueInterceptorMessage
(
*
request
);
->
EnqueueInterceptorMessage
(
*
request
);
response
->
set_rst
(
flag
);
response
->
set_rst
(
flag
);
}
}
...
...
paddle/fluid/distributed/fleet_executor/runtime_graph.h
浏览文件 @
2273471d
...
@@ -35,10 +35,6 @@ class RuntimeGraph final {
...
@@ -35,10 +35,6 @@ class RuntimeGraph final {
const
std
::
unordered_map
<
int64_t
,
int64_t
>&
interceptor_id_to_rank
()
const
{
const
std
::
unordered_map
<
int64_t
,
int64_t
>&
interceptor_id_to_rank
()
const
{
return
interceptor_id_to_rank_
;
return
interceptor_id_to_rank_
;
}
}
const
std
::
unordered_map
<
int64_t
,
std
::
unordered_set
<
int64_t
>>&
carrier_id_to_interceptor_ids
()
const
{
return
carrier_id_to_interceptor_ids_
;
}
void
SetInterceptorIdToRank
(
void
SetInterceptorIdToRank
(
const
std
::
unordered_map
<
int64_t
,
int64_t
>&
interceptor_id_to_rank
)
{
const
std
::
unordered_map
<
int64_t
,
int64_t
>&
interceptor_id_to_rank
)
{
interceptor_id_to_rank_
=
interceptor_id_to_rank
;
interceptor_id_to_rank_
=
interceptor_id_to_rank
;
...
@@ -47,19 +43,12 @@ class RuntimeGraph final {
...
@@ -47,19 +43,12 @@ class RuntimeGraph final {
const
std
::
unordered_map
<
int64_t
,
TaskNode
*>&
interceptor_id_to_node
)
{
const
std
::
unordered_map
<
int64_t
,
TaskNode
*>&
interceptor_id_to_node
)
{
interceptor_id_to_node_
=
interceptor_id_to_node
;
interceptor_id_to_node_
=
interceptor_id_to_node
;
}
}
void
SetCarrierIdToInterceptorIds
(
const
std
::
unordered_map
<
int64_t
,
std
::
unordered_set
<
int64_t
>>&
carrier_id_to_interceptor_ids
)
{
carrier_id_to_interceptor_ids_
=
carrier_id_to_interceptor_ids
;
}
std
::
string
DebugString
()
const
;
std
::
string
DebugString
()
const
;
private:
private:
DISABLE_COPY_AND_ASSIGN
(
RuntimeGraph
);
DISABLE_COPY_AND_ASSIGN
(
RuntimeGraph
);
std
::
unordered_map
<
int64_t
,
TaskNode
*>
interceptor_id_to_node_
;
std
::
unordered_map
<
int64_t
,
TaskNode
*>
interceptor_id_to_node_
;
std
::
unordered_map
<
int64_t
,
int64_t
>
interceptor_id_to_rank_
;
std
::
unordered_map
<
int64_t
,
int64_t
>
interceptor_id_to_rank_
;
std
::
unordered_map
<
int64_t
,
std
::
unordered_set
<
int64_t
>>
carrier_id_to_interceptor_ids_
;
};
};
}
// namespace distributed
}
// namespace distributed
...
...
paddle/fluid/distributed/fleet_executor/test/CMakeLists.txt
浏览文件 @
2273471d
...
@@ -13,9 +13,6 @@ cc_test(interceptor_pipeline_long_path_test SRCS interceptor_pipeline_long_path_
...
@@ -13,9 +13,6 @@ cc_test(interceptor_pipeline_long_path_test SRCS interceptor_pipeline_long_path_
set_source_files_properties
(
compute_interceptor_run_op_test.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
set_source_files_properties
(
compute_interceptor_run_op_test.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
cc_test
(
compute_interceptor_run_op_test SRCS compute_interceptor_run_op_test.cc DEPS fleet_executor
${
BRPC_DEPS
}
op_registry fill_constant_op elementwise_add_op scope device_context
)
cc_test
(
compute_interceptor_run_op_test SRCS compute_interceptor_run_op_test.cc DEPS fleet_executor
${
BRPC_DEPS
}
op_registry fill_constant_op elementwise_add_op scope device_context
)
set_source_files_properties
(
interceptor_pass_the_parcel_test.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
cc_test
(
interceptor_pass_the_parcel_test SRCS interceptor_pass_the_parcel_test.cc DEPS fleet_executor
${
BRPC_DEPS
}
)
if
(
WITH_DISTRIBUTE AND WITH_PSCORE AND
NOT
(
WITH_ASCEND OR WITH_ASCEND_CL
))
if
(
WITH_DISTRIBUTE AND WITH_PSCORE AND
NOT
(
WITH_ASCEND OR WITH_ASCEND_CL
))
set_source_files_properties
(
interceptor_ping_pong_with_brpc_test.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
set_source_files_properties
(
interceptor_ping_pong_with_brpc_test.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
cc_test
(
interceptor_ping_pong_with_brpc_test SRCS interceptor_ping_pong_with_brpc_test.cc DEPS fleet_executor
${
BRPC_DEPS
}
)
cc_test
(
interceptor_ping_pong_with_brpc_test SRCS interceptor_ping_pong_with_brpc_test.cc DEPS fleet_executor
${
BRPC_DEPS
}
)
...
...
paddle/fluid/distributed/fleet_executor/test/compute_interceptor_run_op_test.cc
浏览文件 @
2273471d
...
@@ -62,8 +62,10 @@ TEST(ComputeInterceptor, Compute) {
...
@@ -62,8 +62,10 @@ TEST(ComputeInterceptor, Compute) {
std
::
vector
<
framework
::
Scope
*>
scopes
=
{
scope
,
scope
};
std
::
vector
<
framework
::
Scope
*>
scopes
=
{
scope
,
scope
};
platform
::
Place
place
=
platform
::
CPUPlace
();
platform
::
Place
place
=
platform
::
CPUPlace
();
Carrier
*
carrier
=
GlobalMap
<
int64_t
,
Carrier
>::
Create
(
0
,
0
);
std
::
string
carrier_id
=
"0"
;
carrier
->
Init
(
0
,
{{
0
,
0
},
{
1
,
0
}},
{
0
,
1
});
Carrier
*
carrier
=
GlobalMap
<
std
::
string
,
Carrier
>::
Create
(
carrier_id
,
carrier_id
);
carrier
->
Init
(
0
,
{{
0
,
0
},
{
1
,
0
}});
auto
msg_bus
=
std
::
make_shared
<
MessageBus
>
();
auto
msg_bus
=
std
::
make_shared
<
MessageBus
>
();
msg_bus
->
Init
(
0
,
{{
0
,
"127.0.0.0:0"
}},
""
);
msg_bus
->
Init
(
0
,
{{
0
,
"127.0.0.0:0"
}},
""
);
...
...
paddle/fluid/distributed/fleet_executor/test/compute_interceptor_test.cc
浏览文件 @
2273471d
...
@@ -47,8 +47,10 @@ class StartInterceptor : public Interceptor {
...
@@ -47,8 +47,10 @@ class StartInterceptor : public Interceptor {
};
};
TEST
(
ComputeInterceptor
,
Compute
)
{
TEST
(
ComputeInterceptor
,
Compute
)
{
Carrier
*
carrier
=
GlobalMap
<
int64_t
,
Carrier
>::
Create
(
0
,
0
);
std
::
string
carrier_id
=
"0"
;
carrier
->
Init
(
0
,
{{
0
,
0
},
{
1
,
0
},
{
2
,
0
}},
{
0
,
1
,
2
});
Carrier
*
carrier
=
GlobalMap
<
std
::
string
,
Carrier
>::
Create
(
carrier_id
,
carrier_id
);
carrier
->
Init
(
0
,
{{
0
,
0
},
{
1
,
0
},
{
2
,
0
}});
auto
msg_bus
=
std
::
make_shared
<
MessageBus
>
();
auto
msg_bus
=
std
::
make_shared
<
MessageBus
>
();
msg_bus
->
Init
(
0
,
{{
0
,
"127.0.0.0:0"
}},
""
);
msg_bus
->
Init
(
0
,
{{
0
,
"127.0.0.0:0"
}},
""
);
...
...
paddle/fluid/distributed/fleet_executor/test/interceptor_pass_the_parcel_test.cc
已删除
100644 → 0
浏览文件 @
2d2609ea
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <iostream>
#include <unordered_map>
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/global_map.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
namespace
paddle
{
namespace
distributed
{
class
ParcelInterceptor
:
public
Interceptor
{
public:
ParcelInterceptor
(
int64_t
interceptor_id
,
TaskNode
*
node
)
:
Interceptor
(
interceptor_id
,
node
)
{
RegisterMsgHandle
(
[
this
](
const
InterceptorMessage
&
msg
)
{
PassParcel
(
msg
);
});
}
void
PassParcel
(
const
InterceptorMessage
&
msg
)
{
if
(
msg
.
message_type
()
==
STOP
)
{
stop_
=
true
;
return
;
}
std
::
cout
<<
GetInterceptorId
()
<<
" recv msg, count="
<<
count_
<<
std
::
endl
;
if
(
count_
==
5
&&
interceptor_id_
==
0
)
{
InterceptorMessage
stop
;
stop
.
set_message_type
(
STOP
);
Send
(
0
,
stop
);
Send
(
1
,
stop
);
Send
(
2
,
stop
);
Send
(
3
,
stop
);
StopCarrier
();
return
;
}
++
count_
;
InterceptorMessage
new_msg
;
if
(
msg
.
dst_id
()
==
3
)
{
Send
(
0
,
new_msg
);
}
else
{
Send
(
msg
.
dst_id
()
+
1
,
new_msg
);
}
}
private:
int
count_
{
0
};
};
REGISTER_INTERCEPTOR
(
Parcel
,
ParcelInterceptor
);
TEST
(
InterceptorTest
,
PassTheParcel
)
{
auto
msg_bus
=
std
::
make_shared
<
MessageBus
>
();
Carrier
*
carrier_0
=
GlobalMap
<
int64_t
,
Carrier
>::
Create
(
0
,
0
);
carrier_0
->
Init
(
0
,
{{
0
,
0
},
{
1
,
0
},
{
2
,
0
},
{
3
,
0
}},
{
0
});
carrier_0
->
SetMsgBus
(
msg_bus
);
Carrier
*
carrier_1
=
GlobalMap
<
int64_t
,
Carrier
>::
Create
(
1
,
1
);
carrier_1
->
Init
(
0
,
{{
0
,
0
},
{
1
,
0
},
{
2
,
0
},
{
3
,
0
}},
{
1
});
carrier_1
->
SetMsgBus
(
msg_bus
);
Carrier
*
carrier_2
=
GlobalMap
<
int64_t
,
Carrier
>::
Create
(
2
,
2
);
carrier_2
->
Init
(
0
,
{{
0
,
0
},
{
1
,
0
},
{
2
,
0
},
{
3
,
0
}},
{
2
});
carrier_2
->
SetMsgBus
(
msg_bus
);
Carrier
*
carrier_3
=
GlobalMap
<
int64_t
,
Carrier
>::
Create
(
3
,
3
);
carrier_3
->
Init
(
0
,
{{
0
,
0
},
{
1
,
0
},
{
2
,
0
},
{
3
,
0
}},
{
3
});
carrier_3
->
SetMsgBus
(
msg_bus
);
msg_bus
->
Init
(
0
,
{{
0
,
"127.0.0.0:0"
}},
""
);
Interceptor
*
a
=
carrier_0
->
SetInterceptor
(
0
,
InterceptorFactory
::
Create
(
"Parcel"
,
0
,
nullptr
));
carrier_1
->
SetInterceptor
(
1
,
InterceptorFactory
::
Create
(
"Parcel"
,
1
,
nullptr
));
carrier_2
->
SetInterceptor
(
2
,
InterceptorFactory
::
Create
(
"Parcel"
,
2
,
nullptr
));
carrier_3
->
SetInterceptor
(
3
,
InterceptorFactory
::
Create
(
"Parcel"
,
3
,
nullptr
));
InterceptorMessage
msg
;
a
->
Send
(
1
,
msg
);
carrier_0
->
Wait
();
}
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_test.cc
浏览文件 @
2273471d
...
@@ -60,8 +60,10 @@ class PingPongInterceptor : public Interceptor {
...
@@ -60,8 +60,10 @@ class PingPongInterceptor : public Interceptor {
REGISTER_INTERCEPTOR
(
PingPong
,
PingPongInterceptor
);
REGISTER_INTERCEPTOR
(
PingPong
,
PingPongInterceptor
);
TEST
(
InterceptorTest
,
PingPong
)
{
TEST
(
InterceptorTest
,
PingPong
)
{
Carrier
*
carrier
=
GlobalMap
<
int64_t
,
Carrier
>::
Create
(
0
,
0
);
std
::
string
carrier_id
=
"0"
;
carrier
->
Init
(
0
,
{{
0
,
0
},
{
1
,
0
}},
{
0
,
1
});
Carrier
*
carrier
=
GlobalMap
<
std
::
string
,
Carrier
>::
Create
(
carrier_id
,
carrier_id
);
carrier
->
Init
(
0
,
{{
0
,
0
},
{
1
,
0
}});
auto
msg_bus
=
std
::
make_shared
<
MessageBus
>
();
auto
msg_bus
=
std
::
make_shared
<
MessageBus
>
();
msg_bus
->
Init
(
0
,
{{
0
,
"127.0.0.0:0"
}},
""
);
msg_bus
->
Init
(
0
,
{{
0
,
"127.0.0.0:0"
}},
""
);
carrier
->
SetMsgBus
(
msg_bus
);
carrier
->
SetMsgBus
(
msg_bus
);
...
...
paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_with_brpc_test.cc
浏览文件 @
2273471d
...
@@ -106,16 +106,18 @@ TEST(InterceptorTest, PingPong) {
...
@@ -106,16 +106,18 @@ TEST(InterceptorTest, PingPong) {
std
::
cout
<<
"ip1: "
<<
ip1
<<
std
::
endl
;
std
::
cout
<<
"ip1: "
<<
ip1
<<
std
::
endl
;
std
::
unordered_map
<
int64_t
,
int64_t
>
interceptor_id_to_rank
=
{{
0
,
0
},
std
::
unordered_map
<
int64_t
,
int64_t
>
interceptor_id_to_rank
=
{{
0
,
0
},
{
1
,
1
}};
{
1
,
1
}};
std
::
string
carrier_id
=
"0"
;
int
pid
=
fork
();
int
pid
=
fork
();
if
(
pid
==
0
)
{
if
(
pid
==
0
)
{
Carrier
*
carrier
=
GlobalMap
<
int64_t
,
Carrier
>::
Create
(
0
,
0
);
Carrier
*
carrier
=
GlobalMap
<
std
::
string
,
Carrier
>::
Create
(
carrier_id
,
carrier_id
);
GlobalVal
<
std
::
string
>::
Set
(
carrier_id
);
auto
msg_bus
=
std
::
make_shared
<
MessageBus
>
();
auto
msg_bus
=
std
::
make_shared
<
MessageBus
>
();
carrier
->
SetMsgBus
(
msg_bus
);
carrier
->
SetMsgBus
(
msg_bus
);
// NOTE: need Init msg_bus after carrier SetMsgBus
// NOTE: need Init msg_bus after carrier SetMsgBus
carrier
->
Init
(
0
,
interceptor_id_to_rank
,
{
0
}
);
carrier
->
Init
(
0
,
interceptor_id_to_rank
);
msg_bus
->
Init
(
0
,
{{
0
,
ip0
},
{
1
,
ip1
}},
ip0
);
msg_bus
->
Init
(
0
,
{{
0
,
ip0
},
{
1
,
ip1
}},
ip0
);
carrier
->
SetMsgBus
(
msg_bus
);
Interceptor
*
a
=
carrier
->
SetInterceptor
(
Interceptor
*
a
=
carrier
->
SetInterceptor
(
0
,
InterceptorFactory
::
Create
(
"PingPong"
,
0
,
nullptr
));
0
,
InterceptorFactory
::
Create
(
"PingPong"
,
0
,
nullptr
));
msg_bus
->
Barrier
();
msg_bus
->
Barrier
();
...
@@ -123,10 +125,12 @@ TEST(InterceptorTest, PingPong) {
...
@@ -123,10 +125,12 @@ TEST(InterceptorTest, PingPong) {
a
->
Send
(
1
,
msg
);
a
->
Send
(
1
,
msg
);
carrier
->
Wait
();
carrier
->
Wait
();
}
else
{
}
else
{
Carrier
*
carrier
=
GlobalMap
<
int64_t
,
Carrier
>::
Create
(
0
,
0
);
Carrier
*
carrier
=
GlobalMap
<
std
::
string
,
Carrier
>::
Create
(
carrier_id
,
carrier_id
);
GlobalVal
<
std
::
string
>::
Set
(
carrier_id
);
auto
msg_bus
=
std
::
make_shared
<
MessageBus
>
();
auto
msg_bus
=
std
::
make_shared
<
MessageBus
>
();
carrier
->
SetMsgBus
(
msg_bus
);
carrier
->
SetMsgBus
(
msg_bus
);
carrier
->
Init
(
1
,
interceptor_id_to_rank
,
{
1
}
);
carrier
->
Init
(
1
,
interceptor_id_to_rank
);
msg_bus
->
Init
(
1
,
{{
0
,
ip0
},
{
1
,
ip1
}},
ip1
);
msg_bus
->
Init
(
1
,
{{
0
,
ip0
},
{
1
,
ip1
}},
ip1
);
carrier
->
SetInterceptor
(
1
,
carrier
->
SetInterceptor
(
1
,
InterceptorFactory
::
Create
(
"PingPong"
,
1
,
nullptr
));
InterceptorFactory
::
Create
(
"PingPong"
,
1
,
nullptr
));
...
...
paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_long_path_test.cc
浏览文件 @
2273471d
...
@@ -52,9 +52,10 @@ void LinkNodes(const std::vector<TaskNode*>& nodes) {
...
@@ -52,9 +52,10 @@ void LinkNodes(const std::vector<TaskNode*>& nodes) {
}
}
TEST
(
AmplifierInterceptor
,
Amplifier
)
{
TEST
(
AmplifierInterceptor
,
Amplifier
)
{
Carrier
*
carrier
=
GlobalMap
<
int64_t
,
Carrier
>::
Create
(
0
,
0
);
std
::
string
carrier_id
=
"0"
;
carrier
->
Init
(
0
,
{{
0
,
0
},
{
1
,
0
},
{
2
,
0
},
{
3
,
0
},
{
4
,
0
},
{
5
,
0
}},
Carrier
*
carrier
=
{
0
,
1
,
2
,
3
,
4
,
5
});
GlobalMap
<
std
::
string
,
Carrier
>::
Create
(
carrier_id
,
carrier_id
);
carrier
->
Init
(
0
,
{{
0
,
0
},
{
1
,
0
},
{
2
,
0
},
{
3
,
0
},
{
4
,
0
},
{
5
,
0
}});
auto
msg_bus
=
std
::
make_shared
<
MessageBus
>
();
auto
msg_bus
=
std
::
make_shared
<
MessageBus
>
();
msg_bus
->
Init
(
0
,
{{
0
,
"127.0.0.0:0"
}},
"127.0.0.0:0"
);
msg_bus
->
Init
(
0
,
{{
0
,
"127.0.0.0:0"
}},
"127.0.0.0:0"
);
carrier
->
SetMsgBus
(
msg_bus
);
carrier
->
SetMsgBus
(
msg_bus
);
...
...
paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_short_path_test.cc
浏览文件 @
2273471d
...
@@ -70,8 +70,10 @@ void LinkNodes(const std::vector<TaskNode*>& nodes,
...
@@ -70,8 +70,10 @@ void LinkNodes(const std::vector<TaskNode*>& nodes,
}
}
TEST
(
AmplifierInterceptor
,
Amplifier
)
{
TEST
(
AmplifierInterceptor
,
Amplifier
)
{
Carrier
*
carrier
=
GlobalMap
<
int64_t
,
Carrier
>::
Create
(
0
,
0
);
std
::
string
carrier_id
=
"0"
;
carrier
->
Init
(
0
,
{{
0
,
0
},
{
1
,
0
},
{
2
,
0
},
{
3
,
0
}},
{
0
,
1
,
2
,
3
});
Carrier
*
carrier
=
GlobalMap
<
std
::
string
,
Carrier
>::
Create
(
carrier_id
,
carrier_id
);
carrier
->
Init
(
0
,
{{
0
,
0
},
{
1
,
0
},
{
2
,
0
},
{
3
,
0
}});
auto
msg_bus
=
std
::
make_shared
<
MessageBus
>
();
auto
msg_bus
=
std
::
make_shared
<
MessageBus
>
();
msg_bus
->
Init
(
0
,
{{
0
,
""
}},
""
);
msg_bus
->
Init
(
0
,
{{
0
,
""
}},
""
);
carrier
->
SetMsgBus
(
msg_bus
);
carrier
->
SetMsgBus
(
msg_bus
);
...
...
python/paddle/fluid/executor.py
浏览文件 @
2273471d
...
@@ -1956,7 +1956,11 @@ class Executor(object):
...
@@ -1956,7 +1956,11 @@ class Executor(object):
return
ctx
return
ctx
def
_prepare_fleet_executor
(
self
,
program
=
None
,
scope
=
None
,
fleet_opt
=
None
):
def
_prepare_fleet_executor
(
self
,
carrier_id
=
""
,
program
=
None
,
scope
=
None
,
fleet_opt
=
None
):
from
..distributed.fleet.proto
import
fleet_executor_desc_pb2
from
..distributed.fleet.proto
import
fleet_executor_desc_pb2
assert
program
,
"Program for fleet executor should not be None"
assert
program
,
"Program for fleet executor should not be None"
assert
fleet_opt
,
"Configurations for fleet executor should not be None"
assert
fleet_opt
,
"Configurations for fleet executor should not be None"
...
@@ -2014,7 +2018,8 @@ class Executor(object):
...
@@ -2014,7 +2018,8 @@ class Executor(object):
fleet_exe
=
core
.
FleetExecutor
(
fleet_exe_desc
.
SerializeToString
())
fleet_exe
=
core
.
FleetExecutor
(
fleet_exe_desc
.
SerializeToString
())
place
=
core
.
Place
()
place
=
core
.
Place
()
place
.
set_place
(
self
.
place
)
place
.
set_place
(
self
.
place
)
fleet_exe
.
init
(
program
.
desc
,
scope
,
place
,
tasks
,
task_id_to_rank
)
fleet_exe
.
init
(
carrier_id
,
program
.
desc
,
scope
,
place
,
tasks
,
task_id_to_rank
)
return
fleet_exe
return
fleet_exe
def
_run_using_fleet_executor
(
self
,
def
_run_using_fleet_executor
(
self
,
...
@@ -2023,6 +2028,7 @@ class Executor(object):
...
@@ -2023,6 +2028,7 @@ class Executor(object):
feed_var_name
=
"feed"
,
feed_var_name
=
"feed"
,
fetch_var_name
=
"fetch"
,
fetch_var_name
=
"fetch"
,
fetch_list
=
None
):
fetch_list
=
None
):
# TODO(liyurui): Change cache strategy for multi carriers
cache_key
=
_get_strong_program_cache_key
(
program
,
feed
,
fetch_list
)
cache_key
=
_get_strong_program_cache_key
(
program
,
feed
,
fetch_list
)
cached_ctx
=
self
.
_get_ctx_cache
(
cache_key
)
cached_ctx
=
self
.
_get_ctx_cache
(
cache_key
)
cached_scope
=
self
.
_get_scope_cache
(
cache_key
)
cached_scope
=
self
.
_get_scope_cache
(
cache_key
)
...
@@ -2088,7 +2094,10 @@ class Executor(object):
...
@@ -2088,7 +2094,10 @@ class Executor(object):
fetch_task
.
set_program
(
fetch_program
)
fetch_task
.
set_program
(
fetch_program
)
cached_ctx
=
self
.
_prepare_fleet_executor
(
cached_ctx
=
self
.
_prepare_fleet_executor
(
program
=
cached_program
,
scope
=
cached_scope
,
fleet_opt
=
fleet_opt
)
cache_key
,
program
=
cached_program
,
scope
=
cached_scope
,
fleet_opt
=
fleet_opt
)
self
.
_add_ctx_cache
(
cache_key
,
cached_ctx
)
self
.
_add_ctx_cache
(
cache_key
,
cached_ctx
)
if
feed
:
if
feed
:
# NOTE: don't have to traverse programs in task nodes,
# NOTE: don't have to traverse programs in task nodes,
...
@@ -2107,7 +2116,7 @@ class Executor(object):
...
@@ -2107,7 +2116,7 @@ class Executor(object):
lr_sheduler
.
_var_name
)
lr_sheduler
.
_var_name
)
tensor
.
set
(
data
,
self
.
place
)
tensor
.
set
(
data
,
self
.
place
)
cached_ctx
.
run
()
cached_ctx
.
run
(
cache_key
)
if
fetch_list
:
if
fetch_list
:
arr
=
cached_scope
.
find_var
(
fetch_var_name
).
get_fetch_list
()
arr
=
cached_scope
.
find_var
(
fetch_var_name
).
get_fetch_list
()
tensors
=
arr
.
_move_to_list
()
tensors
=
arr
.
_move_to_list
()
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录