Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
9b960330
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
9b960330
编写于
7月 11, 2018
作者:
X
Xin Pan
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
graph attrs
上级
2eeaa8d5
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
111 addition
and
86 deletion
+111
-86
paddle/fluid/framework/details/multi_devices_graph_builder.cc
...le/fluid/framework/details/multi_devices_graph_builder.cc
+43
-72
paddle/fluid/framework/details/ssa_graph_builder.cc
paddle/fluid/framework/details/ssa_graph_builder.cc
+6
-11
paddle/fluid/framework/ir/graph.h
paddle/fluid/framework/ir/graph.h
+62
-3
未找到文件。
paddle/fluid/framework/details/multi_devices_graph_builder.cc
浏览文件 @
9b960330
...
@@ -70,8 +70,7 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
...
@@ -70,8 +70,7 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
void
MultiDevSSAGraphBuilder
::
CreateOpHandleIOs
(
Graph
*
result
,
const
OpDesc
&
op
,
void
MultiDevSSAGraphBuilder
::
CreateOpHandleIOs
(
Graph
*
result
,
const
OpDesc
&
op
,
size_t
place_id
)
const
{
size_t
place_id
)
const
{
auto
p
=
places_
[
place_id
];
auto
p
=
places_
[
place_id
];
auto
*
op_handle
=
auto
*
op_handle
=
result
->
Get
<
GraphOps
>
(
"ops"
).
back
().
get
();
boost
::
any_cast
<
GraphOps
*>
(
result
->
attrs
[
"ops"
])
->
back
().
get
();
op_handle
->
SetDeviceContext
(
p
,
op_handle
->
SetDeviceContext
(
p
,
platform
::
DeviceContextPool
::
Instance
().
Get
(
p
));
platform
::
DeviceContextPool
::
Instance
().
Get
(
p
));
...
@@ -179,13 +178,9 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
...
@@ -179,13 +178,9 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
std
::
unordered_set
<
std
::
string
>
og_has_been_broadcast
;
std
::
unordered_set
<
std
::
string
>
og_has_been_broadcast
;
// We cannot invoke resize. It is a bug of GCC 4.8
// We cannot invoke resize. It is a bug of GCC 4.8
result
.
attrs
[
"vars"
]
=
new
std
::
vector
<
result
.
Set
(
"vars"
,
new
GraphVars
(
places_
.
size
()));
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
std
::
unique_ptr
<
VarHandle
>>>>
(
result
.
Set
(
"dep_vars"
,
new
GraphDepVars
);
places_
.
size
());
result
.
Set
(
"ops"
,
new
GraphOps
);
result
.
attrs
[
"dep_vars"
]
=
new
std
::
unordered_set
<
std
::
unique_ptr
<
VarHandleBase
>>
();
result
.
attrs
[
"ops"
]
=
new
std
::
vector
<
std
::
unique_ptr
<
OpHandleBase
>>
();
// find send/recv vars so that we can place the distributed training
// find send/recv vars so that we can place the distributed training
// realted op in the place 0
// realted op in the place 0
auto
send_vars
=
FindDistTrainSendVars
(
program
);
auto
send_vars
=
FindDistTrainSendVars
(
program
);
...
@@ -308,13 +303,9 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
...
@@ -308,13 +303,9 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
AddOutputToLeafOps
(
&
result
);
AddOutputToLeafOps
(
&
result
);
std
::
unique_ptr
<
SSAGraph
>
ssa_graph
(
new
SSAGraph
);
std
::
unique_ptr
<
SSAGraph
>
ssa_graph
(
new
SSAGraph
);
ssa_graph
->
vars_
=
ssa_graph
->
vars_
=
std
::
move
(
*
graph
->
Erase
<
GraphVars
>
(
"vars"
));
std
::
move
(
*
boost
::
any_cast
<
GraphVars
*>
(
graph
->
attrs
[
"vars"
]));
ssa_graph
->
ops_
=
std
::
move
(
*
graph
->
Erase
<
GraphOps
>
(
"ops"
));
ssa_graph
->
ops_
=
ssa_graph
->
dep_vars_
=
std
::
move
(
*
graph
->
Erase
<
GraphDepVars
>
(
"dep_vars"
));
std
::
move
(
*
boost
::
any_cast
<
GraphOps
*>
(
graph
->
attrs
[
"ops"
]));
ssa_graph
->
dep_vars_
=
std
::
move
(
*
boost
::
any_cast
<
GraphDepVars
*>
(
graph
->
attrs
[
"dep_vars"
]));
return
std
::
move
(
ssa_graph
);
return
std
::
move
(
ssa_graph
);
}
}
...
@@ -347,20 +338,15 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(Graph *result,
...
@@ -347,20 +338,15 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(Graph *result,
#else
#else
auto
*
op_handle
=
new
BroadcastOpHandle
(
local_scopes_
,
places_
);
auto
*
op_handle
=
new
BroadcastOpHandle
(
local_scopes_
,
places_
);
#endif
#endif
result
->
Get
<
GraphOps
>
(
"ops"
).
emplace_back
(
op_handle
);
boost
::
any_cast
<
GraphOps
*>
(
result
->
attrs
[
"ops"
])
->
emplace_back
(
op_handle
);
auto
*
in
=
auto
*
in
=
boost
::
any_cast
<
GraphVars
*>
(
result
->
attrs
[
"vars"
])
result
->
Get
<
GraphVars
>
(
"vars"
).
at
(
src_dev_id
).
at
(
p_name
).
back
().
get
();
->
at
(
src_dev_id
)
.
at
(
p_name
)
.
back
()
.
get
();
op_handle
->
AddInput
(
in
);
op_handle
->
AddInput
(
in
);
for
(
size_t
i
=
0
;
i
<
places_
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
places_
.
size
();
++
i
)
{
auto
&
p
=
places_
[
i
];
auto
&
p
=
places_
[
i
];
SetCommunicationContext
(
op_handle
,
p
);
SetCommunicationContext
(
op_handle
,
p
);
auto
&
vars
=
auto
&
vars
=
result
->
Get
<
GraphVars
>
(
"vars"
).
at
(
i
).
at
(
p_name
);
boost
::
any_cast
<
GraphVars
*>
(
result
->
attrs
[
"vars"
])
->
at
(
i
).
at
(
p_name
);
auto
*
out_var
=
new
VarHandle
(
vars
.
size
(),
i
,
p_name
,
p
);
auto
*
out_var
=
new
VarHandle
(
vars
.
size
(),
i
,
p_name
,
p
);
vars
.
emplace_back
(
out_var
);
vars
.
emplace_back
(
out_var
);
op_handle
->
AddOutput
(
out_var
);
op_handle
->
AddOutput
(
out_var
);
...
@@ -370,28 +356,26 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(Graph *result,
...
@@ -370,28 +356,26 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(Graph *result,
void
MultiDevSSAGraphBuilder
::
CreateComputationalOp
(
Graph
*
result
,
void
MultiDevSSAGraphBuilder
::
CreateComputationalOp
(
Graph
*
result
,
const
OpDesc
&
op
,
const
OpDesc
&
op
,
int
dev_id
)
const
{
int
dev_id
)
const
{
boost
::
any_cast
<
GraphOps
*>
(
result
->
attrs
[
"ops"
])
result
->
Get
<
GraphOps
>
(
"ops"
).
emplace_back
(
->
emplace_back
(
new
ComputationOpHandle
(
op
,
local_scopes_
[
dev_id
],
places_
[
dev_id
]));
new
ComputationOpHandle
(
op
,
local_scopes_
[
dev_id
],
places_
[
dev_id
]));
CreateOpHandleIOs
(
result
,
op
,
dev_id
);
CreateOpHandleIOs
(
result
,
op
,
dev_id
);
}
}
void
MultiDevSSAGraphBuilder
::
InsertAllReduceOp
(
Graph
*
result
,
void
MultiDevSSAGraphBuilder
::
InsertAllReduceOp
(
Graph
*
result
,
const
std
::
string
&
og
)
const
{
const
std
::
string
&
og
)
const
{
#ifdef PADDLE_WITH_CUDA
#ifdef PADDLE_WITH_CUDA
boost
::
any_cast
<
GraphOps
*>
(
result
->
attrs
[
"ops"
])
result
->
Get
<
GraphOps
>
(
"ops"
).
emplace_back
(
->
emplace_back
(
new
AllReduceOpHandle
(
local_scopes_
,
places_
,
nccl_ctxs_
));
new
AllReduceOpHandle
(
local_scopes_
,
places_
,
nccl_ctxs_
));
#else
#else
boost
::
any_cast
<
GraphOps
*>
(
result
->
attrs
[
"ops"
])
result
->
Get
<
GraphOps
>
(
"ops"
).
emplace_back
(
->
emplace_back
(
new
AllReduceOpHandle
(
local_scopes_
,
places_
));
new
AllReduceOpHandle
(
local_scopes_
,
places_
));
#endif
#endif
auto
*
op_handle
=
auto
*
op_handle
=
result
->
Get
<
GraphOps
>
(
"ops"
).
back
().
get
();
boost
::
any_cast
<
GraphOps
*>
(
result
->
attrs
[
"ops"
])
->
back
().
get
();
for
(
size_t
i
=
0
;
i
<
places_
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
places_
.
size
();
++
i
)
{
auto
&
p
=
places_
[
i
];
auto
&
p
=
places_
[
i
];
SetCommunicationContext
(
op_handle
,
p
);
SetCommunicationContext
(
op_handle
,
p
);
auto
&
vars
=
(
*
boost
::
any_cast
<
GraphVars
*>
(
result
->
attrs
[
"vars"
])
)[
i
][
og
];
auto
&
vars
=
result
->
Get
<
GraphVars
>
(
"vars"
)[
i
][
og
];
PADDLE_ENFORCE
(
!
vars
.
empty
());
PADDLE_ENFORCE
(
!
vars
.
empty
());
auto
&
prev_grad
=
vars
.
back
();
auto
&
prev_grad
=
vars
.
back
();
op_handle
->
AddInput
(
prev_grad
.
get
());
op_handle
->
AddInput
(
prev_grad
.
get
());
...
@@ -405,21 +389,18 @@ void MultiDevSSAGraphBuilder::InsertAllReduceOp(Graph *result,
...
@@ -405,21 +389,18 @@ void MultiDevSSAGraphBuilder::InsertAllReduceOp(Graph *result,
void
MultiDevSSAGraphBuilder
::
InsertDataBalanceOp
(
void
MultiDevSSAGraphBuilder
::
InsertDataBalanceOp
(
Graph
*
result
,
const
std
::
vector
<
std
::
string
>
&
datas
)
const
{
Graph
*
result
,
const
std
::
vector
<
std
::
string
>
&
datas
)
const
{
#ifdef PADDLE_WITH_CUDA
#ifdef PADDLE_WITH_CUDA
boost
::
any_cast
<
GraphOps
*>
(
result
->
attrs
[
"ops"
])
result
->
Get
<
GraphOps
>
(
"ops"
).
emplace_back
(
->
emplace_back
(
new
DataBalanceOpHandle
(
local_scopes_
,
places_
,
nccl_ctxs_
));
new
DataBalanceOpHandle
(
local_scopes_
,
places_
,
nccl_ctxs_
));
#else
#else
boost
::
any_cast
<
GraphOps
*>
(
result
->
attrs
[
"ops"
])
result
->
Get
<
GraphOps
>
(
"ops"
).
emplace_back
(
->
emplace_back
(
new
DataBalanceOpHandle
(
local_scopes_
,
places_
));
new
DataBalanceOpHandle
(
local_scopes_
,
places_
));
#endif
#endif
auto
*
op_handle
=
auto
*
op_handle
=
result
->
Get
<
GraphOps
>
(
"ops"
).
back
().
get
();
boost
::
any_cast
<
GraphOps
*>
(
result
->
attrs
[
"ops"
])
->
back
().
get
();
for
(
size_t
i
=
0
;
i
<
places_
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
places_
.
size
();
++
i
)
{
auto
&
p
=
places_
[
i
];
auto
&
p
=
places_
[
i
];
SetCommunicationContext
(
op_handle
,
p
);
SetCommunicationContext
(
op_handle
,
p
);
for
(
const
std
::
string
&
d_name
:
datas
)
{
for
(
const
std
::
string
&
d_name
:
datas
)
{
auto
&
vars
=
auto
&
vars
=
result
->
Get
<
GraphVars
>
(
"vars"
)[
i
][
d_name
];
(
*
boost
::
any_cast
<
GraphVars
*>
(
result
->
attrs
[
"vars"
]))[
i
][
d_name
];
PADDLE_ENFORCE
(
!
vars
.
empty
());
PADDLE_ENFORCE
(
!
vars
.
empty
());
op_handle
->
AddInput
(
vars
.
back
().
get
());
op_handle
->
AddInput
(
vars
.
back
().
get
());
auto
var
=
new
VarHandle
(
vars
.
size
(),
i
,
d_name
,
p
);
auto
var
=
new
VarHandle
(
vars
.
size
(),
i
,
d_name
,
p
);
...
@@ -480,7 +461,7 @@ void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(Graph *result) const {
...
@@ -480,7 +461,7 @@ void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(Graph *result) const {
auto
*
op_handle
=
auto
*
op_handle
=
new
ScaleLossGradOpHandle
(
local_scopes_
.
size
(),
local_scopes_
[
i
],
new
ScaleLossGradOpHandle
(
local_scopes_
.
size
(),
local_scopes_
[
i
],
places_
[
i
],
communication_dev_ctx
);
places_
[
i
],
communication_dev_ctx
);
boost
::
any_cast
<
GraphOps
*>
(
result
->
attrs
[
"ops"
])
->
emplace_back
(
op_handle
);
result
->
Get
<
GraphOps
>
(
"ops"
).
emplace_back
(
op_handle
);
// FIXME: Currently ScaleLossGradOp only use device_count as scale
// FIXME: Currently ScaleLossGradOp only use device_count as scale
// factor. So it does not depend on any other operators.
// factor. So it does not depend on any other operators.
...
@@ -499,8 +480,8 @@ void MultiDevSSAGraphBuilder::CreateComputationalOps(Graph *result,
...
@@ -499,8 +480,8 @@ void MultiDevSSAGraphBuilder::CreateComputationalOps(Graph *result,
for
(
size_t
scope_idx
=
0
;
scope_idx
<
num_places
;
++
scope_idx
)
{
for
(
size_t
scope_idx
=
0
;
scope_idx
<
num_places
;
++
scope_idx
)
{
auto
p
=
places_
[
scope_idx
];
auto
p
=
places_
[
scope_idx
];
auto
s
=
local_scopes_
[
scope_idx
];
auto
s
=
local_scopes_
[
scope_idx
];
boost
::
any_cast
<
GraphOps
*>
(
result
->
attrs
[
"ops"
])
result
->
Get
<
GraphOps
>
(
"ops"
).
emplace_back
(
->
emplace_back
(
new
ComputationOpHandle
(
op
,
s
,
p
));
new
ComputationOpHandle
(
op
,
s
,
p
));
CreateOpHandleIOs
(
result
,
op
,
scope_idx
);
CreateOpHandleIOs
(
result
,
op
,
scope_idx
);
}
}
}
}
...
@@ -509,25 +490,23 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(Graph *result,
...
@@ -509,25 +490,23 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(Graph *result,
const
std
::
string
&
og
,
const
std
::
string
&
og
,
int
dst_dev_id
)
const
{
int
dst_dev_id
)
const
{
#ifdef PADDLE_WITH_CUDA
#ifdef PADDLE_WITH_CUDA
boost
::
any_cast
<
GraphOps
*>
(
result
->
attrs
[
"ops"
])
result
->
Get
<
GraphOps
>
(
"ops"
).
emplace_back
(
->
emplace_back
(
new
ReduceOpHandle
(
local_scopes_
,
places_
,
nccl_ctxs_
));
new
ReduceOpHandle
(
local_scopes_
,
places_
,
nccl_ctxs_
));
#else
#else
boost
::
any_cast
<
GraphOps
*>
(
result
->
attrs
[
"ops"
])
result
->
Get
<
GraphOps
>
(
"ops"
).
emplace_back
(
->
emplace_back
(
new
ReduceOpHandle
(
local_scopes_
,
places_
));
new
ReduceOpHandle
(
local_scopes_
,
places_
));
#endif
#endif
auto
*
op_handle
=
auto
*
op_handle
=
result
->
Get
<
GraphOps
>
(
"ops"
).
back
().
get
();
boost
::
any_cast
<
GraphOps
*>
(
result
->
attrs
[
"ops"
])
->
back
().
get
();
for
(
size_t
i
=
0
;
i
<
places_
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
places_
.
size
();
++
i
)
{
auto
&
p
=
places_
[
i
];
auto
&
p
=
places_
[
i
];
SetCommunicationContext
(
op_handle
,
p
);
SetCommunicationContext
(
op_handle
,
p
);
auto
&
vars
=
(
*
boost
::
any_cast
<
GraphVars
*>
(
result
->
attrs
[
"vars"
])
)[
i
][
og
];
auto
&
vars
=
result
->
Get
<
GraphVars
>
(
"vars"
)[
i
][
og
];
PADDLE_ENFORCE
(
!
vars
.
empty
());
PADDLE_ENFORCE
(
!
vars
.
empty
());
auto
&
prev_grad
=
vars
.
back
();
auto
&
prev_grad
=
vars
.
back
();
op_handle
->
AddInput
(
prev_grad
.
get
());
op_handle
->
AddInput
(
prev_grad
.
get
());
}
}
auto
&
vars
=
auto
&
vars
=
result
->
Get
<
GraphVars
>
(
"vars"
)[
dst_dev_id
][
og
];
(
*
boost
::
any_cast
<
GraphVars
*>
(
result
->
attrs
[
"vars"
]))[
dst_dev_id
][
og
];
auto
var
=
new
VarHandle
(
vars
.
size
(),
dst_dev_id
,
og
,
places_
[
dst_dev_id
]);
auto
var
=
new
VarHandle
(
vars
.
size
(),
dst_dev_id
,
og
,
places_
[
dst_dev_id
]);
vars
.
emplace_back
(
var
);
vars
.
emplace_back
(
var
);
op_handle
->
AddOutput
(
var
);
op_handle
->
AddOutput
(
var
);
...
@@ -538,12 +517,11 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(Graph *result,
...
@@ -538,12 +517,11 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(Graph *result,
// on it.
// on it.
void
MultiDevSSAGraphBuilder
::
ConnectOp
(
Graph
*
result
,
OpHandleBase
*
op
,
void
MultiDevSSAGraphBuilder
::
ConnectOp
(
Graph
*
result
,
OpHandleBase
*
op
,
const
std
::
string
&
prev_op_name
)
const
{
const
std
::
string
&
prev_op_name
)
const
{
for
(
auto
&
prev_op
:
(
*
boost
::
any_cast
<
GraphOps
*>
(
result
->
attrs
[
"ops"
])
))
{
for
(
auto
&
prev_op
:
result
->
Get
<
GraphOps
>
(
"ops"
))
{
if
(
prev_op
->
Name
()
==
prev_op_name
)
{
if
(
prev_op
->
Name
()
==
prev_op_name
)
{
auto
*
dep_var
=
new
DummyVarHandle
();
auto
*
dep_var
=
new
DummyVarHandle
();
prev_op
->
AddOutput
(
dep_var
);
prev_op
->
AddOutput
(
dep_var
);
boost
::
any_cast
<
GraphDepVars
*>
(
result
->
attrs
[
"dep_vars"
])
result
->
Get
<
GraphDepVars
>
(
"dep_vars"
).
emplace
(
dep_var
);
->
emplace
(
dep_var
);
op
->
AddInput
(
dep_var
);
op
->
AddInput
(
dep_var
);
}
}
}
}
...
@@ -579,8 +557,7 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(Graph *result,
...
@@ -579,8 +557,7 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(Graph *result,
CreateComputationalOp
(
result
,
op
,
op_dev_id
);
CreateComputationalOp
(
result
,
op
,
op_dev_id
);
if
(
op
.
Type
()
==
"concat"
)
{
if
(
op
.
Type
()
==
"concat"
)
{
ConnectOp
(
result
,
ConnectOp
(
result
,
result
->
Get
<
GraphOps
>
(
"ops"
).
back
().
get
(),
boost
::
any_cast
<
GraphOps
*>
(
result
->
attrs
[
"ops"
])
->
back
().
get
(),
"fetch_barrier"
);
"fetch_barrier"
);
}
}
}
}
...
@@ -615,22 +592,16 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(Graph *result,
...
@@ -615,22 +592,16 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(Graph *result,
PADDLE_ENFORCE
(
op_dev_id
!=
-
1
,
"can not find the right place for rpc op: %s"
,
PADDLE_ENFORCE
(
op_dev_id
!=
-
1
,
"can not find the right place for rpc op: %s"
,
op
.
Type
());
op
.
Type
());
boost
::
any_cast
<
GraphOps
*>
(
result
->
attrs
[
"ops"
])
result
->
Get
<
GraphOps
>
(
"ops"
).
emplace_back
(
new
RPCOpHandle
(
->
emplace_back
(
new
RPCOpHandle
(
op
,
local_scopes_
[
op_dev_id
],
op
.
Type
(),
op
,
local_scopes_
[
op_dev_id
],
op
.
Type
(),
places_
[
op_dev_id
]));
places_
[
op_dev_id
]));
if
(
op
.
Type
()
==
"send_barrier"
)
{
if
(
op
.
Type
()
==
"send_barrier"
)
{
ConnectOp
(
result
,
ConnectOp
(
result
,
result
->
Get
<
GraphOps
>
(
"ops"
).
back
().
get
(),
"send"
);
boost
::
any_cast
<
GraphOps
*>
(
result
->
attrs
[
"ops"
])
->
back
().
get
(),
"send"
);
}
else
if
(
op
.
Type
()
==
"recv"
)
{
}
else
if
(
op
.
Type
()
==
"recv"
)
{
ConnectOp
(
result
,
ConnectOp
(
result
,
result
->
Get
<
GraphOps
>
(
"ops"
).
back
().
get
(),
boost
::
any_cast
<
GraphOps
*>
(
result
->
attrs
[
"ops"
])
->
back
().
get
(),
"send_barrier"
);
"send_barrier"
);
}
else
if
(
op
.
Type
()
==
"fetch_barrier"
)
{
}
else
if
(
op
.
Type
()
==
"fetch_barrier"
)
{
ConnectOp
(
result
,
ConnectOp
(
result
,
result
->
Get
<
GraphOps
>
(
"ops"
).
back
().
get
(),
"recv"
);
boost
::
any_cast
<
GraphOps
*>
(
result
->
attrs
[
"ops"
])
->
back
().
get
(),
"recv"
);
}
else
if
(
op
.
Type
()
==
"send"
)
{
}
else
if
(
op
.
Type
()
==
"send"
)
{
// do nothing
// do nothing
}
else
{
}
else
{
...
...
paddle/fluid/framework/details/ssa_graph_builder.cc
浏览文件 @
9b960330
...
@@ -18,7 +18,7 @@ namespace paddle {
...
@@ -18,7 +18,7 @@ namespace paddle {
namespace
framework
{
namespace
framework
{
namespace
details
{
namespace
details
{
void
SSAGraphBuilder
::
PolishGraphToSupportDataHazards
(
Graph
*
graph
)
{
void
SSAGraphBuilder
::
PolishGraphToSupportDataHazards
(
Graph
*
graph
)
{
for
(
auto
&
var_map
:
*
boost
::
any_cast
<
GraphVars
*>
(
graph
->
attrs
[
"vars"
]
))
{
for
(
auto
&
var_map
:
graph
->
Get
<
GraphVars
>
(
"vars"
))
{
for
(
auto
&
name_pair
:
var_map
)
{
for
(
auto
&
name_pair
:
var_map
)
{
if
(
name_pair
.
second
.
size
()
<=
1
)
{
if
(
name_pair
.
second
.
size
()
<=
1
)
{
continue
;
continue
;
...
@@ -40,8 +40,7 @@ void SSAGraphBuilder::PolishGraphToSupportDataHazards(Graph *graph) {
...
@@ -40,8 +40,7 @@ void SSAGraphBuilder::PolishGraphToSupportDataHazards(Graph *graph) {
auto
*
dep_var
=
new
DummyVarHandle
();
auto
*
dep_var
=
new
DummyVarHandle
();
read_op
->
AddOutput
(
dep_var
);
read_op
->
AddOutput
(
dep_var
);
write_op
->
AddInput
(
dep_var
);
write_op
->
AddInput
(
dep_var
);
boost
::
any_cast
<
GraphDepVars
*>
(
graph
->
attrs
[
"dep_vars"
])
graph
->
Get
<
GraphDepVars
>
(
"dep_vars"
).
emplace
(
dep_var
);
->
emplace
(
dep_var
);
}
}
}
}
}
}
...
@@ -51,8 +50,7 @@ void SSAGraphBuilder::PolishGraphToSupportDataHazards(Graph *graph) {
...
@@ -51,8 +50,7 @@ void SSAGraphBuilder::PolishGraphToSupportDataHazards(Graph *graph) {
VarHandle
*
SSAGraphBuilder
::
CreateOrGetLatestVarHandle
(
VarHandle
*
SSAGraphBuilder
::
CreateOrGetLatestVarHandle
(
Graph
*
graph
,
const
std
::
string
&
each_var_name
,
Graph
*
graph
,
const
std
::
string
&
each_var_name
,
const
platform
::
Place
&
place
,
size_t
place_offset
)
{
const
platform
::
Place
&
place
,
size_t
place_offset
)
{
auto
&
var_holders
=
auto
&
var_holders
=
graph
->
Get
<
GraphVars
>
(
"vars"
)[
place_offset
];
(
*
boost
::
any_cast
<
GraphVars
*>
(
graph
->
attrs
[
"vars"
]))[
place_offset
];
auto
&
var_holder
=
var_holders
[
each_var_name
];
auto
&
var_holder
=
var_holders
[
each_var_name
];
VarHandle
*
var
=
nullptr
;
VarHandle
*
var
=
nullptr
;
if
(
var_holder
.
empty
())
{
if
(
var_holder
.
empty
())
{
...
@@ -68,9 +66,7 @@ void SSAGraphBuilder::CreateOpOutput(Graph *graph, OpHandleBase *op_handle,
...
@@ -68,9 +66,7 @@ void SSAGraphBuilder::CreateOpOutput(Graph *graph, OpHandleBase *op_handle,
const
std
::
string
&
each_var_name
,
const
std
::
string
&
each_var_name
,
const
platform
::
Place
&
place
,
const
platform
::
Place
&
place
,
size_t
place_offset
)
{
size_t
place_offset
)
{
auto
&
vars
=
auto
&
vars
=
graph
->
Get
<
GraphVars
>
(
"vars"
)[
place_offset
][
each_var_name
];
(
*
boost
::
any_cast
<
GraphVars
*>
(
graph
->
attrs
[
"vars"
]))[
place_offset
]
[
each_var_name
];
size_t
version
=
vars
.
size
();
size_t
version
=
vars
.
size
();
auto
var
=
new
VarHandle
(
version
,
place_offset
,
each_var_name
,
place
);
auto
var
=
new
VarHandle
(
version
,
place_offset
,
each_var_name
,
place
);
vars
.
emplace_back
(
var
);
vars
.
emplace_back
(
var
);
...
@@ -78,15 +74,14 @@ void SSAGraphBuilder::CreateOpOutput(Graph *graph, OpHandleBase *op_handle,
...
@@ -78,15 +74,14 @@ void SSAGraphBuilder::CreateOpOutput(Graph *graph, OpHandleBase *op_handle,
}
}
void
SSAGraphBuilder
::
AddOutputToLeafOps
(
Graph
*
graph
)
{
void
SSAGraphBuilder
::
AddOutputToLeafOps
(
Graph
*
graph
)
{
GraphOps
&
all_ops
=
*
boost
::
any_cast
<
GraphOps
*>
(
graph
->
attrs
[
"ops"
]
);
GraphOps
&
all_ops
=
graph
->
Get
<
GraphOps
>
(
"ops"
);
for
(
auto
&
op
:
all_ops
)
{
for
(
auto
&
op
:
all_ops
)
{
if
(
!
op
->
Outputs
().
empty
())
{
if
(
!
op
->
Outputs
().
empty
())
{
continue
;
continue
;
}
}
auto
*
dummy_leaf
=
new
DummyVarHandle
();
auto
*
dummy_leaf
=
new
DummyVarHandle
();
boost
::
any_cast
<
GraphDepVars
*>
(
graph
->
attrs
[
"dep_vars"
])
graph
->
Get
<
GraphDepVars
>
(
"dep_vars"
).
emplace
(
dummy_leaf
);
->
emplace
(
dummy_leaf
);
op
->
AddOutput
(
dummy_leaf
);
op
->
AddOutput
(
dummy_leaf
);
}
}
}
}
...
...
paddle/fluid/framework/ir/graph.h
浏览文件 @
9b960330
...
@@ -20,18 +20,77 @@ limitations under the License. */
...
@@ -20,18 +20,77 @@ limitations under the License. */
#include <vector>
#include <vector>
#include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/variant.h"
#include "paddle/fluid/platform/variant.h"
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
class
Graph
;
template
<
typename
AttrType
>
struct
AnyAttr
{
public:
explicit
AnyAttr
(
AttrType
*
attr
)
:
attr_
(
attr
)
{}
AttrType
&
Get
()
{
return
*
boost
::
any_cast
<
AttrType
*>
(
attr_
);
}
private:
friend
Graph
;
AttrType
*
Release
()
{
released_
=
true
;
return
boost
::
any_cast
<
AttrType
*>
(
attr_
);
}
void
Delete
()
{
if
(
!
released_
)
{
delete
boost
::
any_cast
<
AttrType
*>
(
attr_
);
}
}
bool
released_
=
false
;
boost
::
any
attr_
;
};
class
Graph
{
class
Graph
{
public:
public:
std
::
map
<
std
::
string
,
boost
::
any
>
attrs
;
virtual
~
Graph
()
{
for
(
auto
&
attr
:
attrs
)
{
attr_dels
[
attr
.
first
]();
}
attrs
.
clear
();
attr_dels
.
clear
();
}
template
<
typename
AttrType
>
AttrType
&
Get
(
const
std
::
string
&
attr_name
)
{
return
boost
::
any_cast
<
AnyAttr
<
AttrType
>>
(
attrs
[
attr_name
]).
Get
();
}
template
<
typename
AttrType
>
void
Set
(
const
std
::
string
&
attr_name
,
AttrType
*
attr
)
{
AnyAttr
<
AttrType
>
any_attr
=
AnyAttr
<
AttrType
>
(
attr
);
attrs
[
attr_name
]
=
any_attr
;
attr_dels
[
attr_name
]
=
[
&
any_attr
]()
{
any_attr
.
Delete
();
};
}
std
::
vector
<
Node
*>
inputs
;
template
<
typename
AttrType
>
std
::
vector
<
Node
*>
outputs
;
AttrType
*
Erase
(
const
std
::
string
&
attr_name
)
{
AnyAttr
<
AttrType
>
attr_type
=
boost
::
any_cast
<
AnyAttr
<
AttrType
>>
(
attrs
[
attr_name
]);
attrs
.
erase
(
attr_name
);
attr_dels
.
erase
(
attr_name
);
return
attr_type
.
Release
();
}
std
::
vector
<
Node
*>
inputs
;
std
::
vector
<
Node
*>
outputs
;
std
::
vector
<
std
::
unique_ptr
<
Node
>>
nodes
;
std
::
vector
<
std
::
unique_ptr
<
Node
>>
nodes
;
std
::
map
<
std
::
string
,
boost
::
any
>
attrs
;
std
::
map
<
std
::
string
,
std
::
function
<
void
(
void
)
>>
attr_dels
;
private:
};
};
}
// namespace framework
}
// namespace framework
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录