Commit 2eeaa8d5
Authored July 11, 2018 by Xin Pan

Graph in ParallelExecutor Builder

Parent: 7781297c

Showing 6 changed files with 151 additions and 76 deletions (+151, -76):
paddle/fluid/framework/details/multi_devices_graph_builder.cc   +86  -48
paddle/fluid/framework/details/multi_devices_graph_builder.h    +12  -12
paddle/fluid/framework/details/ssa_graph_builder.cc             +17  -10
paddle/fluid/framework/details/ssa_graph_builder.h              +13   -4
paddle/fluid/framework/ir/graph.h                                +1   -1
paddle/fluid/framework/ir/pass.h                                +22   -1
paddle/fluid/framework/details/multi_devices_graph_builder.cc
@@ -25,6 +25,7 @@
 #include "paddle/fluid/framework/details/reduce_op_handle.h"
 #include "paddle/fluid/framework/details/rpc_op_handle.h"
 #include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h"
+#include "paddle/fluid/framework/ir/node.h"
 #include "paddle/fluid/framework/op_info.h"
 #include "paddle/fluid/framework/scope.h"
@@ -66,11 +67,11 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
   }
 }

-void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result,
-                                                const OpDesc &op,
+void MultiDevSSAGraphBuilder::CreateOpHandleIOs(Graph *result, const OpDesc &op,
                                                 size_t place_id) const {
   auto p = places_[place_id];
-  auto *op_handle = result->ops_.back().get();
+  auto *op_handle =
+      boost::any_cast<GraphOps *>(result->attrs["ops"])->back().get();

   op_handle->SetDeviceContext(p, platform::DeviceContextPool::Instance().Get(p));
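Note: this hunk establishes the access pattern that the rest of the file repeats. A direct member read on the old SSAGraph (result->ops_) becomes a boost::any_cast lookup into the generic Graph::attrs map, using the GraphOps typedef added in ssa_graph_builder.h. Every later hunk in this file applies the same mechanical substitution to ops_, vars_, and dep_vars_.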
@@ -169,18 +170,21 @@ size_t MultiDevSSAGraphBuilder::GetAppropriateDeviceID(
 std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
     const ProgramDesc &program) const {
+  std::unique_ptr<Graph> graph(new Graph);
   for (auto *var : program.Block(0).AllVars()) {
     all_vars_.emplace(var->Name(), var);
   }

-  auto graph = new SSAGraph();
-  SSAGraph &result = *graph;
+  Graph &result = *graph;
   std::unordered_set<std::string> og_has_been_broadcast;

   // We cannot invoke resize. It is a bug of GCC 4.8
-  result.vars_ = std::vector<
+  result.attrs["vars"] = new std::vector<
       std::unordered_map<std::string, std::vector<std::unique_ptr<VarHandle>>>>(
       places_.size());
+  result.attrs["dep_vars"] =
+      new std::unordered_set<std::unique_ptr<VarHandleBase>>();
+  result.attrs["ops"] = new std::vector<std::unique_ptr<OpHandleBase>>();

   // find send/recv vars so that we can place the distributed training
   // realted op in the place 0
@@ -303,7 +307,15 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
    */
   AddOutputToLeafOps(&result);

-  return std::unique_ptr<SSAGraph>(graph);
+  std::unique_ptr<SSAGraph> ssa_graph(new SSAGraph);
+  ssa_graph->vars_ =
+      std::move(*boost::any_cast<GraphVars *>(graph->attrs["vars"]));
+  ssa_graph->ops_ = std::move(*boost::any_cast<GraphOps *>(graph->attrs["ops"]));
+  ssa_graph->dep_vars_ =
+      std::move(*boost::any_cast<GraphDepVars *>(graph->attrs["dep_vars"]));
+  return std::move(ssa_graph);
 }

 bool MultiDevSSAGraphBuilder::IsSparseGradient(const std::string &og) const {
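Build() now round-trips the containers through the attribute map: it allocates them into graph->attrs as raw pointers wrapped in boost::any, lets the Create*/Insert* helpers below fill them, and finally moves their contents back into a legacy SSAGraph for the caller. A minimal, self-contained sketch of that pattern follows; the stand-in types are illustrative, not Paddle's real definitions, and cleanup is shown manually because boost::any only stores the raw pointer.

#include <boost/any.hpp>
#include <map>
#include <memory>
#include <string>
#include <vector>

struct OpHandleBase {};  // illustrative stand-in for the real op handle
typedef std::vector<std::unique_ptr<OpHandleBase>> GraphOps;

int main() {
  std::map<std::string, boost::any> attrs;  // same shape as Graph::attrs

  // Builder side: store the typed container as a raw pointer in the any-map.
  attrs["ops"] = new GraphOps();

  // Access side: every call site re-casts the any back to the typed pointer.
  boost::any_cast<GraphOps *>(attrs["ops"])->emplace_back(new OpHandleBase());

  // Hand-off side: move the contents back out, as Build() does into SSAGraph.
  GraphOps ssa_ops = std::move(*boost::any_cast<GraphOps *>(attrs["ops"]));

  // The pointer held in attrs still owns the (now empty) container, so a
  // real caller must eventually delete it.
  delete boost::any_cast<GraphOps *>(attrs["ops"]);
  return ssa_ops.size() == 1 ? 0 : 1;
}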
@@ -327,7 +339,7 @@ void MultiDevSSAGraphBuilder::SetCommunicationContext(
 #endif
 }

-void MultiDevSSAGraphBuilder::CreateBroadcastOp(SSAGraph *result,
+void MultiDevSSAGraphBuilder::CreateBroadcastOp(Graph *result,
                                                 const std::string &p_name,
                                                 size_t src_dev_id) const {
 #ifdef PADDLE_WITH_CUDA
@@ -336,42 +348,50 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(SSAGraph *result,
   auto *op_handle = new BroadcastOpHandle(local_scopes_, places_);
 #endif

-  result->ops_.emplace_back(op_handle);
-
-  auto *in = result->vars_.at(src_dev_id).at(p_name).back().get();
+  boost::any_cast<GraphOps *>(result->attrs["ops"])->emplace_back(op_handle);
+
+  auto *in = boost::any_cast<GraphVars *>(result->attrs["vars"])
+                 ->at(src_dev_id)
+                 .at(p_name)
+                 .back()
+                 .get();
   op_handle->AddInput(in);

   for (size_t i = 0; i < places_.size(); ++i) {
     auto &p = places_[i];
     SetCommunicationContext(op_handle, p);
-    auto &vars = result->vars_.at(i).at(p_name);
+    auto &vars =
+        boost::any_cast<GraphVars *>(result->attrs["vars"])->at(i).at(p_name);
     auto *out_var = new VarHandle(vars.size(), i, p_name, p);
     vars.emplace_back(out_var);
     op_handle->AddOutput(out_var);
   }
 }

-void MultiDevSSAGraphBuilder::CreateComputationalOp(SSAGraph *result,
+void MultiDevSSAGraphBuilder::CreateComputationalOp(Graph *result,
                                                     const OpDesc &op,
                                                     int dev_id) const {
-  result->ops_.emplace_back(
+  boost::any_cast<GraphOps *>(result->attrs["ops"])->emplace_back(
       new ComputationOpHandle(op, local_scopes_[dev_id], places_[dev_id]));
   CreateOpHandleIOs(result, op, dev_id);
 }

-void MultiDevSSAGraphBuilder::InsertAllReduceOp(SSAGraph *result,
+void MultiDevSSAGraphBuilder::InsertAllReduceOp(Graph *result,
                                                 const std::string &og) const {
 #ifdef PADDLE_WITH_CUDA
-  result->ops_.emplace_back(
+  boost::any_cast<GraphOps *>(result->attrs["ops"])->emplace_back(
       new AllReduceOpHandle(local_scopes_, places_, nccl_ctxs_));
 #else
-  result->ops_.emplace_back(new AllReduceOpHandle(local_scopes_, places_));
+  boost::any_cast<GraphOps *>(result->attrs["ops"])->emplace_back(
      new AllReduceOpHandle(local_scopes_, places_));
 #endif
-  auto *op_handle = result->ops_.back().get();
+  auto *op_handle =
+      boost::any_cast<GraphOps *>(result->attrs["ops"])->back().get();

   for (size_t i = 0; i < places_.size(); ++i) {
     auto &p = places_[i];
     SetCommunicationContext(op_handle, p);
-    auto &vars = result->vars_[i][og];
+    auto &vars = (*boost::any_cast<GraphVars *>(result->attrs["vars"]))[i][og];
     PADDLE_ENFORCE(!vars.empty());
     auto &prev_grad = vars.back();
     op_handle->AddInput(prev_grad.get());
@@ -383,19 +403,23 @@ void MultiDevSSAGraphBuilder::InsertAllReduceOp(SSAGraph *result,
 }

 void MultiDevSSAGraphBuilder::InsertDataBalanceOp(
-    SSAGraph *result, const std::vector<std::string> &datas) const {
+    Graph *result, const std::vector<std::string> &datas) const {
 #ifdef PADDLE_WITH_CUDA
-  result->ops_.emplace_back(
+  boost::any_cast<GraphOps *>(result->attrs["ops"])->emplace_back(
       new DataBalanceOpHandle(local_scopes_, places_, nccl_ctxs_));
 #else
-  result->ops_.emplace_back(new DataBalanceOpHandle(local_scopes_, places_));
+  boost::any_cast<GraphOps *>(result->attrs["ops"])->emplace_back(
+      new DataBalanceOpHandle(local_scopes_, places_));
 #endif
-  auto *op_handle = result->ops_.back().get();
+  auto *op_handle =
+      boost::any_cast<GraphOps *>(result->attrs["ops"])->back().get();
   for (size_t i = 0; i < places_.size(); ++i) {
     auto &p = places_[i];
     SetCommunicationContext(op_handle, p);
     for (const std::string &d_name : datas) {
-      auto &vars = result->vars_[i][d_name];
+      auto &vars =
+          (*boost::any_cast<GraphVars *>(result->attrs["vars"]))[i][d_name];
       PADDLE_ENFORCE(!vars.empty());
       op_handle->AddInput(vars.back().get());
       auto var = new VarHandle(vars.size(), i, d_name, p);
@@ -441,7 +465,7 @@ int MultiDevSSAGraphBuilder::GetVarDeviceID(const std::string &varname) const {
   return got == var_name_on_devices_.end() ? -1 : got->second;
 }

-void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(SSAGraph *result) const {
+void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(Graph *result) const {
   for (size_t i = 0; i < places_.size(); ++i) {
 // Insert ScaleCost OpHandle
 #ifdef PADDLE_WITH_CUDA
@@ -456,7 +480,7 @@ void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(SSAGraph *result) const {
     auto *op_handle = new ScaleLossGradOpHandle(local_scopes_.size(),
                                                 local_scopes_[i], places_[i],
                                                 communication_dev_ctx);
-    result->ops_.emplace_back(op_handle);
+    boost::any_cast<GraphOps *>(result->attrs["ops"])->emplace_back(op_handle);

     // FIXME: Currently ScaleLossGradOp only use device_count as scale
     // factor. So it does not depend on any other operators.
@@ -469,37 +493,41 @@ void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(SSAGraph *result) const {
   }
 }

-void MultiDevSSAGraphBuilder::CreateComputationalOps(SSAGraph *result,
+void MultiDevSSAGraphBuilder::CreateComputationalOps(Graph *result,
                                                      const OpDesc &op,
                                                      size_t num_places) const {
   for (size_t scope_idx = 0; scope_idx < num_places; ++scope_idx) {
     auto p = places_[scope_idx];
     auto s = local_scopes_[scope_idx];
-    result->ops_.emplace_back(new ComputationOpHandle(op, s, p));
+    boost::any_cast<GraphOps *>(result->attrs["ops"])->emplace_back(
+        new ComputationOpHandle(op, s, p));
     CreateOpHandleIOs(result, op, scope_idx);
   }
 }

-VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(SSAGraph *result,
+VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(Graph *result,
                                                    const std::string &og,
                                                    int dst_dev_id) const {
 #ifdef PADDLE_WITH_CUDA
-  result->ops_.emplace_back(
+  boost::any_cast<GraphOps *>(result->attrs["ops"])->emplace_back(
       new ReduceOpHandle(local_scopes_, places_, nccl_ctxs_));
 #else
-  result->ops_.emplace_back(new ReduceOpHandle(local_scopes_, places_));
+  boost::any_cast<GraphOps *>(result->attrs["ops"])->emplace_back(
+      new ReduceOpHandle(local_scopes_, places_));
 #endif
-  auto *op_handle = result->ops_.back().get();
+  auto *op_handle =
+      boost::any_cast<GraphOps *>(result->attrs["ops"])->back().get();

   for (size_t i = 0; i < places_.size(); ++i) {
     auto &p = places_[i];
     SetCommunicationContext(op_handle, p);
-    auto &vars = result->vars_[i][og];
+    auto &vars = (*boost::any_cast<GraphVars *>(result->attrs["vars"]))[i][og];
     PADDLE_ENFORCE(!vars.empty());
     auto &prev_grad = vars.back();
     op_handle->AddInput(prev_grad.get());
   }

-  auto &vars = result->vars_[dst_dev_id][og];
+  auto &vars =
+      (*boost::any_cast<GraphVars *>(result->attrs["vars"]))[dst_dev_id][og];
   auto var = new VarHandle(vars.size(), dst_dev_id, og, places_[dst_dev_id]);
   vars.emplace_back(var);
   op_handle->AddOutput(var);
@@ -508,19 +536,20 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(SSAGraph *result,

 // Find the first occurence of `prev_op_name` and make current `op` depend
 // on it.
-void MultiDevSSAGraphBuilder::ConnectOp(SSAGraph *result, OpHandleBase *op,
+void MultiDevSSAGraphBuilder::ConnectOp(Graph *result, OpHandleBase *op,
                                         const std::string &prev_op_name) const {
-  for (auto &prev_op : result->ops_) {
+  for (auto &prev_op : (*boost::any_cast<GraphOps *>(result->attrs["ops"]))) {
     if (prev_op->Name() == prev_op_name) {
       auto *dep_var = new DummyVarHandle();
       prev_op->AddOutput(dep_var);
-      result->dep_vars_.emplace(dep_var);
+      boost::any_cast<GraphDepVars *>(result->attrs["dep_vars"])
+          ->emplace(dep_var);
       op->AddInput(dep_var);
     }
   }
 }

-void MultiDevSSAGraphBuilder::CreateDistTrainOp(SSAGraph *result,
+void MultiDevSSAGraphBuilder::CreateDistTrainOp(Graph *result,
                                                 const OpDesc &op) const {
   int op_dev_id = -1;
   if (op.Type() == "split_byref" || op.Type() == "split_selected_rows") {
@@ -550,12 +579,14 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(SSAGraph *result,
   CreateComputationalOp(result, op, op_dev_id);
   if (op.Type() == "concat") {
-    ConnectOp(result, result->ops_.back().get(), "fetch_barrier");
+    ConnectOp(result,
+              boost::any_cast<GraphOps *>(result->attrs["ops"])->back().get(),
+              "fetch_barrier");
   }
 }

 // Create RPC related op handles that connects its in ops and out ops.
-void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result,
+void MultiDevSSAGraphBuilder::CreateRPCOp(Graph *result,
                                           const OpDesc &op) const {
   int op_dev_id = -1;
   if (op.Type() == "send") {
@@ -584,15 +615,22 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result,
   PADDLE_ENFORCE(op_dev_id != -1, "can not find the right place for rpc op: %s",
                  op.Type());

-  result->ops_.emplace_back(new RPCOpHandle(op, local_scopes_[op_dev_id],
-                                            op.Type(), places_[op_dev_id]));
+  boost::any_cast<GraphOps *>(result->attrs["ops"])
+      ->emplace_back(new RPCOpHandle(op, local_scopes_[op_dev_id], op.Type(),
+                                     places_[op_dev_id]));

   if (op.Type() == "send_barrier") {
-    ConnectOp(result, result->ops_.back().get(), "send");
+    ConnectOp(result,
+              boost::any_cast<GraphOps *>(result->attrs["ops"])->back().get(),
+              "send");
   } else if (op.Type() == "recv") {
-    ConnectOp(result, result->ops_.back().get(), "send_barrier");
+    ConnectOp(result,
+              boost::any_cast<GraphOps *>(result->attrs["ops"])->back().get(),
+              "send_barrier");
   } else if (op.Type() == "fetch_barrier") {
-    ConnectOp(result, result->ops_.back().get(), "recv");
+    ConnectOp(result,
+              boost::any_cast<GraphOps *>(result->attrs["ops"])->back().get(),
+              "recv");
   } else if (op.Type() == "send") {
     // do nothing
   } else {
paddle/fluid/framework/details/multi_devices_graph_builder.h
@@ -19,6 +19,7 @@
 #include "paddle/fluid/framework/details/build_strategy.h"
 #include "paddle/fluid/framework/details/ssa_graph_builder.h"
+#include "paddle/fluid/framework/ir/graph.h"

 namespace paddle {
 namespace platform {
@@ -50,7 +51,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
   int GetVarDeviceID(const std::string &varname) const override;

  private:
-  void CreateOpHandleIOs(SSAGraph *result, const OpDesc &op,
+  void CreateOpHandleIOs(Graph *result, const OpDesc &op,
                          size_t device_id) const;

  private:
@@ -65,8 +66,8 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
   bool IsScaleLossOp(const OpDesc &op) const;

-  void CreateRPCOp(SSAGraph *result, const OpDesc &op) const;
-  void CreateDistTrainOp(SSAGraph *result, const OpDesc &op) const;
+  void CreateRPCOp(Graph *result, const OpDesc &op) const;
+  void CreateDistTrainOp(Graph *result, const OpDesc &op) const;

   /**
    * Is this operator as the end-point operator before/after send operator.
@@ -81,17 +82,16 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
   std::vector<std::string> FindDistTrainRecvVars(
       const ProgramDesc &program) const;

-  void ConnectOp(SSAGraph *result, OpHandleBase *op,
+  void ConnectOp(Graph *result, OpHandleBase *op,
                  const std::string &prev_op_name) const;

-  void CreateComputationalOps(SSAGraph *result, const OpDesc &op,
+  void CreateComputationalOps(Graph *result, const OpDesc &op,
                               size_t num_places) const;

-  void CreateScaleLossGradOp(SSAGraph *result) const;
-  VarHandle *CreateReduceOp(SSAGraph *result, const std::string &og,
+  void CreateScaleLossGradOp(Graph *result) const;
+  VarHandle *CreateReduceOp(Graph *result, const std::string &og,
                             int dst_dev_id) const;

-  void CreateComputationalOp(SSAGraph *result, const OpDesc &op,
-                             int dev_id) const;
+  void CreateComputationalOp(Graph *result, const OpDesc &op, int dev_id) const;

   bool IsParameterGradientOnce(const std::string &og,
@@ -99,12 +99,12 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
   int GetOpDeviceID(const OpDesc &op) const;

-  void InsertAllReduceOp(SSAGraph *result, const std::string &og) const;
+  void InsertAllReduceOp(Graph *result, const std::string &og) const;

-  void InsertDataBalanceOp(SSAGraph *result,
+  void InsertDataBalanceOp(Graph *result,
                            const std::vector<std::string> &datas) const;

-  void CreateBroadcastOp(SSAGraph *result, const std::string &p_name,
+  void CreateBroadcastOp(Graph *result, const std::string &p_name,
                          size_t src_dev_id) const;

   bool IsSparseGradient(const std::string &og) const;
paddle/fluid/framework/details/ssa_graph_builder.cc
@@ -17,8 +17,8 @@
 namespace paddle {
 namespace framework {
 namespace details {
-void SSAGraphBuilder::PolishGraphToSupportDataHazards(SSAGraph *graph) {
-  for (auto &var_map : graph->vars_) {
+void SSAGraphBuilder::PolishGraphToSupportDataHazards(Graph *graph) {
+  for (auto &var_map : *boost::any_cast<GraphVars *>(graph->attrs["vars"])) {
     for (auto &name_pair : var_map) {
       if (name_pair.second.size() <= 1) {
         continue;
@@ -40,7 +40,8 @@ void SSAGraphBuilder::PolishGraphToSupportDataHazards(SSAGraph *graph) {
         auto *dep_var = new DummyVarHandle();
         read_op->AddOutput(dep_var);
         write_op->AddInput(dep_var);
-        graph->dep_vars_.emplace(dep_var);
+        boost::any_cast<GraphDepVars *>(graph->attrs["dep_vars"])
+            ->emplace(dep_var);
       }
     }
   }
@@ -48,9 +49,10 @@ void SSAGraphBuilder::PolishGraphToSupportDataHazards(SSAGraph *graph) {
 }

 VarHandle *SSAGraphBuilder::CreateOrGetLatestVarHandle(
-    SSAGraph *graph, const std::string &each_var_name,
+    Graph *graph, const std::string &each_var_name,
     const platform::Place &place, size_t place_offset) {
-  auto &var_holders = graph->vars_[place_offset];
+  auto &var_holders =
+      (*boost::any_cast<GraphVars *>(graph->attrs["vars"]))[place_offset];
   auto &var_holder = var_holders[each_var_name];
   VarHandle *var = nullptr;
   if (var_holder.empty()) {
@@ -62,24 +64,29 @@ VarHandle *SSAGraphBuilder::CreateOrGetLatestVarHandle(
   return var;
 }

-void SSAGraphBuilder::CreateOpOutput(SSAGraph *graph, OpHandleBase *op_handle,
+void SSAGraphBuilder::CreateOpOutput(Graph *graph, OpHandleBase *op_handle,
                                      const std::string &each_var_name,
                                      const platform::Place &place,
                                      size_t place_offset) {
-  auto &vars = graph->vars_[place_offset][each_var_name];
+  auto &vars = (*boost::any_cast<GraphVars *>(graph->attrs["vars"]))[place_offset]
+                   [each_var_name];
   size_t version = vars.size();
   auto var = new VarHandle(version, place_offset, each_var_name, place);
   vars.emplace_back(var);
   op_handle->AddOutput(var);
 }

-void SSAGraphBuilder::AddOutputToLeafOps(SSAGraph *graph) {
-  for (auto &op : graph->ops_) {
+void SSAGraphBuilder::AddOutputToLeafOps(Graph *graph) {
+  GraphOps &all_ops = *boost::any_cast<GraphOps *>(graph->attrs["ops"]);
+  for (auto &op : all_ops) {
     if (!op->Outputs().empty()) {
       continue;
     }
     auto *dummy_leaf = new DummyVarHandle();
-    graph->dep_vars_.emplace(dummy_leaf);
+    boost::any_cast<GraphDepVars *>(graph->attrs["dep_vars"])
+        ->emplace(dummy_leaf);
     op->AddOutput(dummy_leaf);
   }
 }
paddle/fluid/framework/details/ssa_graph_builder.h
@@ -16,15 +16,24 @@
 #include <memory>
 #include <string>
 #include <vector>

 #include "paddle/fluid/framework/details/ssa_graph.h"
 #include "paddle/fluid/framework/program_desc.h"
 #include "paddle/fluid/platform/place.h"
+#include "paddle/fluid/framework/ir/graph.h"

 namespace paddle {
 namespace framework {
 namespace details {

+typedef std::vector<
+    std::unordered_map<std::string, std::vector<std::unique_ptr<VarHandle>>>>
+    GraphVars;
+typedef std::unordered_set<std::unique_ptr<VarHandleBase>> GraphDepVars;
+typedef std::vector<std::unique_ptr<OpHandleBase>> GraphOps;
+
 class SSAGraphBuilder {
  public:
   SSAGraphBuilder() {}
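The three typedefs above name the containers that previously lived as SSAGraph members, so both .cc files now spell out the same boost::any_cast expression at every call site. As a purely illustrative aside (not part of this commit), a small helper could centralize that cast; GetGraphAttr below is hypothetical:

// Hypothetical helper (not in the commit): fetch a typed attribute from a
// Graph's any-map. boost::any_cast throws boost::bad_any_cast on a type
// mismatch, and std::map::at throws std::out_of_range on a missing key.
template <typename AttrType>
AttrType &GetGraphAttr(Graph *graph, const std::string &name) {
  return *boost::any_cast<AttrType *>(graph->attrs.at(name));
}

// Call sites such as those in ssa_graph_builder.cc would then read:
//   GraphOps &ops = GetGraphAttr<GraphOps>(graph, "ops");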
@@ -42,20 +51,20 @@ class SSAGraphBuilder {
    *
    * https://en.wikipedia.org/wiki/Hazard_(computer_architecture)#Write_after_read_(WAR)
    */
-  static void PolishGraphToSupportDataHazards(SSAGraph *graph);
+  static void PolishGraphToSupportDataHazards(Graph *graph);

-  static VarHandle *CreateOrGetLatestVarHandle(SSAGraph *graph,
+  static VarHandle *CreateOrGetLatestVarHandle(Graph *graph,
                                                const std::string &each_var_name,
                                                const platform::Place &place,
                                                size_t place_offset);

   // Add an output variable (each_var_name, place, place_offset) to op_handle,
   // which belongs to graph
-  static void CreateOpOutput(SSAGraph *graph, OpHandleBase *op_handle,
+  static void CreateOpOutput(Graph *graph, OpHandleBase *op_handle,
                              const std::string &each_var_name,
                              const platform::Place &place,
                              size_t place_offset);

-  static void AddOutputToLeafOps(SSAGraph *graph);
+  static void AddOutputToLeafOps(Graph *graph);
 };

 }  // namespace details
 }  // namespace framework
paddle/fluid/framework/ir/graph.h
@@ -27,7 +27,7 @@ namespace framework {
 class Graph {
  public:
-  std::map<std::string, std::vector<boost::any>> attrs;
+  std::map<std::string, boost::any> attrs;

   std::vector<Node *> inputs;
   std::vector<Node *> outputs;
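This one-line change is what all the builder code above depends on: each attribute name now maps to exactly one boost::any value rather than a vector of them, so an expression like result->attrs["ops"] denotes a single typed payload that can be any_cast back to GraphOps *.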
paddle/fluid/framework/ir/pass.h
@@ -14,6 +14,27 @@ limitations under the License. */
 #pragma once

 #include "paddle/fluid/framework/ir/graph.h"
 #include "paddle/fluid/framework/ir/node.h"
+#include "paddle/fluid/framework/program_desc.h"

 namespace paddle {
-namespace framework {}  // namespace framework
+namespace framework {
+
+class Pass {
+ public:
+  Pass() = default;
+  virtual ~Pass() {}
+
+  virtual std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> graph) {
+    return std::move(graph);
+  }
+};
+
+std::unique_ptr<Graph> ProgramToGraph(const ProgramDesc &program) {
+  std::unique_ptr<Graph> g(new Graph);
+  return std::move(g);
+}
+
+}  // namespace framework
 }  // namespace paddle
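The new Pass base class fixes the pipeline contract: a pass takes ownership of a Graph through unique_ptr and returns ownership, so passes compose by threading the pointer along. Below is a sketch of a derived pass under that contract; NoOpPass and the chaining snippet are illustrative, not part of the commit.

// Hypothetical pass built on the interface above: it forwards the graph
// unchanged; a real pass would rewrite graph->attrs or the node lists here.
class NoOpPass : public Pass {
 public:
  std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> graph) override {
    return graph;  // returning the by-value parameter moves it out
  }
};

// Chaining, starting from ProgramToGraph():
//   std::unique_ptr<Graph> g = ProgramToGraph(program);
//   g = NoOpPass().Apply(std::move(g));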