Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
93401c98
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
93401c98
编写于
6月 06, 2018
作者:
Y
Yancey1989
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
overlap rpc op memcpy in distributed training
上级
df87e63b
变更
4
隐藏空白更改
内联
并排
Showing
4 changed files
with
82 additions
and
20 deletions
+82
-20
paddle/fluid/framework/details/multi_devices_graph_builder.cc
...le/fluid/framework/details/multi_devices_graph_builder.cc
+49
-9
paddle/fluid/framework/details/multi_devices_graph_builder.h
paddle/fluid/framework/details/multi_devices_graph_builder.h
+12
-2
paddle/fluid/framework/parallel_executor.cc
paddle/fluid/framework/parallel_executor.cc
+18
-9
paddle/fluid/framework/parallel_executor.h
paddle/fluid/framework/parallel_executor.h
+3
-0
未找到文件。
paddle/fluid/framework/details/multi_devices_graph_builder.cc
浏览文件 @
93401c98
...
...
@@ -191,15 +191,54 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
};
bool
is_forwarding
=
true
;
std
::
unordered_map
<
std
::
string
,
int
>
rpc_var_device_mapping
;
int
rpc_op_device_id
=
0
;
auto
schedule_rpc_op
=
[
&
]()
->
void
{
rpc_op_device_id
++
;
if
(
rpc_op_device_id
>=
static_cast
<
int
>
(
places_
.
size
()))
{
rpc_op_device_id
=
0
;
}
};
for
(
auto
*
op
:
program
.
Block
(
0
).
AllOps
())
{
if
(
boost
::
get
<
int
>
(
op
->
GetAttr
(
OpProtoAndCheckerMaker
::
OpRoleAttrName
()))
==
static_cast
<
int
>
(
OpRole
::
kRPC
))
{
// append rpc op if program is distributed trainer main program.
// always use the first device
CreateRPCOp
(
&
result
,
*
op
);
if
(
op
->
Type
()
==
"send_vars"
)
{
auto
got
=
remote_vars_devices_
.
find
(
op
->
InputArgumentNames
()[
0
]);
if
(
got
==
remote_vars_devices_
.
end
())
{
schedule_rpc_op
();
}
else
{
rpc_op_device_id
=
got
->
second
;
}
CreateRPCOp
(
&
result
,
*
op
,
rpc_op_device_id
);
}
else
if
(
op
->
Type
()
==
"recv"
)
{
schedule_rpc_op
();
for
(
auto
&
varname
:
op
->
OutputArgumentNames
())
{
remote_vars_devices_
.
insert
({
varname
,
rpc_op_device_id
});
}
CreateRPCOp
(
&
result
,
*
op
,
rpc_op_device_id
);
}
else
{
CreateRPCOp
(
&
result
,
*
op
,
0
);
}
}
else
if
(
IsDistTrainOp
(
*
op
,
send_vars
,
recv_vars
))
{
CreateDistTrainOp
(
&
result
,
*
op
);
if
(
op
->
Type
()
==
"split_byref"
)
{
schedule_rpc_op
();
for
(
auto
&
varname
:
op
->
OutputArgumentNames
())
{
remote_vars_devices_
.
insert
({
varname
,
rpc_op_device_id
});
}
CreateDistTrainOp
(
&
result
,
*
op
,
rpc_op_device_id
);
}
else
if
(
op
->
Type
()
==
"concat"
)
{
auto
got
=
remote_vars_devices_
.
find
(
op
->
InputArgumentNames
()[
0
]);
PADDLE_ENFORCE
(
got
!=
remote_vars_devices_
.
end
(),
"can not find right place to concat received var."
);
CreateDistTrainOp
(
&
result
,
*
op
,
got
->
second
);
}
else
{
CreateDistTrainOp
(
&
result
,
*
op
,
0
);
}
}
else
if
(
IsScaleLossOp
(
*
op
))
{
// user can customize loss@grad if not use_default_grad_scale_
if
(
strategy_
.
gradient_scale_
!=
...
...
@@ -464,17 +503,18 @@ void MultiDevSSAGraphBuilder::ConnectOp(SSAGraph *result, OpHandleBase *op,
}
void
MultiDevSSAGraphBuilder
::
CreateDistTrainOp
(
SSAGraph
*
result
,
const
OpDesc
&
op
)
const
{
CreateComputationalOp
(
result
,
op
,
0
);
const
OpDesc
&
op
,
int
place_id
)
const
{
CreateComputationalOp
(
result
,
op
,
place_id
);
if
(
op
.
Type
()
==
"concat"
)
{
ConnectOp
(
result
,
result
->
ops_
.
back
().
get
(),
"fetch_barrier"
);
}
}
void
MultiDevSSAGraphBuilder
::
CreateRPCOp
(
SSAGraph
*
result
,
const
OpDesc
&
op
)
const
{
auto
&
p
=
places_
[
0
];
auto
*
s
=
local_scopes_
[
0
];
void
MultiDevSSAGraphBuilder
::
CreateRPCOp
(
SSAGraph
*
result
,
const
OpDesc
&
op
,
int
place_id
)
const
{
auto
&
p
=
places_
[
place_id
];
auto
*
s
=
local_scopes_
[
place_id
];
result
->
ops_
.
emplace_back
(
new
RPCOpHandle
(
op
,
s
,
p
,
op
.
Type
()));
if
(
op
.
Type
()
==
"send_barrier"
)
{
...
...
@@ -493,7 +533,7 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result,
// TODO(Yancey1989): schedule rpc op on different place may
// increase throughput
CreateOpHandleIOs
(
result
,
op
,
0
);
CreateOpHandleIOs
(
result
,
op
,
place_id
);
}
bool
MultiDevSSAGraphBuilder
::
IsScaleLossOp
(
const
OpDesc
&
op
)
const
{
...
...
paddle/fluid/framework/details/multi_devices_graph_builder.h
浏览文件 @
93401c98
...
...
@@ -48,6 +48,14 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
std
::
unique_ptr
<
SSAGraph
>
Build
(
const
ProgramDesc
&
program
)
const
override
;
// Looks up `var_name` in remote_vars_devices_ and returns the device id
// stored for it; returns -1 when the map holds no entry for that name.
int GetRemoteVarDevice(const std::string &var_name) const {
  const auto entry = remote_vars_devices_.find(var_name);
  return entry == remote_vars_devices_.end() ? -1 : entry->second;
}
private:
void
CreateOpHandleIOs
(
SSAGraph
*
result
,
const
OpDesc
&
op
,
size_t
place_id
)
const
;
...
...
@@ -64,8 +72,9 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
bool
IsScaleLossOp
(
const
OpDesc
&
op
)
const
;
void
CreateRPCOp
(
SSAGraph
*
result
,
const
OpDesc
&
op
)
const
;
void
CreateDistTrainOp
(
SSAGraph
*
result
,
const
OpDesc
&
op
)
const
;
void
CreateRPCOp
(
SSAGraph
*
result
,
const
OpDesc
&
op
,
int
place_id
)
const
;
void
CreateDistTrainOp
(
SSAGraph
*
result
,
const
OpDesc
&
op
,
int
place_id
)
const
;
/**
* Is this operator as the end-point operator before/after send operator.
...
...
@@ -111,6 +120,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
private:
BuildStrategy
strategy_
;
mutable
std
::
unordered_map
<
std
::
string
,
int
>
remote_vars_devices_
;
};
}
// namespace details
}
// namespace framework
...
...
paddle/fluid/framework/parallel_executor.cc
浏览文件 @
93401c98
...
...
@@ -22,7 +22,6 @@ limitations under the License. */
#include "paddle/fluid/platform/nccl_helper.h"
#endif
#include "paddle/fluid/framework/details/multi_devices_graph_builder.h"
#include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h"
#include "paddle/fluid/platform/profiler.h"
...
...
@@ -97,15 +96,17 @@ ParallelExecutor::ParallelExecutor(
// Step 2. Convert main_program to SSA form and dependency graph. Also, insert
// ncclOp
#ifdef PADDLE_WITH_CUDA
details
::
MultiDevSSAGraphBuilder
builder
(
builder_
.
reset
(
new
details
::
MultiDevSSAGraphBuilder
(
member_
->
places_
,
loss_var_name
,
params
,
member_
->
local_scopes_
,
member_
->
nccl_ctxs_
.
get
(),
build_strategy
);
member_
->
nccl_ctxs_
.
get
(),
build_strategy
));
#else
details
::
MultiDevSSAGraphBuilder
builder
(
member_
->
places_
,
loss_var_name
,
params
,
member_
->
local_scopes_
,
build_strategy
);
builder_
.
reset
(
new
details
::
MultiDevSSAGraphBuilder
(
member_
->
places_
,
loss_var_name
,
params
,
member_
->
local_scopes_
,
build_strategy
));
#endif
auto
graph
=
builder
.
Build
(
main_program
);
auto
graph
=
builder_
.
get
()
->
Build
(
main_program
);
member_
->
executor_
.
reset
(
new
details
::
ThreadedSSAGraphExecutor
(
exec_strategy
,
member_
->
local_scopes_
,
places
,
std
::
move
(
graph
)));
...
...
@@ -146,8 +147,16 @@ void ParallelExecutor::BCastParamsToGPUs(
buffer
=
t
->
mutable_data
(
place
,
main_tensor
.
type
());
}
auto
&
nccl_ctx
=
member_
->
nccl_ctxs_
->
at
(
place
);
platform
::
dynload
::
ncclBcast
(
buffer
,
numel
,
data_type
,
0
,
nccl_ctx
.
comm_
,
nccl_ctx
.
stream
());
if
(
builder_
.
get
()
!=
nullptr
&&
builder_
->
GetRemoteVarDevice
(
var
)
!=
-
1
)
{
int
place_id
=
builder_
->
GetRemoteVarDevice
(
var
);
platform
::
dynload
::
ncclBcast
(
buffer
,
numel
,
data_type
,
place_id
,
nccl_ctx
.
comm_
,
nccl_ctx
.
stream
());
}
else
{
platform
::
dynload
::
ncclBcast
(
buffer
,
numel
,
data_type
,
0
,
nccl_ctx
.
comm_
,
nccl_ctx
.
stream
());
}
}
}
else
{
platform
::
CPUPlace
cpu
;
...
...
paddle/fluid/framework/parallel_executor.h
浏览文件 @
93401c98
...
...
@@ -19,12 +19,14 @@ limitations under the License. */
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/details/execution_strategy.h"
#include "paddle/fluid/framework/details/multi_devices_graph_builder.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/device_context.h"
namespace
paddle
{
namespace
framework
{
...
...
@@ -68,6 +70,7 @@ class ParallelExecutor {
private:
ParallelExecutorPrivate
*
member_
;
std
::
unique_ptr
<
details
::
MultiDevSSAGraphBuilder
>
builder_
;
};
}
// namespace framework
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录