Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
93355cc0
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
93355cc0
编写于
7月 21, 2018
作者:
X
Xin Pan
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix control deps
上级
f6d99d1f
变更
14
隐藏空白更改
内联
并排
Showing
14 changed file
with
155 addition
and
61 deletion
+155
-61
paddle/fluid/framework/details/CMakeLists.txt
paddle/fluid/framework/details/CMakeLists.txt
+1
-1
paddle/fluid/framework/details/multi_devices_graph_builder.cc
...le/fluid/framework/details/multi_devices_graph_builder.cc
+21
-13
paddle/fluid/framework/details/multi_devices_graph_builder.h
paddle/fluid/framework/details/multi_devices_graph_builder.h
+2
-2
paddle/fluid/framework/details/rpc_op_handle.cc
paddle/fluid/framework/details/rpc_op_handle.cc
+2
-1
paddle/fluid/framework/details/ssa_graph_builder.cc
paddle/fluid/framework/details/ssa_graph_builder.cc
+31
-2
paddle/fluid/framework/details/ssa_graph_builder.h
paddle/fluid/framework/details/ssa_graph_builder.h
+9
-0
paddle/fluid/framework/details/var_handle.cc
paddle/fluid/framework/details/var_handle.cc
+1
-1
paddle/fluid/framework/ir/graph.cc
paddle/fluid/framework/ir/graph.cc
+19
-7
paddle/fluid/framework/ir/graph.h
paddle/fluid/framework/ir/graph.h
+38
-7
paddle/fluid/framework/ir/graph_helper.cc
paddle/fluid/framework/ir/graph_helper.cc
+9
-10
paddle/fluid/framework/ir/graph_helper.h
paddle/fluid/framework/ir/graph_helper.h
+1
-0
paddle/fluid/framework/ir/graph_test.cc
paddle/fluid/framework/ir/graph_test.cc
+15
-14
paddle/fluid/operators/CMakeLists.txt
paddle/fluid/operators/CMakeLists.txt
+2
-2
paddle/fluid/operators/send_recv_util.h
paddle/fluid/operators/send_recv_util.h
+4
-1
未找到文件。
paddle/fluid/framework/details/CMakeLists.txt
浏览文件 @
93355cc0
cc_library
(
var_handle SRCS var_handle.cc DEPS place framework_proto
)
cc_library
(
var_handle SRCS var_handle.cc DEPS place framework_proto
node
)
cc_library
(
op_handle_base SRCS op_handle_base.cc DEPS var_handle device_context lod_tensor
)
cc_library
(
op_handle_base SRCS op_handle_base.cc DEPS var_handle device_context lod_tensor
)
cc_library
(
scale_loss_grad_op_handle SRCS scale_loss_grad_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory
)
cc_library
(
scale_loss_grad_op_handle SRCS scale_loss_grad_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory
)
cc_library
(
fetch_op_handle SRCS fetch_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory
)
cc_library
(
fetch_op_handle SRCS fetch_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory
)
...
...
paddle/fluid/framework/details/multi_devices_graph_builder.cc
浏览文件 @
93355cc0
...
@@ -94,12 +94,11 @@ void MultiDevSSAGraphBuilder::CreateOpHandleIOs(ir::Graph *result,
...
@@ -94,12 +94,11 @@ void MultiDevSSAGraphBuilder::CreateOpHandleIOs(ir::Graph *result,
}
}
std
::
vector
<
std
::
string
>
MultiDevSSAGraphBuilder
::
FindDistTrainSendVars
(
std
::
vector
<
std
::
string
>
MultiDevSSAGraphBuilder
::
FindDistTrainSendVars
(
const
std
::
vector
<
std
::
unique_ptr
<
ir
::
Node
>
>
&
nodes
)
const
{
const
std
::
vector
<
ir
::
Node
*
>
&
nodes
)
const
{
std
::
vector
<
std
::
string
>
send_vars
;
std
::
vector
<
std
::
string
>
send_vars
;
// since parameters are all in block 0,
// since parameters are all in block 0,
// it's enough to only scan send ops in block 0
// it's enough to only scan send ops in block 0
for
(
auto
&
node
:
nodes
)
{
for
(
auto
&
node
:
nodes
)
{
if
(
node
->
NodeType
()
!=
ir
::
Node
::
Type
::
kOperation
)
continue
;
OpDesc
*
op
=
node
->
Op
();
OpDesc
*
op
=
node
->
Op
();
// TODO(Yancey1989): use a graceful method to find send op,
// TODO(Yancey1989): use a graceful method to find send op,
// instead of the the hard code string
// instead of the the hard code string
...
@@ -114,10 +113,9 @@ std::vector<std::string> MultiDevSSAGraphBuilder::FindDistTrainSendVars(
...
@@ -114,10 +113,9 @@ std::vector<std::string> MultiDevSSAGraphBuilder::FindDistTrainSendVars(
}
}
std
::
vector
<
std
::
string
>
MultiDevSSAGraphBuilder
::
FindDistTrainRecvVars
(
std
::
vector
<
std
::
string
>
MultiDevSSAGraphBuilder
::
FindDistTrainRecvVars
(
const
std
::
vector
<
std
::
unique_ptr
<
ir
::
Node
>
>
&
nodes
)
const
{
const
std
::
vector
<
ir
::
Node
*
>
&
nodes
)
const
{
std
::
vector
<
std
::
string
>
recv_vars
;
std
::
vector
<
std
::
string
>
recv_vars
;
for
(
auto
&
node
:
nodes
)
{
for
(
auto
&
node
:
nodes
)
{
if
(
node
->
NodeType
()
!=
ir
::
Node
::
Type
::
kOperation
)
continue
;
OpDesc
*
op
=
node
->
Op
();
OpDesc
*
op
=
node
->
Op
();
// TODO(Yancey1989): use a graceful method to find recv op,
// TODO(Yancey1989): use a graceful method to find recv op,
// instead of the hard code string
// instead of the hard code string
...
@@ -214,6 +212,19 @@ std::vector<ir::Node *> SortOpsAndDelayOptimizeOp(const ir::Graph &graph) {
...
@@ -214,6 +212,19 @@ std::vector<ir::Node *> SortOpsAndDelayOptimizeOp(const ir::Graph &graph) {
}
}
}
}
// Verify that no operations before optimize ops depends on optimize ops.
std
::
unordered_set
<
ir
::
Node
*>
optimize_set
(
optimize_ops
.
begin
(),
optimize_ops
.
end
());
for
(
size_t
i
=
0
;
i
<
last_backward
;
++
i
)
{
for
(
ir
::
Node
*
in
:
sorted_ret
[
i
]
->
inputs
)
{
for
(
ir
::
Node
*
pre_n
:
in
->
inputs
)
{
PADDLE_ENFORCE
(
optimize_set
.
find
(
pre_n
)
==
optimize_set
.
end
(),
"optimize operations cannot be depended by forward "
"or backward node %s -> %s"
,
pre_n
->
Name
(),
sorted_ret
[
i
]
->
Name
());
}
}
}
sorted_ret
.
insert
(
sorted_ret
.
begin
()
+
last_backward
,
optimize_ops
.
begin
(),
sorted_ret
.
insert
(
sorted_ret
.
begin
()
+
last_backward
,
optimize_ops
.
begin
(),
optimize_ops
.
end
());
optimize_ops
.
end
());
return
sorted_ret
;
return
sorted_ret
;
...
@@ -221,18 +232,16 @@ std::vector<ir::Node *> SortOpsAndDelayOptimizeOp(const ir::Graph &graph) {
...
@@ -221,18 +232,16 @@ std::vector<ir::Node *> SortOpsAndDelayOptimizeOp(const ir::Graph &graph) {
std
::
unique_ptr
<
ir
::
Graph
>
MultiDevSSAGraphBuilder
::
Apply
(
std
::
unique_ptr
<
ir
::
Graph
>
MultiDevSSAGraphBuilder
::
Apply
(
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
{
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
{
//
R
ebuild the graph structure.
//
Give the topology sort order and r
ebuild the graph structure.
std
::
vector
<
ir
::
Node
*>
sorted_ops
=
SortOpsAndDelayOptimizeOp
(
*
graph
);
std
::
vector
<
ir
::
Node
*>
sorted_ops
=
SortOpsAndDelayOptimizeOp
(
*
graph
);
auto
nodes
=
std
::
move
(
graph
->
nodes
);
auto
nodes
=
graph
->
ReleaseNodes
(
);
graph
->
nodes
.
clear
()
;
ir
::
Graph
&
result
=
*
graph
;
for
(
auto
&
node
:
nodes
)
{
for
(
auto
&
node
:
nodes
)
{
if
(
node
->
NodeType
()
==
ir
::
Node
::
Type
::
kVariable
)
{
if
(
node
->
NodeType
()
==
ir
::
Node
::
Type
::
kVariable
)
{
all_vars_
.
emplace
(
node
->
Name
(),
node
->
Var
());
all_vars_
.
emplace
(
node
->
Name
(),
node
->
Var
());
}
}
}
}
ir
::
Graph
&
result
=
*
graph
;
std
::
unordered_set
<
std
::
string
>
og_has_been_broadcast
;
std
::
unordered_set
<
std
::
string
>
og_has_been_broadcast
;
// We cannot invoke resize. It is a bug of GCC 4.8
// We cannot invoke resize. It is a bug of GCC 4.8
...
@@ -242,8 +251,8 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::Apply(
...
@@ -242,8 +251,8 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::Apply(
// find send/recv vars so that we can place the distributed training
// find send/recv vars so that we can place the distributed training
// realted op in the place 0
// realted op in the place 0
auto
send_vars
=
FindDistTrainSendVars
(
node
s
);
auto
send_vars
=
FindDistTrainSendVars
(
sorted_op
s
);
auto
recv_vars
=
FindDistTrainRecvVars
(
node
s
);
auto
recv_vars
=
FindDistTrainRecvVars
(
sorted_op
s
);
std
::
vector
<
std
::
unordered_set
<
std
::
string
>>
bcast_var_name_set
;
std
::
vector
<
std
::
unordered_set
<
std
::
string
>>
bcast_var_name_set
;
bcast_var_name_set
.
resize
(
places_
.
size
());
bcast_var_name_set
.
resize
(
places_
.
size
());
...
@@ -589,8 +598,7 @@ void MultiDevSSAGraphBuilder::ConnectOp(ir::Graph *result, OpHandleBase *op,
...
@@ -589,8 +598,7 @@ void MultiDevSSAGraphBuilder::ConnectOp(ir::Graph *result, OpHandleBase *op,
const
std
::
string
&
prev_op_name
)
const
{
const
std
::
string
&
prev_op_name
)
const
{
for
(
auto
&
prev_op
:
result
->
Get
<
GraphOps
>
(
"ops"
))
{
for
(
auto
&
prev_op
:
result
->
Get
<
GraphOps
>
(
"ops"
))
{
if
(
prev_op
->
Name
()
==
prev_op_name
)
{
if
(
prev_op
->
Name
()
==
prev_op_name
)
{
auto
*
dep_var
=
new
DummyVarHandle
(
auto
*
dep_var
=
new
DummyVarHandle
(
result
->
CreateControlDepVar
());
result
->
CreateEmptyNode
(
"dummy"
,
ir
::
Node
::
Type
::
kVariable
));
prev_op
->
AddOutput
(
dep_var
);
prev_op
->
AddOutput
(
dep_var
);
result
->
Get
<
GraphDepVars
>
(
"dep_vars"
).
emplace
(
dep_var
);
result
->
Get
<
GraphDepVars
>
(
"dep_vars"
).
emplace
(
dep_var
);
op
->
AddInput
(
dep_var
);
op
->
AddInput
(
dep_var
);
...
...
paddle/fluid/framework/details/multi_devices_graph_builder.h
浏览文件 @
93355cc0
...
@@ -76,10 +76,10 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
...
@@ -76,10 +76,10 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
const
std
::
vector
<
std
::
string
>
&
recv_vars
)
const
;
const
std
::
vector
<
std
::
string
>
&
recv_vars
)
const
;
std
::
vector
<
std
::
string
>
FindDistTrainSendVars
(
std
::
vector
<
std
::
string
>
FindDistTrainSendVars
(
const
std
::
vector
<
std
::
unique_ptr
<
ir
::
Node
>
>
&
nodes
)
const
;
const
std
::
vector
<
ir
::
Node
*
>
&
nodes
)
const
;
std
::
vector
<
std
::
string
>
FindDistTrainRecvVars
(
std
::
vector
<
std
::
string
>
FindDistTrainRecvVars
(
const
std
::
vector
<
std
::
unique_ptr
<
ir
::
Node
>
>
&
nodes
)
const
;
const
std
::
vector
<
ir
::
Node
*
>
&
nodes
)
const
;
void
ConnectOp
(
ir
::
Graph
*
result
,
OpHandleBase
*
op
,
void
ConnectOp
(
ir
::
Graph
*
result
,
OpHandleBase
*
op
,
const
std
::
string
&
prev_op_name
)
const
;
const
std
::
string
&
prev_op_name
)
const
;
...
...
paddle/fluid/framework/details/rpc_op_handle.cc
浏览文件 @
93355cc0
...
@@ -33,7 +33,8 @@ void RPCOpHandle::RunImpl() {
...
@@ -33,7 +33,8 @@ void RPCOpHandle::RunImpl() {
for
(
auto
*
in
:
inputs_
)
{
for
(
auto
*
in
:
inputs_
)
{
auto
&
p
=
static_cast
<
VarHandle
*>
(
in
)
->
place_
;
auto
&
p
=
static_cast
<
VarHandle
*>
(
in
)
->
place_
;
// FIXME(Yancey1989): need a better solution instead of use DebugString()
// FIXME(Yancey1989): need a better solution instead of use DebugString()
if
(
in
->
DebugString
()
==
"dummy"
)
{
// HACK
if
(
in
->
Node
()
->
Name
().
find
(
ir
::
Node
::
kControlDepVarName
)
!=
std
::
string
::
npos
)
{
// HACK
continue
;
continue
;
}
}
if
(
in
->
GeneratedOp
())
{
if
(
in
->
GeneratedOp
())
{
...
...
paddle/fluid/framework/details/ssa_graph_builder.cc
浏览文件 @
93355cc0
...
@@ -17,6 +17,36 @@
...
@@ -17,6 +17,36 @@
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
namespace
details
{
namespace
details
{
void
SSAGraphBuilder
::
PolishGraphToSupportDataHazards
(
ir
::
Graph
*
graph
)
{
for
(
auto
&
var_map
:
graph
->
Get
<
GraphVars
>
(
"vars"
))
{
for
(
auto
&
name_pair
:
var_map
)
{
if
(
name_pair
.
second
.
size
()
<=
1
)
{
continue
;
}
auto
it_new
=
name_pair
.
second
.
rbegin
();
auto
it_old
=
name_pair
.
second
.
rbegin
();
++
it_old
;
for
(;
it_old
!=
name_pair
.
second
.
rend
();
it_new
=
it_old
,
++
it_old
)
{
OpHandleBase
*
write_op
=
(
*
it_new
)
->
GeneratedOp
();
const
auto
&
read_ops
=
(
*
it_old
)
->
PendingOps
();
for
(
auto
*
read_op
:
read_ops
)
{
// Manually add a dependency var from read_op to write_op;
if
(
read_op
==
write_op
)
{
// Read Write is the same op.
continue
;
}
auto
*
dep_var
=
new
DummyVarHandle
(
graph
->
CreateControlDepVar
());
read_op
->
AddOutput
(
dep_var
);
write_op
->
AddInput
(
dep_var
);
graph
->
Get
<
GraphDepVars
>
(
"dep_vars"
).
emplace
(
dep_var
);
}
}
}
}
}
VarHandle
*
SSAGraphBuilder
::
CreateOrGetLatestVarHandle
(
VarHandle
*
SSAGraphBuilder
::
CreateOrGetLatestVarHandle
(
ir
::
Graph
*
graph
,
ir
::
Node
*
node
,
const
platform
::
Place
&
place
,
ir
::
Graph
*
graph
,
ir
::
Node
*
node
,
const
platform
::
Place
&
place
,
size_t
place_offset
)
{
size_t
place_offset
)
{
...
@@ -56,8 +86,7 @@ void SSAGraphBuilder::AddOutputToLeafOps(ir::Graph *graph) {
...
@@ -56,8 +86,7 @@ void SSAGraphBuilder::AddOutputToLeafOps(ir::Graph *graph) {
if
(
!
op
->
Outputs
().
empty
())
{
if
(
!
op
->
Outputs
().
empty
())
{
continue
;
continue
;
}
}
auto
*
dummy_leaf
=
new
DummyVarHandle
(
auto
*
dummy_leaf
=
new
DummyVarHandle
(
graph
->
CreateControlDepVar
());
graph
->
CreateEmptyNode
(
"dummy"
,
ir
::
Node
::
Type
::
kVariable
));
graph
->
Get
<
GraphDepVars
>
(
"dep_vars"
).
emplace
(
dummy_leaf
);
graph
->
Get
<
GraphDepVars
>
(
"dep_vars"
).
emplace
(
dummy_leaf
);
op
->
AddOutput
(
dummy_leaf
);
op
->
AddOutput
(
dummy_leaf
);
}
}
...
...
paddle/fluid/framework/details/ssa_graph_builder.h
浏览文件 @
93355cc0
...
@@ -57,6 +57,15 @@ class SSAGraphBuilder : public ir::Pass {
...
@@ -57,6 +57,15 @@ class SSAGraphBuilder : public ir::Pass {
DISABLE_COPY_AND_ASSIGN
(
SSAGraphBuilder
);
DISABLE_COPY_AND_ASSIGN
(
SSAGraphBuilder
);
protected:
protected:
/**
* We only handle write after read(WAR), since it should not have a write
* after write in program. If there are write after write operators, we need
* prune them.
*
* https://en.wikipedia.org/wiki/Hazard_(computer_architecture)#Write_after_read_(WAR)
*/
static
void
PolishGraphToSupportDataHazards
(
ir
::
Graph
*
graph
);
static
VarHandle
*
CreateOrGetLatestVarHandle
(
ir
::
Graph
*
graph
,
ir
::
Node
*
node
,
static
VarHandle
*
CreateOrGetLatestVarHandle
(
ir
::
Graph
*
graph
,
ir
::
Node
*
node
,
const
platform
::
Place
&
place
,
const
platform
::
Place
&
place
,
size_t
place_offset
);
size_t
place_offset
);
...
...
paddle/fluid/framework/details/var_handle.cc
浏览文件 @
93355cc0
...
@@ -26,7 +26,7 @@ std::string VarHandle::DebugString() const {
...
@@ -26,7 +26,7 @@ std::string VarHandle::DebugString() const {
return
ss
.
str
();
return
ss
.
str
();
}
}
std
::
string
DummyVarHandle
::
DebugString
()
const
{
return
"dummy"
;
}
std
::
string
DummyVarHandle
::
DebugString
()
const
{
return
node_
->
Name
()
;
}
}
// namespace details
}
// namespace details
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
paddle/fluid/framework/ir/graph.cc
浏览文件 @
93355cc0
...
@@ -34,7 +34,8 @@ Graph::Graph(const ProgramDesc &program) : program_(program) {
...
@@ -34,7 +34,8 @@ Graph::Graph(const ProgramDesc &program) : program_(program) {
std
::
map
<
std
::
string
,
std
::
vector
<
ir
::
Node
*>>
var_nodes
;
std
::
map
<
std
::
string
,
std
::
vector
<
ir
::
Node
*>>
var_nodes
;
for
(
auto
*
op
:
program
.
Block
(
0
).
AllOps
())
{
for
(
auto
*
op
:
program
.
Block
(
0
).
AllOps
())
{
ir
::
Node
*
node
=
CreateOpNode
(
op
);
ir
::
Node
*
node
=
CreateOpNode
(
op
);
// For input args, reuse the same var name if it was created before.
// Otherwise, create a new one.
for
(
auto
&
each_var_name
:
op
->
InputArgumentNames
())
{
for
(
auto
&
each_var_name
:
op
->
InputArgumentNames
())
{
ir
::
Node
*
var
=
nullptr
;
ir
::
Node
*
var
=
nullptr
;
if
(
var_nodes
.
find
(
each_var_name
)
!=
var_nodes
.
end
())
{
if
(
var_nodes
.
find
(
each_var_name
)
!=
var_nodes
.
end
())
{
...
@@ -43,16 +44,16 @@ Graph::Graph(const ProgramDesc &program) : program_(program) {
...
@@ -43,16 +44,16 @@ Graph::Graph(const ProgramDesc &program) : program_(program) {
var
=
CreateVarNode
(
all_vars
.
at
(
each_var_name
));
var
=
CreateVarNode
(
all_vars
.
at
(
each_var_name
));
var_nodes
[
each_var_name
].
push_back
(
var
);
var_nodes
[
each_var_name
].
push_back
(
var
);
}
else
{
}
else
{
//
TODO(paddle-dev): Seems some assumption doesn't hold?
//
Operation input var can be optional (dispensable). Which means
VLOG
(
3
)
<<
op
->
Type
()
// the operation doesn't really need the var at runtime. In this
<<
" input var not in all_var list: "
<<
each_var_name
;
// case, the no-existed var is ready at the beginning.
var
=
CreateEmptyNode
(
each_var_name
,
ir
::
Node
::
Type
::
kVariable
);
var
=
CreateEmptyNode
(
each_var_name
,
ir
::
Node
::
Type
::
kVariable
);
var_nodes
[
each_var_name
].
push_back
(
var
);
var_nodes
[
each_var_name
].
push_back
(
var
);
}
}
node
->
inputs
.
push_back
(
var
);
node
->
inputs
.
push_back
(
var
);
var
->
outputs
.
push_back
(
node
);
var
->
outputs
.
push_back
(
node
);
}
}
// For output args, always create a new var.
for
(
auto
&
each_var_name
:
op
->
OutputArgumentNames
())
{
for
(
auto
&
each_var_name
:
op
->
OutputArgumentNames
())
{
ir
::
Node
*
var
=
CreateVarNode
(
all_vars
.
at
(
each_var_name
));
ir
::
Node
*
var
=
CreateVarNode
(
all_vars
.
at
(
each_var_name
));
var_nodes
[
each_var_name
].
push_back
(
var
);
var_nodes
[
each_var_name
].
push_back
(
var
);
...
@@ -67,6 +68,7 @@ Graph::Graph(const ProgramDesc &program) : program_(program) {
...
@@ -67,6 +68,7 @@ Graph::Graph(const ProgramDesc &program) : program_(program) {
*
*
* https://en.wikipedia.org/wiki/Hazard_(computer_architecture)#Write_after_read_(WAR)
* https://en.wikipedia.org/wiki/Hazard_(computer_architecture)#Write_after_read_(WAR)
*/
*/
for
(
auto
&
var
:
var_nodes
)
{
for
(
auto
&
var
:
var_nodes
)
{
auto
&
versions
=
var
.
second
;
auto
&
versions
=
var
.
second
;
if
(
versions
.
size
()
<=
1
)
continue
;
if
(
versions
.
size
()
<=
1
)
continue
;
...
@@ -85,8 +87,18 @@ Graph::Graph(const ProgramDesc &program) : program_(program) {
...
@@ -85,8 +87,18 @@ Graph::Graph(const ProgramDesc &program) : program_(program) {
// Read Write is the same op.
// Read Write is the same op.
continue
;
continue
;
}
}
ir
::
Node
*
dep_var
=
CreateEmptyNode
(
ir
::
Node
::
kControlDepVarName
,
// 2 ops might have been connected via other vars.
ir
::
Node
::
Type
::
kVariable
);
bool
has_dep
=
false
;
for
(
ir
::
Node
*
r_out
:
read_op
->
outputs
)
{
for
(
ir
::
Node
*
w_in
:
write_op
->
inputs
)
{
if
(
r_out
==
w_in
)
{
has_dep
=
true
;
break
;
}
}
}
if
(
has_dep
)
continue
;
ir
::
Node
*
dep_var
=
CreateControlDepVar
();
read_op
->
outputs
.
push_back
(
dep_var
);
read_op
->
outputs
.
push_back
(
dep_var
);
dep_var
->
inputs
.
push_back
(
read_op
);
dep_var
->
inputs
.
push_back
(
read_op
);
write_op
->
inputs
.
push_back
(
dep_var
);
write_op
->
inputs
.
push_back
(
dep_var
);
...
...
paddle/fluid/framework/ir/graph.h
浏览文件 @
93355cc0
...
@@ -27,6 +27,7 @@ limitations under the License. */
...
@@ -27,6 +27,7 @@ limitations under the License. */
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
namespace
ir
{
namespace
ir
{
class
Graph
{
class
Graph
{
public:
public:
explicit
Graph
(
const
ProgramDesc
&
program
);
explicit
Graph
(
const
ProgramDesc
&
program
);
...
@@ -54,28 +55,58 @@ class Graph {
...
@@ -54,28 +55,58 @@ class Graph {
};
};
}
}
const
std
::
unordered_set
<
ir
::
Node
*>
&
Nodes
()
const
{
return
node_set_
;
}
ir
::
Node
*
CreateVarNode
(
VarDesc
*
var_desc
)
{
ir
::
Node
*
CreateVarNode
(
VarDesc
*
var_desc
)
{
nodes
.
emplace_back
(
new
ir
::
Node
(
var_desc
));
return
AddNode
(
new
ir
::
Node
(
var_desc
));
return
nodes
.
back
().
get
();
}
}
ir
::
Node
*
CreateOpNode
(
OpDesc
*
op_desc
)
{
ir
::
Node
*
CreateOpNode
(
OpDesc
*
op_desc
)
{
nodes
.
emplace_back
(
new
ir
::
Node
(
op_desc
));
return
AddNode
(
new
ir
::
Node
(
op_desc
));
return
nodes
.
back
().
get
();
}
ir
::
Node
*
CreateControlDepVar
()
{
// TODO(panyx0718): control var name should be unique.
const
std
::
string
name
=
string
::
Sprintf
(
"%s@%llu"
,
ir
::
Node
::
kControlDepVarName
,
node_set_
.
size
());
return
AddNode
(
new
ir
::
Node
(
name
,
ir
::
Node
::
Type
::
kVariable
));
}
}
ir
::
Node
*
CreateEmptyNode
(
const
std
::
string
&
name
,
ir
::
Node
::
Type
type
)
{
ir
::
Node
*
CreateEmptyNode
(
const
std
::
string
&
name
,
ir
::
Node
::
Type
type
)
{
nodes
.
emplace_back
(
new
ir
::
Node
(
name
,
type
));
return
AddNode
(
new
ir
::
Node
(
name
,
type
));
return
nodes
.
back
().
get
();
}
}
std
::
vector
<
std
::
unique_ptr
<
ir
::
Node
>>
nodes
;
std
::
vector
<
std
::
unique_ptr
<
ir
::
Node
>>
ReleaseNodes
()
{
std
::
vector
<
std
::
unique_ptr
<
ir
::
Node
>>
ret
;
for
(
auto
&
n
:
nodes_
)
{
ret
.
emplace_back
(
n
.
second
.
release
());
}
nodes_
.
clear
();
node_set_
.
clear
();
return
ret
;
}
private:
private:
// This method takes ownership of `node`.
ir
::
Node
*
AddNode
(
ir
::
Node
*
node
)
{
PADDLE_ENFORCE
(
node_set_
.
find
(
node
)
==
node_set_
.
end
());
nodes_
[
node
].
reset
(
node
);
node_set_
.
insert
(
node
);
return
node
;
}
void
RemoveNode
(
ir
::
Node
*
node
)
{
PADDLE_ENFORCE
(
node_set_
.
find
(
node
)
!=
node_set_
.
end
());
node_set_
.
erase
(
node
);
nodes_
.
erase
(
node
);
}
// NOTE: program_ shouldn't be exposed to user.
// NOTE: program_ shouldn't be exposed to user.
const
ProgramDesc
&
program_
;
const
ProgramDesc
&
program_
;
std
::
map
<
std
::
string
,
boost
::
any
>
attrs_
;
std
::
map
<
std
::
string
,
boost
::
any
>
attrs_
;
std
::
map
<
std
::
string
,
std
::
function
<
void
(
void
)
>>
attr_dels_
;
std
::
map
<
std
::
string
,
std
::
function
<
void
(
void
)
>>
attr_dels_
;
std
::
map
<
ir
::
Node
*
,
std
::
unique_ptr
<
ir
::
Node
>>
nodes_
;
std
::
unordered_set
<
ir
::
Node
*>
node_set_
;
};
};
}
// namespace ir
}
// namespace ir
}
// namespace framework
}
// namespace framework
...
...
paddle/fluid/framework/ir/graph_helper.cc
浏览文件 @
93355cc0
...
@@ -33,9 +33,8 @@ void SortHelper(
...
@@ -33,9 +33,8 @@ void SortHelper(
}
}
}
}
LOG
(
ERROR
)
<<
"topology sort insert: "
<<
node
->
Name
()
VLOG
(
3
)
<<
"topology sort insert: "
<<
node
->
Name
()
<<
reinterpret_cast
<
void
*>
(
node
)
<<
" input "
<<
reinterpret_cast
<
void
*>
(
node
)
<<
" input "
<<
node
->
inputs
.
size
();
<<
node
->
inputs
.
size
();
ret
->
push_back
(
node
);
ret
->
push_back
(
node
);
}
}
...
@@ -93,18 +92,18 @@ std::map<ir::Node *, std::unordered_set<ir::Node *>> BuildOperationAdjList(
...
@@ -93,18 +92,18 @@ std::map<ir::Node *, std::unordered_set<ir::Node *>> BuildOperationAdjList(
const
Graph
&
graph
)
{
const
Graph
&
graph
)
{
std
::
map
<
ir
::
Node
*
,
std
::
unordered_set
<
ir
::
Node
*>>
adj_list
;
std
::
map
<
ir
::
Node
*
,
std
::
unordered_set
<
ir
::
Node
*>>
adj_list
;
for
(
auto
&
n
:
graph
.
nodes
)
{
for
(
auto
&
n
:
graph
.
Nodes
()
)
{
if
(
n
->
NodeType
()
!=
ir
::
Node
::
Type
::
kOperation
)
continue
;
if
(
n
->
NodeType
()
!=
ir
::
Node
::
Type
::
kOperation
)
continue
;
if
(
adj_list
.
find
(
n
.
get
()
)
==
adj_list
.
end
())
{
if
(
adj_list
.
find
(
n
)
==
adj_list
.
end
())
{
adj_list
[
n
.
get
()
]
=
std
::
unordered_set
<
ir
::
Node
*>
();
adj_list
[
n
]
=
std
::
unordered_set
<
ir
::
Node
*>
();
}
}
for
(
auto
&
var
:
n
->
inputs
)
{
for
(
auto
&
var
:
n
->
inputs
)
{
for
(
auto
&
adj_n
:
var
->
inputs
)
{
for
(
auto
&
adj_n
:
var
->
inputs
)
{
PADDLE_ENFORCE
(
adj_n
->
NodeType
()
==
ir
::
Node
::
Type
::
kOperation
);
PADDLE_ENFORCE
(
adj_n
->
NodeType
()
==
ir
::
Node
::
Type
::
kOperation
);
adj_list
[
n
.
get
()
].
insert
(
adj_n
);
adj_list
[
n
].
insert
(
adj_n
);
LOG
(
ERROR
)
<<
"adj "
<<
adj_n
->
Name
()
<<
reinterpret_cast
<
void
*>
(
adj_n
)
VLOG
(
3
)
<<
"adj "
<<
adj_n
->
Name
()
<<
reinterpret_cast
<
void
*>
(
adj_n
)
<<
" -> "
<<
n
->
Name
()
<<
reinterpret_cast
<
void
*>
(
n
.
get
()
)
<<
" -> "
<<
n
->
Name
()
<<
reinterpret_cast
<
void
*>
(
n
)
<<
" via "
<<
var
->
Name
()
<<
reinterpret_cast
<
void
*>
(
var
);
<<
" via "
<<
var
->
Name
()
<<
reinterpret_cast
<
void
*>
(
var
);
}
}
}
}
}
}
...
...
paddle/fluid/framework/ir/graph_helper.h
浏览文件 @
93355cc0
...
@@ -30,6 +30,7 @@ std::vector<ir::Node *> TopologySortOperations(const Graph &graph);
...
@@ -30,6 +30,7 @@ std::vector<ir::Node *> TopologySortOperations(const Graph &graph);
std
::
map
<
ir
::
Node
*
,
std
::
unordered_set
<
ir
::
Node
*>>
BuildOperationAdjList
(
std
::
map
<
ir
::
Node
*
,
std
::
unordered_set
<
ir
::
Node
*>>
BuildOperationAdjList
(
const
Graph
&
graph
);
const
Graph
&
graph
);
}
// namespace ir
}
// namespace ir
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
paddle/fluid/framework/ir/graph_test.cc
浏览文件 @
93355cc0
...
@@ -94,20 +94,21 @@ TEST(GraphTest, Basic) {
...
@@ -94,20 +94,21 @@ TEST(GraphTest, Basic) {
prog
.
MutableBlock
(
0
)
->
Var
(
"test_out"
)
->
GetType
());
prog
.
MutableBlock
(
0
)
->
Var
(
"test_out"
)
->
GetType
());
std
::
unique_ptr
<
ir
::
Graph
>
g
(
new
ir
::
Graph
(
prog
));
std
::
unique_ptr
<
ir
::
Graph
>
g
(
new
ir
::
Graph
(
prog
));
ASSERT_EQ
(
g
->
nodes
[
0
]
->
Name
(),
"sum"
);
std
::
vector
<
ir
::
Node
*>
nodes
(
g
->
Nodes
().
begin
(),
g
->
Nodes
().
end
());
ASSERT_EQ
(
g
->
nodes
[
0
]
->
inputs
[
0
]
->
Name
(),
"test_a"
);
ASSERT_EQ
(
nodes
[
0
]
->
Name
(),
"sum"
);
ASSERT_EQ
(
g
->
nodes
[
0
]
->
inputs
[
1
]
->
Name
(),
"test_b"
);
ASSERT_EQ
(
nodes
[
0
]
->
inputs
[
0
]
->
Name
(),
"test_a"
);
ASSERT_EQ
(
g
->
nodes
[
0
]
->
inputs
[
2
]
->
Name
(),
"test_c"
);
ASSERT_EQ
(
nodes
[
0
]
->
inputs
[
1
]
->
Name
(),
"test_b"
);
ASSERT_EQ
(
g
->
nodes
[
0
]
->
outputs
[
0
]
->
Name
(),
"test_out"
);
ASSERT_EQ
(
nodes
[
0
]
->
inputs
[
2
]
->
Name
(),
"test_c"
);
ASSERT_EQ
(
g
->
nodes
[
1
]
->
Name
(),
"test_a"
);
ASSERT_EQ
(
nodes
[
0
]
->
outputs
[
0
]
->
Name
(),
"test_out"
);
ASSERT_EQ
(
g
->
nodes
[
1
]
->
outputs
[
0
]
->
Name
(),
"sum"
);
ASSERT_EQ
(
nodes
[
1
]
->
Name
(),
"test_a"
);
ASSERT_EQ
(
g
->
nodes
[
2
]
->
Name
(),
"test_b"
);
ASSERT_EQ
(
nodes
[
1
]
->
outputs
[
0
]
->
Name
(),
"sum"
);
ASSERT_EQ
(
g
->
nodes
[
2
]
->
outputs
[
0
]
->
Name
(),
"sum"
);
ASSERT_EQ
(
nodes
[
2
]
->
Name
(),
"test_b"
);
ASSERT_EQ
(
g
->
nodes
[
3
]
->
Name
(),
"test_c"
);
ASSERT_EQ
(
nodes
[
2
]
->
outputs
[
0
]
->
Name
(),
"sum"
);
ASSERT_EQ
(
g
->
nodes
[
3
]
->
outputs
[
0
]
->
Name
(),
"sum"
);
ASSERT_EQ
(
nodes
[
3
]
->
Name
(),
"test_c"
);
ASSERT_EQ
(
g
->
nodes
[
4
]
->
Name
(),
"test_out"
);
ASSERT_EQ
(
nodes
[
3
]
->
outputs
[
0
]
->
Name
(),
"sum"
);
ASSERT_EQ
(
g
->
nodes
[
4
]
->
inputs
[
0
]
->
Name
(),
"sum"
);
ASSERT_EQ
(
nodes
[
4
]
->
Name
(),
"test_out"
);
ASSERT_EQ
(
g
->
nodes
.
size
(),
5
);
ASSERT_EQ
(
nodes
[
4
]
->
inputs
[
0
]
->
Name
(),
"sum"
);
ASSERT_EQ
(
nodes
.
size
(),
5
);
}
}
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
paddle/fluid/operators/CMakeLists.txt
浏览文件 @
93355cc0
...
@@ -192,9 +192,9 @@ if(WITH_DISTRIBUTE)
...
@@ -192,9 +192,9 @@ if(WITH_DISTRIBUTE)
set
(
DISTRIBUTE_DEPS
""
)
set
(
DISTRIBUTE_DEPS
""
)
if
(
WITH_GRPC
)
if
(
WITH_GRPC
)
set
(
DISTRIBUTE_DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf
)
set
(
DISTRIBUTE_DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf
node
)
else
()
else
()
set
(
DISTRIBUTE_DEPS sendrecvop_brpc brpc leveldb snappystream snappy protobuf ssl crypto zlib
)
set
(
DISTRIBUTE_DEPS sendrecvop_brpc brpc leveldb snappystream snappy protobuf ssl crypto zlib
node
)
if
(
WITH_BRPC_RDMA
)
if
(
WITH_BRPC_RDMA
)
find_library
(
IBVERBS_LIBRARY NAMES ibverbs
)
find_library
(
IBVERBS_LIBRARY NAMES ibverbs
)
ADD_LIBRARY
(
ibverbs SHARED IMPORTED GLOBAL
)
ADD_LIBRARY
(
ibverbs SHARED IMPORTED GLOBAL
)
...
...
paddle/fluid/operators/send_recv_util.h
浏览文件 @
93355cc0
...
@@ -14,6 +14,7 @@ limitations under the License. */
...
@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once
#pragma once
#include <string>
#include <string>
#include "paddle/fluid/framework/ir/node.h"
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
...
@@ -22,7 +23,9 @@ inline bool NeedSend(const framework::Scope& scope,
...
@@ -22,7 +23,9 @@ inline bool NeedSend(const framework::Scope& scope,
const
std
::
string
&
varname
)
{
const
std
::
string
&
varname
)
{
// dummy variable is only used in parallel executor to represent
// dummy variable is only used in parallel executor to represent
// some dependency relationship, we don't need to send/recv it.
// some dependency relationship, we don't need to send/recv it.
if
(
varname
==
"dummy"
)
return
false
;
if
(
varname
.
find
(
framework
::
ir
::
Node
::
kControlDepVarName
)
!=
std
::
string
::
npos
)
return
false
;
auto
*
var
=
scope
.
FindVar
(
varname
);
auto
*
var
=
scope
.
FindVar
(
varname
);
PADDLE_ENFORCE_NOT_NULL
(
var
,
"Can not find variable '%s' in the send side."
,
PADDLE_ENFORCE_NOT_NULL
(
var
,
"Can not find variable '%s' in the send side."
,
varname
);
varname
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录