Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
ff599b92
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
ff599b92
编写于
5月 04, 2018
作者:
C
chengduoZH
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
use Reduce and Broadcast
上级
0441c2cc
变更
2
隐藏空白更改
内联
并排
Showing
2 changed files
with
13 additions
and
59 deletions
+13
-59
paddle/fluid/framework/details/multi_devices_graph_builder.cc
...le/fluid/framework/details/multi_devices_graph_builder.cc
+10
-52
paddle/fluid/framework/details/multi_devices_graph_builder.h
paddle/fluid/framework/details/multi_devices_graph_builder.h
+3
-7
未找到文件。
paddle/fluid/framework/details/multi_devices_graph_builder.cc
浏览文件 @
ff599b92
...
...
@@ -111,6 +111,7 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
for
(
auto
*
var
:
program
.
Block
(
0
).
AllVars
())
{
var_types
[
var
->
Name
()]
=
var
->
GetType
();
}
auto
graph
=
new
SSAGraph
();
SSAGraph
&
result
=
*
graph
;
std
::
unordered_set
<
std
::
string
>
og_has_been_broadcast
;
...
...
@@ -120,13 +121,6 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
std
::
unique_ptr
<
VarHandle
>>>>
(
places_
.
size
());
size_t
cur_dev_id
=
0
;
std
::
vector
<
std
::
unordered_set
<
std
::
string
>>
sparse_var_name_on_devices
;
std
::
vector
<
std
::
unordered_set
<
std
::
string
>>
bcast_sparse_var_name_set
;
sparse_var_name_on_devices
.
resize
(
places_
.
size
());
bcast_sparse_var_name_set
.
resize
(
places_
.
size
());
// Find the "send" op first, because the split op is placed in front of send.
OpDesc
*
send_op
=
GetSendOpDesc
(
program
);
...
...
@@ -145,27 +139,15 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
}
is_forwarding
=
false
;
}
else
{
int
op_dev_id
=
GetOpDeviceID
(
sparse_var_name_on_devices
,
*
op
);
if
(
op_dev_id
==
-
1
)
{
// var on all device
CreateComputationalOps
(
&
result
,
*
op
,
places_
.
size
());
}
else
{
CreateComputationalOp
(
&
result
,
*
op
,
op_dev_id
);
for
(
auto
&
var_name
:
op
->
OutputArgumentNames
())
{
sparse_var_name_on_devices
[
op_dev_id
].
emplace
(
var_name
);
}
}
CreateComputationalOps
(
&
result
,
*
op
,
places_
.
size
());
if
(
!
is_forwarding
&&
places_
.
size
()
>
1
)
{
// Currently, we assume that once gradient is generated, it can be
// broadcast, and each gradient is only broadcast once.
for
(
auto
&
og
:
op
->
OutputArgumentNames
())
{
if
(
IsParameterGradientOnce
(
og
,
&
og_has_been_broadcast
))
{
if
(
IsSparseGradient
(
var_types
,
og
))
{
CreateReduceOp
(
&
result
,
cur_dev_id
,
og
);
sparse_var_name_on_devices
[
cur_dev_id
].
emplace
(
og
);
bcast_sparse_var_name_set
[
cur_dev_id
].
emplace
(
og
.
substr
(
0
,
og
.
size
()
-
strlen
(
kGradVarSuffix
)));
cur_dev_id
=
(
cur_dev_id
+
1
)
%
places_
.
size
();
CreateReduceOp
(
&
result
,
og
,
0
);
CreateBroadcastOp
(
&
result
,
og
,
0
);
}
else
{
InsertNCCLAllReduceOp
(
&
result
,
og
);
}
...
...
@@ -175,14 +157,6 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
}
}
// Insert BCast Ops
for
(
size_t
dev_id
=
0
;
dev_id
<
bcast_sparse_var_name_set
.
size
();
++
dev_id
)
{
auto
&
to_bcast_set
=
bcast_sparse_var_name_set
[
dev_id
];
for
(
auto
&
bcast_name
:
to_bcast_set
)
{
CreateBroadcastOp
(
&
result
,
bcast_name
,
dev_id
);
}
}
/*
Dependency graph has been constructed. However, there are still data
hazards that need to be handled.
...
...
@@ -213,26 +187,9 @@ bool MultiDevSSAGraphBuilder::IsSparseGradient(
return
false
;
}
int
MultiDevSSAGraphBuilder
::
GetOpDeviceID
(
const
std
::
vector
<
std
::
unordered_set
<
std
::
string
>>
&
sparse_var_name_on_devices
,
const
OpDesc
&
op
)
const
{
int
var_dev_id
=
-
1
;
for
(
auto
&
var_name
:
op
.
InputArgumentNames
())
{
if
(
var_dev_id
!=
-
1
)
break
;
for
(
size_t
i
=
0
;
i
<
sparse_var_name_on_devices
.
size
();
++
i
)
{
if
(
sparse_var_name_on_devices
[
i
].
count
(
var_name
))
{
var_dev_id
=
static_cast
<
int
>
(
i
);
break
;
}
}
}
return
var_dev_id
;
}
void
MultiDevSSAGraphBuilder
::
CreateBroadcastOp
(
SSAGraph
*
result
,
const
std
::
string
&
p_name
,
size_t
dev_id
)
const
{
size_t
src_
dev_id
)
const
{
#ifdef PADDLE_WITH_CUDA
auto
*
op_handle
=
new
BroadcastOpHandle
(
local_scopes_
,
places_
,
nccl_ctxs_
);
#else
...
...
@@ -240,11 +197,11 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(SSAGraph *result,
#endif
result
->
ops_
.
emplace_back
(
op_handle
);
auto
*
in
=
result
->
vars_
.
at
(
dev_id
).
at
(
p_name
).
back
().
get
();
auto
*
in
=
result
->
vars_
.
at
(
src_
dev_id
).
at
(
p_name
).
back
().
get
();
op_handle
->
AddInput
(
in
);
for
(
size_t
i
=
0
;
i
<
places_
.
size
();
++
i
)
{
auto
&
vars
=
result
->
vars_
.
at
(
dev_id
).
at
(
p_name
);
auto
&
vars
=
result
->
vars_
.
at
(
i
).
at
(
p_name
);
auto
&
p
=
places_
[
i
];
auto
*
out_var
=
new
VarHandle
(
vars
.
size
(),
i
,
p_name
,
p
);
vars
.
emplace_back
(
out_var
);
...
...
@@ -345,8 +302,9 @@ void MultiDevSSAGraphBuilder::CreateComputationalOps(SSAGraph *result,
}
}
VarHandle
*
MultiDevSSAGraphBuilder
::
CreateReduceOp
(
SSAGraph
*
result
,
int
dst_dev_id
,
const
std
::
string
&
og
)
const
{
VarHandle
*
MultiDevSSAGraphBuilder
::
CreateReduceOp
(
SSAGraph
*
result
,
const
std
::
string
&
og
,
int
dst_dev_id
)
const
{
#ifdef PADDLE_WITH_CUDA
result
->
ops_
.
emplace_back
(
new
ReduceOpHandle
(
local_scopes_
,
places_
,
nccl_ctxs_
));
...
...
paddle/fluid/framework/details/multi_devices_graph_builder.h
浏览文件 @
ff599b92
...
...
@@ -75,8 +75,8 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
size_t
num_places
)
const
;
void
CreateScaleLossGradOp
(
SSAGraph
*
result
)
const
;
VarHandle
*
CreateReduceOp
(
SSAGraph
*
result
,
int
dst_dev_id
,
const
std
::
string
&
og
)
const
;
VarHandle
*
CreateReduceOp
(
SSAGraph
*
result
,
const
std
::
string
&
og
,
int
dst_dev_id
)
const
;
void
CreateComputationalOp
(
SSAGraph
*
result
,
const
OpDesc
&
op
,
int
dev_id
)
const
;
...
...
@@ -87,11 +87,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
void
InsertNCCLAllReduceOp
(
SSAGraph
*
result
,
const
std
::
string
&
og
)
const
;
void
CreateBroadcastOp
(
SSAGraph
*
result
,
const
std
::
string
&
p_name
,
size_t
dev_id
)
const
;
int
GetOpDeviceID
(
const
std
::
vector
<
std
::
unordered_set
<
std
::
string
>>
&
var_name_on_devices
,
const
OpDesc
&
op
)
const
;
size_t
src_dev_id
)
const
;
/**
* Get send op in the global block of program.
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录