Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
e4d7d7ae
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
e4d7d7ae
编写于
7月 26, 2018
作者:
X
Xin Pan
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
pass refactoring
上级
142e832d
变更
17
隐藏空白更改
内联
并排
Showing
17 changed file
with
229 addition
and
81 deletion
+229
-81
paddle/fluid/framework/details/multi_devices_graph_builder.cc
...le/fluid/framework/details/multi_devices_graph_builder.cc
+30
-16
paddle/fluid/framework/details/multi_devices_graph_builder.h
paddle/fluid/framework/details/multi_devices_graph_builder.h
+3
-3
paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h
...uid/framework/details/scope_buffered_ssa_graph_executor.h
+3
-0
paddle/fluid/framework/details/ssa_graph_builder.h
paddle/fluid/framework/details/ssa_graph_builder.h
+2
-2
paddle/fluid/framework/details/ssa_graph_builder_factory.cc
paddle/fluid/framework/details/ssa_graph_builder_factory.cc
+6
-6
paddle/fluid/framework/details/ssa_graph_builder_factory.h
paddle/fluid/framework/details/ssa_graph_builder_factory.h
+7
-7
paddle/fluid/framework/details/ssa_graph_checker.cc
paddle/fluid/framework/details/ssa_graph_checker.cc
+3
-0
paddle/fluid/framework/details/ssa_graph_checker.h
paddle/fluid/framework/details/ssa_graph_checker.h
+1
-6
paddle/fluid/framework/details/ssa_graph_executor.h
paddle/fluid/framework/details/ssa_graph_executor.h
+3
-1
paddle/fluid/framework/details/ssa_graph_printer.cc
paddle/fluid/framework/details/ssa_graph_printer.cc
+3
-0
paddle/fluid/framework/details/ssa_graph_printer.h
paddle/fluid/framework/details/ssa_graph_printer.h
+1
-6
paddle/fluid/framework/details/threaded_ssa_graph_executor.h
paddle/fluid/framework/details/threaded_ssa_graph_executor.h
+1
-0
paddle/fluid/framework/ir/graph.h
paddle/fluid/framework/ir/graph.h
+2
-0
paddle/fluid/framework/ir/pass.h
paddle/fluid/framework/ir/pass.h
+2
-0
paddle/fluid/framework/parallel_executor.cc
paddle/fluid/framework/parallel_executor.cc
+65
-33
paddle/fluid/framework/parallel_executor.h
paddle/fluid/framework/parallel_executor.h
+0
-1
paddle/fluid/operators/distributed/send_recv.proto
paddle/fluid/operators/distributed/send_recv.proto
+97
-0
未找到文件。
paddle/fluid/framework/details/multi_devices_graph_builder.cc
浏览文件 @
e4d7d7ae
...
...
@@ -244,6 +244,7 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::Apply(
result
.
Set
(
"vars"
,
new
GraphVars
(
places_
.
size
()));
result
.
Set
(
"dep_vars"
,
new
GraphDepVars
);
result
.
Set
(
"ops"
,
new
GraphOps
);
result
.
Set
(
"sharded_var_device"
,
new
ShardedVarDevice
);
// find send/recv vars so that we can place the distributed training
// realted op in the place 0
...
...
@@ -276,11 +277,12 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::Apply(
// the block.
is_forwarding
=
false
;
}
else
{
int
op_dev_id
=
GetOpDeviceID
(
node
);
int
op_dev_id
=
GetOpDeviceID
(
result
,
node
);
if
(
op_dev_id
!=
-
1
)
{
// This op only runs on one specific device.
CreateComputationalOp
(
&
result
,
node
,
op_dev_id
);
for
(
ir
::
Node
*
n
:
node
->
outputs
)
{
var_name_on_devices_
.
emplace
(
n
->
Name
(),
op_dev_id
);
graph
->
Get
<
ShardedVarDevice
>
(
"sharded_var_device"
)
.
emplace
(
n
->
Name
(),
op_dev_id
);
}
}
else
{
// This op runs on all devices, and its output may have parameter's
...
...
@@ -317,7 +319,8 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::Apply(
case
BuildStrategy
::
ReduceStrategy
::
kReduce
:
cur_device_id
=
GetAppropriateDeviceID
({
g_name
});
CreateReduceOp
(
&
result
,
g_name
,
cur_device_id
);
var_name_on_devices_
.
emplace
(
g_name
,
cur_device_id
);
graph
->
Get
<
ShardedVarDevice
>
(
"sharded_var_device"
)
.
emplace
(
g_name
,
cur_device_id
);
bcast_var_name_set
[
cur_device_id
].
emplace
(
p_name
);
break
;
case
BuildStrategy
::
ReduceStrategy
::
kAllReduce
:
...
...
@@ -499,7 +502,8 @@ bool MultiDevSSAGraphBuilder::IsParameterGradientOnce(
return
is_pg_once
;
}
int
MultiDevSSAGraphBuilder
::
GetOpDeviceID
(
ir
::
Node
*
node
)
const
{
int
MultiDevSSAGraphBuilder
::
GetOpDeviceID
(
const
ir
::
Graph
&
graph
,
ir
::
Node
*
node
)
const
{
if
(
strategy_
.
reduce_
!=
BuildStrategy
::
ReduceStrategy
::
kReduce
)
{
return
-
1
;
}
...
...
@@ -512,15 +516,17 @@ int MultiDevSSAGraphBuilder::GetOpDeviceID(ir::Node *node) const {
node
->
Op
()
->
GetAttr
(
OpProtoAndCheckerMaker
::
OpRoleVarAttrName
()));
PADDLE_ENFORCE_EQ
(
param_grad
.
size
(),
2U
);
int
dev_id
=
GetVarDeviceID
(
param_grad
[
1
]);
int
dev_id
=
GetVarDeviceID
(
graph
,
param_grad
[
1
]);
PADDLE_ENFORCE_NE
(
dev_id
,
-
1
,
"dev_id should not be -1.[%s, %s, %s]"
,
node
->
Op
()
->
Type
(),
param_grad
[
0
],
param_grad
[
1
]);
return
dev_id
;
}
int
MultiDevSSAGraphBuilder
::
GetVarDeviceID
(
const
std
::
string
&
varname
)
const
{
auto
got
=
var_name_on_devices_
.
find
(
varname
);
return
got
==
var_name_on_devices_
.
end
()
?
-
1
:
got
->
second
;
int
MultiDevSSAGraphBuilder
::
GetVarDeviceID
(
const
ir
::
Graph
&
graph
,
const
std
::
string
&
varname
)
const
{
auto
&
sharded_var_device
=
graph
.
Get
<
ShardedVarDevice
>
(
"sharded_var_device"
);
auto
got
=
sharded_var_device
.
find
(
varname
);
return
got
==
sharded_var_device
.
end
()
?
-
1
:
got
->
second
;
}
void
MultiDevSSAGraphBuilder
::
CreateScaleLossGradOp
(
ir
::
Graph
*
result
)
const
{
...
...
@@ -625,20 +631,23 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
if
(
node
->
Op
()
->
Type
()
==
"split_byref"
||
node
->
Op
()
->
Type
()
==
"split_selected_rows"
)
{
// TODO(paddle-dev): getting the first var is not safe.
op_dev_id
=
GetVarDeviceID
(
input_var_names
[
0
]);
op_dev_id
=
GetVarDeviceID
(
*
result
,
input_var_names
[
0
]);
if
(
strategy_
.
reduce_
==
BuildStrategy
::
ReduceStrategy
::
kAllReduce
)
{
op_dev_id
=
GetAppropriateDeviceID
(
input_var_names
);
for
(
auto
&
varname
:
input_var_names
)
{
var_name_on_devices_
.
emplace
(
varname
,
op_dev_id
);
result
->
Get
<
ShardedVarDevice
>
(
"sharded_var_device"
)
.
emplace
(
varname
,
op_dev_id
);
}
}
for
(
auto
&
varname
:
output_var_names
)
{
var_name_on_devices_
.
emplace
(
varname
,
op_dev_id
);
result
->
Get
<
ShardedVarDevice
>
(
"sharded_var_device"
)
.
emplace
(
varname
,
op_dev_id
);
}
}
else
if
(
node
->
Op
()
->
Type
()
==
"concat"
)
{
op_dev_id
=
GetVarDeviceID
(
input_var_names
[
0
]);
op_dev_id
=
GetVarDeviceID
(
*
result
,
input_var_names
[
0
]);
for
(
auto
&
varname
:
output_var_names
)
{
var_name_on_devices_
.
emplace
(
varname
,
op_dev_id
);
result
->
Get
<
ShardedVarDevice
>
(
"sharded_var_device"
)
.
emplace
(
varname
,
op_dev_id
);
}
}
else
{
PADDLE_ENFORCE
(
...
...
@@ -663,7 +672,7 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
int
op_dev_id
=
-
1
;
if
(
node
->
Op
()
->
Type
()
==
"send"
)
{
// TODO(paddle-dev): getting the first var is not safe.
op_dev_id
=
GetVarDeviceID
(
node
->
inputs
[
0
]
->
Name
());
op_dev_id
=
GetVarDeviceID
(
*
result
,
node
->
inputs
[
0
]
->
Name
());
PADDLE_ENFORCE
(
!
ir
::
IsControlDepVar
(
*
node
->
inputs
[
0
]),
"This hack no longer holds, please fix."
);
// the variable name which contains .block means it was splited by
...
...
@@ -678,7 +687,8 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
}
op_dev_id
=
GetAppropriateDeviceID
(
input_var_names
);
for
(
auto
&
varname
:
input_var_names
)
{
var_name_on_devices_
.
emplace
(
varname
,
op_dev_id
);
result
->
Get
<
ShardedVarDevice
>
(
"sharded_var_device"
)
.
emplace
(
varname
,
op_dev_id
);
}
}
}
else
if
(
node
->
Op
()
->
Type
()
==
"recv"
)
{
...
...
@@ -688,7 +698,8 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
}
op_dev_id
=
GetAppropriateDeviceID
(
output_var_names
);
for
(
auto
&
varname
:
output_var_names
)
{
var_name_on_devices_
.
emplace
(
varname
,
op_dev_id
);
result
->
Get
<
ShardedVarDevice
>
(
"sharded_var_device"
)
.
emplace
(
varname
,
op_dev_id
);
}
}
else
{
// send_barrier and fetch_barrier op can be scheduled on device 0
...
...
@@ -730,3 +741,6 @@ bool MultiDevSSAGraphBuilder::IsScaleLossOp(ir::Node *node) const {
}
// namespace details
}
// namespace framework
}
// namespace paddle
REGISTER_PASS
(
multi_device_pass
,
paddle
::
framework
::
details
::
MultiDevSSAGraphBuilder
);
paddle/fluid/framework/details/multi_devices_graph_builder.h
浏览文件 @
e4d7d7ae
...
...
@@ -34,7 +34,6 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
public:
std
::
unique_ptr
<
ir
::
Graph
>
Apply
(
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
override
;
int
GetVarDeviceID
(
const
std
::
string
&
varname
)
const
override
;
private:
void
CreateOpHandleIOs
(
ir
::
Graph
*
result
,
ir
::
Node
*
node
,
...
...
@@ -51,6 +50,8 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
mutable
platform
::
NCCLContextMap
*
nccl_ctxs_
;
#endif
int
GetVarDeviceID
(
const
ir
::
Graph
&
graph
,
const
std
::
string
&
varname
)
const
;
bool
IsScaleLossOp
(
ir
::
Node
*
node
)
const
;
void
CreateRPCOp
(
ir
::
Graph
*
result
,
ir
::
Node
*
node
)
const
;
...
...
@@ -84,7 +85,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
const
std
::
string
&
og
,
std
::
unordered_set
<
std
::
string
>
*
og_has_been_broadcast
)
const
;
int
GetOpDeviceID
(
ir
::
Node
*
node
)
const
;
int
GetOpDeviceID
(
const
ir
::
Graph
&
graph
,
ir
::
Node
*
node
)
const
;
void
InsertAllReduceOp
(
ir
::
Graph
*
result
,
const
std
::
string
&
og
)
const
;
...
...
@@ -102,7 +103,6 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
private:
mutable
BuildStrategy
strategy_
;
mutable
std
::
unordered_map
<
std
::
string
,
VarDesc
*>
all_vars_
;
mutable
std
::
unordered_map
<
std
::
string
,
int
>
var_name_on_devices_
;
mutable
std
::
vector
<
int64_t
>
balance_vars_
;
void
SetCommunicationContext
(
OpHandleBase
*
op_handle
,
...
...
paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h
浏览文件 @
e4d7d7ae
...
...
@@ -40,6 +40,9 @@ class ScopeBufferedSSAGraphExecutor : public SSAGraphExecutor {
ExecutionStrategy
strategy
,
std
::
vector
<
Scope
*>
local_scopes
,
std
::
vector
<
VariableInfo
>
var_infos
,
std
::
vector
<
platform
::
Place
>
places
,
std
::
unique_ptr
<
SSAGraphExecutor
>&&
underlying_executor
);
const
ir
::
Graph
&
Graph
()
const
{
return
underlying_executor_
->
Graph
();
}
FeedFetchList
Run
(
const
std
::
vector
<
std
::
string
>&
fetch_tensors
)
override
;
private:
...
...
paddle/fluid/framework/details/ssa_graph_builder.h
浏览文件 @
e4d7d7ae
...
...
@@ -47,13 +47,13 @@ typedef std::unordered_set<std::unique_ptr<VarHandleBase>> GraphDepVars;
// unordered.
typedef
std
::
vector
<
std
::
unique_ptr
<
OpHandleBase
>>
GraphOps
;
typedef
std
::
unordered_map
<
std
::
string
,
int
>
ShardedVarDevice
;
class
SSAGraphBuilder
:
public
ir
::
Pass
{
public:
SSAGraphBuilder
()
{}
virtual
~
SSAGraphBuilder
()
{}
virtual
int
GetVarDeviceID
(
const
std
::
string
&
var_name
)
const
=
0
;
DISABLE_COPY_AND_ASSIGN
(
SSAGraphBuilder
);
protected:
...
...
paddle/fluid/framework/details/ssa_graph_builder_factory.cc
浏览文件 @
e4d7d7ae
...
...
@@ -21,8 +21,8 @@
namespace
paddle
{
namespace
framework
{
namespace
details
{
std
::
unique_ptr
<
SSAGraphBuilder
>
SSAGraphBuilderFactory
::
Create
()
{
std
::
unique_ptr
<
SSAGraphBuilder
>
res
(
new
MultiDevSSAGraphBuilder
);
std
::
unique_ptr
<
ir
::
Pass
>
ParallelExecutorPassManager
::
Create
()
{
std
::
unique_ptr
<
ir
::
Pass
>
res
(
new
MultiDevSSAGraphBuilder
);
res
->
SetNotOwned
<
std
::
vector
<
platform
::
Place
>>
(
"places"
,
&
places_
);
res
->
SetNotOwned
<
std
::
string
>
(
"loss_var_name"
,
&
loss_var_name_
);
res
->
SetNotOwned
<
std
::
unordered_set
<
std
::
string
>>
(
"params"
,
&
param_names_
);
...
...
@@ -33,18 +33,18 @@ std::unique_ptr<SSAGraphBuilder> SSAGraphBuilderFactory::Create() {
#endif
if
(
!
strategy_
.
debug_graphviz_path_
.
empty
())
{
SSAGraphBuilder
*
previous_pass
=
res
.
release
();
ir
::
Pass
*
previous_pass
=
res
.
release
();
res
.
reset
(
new
SSAGraghBuilderWithPrinter
);
res
->
Set
<
SSAGraphBuilder
>
(
"previous_pass"
,
previous_pass
);
res
->
Set
<
ir
::
Pass
>
(
"previous_pass"
,
previous_pass
);
res
->
SetNotOwned
<
std
::
string
>
(
"debug_graphviz_path"
,
&
strategy_
.
debug_graphviz_path_
);
res
->
Set
<
GraphvizSSAGraphPrinter
>
(
"graph_printer"
,
new
GraphvizSSAGraphPrinter
);
}
SSAGraphBuilder
*
previous_pass
=
res
.
release
();
ir
::
Pass
*
previous_pass
=
res
.
release
();
res
.
reset
(
new
SSAGraghBuilderWithChecker
);
res
->
Set
<
SSAGraphBuilder
>
(
"previous_pass"
,
previous_pass
);
res
->
Set
<
ir
::
Pass
>
(
"previous_pass"
,
previous_pass
);
return
res
;
}
...
...
paddle/fluid/framework/details/ssa_graph_builder_factory.h
浏览文件 @
e4d7d7ae
...
...
@@ -29,13 +29,13 @@ namespace framework {
class
Scope
;
namespace
details
{
class
SSAGraphBuilderFactory
{
class
ParallelExecutorPassManager
{
public:
SSAGraphBuilderFactory
(
const
std
::
vector
<
platform
::
Place
>&
places
,
const
std
::
string
&
loss_var_name
,
const
std
::
unordered_set
<
std
::
string
>&
param_names
,
const
std
::
vector
<
Scope
*>&
local_scop
es
,
const
BuildStrategy
&
strategy
)
ParallelExecutorPassManager
(
const
std
::
vector
<
platform
::
Place
>&
places
,
const
std
::
string
&
loss_var_name
,
const
std
::
unordered_set
<
std
::
string
>&
param_nam
es
,
const
std
::
vector
<
Scope
*>&
local_scopes
,
const
BuildStrategy
&
strategy
)
:
places_
(
places
),
loss_var_name_
(
loss_var_name
),
param_names_
(
param_names
),
...
...
@@ -52,7 +52,7 @@ class SSAGraphBuilderFactory {
}
#endif
std
::
unique_ptr
<
SSAGraphBuilder
>
Create
();
std
::
unique_ptr
<
ir
::
Pass
>
Create
();
private:
std
::
vector
<
platform
::
Place
>
places_
;
...
...
paddle/fluid/framework/details/ssa_graph_checker.cc
浏览文件 @
e4d7d7ae
...
...
@@ -85,3 +85,6 @@ bool SSAGraghBuilderWithChecker::IsValidGraph(const ir::Graph *graph) const {
}
// namespace details
}
// namespace framework
}
// namespace paddle
REGISTER_PASS
(
multi_device_check_pass
,
paddle
::
framework
::
details
::
SSAGraghBuilderWithChecker
);
paddle/fluid/framework/details/ssa_graph_checker.h
浏览文件 @
e4d7d7ae
...
...
@@ -26,16 +26,11 @@ class SSAGraghBuilderWithChecker : public SSAGraphBuilder {
public:
std
::
unique_ptr
<
ir
::
Graph
>
Apply
(
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
override
{
auto
new_graph
=
Get
<
SSAGraphBuilder
>
(
"previous_pass"
).
Apply
(
std
::
move
(
graph
));
auto
new_graph
=
Get
<
ir
::
Pass
>
(
"previous_pass"
).
Apply
(
std
::
move
(
graph
));
PADDLE_ENFORCE
(
IsValidGraph
(
new_graph
.
get
()));
return
new_graph
;
}
int
GetVarDeviceID
(
const
std
::
string
&
var_name
)
const
override
{
return
Get
<
SSAGraphBuilder
>
(
"previous_pass"
).
GetVarDeviceID
(
var_name
);
}
bool
IsValidGraph
(
const
ir
::
Graph
*
graph
)
const
;
};
...
...
paddle/fluid/framework/details/ssa_graph_executor.h
浏览文件 @
e4d7d7ae
...
...
@@ -32,7 +32,9 @@ class SSAGraphExecutor {
virtual
~
SSAGraphExecutor
();
virtual
FeedFetchList
Run
(
const
std
::
vector
<
std
::
string
>
&
fetch_tensors
)
=
0
;
virtual
const
ir
::
Graph
&
Graph
()
const
=
0
;
virtual
FeedFetchList
Run
(
const
std
::
vector
<
std
::
string
>&
fetch_tensors
)
=
0
;
};
}
// namespace details
}
// namespace framework
...
...
paddle/fluid/framework/details/ssa_graph_printer.cc
浏览文件 @
e4d7d7ae
...
...
@@ -81,3 +81,6 @@ void GraphvizSSAGraphPrinter::Print(const ir::Graph &graph,
}
// namespace details
}
// namespace framework
}
// namespace paddle
REGISTER_PASS
(
multi_device_print_pass
,
paddle
::
framework
::
details
::
SSAGraghBuilderWithPrinter
);
paddle/fluid/framework/details/ssa_graph_printer.h
浏览文件 @
e4d7d7ae
...
...
@@ -39,8 +39,7 @@ class SSAGraghBuilderWithPrinter : public SSAGraphBuilder {
public:
std
::
unique_ptr
<
ir
::
Graph
>
Apply
(
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
override
{
auto
new_graph
=
Get
<
SSAGraphBuilder
>
(
"previous_pass"
).
Apply
(
std
::
move
(
graph
));
auto
new_graph
=
Get
<
ir
::
Pass
>
(
"previous_pass"
).
Apply
(
std
::
move
(
graph
));
std
::
unique_ptr
<
std
::
ostream
>
fout
(
new
std
::
ofstream
(
Get
<
std
::
string
>
(
"debug_graphviz_path"
)));
...
...
@@ -48,10 +47,6 @@ class SSAGraghBuilderWithPrinter : public SSAGraphBuilder {
Get
<
GraphvizSSAGraphPrinter
>
(
"graph_printer"
).
Print
(
*
new_graph
,
*
fout
);
return
new_graph
;
}
int
GetVarDeviceID
(
const
std
::
string
&
var_name
)
const
override
{
return
Get
<
SSAGraphBuilder
>
(
"previous_pass"
).
GetVarDeviceID
(
var_name
);
}
};
}
// namespace details
...
...
paddle/fluid/framework/details/threaded_ssa_graph_executor.h
浏览文件 @
e4d7d7ae
...
...
@@ -42,6 +42,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
const
std
::
vector
<
platform
::
Place
>
&
places
,
std
::
unique_ptr
<
ir
::
Graph
>
&&
graph
);
const
ir
::
Graph
&
Graph
()
const
{
return
*
graph_
;
}
// Run a SSAGraph by a thread pool
// Use topological sort algorithm
FeedFetchList
Run
(
const
std
::
vector
<
std
::
string
>
&
fetch_tensors
)
override
;
...
...
paddle/fluid/framework/ir/graph.h
浏览文件 @
e4d7d7ae
...
...
@@ -42,6 +42,8 @@ class Graph {
template
<
typename
AttrType
>
AttrType
&
Get
(
const
std
::
string
&
attr_name
)
const
{
PADDLE_ENFORCE
(
attrs_
.
find
(
attr_name
)
!=
attrs_
.
end
(),
"%s attr not registered for graph."
,
attr_name
);
return
*
boost
::
any_cast
<
AttrType
*>
(
attrs_
.
at
(
attr_name
));
}
...
...
paddle/fluid/framework/ir/pass.h
浏览文件 @
e4d7d7ae
...
...
@@ -44,6 +44,8 @@ class Pass {
template
<
typename
AttrType
>
AttrType
&
Get
(
const
std
::
string
&
attr_name
)
const
{
PADDLE_ENFORCE
(
attrs_
.
find
(
attr_name
)
!=
attrs_
.
end
(),
"%s attr not registered for pass."
,
attr_name
);
return
*
boost
::
any_cast
<
AttrType
*>
(
attrs_
.
at
(
attr_name
));
}
...
...
paddle/fluid/framework/parallel_executor.cc
浏览文件 @
e4d7d7ae
...
...
@@ -33,6 +33,48 @@ limitations under the License. */
namespace
paddle
{
namespace
framework
{
std
::
unique_ptr
<
ir
::
Graph
>
ApplyParallelExecutorPass
(
const
ProgramDesc
&
main_program
,
const
std
::
vector
<
platform
::
Place
>
&
places
,
const
std
::
string
&
loss_var_name
,
const
std
::
unordered_set
<
std
::
string
>
&
param_names
,
const
std
::
vector
<
Scope
*>
&
local_scopes
,
const
bool
use_cuda
,
#ifdef PADDLE_WITH_CUDA
const
BuildStrategy
&
strategy
,
platform
::
NCCLContextMap
*
nccl_ctxs
)
{
#else
const
BuildStrategy
&
strategy
)
{
#endif
details
::
ParallelExecutorPassManager
builder_factory
(
places
,
loss_var_name
,
param_names
,
local_scopes
,
strategy
);
if
(
use_cuda
)
{
#ifdef PADDLE_WITH_CUDA
builder_factory
.
SetNCCLContextMap
(
nccl_ctxs
);
#else
PADDLE_THROW
(
"Not compiled with CUDA."
);
#endif
}
std
::
unique_ptr
<
ir
::
Graph
>
graph
(
new
ir
::
Graph
(
main_program
));
if
(
!
strategy
.
debug_graphviz_path_
.
empty
())
{
auto
viz_pass
=
ir
::
PassRegistry
::
Instance
().
Get
(
"graph_viz_pass"
);
const
std
::
string
graph_path
=
string
::
Sprintf
(
"%s%s"
,
strategy
.
debug_graphviz_path_
.
c_str
(),
"_original_graph"
);
viz_pass
->
Set
<
std
::
string
>
(
"graph_viz_path"
,
new
std
::
string
(
graph_path
));
graph
=
viz_pass
->
Apply
(
std
::
move
(
graph
));
}
auto
builder
=
builder_factory
.
Create
();
graph
=
builder
->
Apply
(
std
::
move
(
graph
));
if
(
!
strategy
.
debug_graphviz_path_
.
empty
())
{
auto
viz_pass
=
ir
::
PassRegistry
::
Instance
().
Get
(
"graph_viz_pass"
);
const
std
::
string
graph_path
=
string
::
Sprintf
(
"%s%s"
,
strategy
.
debug_graphviz_path_
.
c_str
(),
"_before_exec"
);
viz_pass
->
Set
<
std
::
string
>
(
"graph_viz_path"
,
new
std
::
string
(
graph_path
));
graph
=
viz_pass
->
Apply
(
std
::
move
(
graph
));
}
return
graph
;
}
class
ParallelExecutorPrivate
{
public:
explicit
ParallelExecutorPrivate
(
const
std
::
vector
<
platform
::
Place
>
&
places
)
...
...
@@ -120,38 +162,18 @@ ParallelExecutor::ParallelExecutor(
var_infos
.
back
().
persistable_
=
var
->
Persistable
();
}
// Step 3. Convert main_program to SSA form and dependency graph. Also, insert
// ncclOp
details
::
SSAGraphBuilderFactory
builder_factory
(
member_
->
places_
,
loss_var_name
,
params
,
member_
->
local_scopes_
,
build_strategy
);
if
(
member_
->
use_cuda_
)
{
// Step 3. Convert main_program to SSA form and dependency graph. Also, insert
// ncclOp
#ifdef PADDLE_WITH_CUDA
builder_factory
.
SetNCCLContextMap
(
member_
->
nccl_ctxs_
.
get
());
std
::
unique_ptr
<
ir
::
Graph
>
graph
=
ApplyParallelExecutorPass
(
main_program
,
member_
->
places_
,
loss_var_name
,
params
,
member_
->
local_scopes_
,
member_
->
use_cuda_
,
build_strategy
,
member_
->
nccl_ctxs_
.
get
());
#else
PADDLE_THROW
(
"Not compiled with CUDA."
);
std
::
unique_ptr
<
ir
::
Graph
>
graph
=
ApplyParallelExecutorPass
(
main_program
,
member_
->
places_
,
loss_var_name
,
params
,
member_
->
local_scopes_
,
member_
->
use_cuda_
,
build_strategy
);
#endif
}
std
::
unique_ptr
<
ir
::
Graph
>
graph
(
new
ir
::
Graph
(
main_program
));
if
(
!
build_strategy
.
debug_graphviz_path_
.
empty
())
{
auto
viz_pass
=
ir
::
PassRegistry
::
Instance
().
Get
(
"graph_viz_pass"
);
const
std
::
string
graph_path
=
string
::
Sprintf
(
"%s%s"
,
build_strategy
.
debug_graphviz_path_
.
c_str
(),
"_original_graph"
);
viz_pass
->
Set
<
std
::
string
>
(
"graph_viz_path"
,
new
std
::
string
(
graph_path
));
graph
=
viz_pass
->
Apply
(
std
::
move
(
graph
));
}
builder_
=
builder_factory
.
Create
();
graph
=
builder_
->
Apply
(
std
::
move
(
graph
));
if
(
!
build_strategy
.
debug_graphviz_path_
.
empty
())
{
auto
viz_pass
=
ir
::
PassRegistry
::
Instance
().
Get
(
"graph_viz_pass"
);
const
std
::
string
graph_path
=
string
::
Sprintf
(
"%s%s"
,
build_strategy
.
debug_graphviz_path_
.
c_str
(),
"_before_exec"
);
viz_pass
->
Set
<
std
::
string
>
(
"graph_viz_path"
,
new
std
::
string
(
graph_path
));
graph
=
viz_pass
->
Apply
(
std
::
move
(
graph
));
}
member_
->
executor_
.
reset
(
new
details
::
ThreadedSSAGraphExecutor
(
exec_strategy
,
member_
->
local_scopes_
,
places
,
std
::
move
(
graph
)));
...
...
@@ -165,11 +187,18 @@ void ParallelExecutor::BCastParamsToDevices(
// the initializing bcast, all vars would be bcast from device(0),
// otherwise
// bcast from the specified device.
bool
initializing
=
builder_
.
get
()
==
nullptr
?
true
:
false
;
bool
initializing
=
member_
->
executor_
?
false
:
true
;
for
(
auto
&
var
:
vars
)
{
int
var_dev_id
=
builder_
.
get
()
==
nullptr
?
-
1
:
builder_
->
GetVarDeviceID
(
var
);
int
var_dev_id
=
-
1
;
if
(
member_
->
executor_
)
{
auto
&
sharded_var_device
=
member_
->
executor_
->
Graph
().
Get
<
details
::
ShardedVarDevice
>
(
"sharded_var_device"
);
if
(
sharded_var_device
.
find
(
var
)
!=
sharded_var_device
.
end
())
{
var_dev_id
=
sharded_var_device
.
at
(
var
);
}
}
if
(
!
initializing
&&
var_dev_id
==
-
1
)
continue
;
framework
::
Variable
*
main_var
=
nullptr
;
...
...
@@ -307,3 +336,6 @@ ParallelExecutor::~ParallelExecutor() {
}
// namespace paddle
USE_PASS
(
graph_viz_pass
);
USE_PASS
(
multi_device_pass
);
USE_PASS
(
multi_device_check_pass
);
USE_PASS
(
multi_device_print_pass
);
paddle/fluid/framework/parallel_executor.h
浏览文件 @
e4d7d7ae
...
...
@@ -70,7 +70,6 @@ class ParallelExecutor {
private:
ParallelExecutorPrivate
*
member_
;
std
::
unique_ptr
<
details
::
SSAGraphBuilder
>
builder_
;
};
}
// namespace framework
...
...
paddle/fluid/operators/distributed/send_recv.proto
0 → 100644
浏览文件 @
e4d7d7ae
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. Licensed under
the Apache License, Version 2.0 (the "License"); you may not use this file
except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
syntax
=
"proto3"
;
package
sendrecv
;
option
cc_generic_services
=
false
;
service
SendRecvService
{
// For parameter server round-robin like hashing, do not split tensors.
// Send and recv only one tensor
// TODO(typhoonzero): add streaming API
rpc
SendVariable
(
VariableMessage
)
returns
(
VoidMessage
)
{}
// Argument VariableMessage for GetVariable should only contain varname.
rpc
GetVariable
(
VariableMessage
)
returns
(
VariableMessage
)
{}
// pre-fetch variable by given variable name and Ids
rpc
PrefetchVariable
(
VariableMessage
)
returns
(
VariableMessage
)
{}
rpc
CheckpointNotify
(
VariableMessage
)
returns
(
VoidMessage
)
{}
}
// VariableMessage is serialized paddle variable message.
// It can be:
// LoDTensor
// SelectedRows
enum
VarType
{
LOD_TENSOR
=
0
;
SELECTED_ROWS
=
1
;
NCCL_ID
=
2
;
}
// NOTICE(gongwb):don't modify this proto if you are not
// not familar with how we serialize in sendrecvop_utils.h
// and deserilize it in variable_response.h.
message
VariableMessage
{
enum
Type
{
// Pod Types
BOOL
=
0
;
INT16
=
1
;
INT32
=
2
;
INT64
=
3
;
FP16
=
4
;
FP32
=
5
;
FP64
=
6
;
}
message
LodData
{
repeated
int64
lod_data
=
1
;
}
string
varname
=
1
;
// TODO(Yancey1989): reference framework::proto::VarDesc::VarType
VarType
type
=
2
;
// bool persistable is not needed for sending.
// tensor info:
Type
data_type
=
3
;
repeated
int64
dims
=
4
;
// lod details:
int64
lod_level
=
5
;
repeated
LodData
lod
=
6
;
// selected_rows height, aka. original dim0
int64
slr_height
=
7
;
// tensor data
bytes
serialized
=
8
;
// selected_rows data
bytes
rows
=
9
;
// Look up table block execution output variable name.
string
out_varname
=
10
;
// If 1, the ps server will start profiling, the ps
// server stops profiling and generates a profile to /tmp/profile_ps_*
// when profile switches from 1 to 2.
int64
profile
=
11
;
}
message
VoidMessage
{}
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录