Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
84367cf8
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
84367cf8
编写于
2月 10, 2019
作者:
Q
Qiao Longfei
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
support async mode in dist mode parallel executor
上级
e72637dd
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
36 addition
and
11 deletion
+36
-11
paddle/fluid/framework/details/multi_devices_graph_pass.cc
paddle/fluid/framework/details/multi_devices_graph_pass.cc
+30
-5
paddle/fluid/framework/details/multi_devices_graph_pass.h
paddle/fluid/framework/details/multi_devices_graph_pass.h
+6
-6
未找到文件。
paddle/fluid/framework/details/multi_devices_graph_pass.cc
浏览文件 @
84367cf8
...
@@ -167,6 +167,10 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilderBase::ApplyImpl(
...
@@ -167,6 +167,10 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilderBase::ApplyImpl(
bool
is_forwarding
=
true
;
bool
is_forwarding
=
true
;
bool
insert_collection_ops
=
NeedCollectiveOps
();
bool
insert_collection_ops
=
NeedCollectiveOps
();
if
(
strategy_
.
async_mode_
)
{
// async mode did not need to merge gradient
insert_collection_ops
=
false
;
}
for
(
ir
::
Node
*
node
:
sorted_ops
)
{
for
(
ir
::
Node
*
node
:
sorted_ops
)
{
if
(
DealWithSpecialOp
(
&
result
,
node
))
{
if
(
DealWithSpecialOp
(
&
result
,
node
))
{
...
@@ -192,8 +196,22 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilderBase::ApplyImpl(
...
@@ -192,8 +196,22 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilderBase::ApplyImpl(
static_cast
<
bool
>
(
boost
::
get
<
int
>
(
node
->
Op
()
->
GetAttr
(
static_cast
<
bool
>
(
boost
::
get
<
int
>
(
node
->
Op
()
->
GetAttr
(
OpProtoAndCheckerMaker
::
OpRoleAttrName
()))
&
OpProtoAndCheckerMaker
::
OpRoleAttrName
()))
&
static_cast
<
int
>
(
OpRole
::
kBackward
));
static_cast
<
int
>
(
OpRole
::
kBackward
));
// optimize op is already processed in DealWithSpecialOp,
// here we only consider backward op
if
(
!
is_bk_op
)
continue
;
if
(
!
is_bk_op
)
continue
;
/*
* the op that will generate the gradient of on parameter will have
one attr op_role_var
* to record the parameter and gradient, like:
attrs {
name: "op_role_var"
type: STRINGS
strings: "fc_1.b_0"
strings: "fc_1.b_0@GRAD"
}
*/
// Currently, we assume that once gradient is generated, it can be
// Currently, we assume that once gradient is generated, it can be
// broadcast, and each gradient is only broadcast once.
// broadcast, and each gradient is only broadcast once.
auto
backward_vars
=
auto
backward_vars
=
...
@@ -204,7 +222,7 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilderBase::ApplyImpl(
...
@@ -204,7 +222,7 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilderBase::ApplyImpl(
for
(
size_t
i
=
0
;
i
<
backward_vars
.
size
();
i
+=
2
)
{
for
(
size_t
i
=
0
;
i
<
backward_vars
.
size
();
i
+=
2
)
{
auto
&
p_name
=
backward_vars
[
i
];
auto
&
p_name
=
backward_vars
[
i
];
auto
&
g_name
=
backward_vars
[
i
+
1
];
auto
&
g_name
=
backward_vars
[
i
+
1
];
VLOG
(
10
)
<<
"Bcast "
<<
g_name
<<
" for parameter "
<<
p_name
;
VLOG
(
3
)
<<
"Bcast "
<<
g_name
<<
" for parameter "
<<
p_name
;
InsertCollectiveOp
(
&
result
,
p_name
,
g_name
);
InsertCollectiveOp
(
&
result
,
p_name
,
g_name
);
}
}
...
@@ -385,7 +403,7 @@ void MultiDevSSAGraphBuilderBase::CreateFusedBroadcastOp(
...
@@ -385,7 +403,7 @@ void MultiDevSSAGraphBuilderBase::CreateFusedBroadcastOp(
void
MultiDevSSAGraphBuilderBase
::
CreateComputationalOp
(
ir
::
Graph
*
result
,
void
MultiDevSSAGraphBuilderBase
::
CreateComputationalOp
(
ir
::
Graph
*
result
,
ir
::
Node
*
node
,
ir
::
Node
*
node
,
in
t
dev_id
)
const
{
size_
t
dev_id
)
const
{
result
->
Get
<
GraphOps
>
(
kGraphOps
).
emplace_back
(
result
->
Get
<
GraphOps
>
(
kGraphOps
).
emplace_back
(
new
ComputationOpHandle
(
result
->
CreateOpNode
(
node
->
Op
()),
new
ComputationOpHandle
(
result
->
CreateOpNode
(
node
->
Op
()),
local_scopes_
[
dev_id
],
places_
[
dev_id
],
dev_id
));
local_scopes_
[
dev_id
],
places_
[
dev_id
],
dev_id
));
...
@@ -454,9 +472,8 @@ void MultiDevSSAGraphBuilderBase::CreateComputationalOps(
...
@@ -454,9 +472,8 @@ void MultiDevSSAGraphBuilderBase::CreateComputationalOps(
}
}
}
}
VarHandle
*
MultiDevSSAGraphBuilderBase
::
CreateReduceOp
(
ir
::
Graph
*
result
,
VarHandle
*
MultiDevSSAGraphBuilderBase
::
CreateReduceOp
(
const
std
::
string
&
og
,
ir
::
Graph
*
result
,
const
std
::
string
&
og
,
size_t
dst_dev_id
)
const
{
int
dst_dev_id
)
const
{
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
result
->
Get
<
GraphOps
>
(
kGraphOps
).
emplace_back
(
new
ReduceOpHandle
(
result
->
Get
<
GraphOps
>
(
kGraphOps
).
emplace_back
(
new
ReduceOpHandle
(
result
->
CreateEmptyNode
(
"reduce"
,
ir
::
Node
::
Type
::
kOperation
),
result
->
CreateEmptyNode
(
"reduce"
,
ir
::
Node
::
Type
::
kOperation
),
...
@@ -720,6 +737,10 @@ bool DistSSAGraphBuilder::DealWithSpecialOp(ir::Graph *result,
...
@@ -720,6 +737,10 @@ bool DistSSAGraphBuilder::DealWithSpecialOp(ir::Graph *result,
ir
::
Node
*
node
)
const
{
ir
::
Node
*
node
)
const
{
bool
insert_op
=
false
;
bool
insert_op
=
false
;
if
(
OpHaveRole
(
*
node
,
OpRole
::
kRPC
))
{
if
(
OpHaveRole
(
*
node
,
OpRole
::
kRPC
))
{
// in async_mode, each graph will send it's own gradient.
if
(
strategy_
.
async_mode_
&&
node
->
Op
()
->
Type
()
==
"send"
)
{
return
false
;
}
int
op_dev_id
=
CreateRPCOp
(
result
,
node
);
int
op_dev_id
=
CreateRPCOp
(
result
,
node
);
PADDLE_ENFORCE
(
op_dev_id
!=
-
1
,
PADDLE_ENFORCE
(
op_dev_id
!=
-
1
,
"Can not schedule the RPC operator to the right place."
);
"Can not schedule the RPC operator to the right place."
);
...
@@ -737,6 +758,8 @@ bool DistSSAGraphBuilder::DealWithSpecialOp(ir::Graph *result,
...
@@ -737,6 +758,8 @@ bool DistSSAGraphBuilder::DealWithSpecialOp(ir::Graph *result,
}
else
if
(
OpHaveRole
(
*
node
,
OpRole
::
kDist
))
{
}
else
if
(
OpHaveRole
(
*
node
,
OpRole
::
kDist
))
{
int
op_dev_id
=
CreateDistTrainOp
(
result
,
node
);
int
op_dev_id
=
CreateDistTrainOp
(
result
,
node
);
if
(
node
->
Op
()
->
Type
()
==
"concat"
)
{
if
(
node
->
Op
()
->
Type
()
==
"concat"
)
{
// the input(block of parameter) of concat is on different device,
// the output(parameter) will on one device.
auto
origin_param_name
=
node
->
Op
()
->
OutputArgumentNames
()[
0
];
auto
origin_param_name
=
node
->
Op
()
->
OutputArgumentNames
()[
0
];
bcast_var_name_set_
[
op_dev_id
].
emplace
(
origin_param_name
);
bcast_var_name_set_
[
op_dev_id
].
emplace
(
origin_param_name
);
}
}
...
@@ -744,6 +767,7 @@ bool DistSSAGraphBuilder::DealWithSpecialOp(ir::Graph *result,
...
@@ -744,6 +767,7 @@ bool DistSSAGraphBuilder::DealWithSpecialOp(ir::Graph *result,
}
else
{
}
else
{
int
op_dev_id
=
GetOpDeviceID
(
node
);
int
op_dev_id
=
GetOpDeviceID
(
node
);
if
(
op_dev_id
!=
-
1
)
{
// This op only runs on one specific device.
if
(
op_dev_id
!=
-
1
)
{
// This op only runs on one specific device.
// optimize op will be processed here.
CreateComputationalOp
(
result
,
node
,
op_dev_id
);
CreateComputationalOp
(
result
,
node
,
op_dev_id
);
for
(
ir
::
Node
*
n
:
node
->
outputs
)
{
for
(
ir
::
Node
*
n
:
node
->
outputs
)
{
sharded_var_device_
.
emplace
(
n
->
Name
(),
op_dev_id
);
sharded_var_device_
.
emplace
(
n
->
Name
(),
op_dev_id
);
...
@@ -905,6 +929,7 @@ int DistSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
...
@@ -905,6 +929,7 @@ int DistSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
void
DistSSAGraphBuilder
::
InsertCollectiveOp
(
ir
::
Graph
*
result
,
void
DistSSAGraphBuilder
::
InsertCollectiveOp
(
ir
::
Graph
*
result
,
const
std
::
string
&
p_name
,
const
std
::
string
&
p_name
,
const
std
::
string
&
g_name
)
const
{
const
std
::
string
&
g_name
)
const
{
// collective gradient to each device
size_t
cur_device_id
=
0
;
size_t
cur_device_id
=
0
;
switch
(
strategy_
.
reduce_
)
{
switch
(
strategy_
.
reduce_
)
{
case
BuildStrategy
::
ReduceStrategy
::
kReduce
:
case
BuildStrategy
::
ReduceStrategy
::
kReduce
:
...
...
paddle/fluid/framework/details/multi_devices_graph_pass.h
浏览文件 @
84367cf8
...
@@ -68,10 +68,10 @@ class MultiDevSSAGraphBuilderBase : public ir::Pass {
...
@@ -68,10 +68,10 @@ class MultiDevSSAGraphBuilderBase : public ir::Pass {
proto
::
VarType
::
Type
dtype
)
const
;
proto
::
VarType
::
Type
dtype
)
const
;
VarHandle
*
CreateReduceOp
(
ir
::
Graph
*
result
,
const
std
::
string
&
og
,
VarHandle
*
CreateReduceOp
(
ir
::
Graph
*
result
,
const
std
::
string
&
og
,
in
t
dst_dev_id
)
const
;
size_
t
dst_dev_id
)
const
;
void
CreateComputationalOp
(
ir
::
Graph
*
result
,
ir
::
Node
*
node
,
void
CreateComputationalOp
(
ir
::
Graph
*
result
,
ir
::
Node
*
node
,
in
t
dev_id
)
const
;
size_
t
dev_id
)
const
;
bool
IsSparseGradient
(
const
std
::
string
&
og
)
const
;
bool
IsSparseGradient
(
const
std
::
string
&
og
)
const
;
...
@@ -118,16 +118,16 @@ class AllReduceSSAGraphBuilder : public MultiDevSSAGraphBuilderBase {
...
@@ -118,16 +118,16 @@ class AllReduceSSAGraphBuilder : public MultiDevSSAGraphBuilderBase {
class
AsyncSSAGraphBuilder
:
public
MultiDevSSAGraphBuilderBase
{
class
AsyncSSAGraphBuilder
:
public
MultiDevSSAGraphBuilderBase
{
protected:
protected:
v
irtual
v
oid
InsertCollectiveOp
(
ir
::
Graph
*
result
,
const
std
::
string
&
p_name
,
void
InsertCollectiveOp
(
ir
::
Graph
*
result
,
const
std
::
string
&
p_name
,
const
std
::
string
&
g_name
)
const
{}
const
std
::
string
&
g_name
)
const
override
{}
bool
NeedCollectiveOps
()
const
override
{
return
false
;
}
bool
NeedCollectiveOps
()
const
override
{
return
false
;
}
virtual
bool
DealWithSpecialOp
(
ir
::
Graph
*
result
,
ir
::
Node
*
node
)
const
{
bool
DealWithSpecialOp
(
ir
::
Graph
*
result
,
ir
::
Node
*
node
)
const
override
{
return
false
;
return
false
;
}
}
v
irtual
void
InsertPostprocessOps
(
ir
::
Graph
*
result
)
const
{}
v
oid
InsertPostprocessOps
(
ir
::
Graph
*
result
)
const
override
{}
};
};
class
BalanceVarSSAGraphBuilder
:
public
MultiDevSSAGraphBuilderBase
{
class
BalanceVarSSAGraphBuilder
:
public
MultiDevSSAGraphBuilderBase
{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录