Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
af83e79a
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
af83e79a
编写于
11月 19, 2021
作者:
L
LiYuRio
提交者:
GitHub
11月 19, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fix runtime graph on gpt, add debug message (#37361)
上级
edc3496f
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
52 addition
and
27 deletion
+52
-27
paddle/fluid/distributed/fleet_executor/fleet_executor.cc
paddle/fluid/distributed/fleet_executor/fleet_executor.cc
+1
-0
paddle/fluid/distributed/fleet_executor/fleet_executor.h
paddle/fluid/distributed/fleet_executor/fleet_executor.h
+1
-0
paddle/fluid/distributed/fleet_executor/runtime_graph.cc
paddle/fluid/distributed/fleet_executor/runtime_graph.cc
+27
-17
paddle/fluid/distributed/fleet_executor/runtime_graph.h
paddle/fluid/distributed/fleet_executor/runtime_graph.h
+2
-0
paddle/fluid/distributed/fleet_executor/task_node.cc
paddle/fluid/distributed/fleet_executor/task_node.cc
+14
-4
paddle/fluid/distributed/fleet_executor/task_node.h
paddle/fluid/distributed/fleet_executor/task_node.h
+7
-6
未找到文件。
paddle/fluid/distributed/fleet_executor/fleet_executor.cc
浏览文件 @
af83e79a
...
...
@@ -35,6 +35,7 @@ FleetExecutor::~FleetExecutor() {
void
FleetExecutor
::
Init
(
const
paddle
::
framework
::
ProgramDesc
&
program_desc
)
{
runtime_graph_
=
std
::
make_unique
<
RuntimeGraph
>
(
program_desc
,
exe_desc_
);
VLOG
(
5
)
<<
runtime_graph_
->
DebugString
();
InitCarrier
();
InitMessageBus
();
}
...
...
paddle/fluid/distributed/fleet_executor/fleet_executor.h
浏览文件 @
af83e79a
...
...
@@ -14,6 +14,7 @@
#pragma once
#include <memory>
#include <string>
#include "paddle/fluid/distributed/fleet_executor/fleet_executor_desc.pb.h"
#include "paddle/fluid/platform/macros.h"
...
...
paddle/fluid/distributed/fleet_executor/runtime_graph.cc
浏览文件 @
af83e79a
...
...
@@ -27,24 +27,24 @@ using OpRole = paddle::framework::OpRole;
using
OpRegistry
=
paddle
::
framework
::
OpRegistry
;
using
ProgramDesc
=
paddle
::
framework
::
ProgramDesc
;
bool
IsForward
(
int
64
_t
op_role
)
{
return
(
op_role
==
static_cast
<
int
64
_t
>
(
OpRole
::
kForward
))
||
(
op_role
==
(
static_cast
<
int
64
_t
>
(
OpRole
::
kForward
)
|
static_cast
<
int
64
_t
>
(
OpRole
::
kLoss
)));
bool
IsForward
(
int
32
_t
op_role
)
{
return
(
op_role
==
static_cast
<
int
32
_t
>
(
OpRole
::
kForward
))
||
(
op_role
==
(
static_cast
<
int
32
_t
>
(
OpRole
::
kForward
)
|
static_cast
<
int
32
_t
>
(
OpRole
::
kLoss
)));
}
bool
IsLRSched
(
int
64
_t
op_role
)
{
return
op_role
==
static_cast
<
int
64
_t
>
(
OpRole
::
kLRSched
);
bool
IsLRSched
(
int
32
_t
op_role
)
{
return
op_role
==
static_cast
<
int
32
_t
>
(
OpRole
::
kLRSched
);
}
bool
IsBackward
(
int
64
_t
op_role
)
{
return
(
op_role
==
static_cast
<
int
64
_t
>
(
OpRole
::
kBackward
))
||
(
op_role
==
(
static_cast
<
int
64
_t
>
(
OpRole
::
kBackward
)
|
static_cast
<
int
64
_t
>
(
OpRole
::
kLoss
)));
bool
IsBackward
(
int
32
_t
op_role
)
{
return
(
op_role
==
static_cast
<
int
32
_t
>
(
OpRole
::
kBackward
))
||
(
op_role
==
(
static_cast
<
int
32
_t
>
(
OpRole
::
kBackward
)
|
static_cast
<
int
32
_t
>
(
OpRole
::
kLoss
)));
}
bool
IsOptimize
(
int
64
_t
op_role
)
{
return
op_role
==
static_cast
<
int
64
_t
>
(
OpRole
::
kOptimize
);
bool
IsOptimize
(
int
32
_t
op_role
)
{
return
op_role
==
static_cast
<
int
32
_t
>
(
OpRole
::
kOptimize
);
}
struct
DistCoord
{
...
...
@@ -112,9 +112,9 @@ void RuntimeGraph::SplitProgramBasedFunctionality(const ProgramDesc& program) {
for
(
const
auto
&
op_desc
:
program
.
Block
(
0
).
AllOps
())
{
ops_
.
emplace_back
(
OpRegistry
::
CreateOp
(
*
op_desc
));
}
std
::
unordered_map
<
int
64
_t
,
std
::
vector
<
OperatorBase
*>>
role_to_ops
;
std
::
unordered_map
<
int
32
_t
,
std
::
vector
<
OperatorBase
*>>
role_to_ops
;
for
(
const
auto
&
op
:
ops_
)
{
int
64_t
op_role
=
op
->
Attr
<
int64
_t
>
(
"op_role"
);
int
32_t
op_role
=
op
->
Attr
<
int32
_t
>
(
"op_role"
);
OpRole
new_op_role
;
if
(
IsLRSched
(
op_role
))
{
new_op_role
=
OpRole
::
kLRSched
;
...
...
@@ -129,7 +129,7 @@ void RuntimeGraph::SplitProgramBasedFunctionality(const ProgramDesc& program) {
"The op %s is None of LRSched, Forward, Backward or Optimize."
,
op
->
Type
()));
}
int
64_t
new_op_role_id
=
static_cast
<
int64
_t
>
(
new_op_role
);
int
32_t
new_op_role_id
=
static_cast
<
int32
_t
>
(
new_op_role
);
if
(
role_to_ops
.
find
(
new_op_role_id
)
==
role_to_ops
.
end
())
{
role_to_ops
.
insert
({
new_op_role_id
,
{}});
}
...
...
@@ -147,7 +147,7 @@ void RuntimeGraph::SplitProgramBasedFunctionality(const ProgramDesc& program) {
int64_t
task_id
=
cur_rank
*
functionality_order
.
size
();
for
(
std
::
size_t
i
=
0
;
i
<
functionality_order
.
size
();
++
i
)
{
OpRole
role
=
functionality_order
[
i
];
int
64
_t
role_id
=
static_cast
<
int64_t
>
(
role
);
int
32
_t
role_id
=
static_cast
<
int64_t
>
(
role
);
int64_t
max_run_times
=
num_micro_batches
;
int64_t
max_slot_nums
=
start_up_steps
;
if
(
IsLRSched
(
role_id
)
||
IsOptimize
(
role_id
))
{
...
...
@@ -225,12 +225,22 @@ void RuntimeGraph::FakeRuntimeInfo() {
int64_t
nrank
=
exe_desc_
.
cluster_info
().
size
();
int32_t
num_of_functionality
=
functionality_order
.
size
();
for
(
int64_t
i
=
0
;
i
<
nrank
;
++
i
)
{
for
(
int
64
_t
j
=
0
;
j
<
num_of_functionality
;
++
j
)
{
for
(
int
32
_t
j
=
0
;
j
<
num_of_functionality
;
++
j
)
{
int64_t
intercepter_id
=
i
*
num_of_functionality
+
j
;
intercepter_id_to_rank_
.
insert
({
intercepter_id
,
i
});
}
}
}
std
::
string
RuntimeGraph
::
DebugString
()
const
{
std
::
ostringstream
os
;
os
<<
"
\n
Runtime Graph Debug:
\n
"
;
for
(
const
auto
&
task
:
task_nodes_
)
{
os
<<
task
->
DebugString
();
os
<<
"
\n
"
;
}
return
os
.
str
();
}
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/fleet_executor/runtime_graph.h
浏览文件 @
af83e79a
...
...
@@ -14,6 +14,7 @@
#pragma once
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/distributed/fleet_executor/fleet_executor_desc.pb.h"
...
...
@@ -43,6 +44,7 @@ class RuntimeGraph final {
const
std
::
unordered_map
<
int64_t
,
int64_t
>&
intercepter_id_to_rank
()
const
{
return
intercepter_id_to_rank_
;
}
std
::
string
DebugString
()
const
;
private:
DISABLE_COPY_AND_ASSIGN
(
RuntimeGraph
);
...
...
paddle/fluid/distributed/fleet_executor/task_node.cc
浏览文件 @
af83e79a
...
...
@@ -21,7 +21,7 @@ namespace {
using
OperatorBase
=
TaskNode
::
OperatorBase
;
}
TaskNode
::
TaskNode
(
int
64
_t
role
,
const
std
::
vector
<
OperatorBase
*>&
ops
,
TaskNode
::
TaskNode
(
int
32
_t
role
,
const
std
::
vector
<
OperatorBase
*>&
ops
,
int64_t
rank
,
int64_t
task_id
,
int64_t
max_run_times
,
int64_t
max_slot_nums
)
:
ops_
(
ops
),
...
...
@@ -31,7 +31,7 @@ TaskNode::TaskNode(int64_t role, const std::vector<OperatorBase*>& ops,
max_run_times_
(
max_run_times
),
max_slot_nums_
(
max_slot_nums
)
{}
TaskNode
::
TaskNode
(
int
64
_t
role
,
int64_t
rank
,
int64_t
task_id
,
TaskNode
::
TaskNode
(
int
32
_t
role
,
int64_t
rank
,
int64_t
task_id
,
int64_t
max_run_times
,
int64_t
max_slot_nums
)
:
role_
(
role
),
rank_
(
rank
),
...
...
@@ -39,7 +39,7 @@ TaskNode::TaskNode(int64_t role, int64_t rank, int64_t task_id,
max_run_times_
(
max_run_times
),
max_slot_nums_
(
max_slot_nums
)
{}
std
::
unique_ptr
<
TaskNode
>
TaskNode
::
CreateEmptyTaskNode
(
int
64
_t
role
,
std
::
unique_ptr
<
TaskNode
>
TaskNode
::
CreateEmptyTaskNode
(
int
32
_t
role
,
int64_t
rank
,
int64_t
task_id
,
int64_t
max_run_times
,
...
...
@@ -49,7 +49,7 @@ std::unique_ptr<TaskNode> TaskNode::CreateEmptyTaskNode(int64_t role,
}
std
::
unique_ptr
<
TaskNode
>
TaskNode
::
CreateTaskNode
(
int
64
_t
role
,
const
std
::
vector
<
OperatorBase
*>&
ops
,
int64_t
rank
,
int
32
_t
role
,
const
std
::
vector
<
OperatorBase
*>&
ops
,
int64_t
rank
,
int64_t
task_id
,
int64_t
max_run_times
,
int64_t
max_slot_nums
)
{
return
std
::
make_unique
<
TaskNode
>
(
role
,
ops
,
rank
,
task_id
,
max_run_times
,
max_slot_nums
);
...
...
@@ -60,5 +60,15 @@ void TaskNode::AddUpstreamTask(int64_t task_id) { upstream_.insert(task_id); }
void
TaskNode
::
AddDownstreamTask
(
int64_t
task_id
)
{
downstream_
.
insert
(
task_id
);
}
std
::
string
TaskNode
::
DebugString
()
const
{
std
::
ostringstream
os
;
os
<<
"role: "
<<
role_
<<
", task_id: "
<<
task_id_
<<
"
\n
"
;
for
(
std
::
size_t
i
=
0
;
i
<
ops_
.
size
();
++
i
)
{
os
<<
ops_
[
i
]
->
Type
()
<<
" "
;
}
os
<<
"
\n
"
;
return
os
.
str
();
}
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/fleet_executor/task_node.h
浏览文件 @
af83e79a
...
...
@@ -28,27 +28,28 @@ namespace distributed {
class
TaskNode
final
{
public:
using
OperatorBase
=
paddle
::
framework
::
OperatorBase
;
TaskNode
(
int
64
_t
role
,
int64_t
rank
,
int64_t
task_id
,
int64_t
max_run_times
,
TaskNode
(
int
32
_t
role
,
int64_t
rank
,
int64_t
task_id
,
int64_t
max_run_times
,
int64_t
max_slot_nums
);
TaskNode
(
int
64
_t
role
,
const
std
::
vector
<
OperatorBase
*>&
ops
,
int64_t
rank
,
TaskNode
(
int
32
_t
role
,
const
std
::
vector
<
OperatorBase
*>&
ops
,
int64_t
rank
,
int64_t
task_id
,
int64_t
max_run_times
,
int64_t
max_slot_nums
);
~
TaskNode
()
=
default
;
int64_t
rank
()
const
{
return
rank_
;
}
int64_t
task_id
()
const
{
return
task_id_
;
}
int
64
_t
role
()
const
{
return
role_
;
}
int
32
_t
role
()
const
{
return
role_
;
}
int64_t
max_run_times
()
const
{
return
max_run_times_
;
}
int64_t
max_slot_nums
()
const
{
return
max_slot_nums_
;
}
const
std
::
unordered_set
<
int64_t
>&
upstream
()
const
{
return
upstream_
;
}
const
std
::
unordered_set
<
int64_t
>&
downstream
()
const
{
return
downstream_
;
}
void
AddUpstreamTask
(
int64_t
task_id
);
void
AddDownstreamTask
(
int64_t
task_id
);
static
std
::
unique_ptr
<
TaskNode
>
CreateEmptyTaskNode
(
int64_t
role
,
std
::
string
DebugString
()
const
;
static
std
::
unique_ptr
<
TaskNode
>
CreateEmptyTaskNode
(
int32_t
role
,
int64_t
rank
,
int64_t
task_id
,
int64_t
max_run_times
,
int64_t
max_slot_nums
);
static
std
::
unique_ptr
<
TaskNode
>
CreateTaskNode
(
int
64
_t
role
,
const
std
::
vector
<
OperatorBase
*>&
ops
,
int64_t
rank
,
int
32
_t
role
,
const
std
::
vector
<
OperatorBase
*>&
ops
,
int64_t
rank
,
int64_t
task_id
,
int64_t
max_run_times
,
int64_t
max_slot_nums
);
private:
...
...
@@ -57,7 +58,7 @@ class TaskNode final {
std
::
vector
<
OperatorBase
*>
ops_
;
std
::
unordered_set
<
int64_t
>
upstream_
;
std
::
unordered_set
<
int64_t
>
downstream_
;
int
64
_t
role_
;
int
32
_t
role_
;
int64_t
rank_
;
int64_t
task_id_
;
int64_t
max_run_times_
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录