Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
ca088f92
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
ca088f92
编写于
11月 19, 2021
作者:
Y
Yuang Liu
提交者:
GitHub
11月 19, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[fleet_executor] Parse pipeline config (#37319)
上级
f11e843a
变更
8
显示空白变更内容
内联
并排
Showing
8 changed file
with
68 addition
and
18 deletion
+68
-18
paddle/fluid/distributed/fleet_executor/fleet_executor_desc.proto
...luid/distributed/fleet_executor/fleet_executor_desc.proto
+2
-0
paddle/fluid/distributed/fleet_executor/interceptor.cc
paddle/fluid/distributed/fleet_executor/interceptor.cc
+3
-1
paddle/fluid/distributed/fleet_executor/interceptor.h
paddle/fluid/distributed/fleet_executor/interceptor.h
+3
-0
paddle/fluid/distributed/fleet_executor/runtime_graph.cc
paddle/fluid/distributed/fleet_executor/runtime_graph.cc
+19
-4
paddle/fluid/distributed/fleet_executor/task_node.cc
paddle/fluid/distributed/fleet_executor/task_node.cc
+23
-8
paddle/fluid/distributed/fleet_executor/task_node.h
paddle/fluid/distributed/fleet_executor/task_node.h
+11
-4
python/paddle/fluid/executor.py
python/paddle/fluid/executor.py
+2
-0
python/paddle/fluid/tests/unittests/test_fleet_executor_multi_devices.py
...luid/tests/unittests/test_fleet_executor_multi_devices.py
+5
-1
未找到文件。
paddle/fluid/distributed/fleet_executor/fleet_executor_desc.proto
浏览文件 @
ca088f92
...
@@ -27,4 +27,6 @@ message FleetExecutorDesc {
...
@@ -27,4 +27,6 @@ message FleetExecutorDesc {
optional
int32
dp_degree
=
4
[
default
=
1
];
optional
int32
dp_degree
=
4
[
default
=
1
];
optional
int32
mp_degree
=
5
[
default
=
1
];
optional
int32
mp_degree
=
5
[
default
=
1
];
optional
int32
pp_degree
=
6
[
default
=
1
];
optional
int32
pp_degree
=
6
[
default
=
1
];
optional
int64
num_micro_batches
=
7
[
default
=
1
];
optional
int64
num_slots
=
8
[
default
=
1
];
}
}
paddle/fluid/distributed/fleet_executor/interceptor.cc
浏览文件 @
ca088f92
...
@@ -14,6 +14,7 @@
...
@@ -14,6 +14,7 @@
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
namespace
paddle
{
namespace
paddle
{
namespace
distributed
{
namespace
distributed
{
...
@@ -21,7 +22,8 @@ namespace distributed {
...
@@ -21,7 +22,8 @@ namespace distributed {
Interceptor
::
Interceptor
(
int64_t
interceptor_id
,
TaskNode
*
node
)
Interceptor
::
Interceptor
(
int64_t
interceptor_id
,
TaskNode
*
node
)
:
interceptor_id_
(
interceptor_id
),
node_
(
node
)
{
:
interceptor_id_
(
interceptor_id
),
node_
(
node
)
{
interceptor_thread_
=
std
::
thread
([
this
]()
{
interceptor_thread_
=
std
::
thread
([
this
]()
{
VLOG
(
3
)
<<
"Start pooling local mailbox's thread."
;
VLOG
(
3
)
<<
"Interceptor "
<<
interceptor_id_
<<
" starts the thread pooling it's local mailbox."
;
PoolTheMailbox
();
PoolTheMailbox
();
});
});
}
}
...
...
paddle/fluid/distributed/fleet_executor/interceptor.h
浏览文件 @
ca088f92
...
@@ -96,6 +96,9 @@ class Interceptor {
...
@@ -96,6 +96,9 @@ class Interceptor {
// local mailbox, written by FetchRemoteMailbox()
// local mailbox, written by FetchRemoteMailbox()
// read by PoolTheMailbox()
// read by PoolTheMailbox()
std
::
queue
<
InterceptorMessage
>
local_mailbox_
;
std
::
queue
<
InterceptorMessage
>
local_mailbox_
;
int64_t
already_run_times_
{
0
};
int64_t
used_slot_nums_
{
0
};
};
};
class
InterceptorFactory
{
class
InterceptorFactory
{
...
...
paddle/fluid/distributed/fleet_executor/runtime_graph.cc
浏览文件 @
ca088f92
...
@@ -136,16 +136,31 @@ void RuntimeGraph::SplitProgramBasedFunctionality(const ProgramDesc& program) {
...
@@ -136,16 +136,31 @@ void RuntimeGraph::SplitProgramBasedFunctionality(const ProgramDesc& program) {
role_to_ops
.
at
(
new_op_role_id
).
emplace_back
(
op
.
get
());
role_to_ops
.
at
(
new_op_role_id
).
emplace_back
(
op
.
get
());
}
}
int64_t
cur_rank
=
exe_desc_
.
cur_rank
();
int64_t
cur_rank
=
exe_desc_
.
cur_rank
();
DistCoordSys
coord_sys
(
exe_desc_
.
dp_degree
(),
exe_desc_
.
pp_degree
(),
exe_desc_
.
mp_degree
());
const
auto
&
coord
=
coord_sys
.
RankToCoord
(
cur_rank
);
int
pipeline_stage
=
coord
.
pp_idx
;
int64_t
num_pipeline_stages
=
exe_desc_
.
pp_degree
();
// TODO(fleet_executor dev): start up steps should be a config `num_slots`
int64_t
start_up_steps
=
num_pipeline_stages
-
pipeline_stage
-
1
;
int64_t
num_micro_batches
=
exe_desc_
.
num_micro_batches
();
int64_t
task_id
=
cur_rank
*
functionality_order
.
size
();
int64_t
task_id
=
cur_rank
*
functionality_order
.
size
();
for
(
std
::
size_t
i
=
0
;
i
<
functionality_order
.
size
();
++
i
)
{
for
(
std
::
size_t
i
=
0
;
i
<
functionality_order
.
size
();
++
i
)
{
OpRole
role
=
functionality_order
[
i
];
OpRole
role
=
functionality_order
[
i
];
int64_t
role_id
=
static_cast
<
int64_t
>
(
role
);
int64_t
role_id
=
static_cast
<
int64_t
>
(
role
);
int64_t
max_run_times
=
num_micro_batches
;
int64_t
max_slot_nums
=
start_up_steps
;
if
(
IsLRSched
(
role_id
)
||
IsOptimize
(
role_id
))
{
max_run_times
=
1
;
max_slot_nums
=
1
;
}
if
(
role_to_ops
.
find
(
role_id
)
==
role_to_ops
.
end
())
{
if
(
role_to_ops
.
find
(
role_id
)
==
role_to_ops
.
end
())
{
task_nodes_
.
emplace_back
(
task_nodes_
.
emplace_back
(
TaskNode
::
CreateEmptyTaskNode
(
TaskNode
::
CreateEmptyTaskNode
(
role_id
,
cur_rank
,
task_id
));
role_id
,
cur_rank
,
task_id
,
max_run_times
,
max_slot_nums
));
}
else
{
}
else
{
task_nodes_
.
emplace_back
(
TaskNode
::
CreateTaskNode
(
task_nodes_
.
emplace_back
(
role_id
,
role_to_ops
.
at
(
role_id
),
cur_rank
,
task_id
));
TaskNode
::
CreateTaskNode
(
role_id
,
role_to_ops
.
at
(
role_id
),
cur_rank
,
task_id
,
max_run_times
,
max_slot_nums
));
}
}
++
task_id
;
++
task_id
;
}
}
...
...
paddle/fluid/distributed/fleet_executor/task_node.cc
浏览文件 @
ca088f92
...
@@ -22,22 +22,37 @@ using OperatorBase = TaskNode::OperatorBase;
...
@@ -22,22 +22,37 @@ using OperatorBase = TaskNode::OperatorBase;
}
}
TaskNode
::
TaskNode
(
int64_t
role
,
const
std
::
vector
<
OperatorBase
*>&
ops
,
TaskNode
::
TaskNode
(
int64_t
role
,
const
std
::
vector
<
OperatorBase
*>&
ops
,
int64_t
rank
,
int64_t
task_id
)
int64_t
rank
,
int64_t
task_id
,
int64_t
max_run_times
,
:
ops_
(
ops
),
role_
(
role
),
rank_
(
rank
),
task_id_
(
task_id
)
{}
int64_t
max_slot_nums
)
:
ops_
(
ops
),
role_
(
role
),
rank_
(
rank
),
task_id_
(
task_id
),
max_run_times_
(
max_run_times
),
max_slot_nums_
(
max_slot_nums
)
{}
TaskNode
::
TaskNode
(
int64_t
role
,
int64_t
rank
,
int64_t
task_id
)
TaskNode
::
TaskNode
(
int64_t
role
,
int64_t
rank
,
int64_t
task_id
,
:
role_
(
role
),
rank_
(
rank
),
task_id_
(
task_id
)
{}
int64_t
max_run_times
,
int64_t
max_slot_nums
)
:
role_
(
role
),
rank_
(
rank
),
task_id_
(
task_id
),
max_run_times_
(
max_run_times
),
max_slot_nums_
(
max_slot_nums
)
{}
std
::
unique_ptr
<
TaskNode
>
TaskNode
::
CreateEmptyTaskNode
(
int64_t
role
,
std
::
unique_ptr
<
TaskNode
>
TaskNode
::
CreateEmptyTaskNode
(
int64_t
role
,
int64_t
rank
,
int64_t
rank
,
int64_t
task_id
)
{
int64_t
task_id
,
return
std
::
make_unique
<
TaskNode
>
(
role
,
rank
,
task_id
);
int64_t
max_run_times
,
int64_t
max_slot_nums
)
{
return
std
::
make_unique
<
TaskNode
>
(
role
,
rank
,
task_id
,
max_run_times
,
max_slot_nums
);
}
}
std
::
unique_ptr
<
TaskNode
>
TaskNode
::
CreateTaskNode
(
std
::
unique_ptr
<
TaskNode
>
TaskNode
::
CreateTaskNode
(
int64_t
role
,
const
std
::
vector
<
OperatorBase
*>&
ops
,
int64_t
rank
,
int64_t
role
,
const
std
::
vector
<
OperatorBase
*>&
ops
,
int64_t
rank
,
int64_t
task_id
)
{
int64_t
task_id
,
int64_t
max_run_times
,
int64_t
max_slot_nums
)
{
return
std
::
make_unique
<
TaskNode
>
(
role
,
ops
,
rank
,
task_id
);
return
std
::
make_unique
<
TaskNode
>
(
role
,
ops
,
rank
,
task_id
,
max_run_times
,
max_slot_nums
);
}
}
void
TaskNode
::
AddUpstreamTask
(
int64_t
task_id
)
{
upstream_
.
insert
(
task_id
);
}
void
TaskNode
::
AddUpstreamTask
(
int64_t
task_id
)
{
upstream_
.
insert
(
task_id
);
}
...
...
paddle/fluid/distributed/fleet_executor/task_node.h
浏览文件 @
ca088f92
...
@@ -28,23 +28,28 @@ namespace distributed {
...
@@ -28,23 +28,28 @@ namespace distributed {
class
TaskNode
final
{
class
TaskNode
final
{
public:
public:
using
OperatorBase
=
paddle
::
framework
::
OperatorBase
;
using
OperatorBase
=
paddle
::
framework
::
OperatorBase
;
TaskNode
(
int64_t
role
,
int64_t
rank
,
int64_t
task_id
);
TaskNode
(
int64_t
role
,
int64_t
rank
,
int64_t
task_id
,
int64_t
max_run_times
,
int64_t
max_slot_nums
);
TaskNode
(
int64_t
role
,
const
std
::
vector
<
OperatorBase
*>&
ops
,
int64_t
rank
,
TaskNode
(
int64_t
role
,
const
std
::
vector
<
OperatorBase
*>&
ops
,
int64_t
rank
,
int64_t
task_id
);
int64_t
task_id
,
int64_t
max_run_times
,
int64_t
max_slot_nums
);
~
TaskNode
()
=
default
;
~
TaskNode
()
=
default
;
int64_t
rank
()
const
{
return
rank_
;
}
int64_t
rank
()
const
{
return
rank_
;
}
int64_t
task_id
()
const
{
return
task_id_
;
}
int64_t
task_id
()
const
{
return
task_id_
;
}
int64_t
role
()
const
{
return
role_
;
}
int64_t
role
()
const
{
return
role_
;
}
int64_t
max_run_times
()
const
{
return
max_run_times_
;
}
int64_t
max_slot_nums
()
const
{
return
max_slot_nums_
;
}
const
std
::
unordered_set
<
int64_t
>&
upstream
()
const
{
return
upstream_
;
}
const
std
::
unordered_set
<
int64_t
>&
upstream
()
const
{
return
upstream_
;
}
const
std
::
unordered_set
<
int64_t
>&
downstream
()
const
{
return
downstream_
;
}
const
std
::
unordered_set
<
int64_t
>&
downstream
()
const
{
return
downstream_
;
}
void
AddUpstreamTask
(
int64_t
task_id
);
void
AddUpstreamTask
(
int64_t
task_id
);
void
AddDownstreamTask
(
int64_t
task_id
);
void
AddDownstreamTask
(
int64_t
task_id
);
static
std
::
unique_ptr
<
TaskNode
>
CreateEmptyTaskNode
(
int64_t
role
,
static
std
::
unique_ptr
<
TaskNode
>
CreateEmptyTaskNode
(
int64_t
role
,
int64_t
rank
,
int64_t
rank
,
int64_t
task_id
);
int64_t
task_id
,
int64_t
max_run_times
,
int64_t
max_slot_nums
);
static
std
::
unique_ptr
<
TaskNode
>
CreateTaskNode
(
static
std
::
unique_ptr
<
TaskNode
>
CreateTaskNode
(
int64_t
role
,
const
std
::
vector
<
OperatorBase
*>&
ops
,
int64_t
rank
,
int64_t
role
,
const
std
::
vector
<
OperatorBase
*>&
ops
,
int64_t
rank
,
int64_t
task_id
);
int64_t
task_id
,
int64_t
max_run_times
,
int64_t
max_slot_nums
);
private:
private:
DISABLE_COPY_AND_ASSIGN
(
TaskNode
);
DISABLE_COPY_AND_ASSIGN
(
TaskNode
);
...
@@ -55,6 +60,8 @@ class TaskNode final {
...
@@ -55,6 +60,8 @@ class TaskNode final {
int64_t
role_
;
int64_t
role_
;
int64_t
rank_
;
int64_t
rank_
;
int64_t
task_id_
;
int64_t
task_id_
;
int64_t
max_run_times_
;
int64_t
max_slot_nums_
;
};
};
}
// namespace distributed
}
// namespace distributed
...
...
python/paddle/fluid/executor.py
浏览文件 @
ca088f92
...
@@ -1981,6 +1981,8 @@ class Executor(object):
...
@@ -1981,6 +1981,8 @@ class Executor(object):
fleet_exe_desc
.
dp_degree
=
fleet_opt
[
"dist_strategy"
][
"dp_degree"
]
fleet_exe_desc
.
dp_degree
=
fleet_opt
[
"dist_strategy"
][
"dp_degree"
]
fleet_exe_desc
.
mp_degree
=
fleet_opt
[
"dist_strategy"
][
"mp_degree"
]
fleet_exe_desc
.
mp_degree
=
fleet_opt
[
"dist_strategy"
][
"mp_degree"
]
fleet_exe_desc
.
pp_degree
=
fleet_opt
[
"dist_strategy"
][
"pp_degree"
]
fleet_exe_desc
.
pp_degree
=
fleet_opt
[
"dist_strategy"
][
"pp_degree"
]
if
"num_micro_batches"
in
fleet_opt
:
fleet_exe_desc
.
num_micro_batches
=
fleet_opt
[
"num_micro_batches"
]
num_of_gpu
=
fleet_exe_desc
.
dp_degree
*
fleet_exe_desc
.
mp_degree
*
fleet_exe_desc
.
pp_degree
num_of_gpu
=
fleet_exe_desc
.
dp_degree
*
fleet_exe_desc
.
mp_degree
*
fleet_exe_desc
.
pp_degree
assert
nrank
==
num_of_gpu
,
"The number of rank is not equal to the number of gpu."
assert
nrank
==
num_of_gpu
,
"The number of rank is not equal to the number of gpu."
fleet_exe
=
core
.
FleetExecutor
(
fleet_exe_desc
.
SerializeToString
())
fleet_exe
=
core
.
FleetExecutor
(
fleet_exe_desc
.
SerializeToString
())
...
...
python/paddle/fluid/tests/unittests/test_fleet_executor_multi_devices.py
浏览文件 @
ca088f92
...
@@ -43,7 +43,11 @@ class TestFleetExecutor(unittest.TestCase):
...
@@ -43,7 +43,11 @@ class TestFleetExecutor(unittest.TestCase):
"mp_degree"
:
2
,
"mp_degree"
:
2
,
"pp_degree"
:
2
"pp_degree"
:
2
}
}
fleet_opt
=
{
"dist_strategy"
:
strategy
.
sharding_configs
}
strategy
.
pipeline_configs
=
{
"accumulate_steps"
:
8
}
fleet_opt
=
{
"dist_strategy"
:
strategy
.
sharding_configs
,
"num_micro_batches"
:
strategy
.
pipeline_configs
[
"accumulate_steps"
]
}
if
fluid
.
is_compiled_with_cuda
():
if
fluid
.
is_compiled_with_cuda
():
self
.
run_fleet_executor
(
fluid
.
CUDAPlace
(
0
),
fleet_opt
)
self
.
run_fleet_executor
(
fluid
.
CUDAPlace
(
0
),
fleet_opt
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录