Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
be0c4823
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
be0c4823
编写于
3月 25, 2019
作者:
Q
Qiao Longfei
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
update trainer_id
上级
c60f312d
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
22 addition
and
13 deletion
+22
-13
paddle/fluid/framework/details/async_ssa_graph_executor.cc
paddle/fluid/framework/details/async_ssa_graph_executor.cc
+7
-2
paddle/fluid/operators/distributed/parameter_recv.cc
paddle/fluid/operators/distributed/parameter_recv.cc
+2
-2
paddle/fluid/operators/distributed/parameter_send.cc
paddle/fluid/operators/distributed/parameter_send.cc
+1
-1
paddle/fluid/operators/distributed/rpc_common.h
paddle/fluid/operators/distributed/rpc_common.h
+5
-2
paddle/fluid/operators/distributed_ops/recv_op.cc
paddle/fluid/operators/distributed_ops/recv_op.cc
+4
-3
paddle/fluid/operators/distributed_ops/send_op.cc
paddle/fluid/operators/distributed_ops/send_op.cc
+3
-3
未找到文件。
paddle/fluid/framework/details/async_ssa_graph_executor.cc
浏览文件 @
be0c4823
...
...
@@ -60,9 +60,12 @@ void ProcessGraph(std::vector<ir::Graph *> graphs, Scope *scope) {
node
->
Op
()
->
GetNullableAttr
(
"epmap"
));
auto
height_section
=
boost
::
get
<
std
::
vector
<
int64_t
>>
(
node
->
Op
()
->
GetNullableAttr
(
"sections"
));
auto
trainer_id
=
boost
::
get
<
int
>
(
node
->
Op
()
->
GetNullableAttr
(
"trainer_id"
));
send_varname_to_ctx
[
send_var_name
]
=
operators
::
distributed
::
RpcContext
(
send_var_name
,
send_varnames
,
epmap
,
height_section
);
epmap
,
height_section
,
trainer_id
);
VLOG
(
3
)
<<
"find and init an send op: "
<<
send_varname_to_ctx
[
send_var_name
];
}
else
if
(
node
->
Name
()
==
"recv"
)
{
...
...
@@ -71,9 +74,11 @@ void ProcessGraph(std::vector<ir::Graph *> graphs, Scope *scope) {
node
->
Op
()
->
GetNullableAttr
(
"recv_varnames"
));
auto
epmap
=
boost
::
get
<
std
::
vector
<
std
::
string
>>
(
node
->
Op
()
->
GetNullableAttr
(
"epmap"
));
auto
trainer_id
=
boost
::
get
<
int
>
(
node
->
Op
()
->
GetNullableAttr
(
"trainer_id"
));
recv_varname_to_ctx
[
recv_var_name
]
=
operators
::
distributed
::
RpcContext
(
recv_var_name
,
recv_varnames
,
epmap
,
{});
epmap
,
{}
,
trainer_id
);
nodes_to_delete
.
push_back
(
node
);
VLOG
(
3
)
<<
"find and remove an recv op: "
<<
recv_varname_to_ctx
[
recv_var_name
];
...
...
paddle/fluid/operators/distributed/parameter_recv.cc
浏览文件 @
be0c4823
...
...
@@ -48,7 +48,7 @@ void ParameterRecv<T>::operator()(const RpcContext &rpc_ctx,
auto
&
cpu_ctx
=
*
pool
.
Get
(
platform
::
CPUPlace
());
distributed
::
RPCClient
*
rpc_client
=
distributed
::
RPCClient
::
GetInstance
<
RPCCLIENT_T
>
(
0
);
distributed
::
RPCClient
::
GetInstance
<
RPCCLIENT_T
>
(
rpc_ctx
.
trainer_id
);
auto
*
recv_var
=
scope
.
FindVar
(
rpc_ctx
.
var_name
);
...
...
@@ -112,7 +112,7 @@ void ParameterRecv<T>::operator()(const RpcContext &rpc_ctx,
// FIXME(qiao): use a trick to avoid the bug of recv an selected rows
for
(
auto
i
=
1
;
i
<
recv_slr
.
rows
().
size
();
++
i
)
{
auto
row_id
=
recv_slr
.
rows
()[
i
]
+
row_offset
;
PADDLE_ENFORCE_LT
(
row_id
,
recv_dims
[
1
]);
PADDLE_ENFORCE_LT
(
row_id
,
recv_dims
[
0
]);
memcpy
(
recv_tensor
->
data
<
T
>
()
+
row_id
*
width
,
recv_slr
.
value
().
data
<
T
>
()
+
i
*
width
,
sizeof
(
T
)
*
width
);
}
...
...
paddle/fluid/operators/distributed/parameter_send.cc
浏览文件 @
be0c4823
...
...
@@ -46,7 +46,7 @@ void ParameterSend<T>::operator()(const RpcContext &rpc_ctx,
auto
&
cpu_ctx
=
*
pool
.
Get
(
platform
::
CPUPlace
());
distributed
::
RPCClient
*
rpc_client
=
distributed
::
RPCClient
::
GetInstance
<
RPCCLIENT_T
>
(
0
);
distributed
::
RPCClient
::
GetInstance
<
RPCCLIENT_T
>
(
rpc_ctx
.
trainer_id
);
auto
*
send_var
=
scope
.
FindVar
(
rpc_ctx
.
var_name
);
size_t
out_num
=
rpc_ctx
.
splited_var_names
.
size
();
...
...
paddle/fluid/operators/distributed/rpc_common.h
浏览文件 @
be0c4823
...
...
@@ -27,23 +27,26 @@ struct RpcContext {
RpcContext
(
const
std
::
string
&
name
,
const
std
::
vector
<
std
::
string
>
&
names
,
const
std
::
vector
<
std
::
string
>
&
emap
,
const
std
::
vector
<
int64_t
>
&
sections
)
const
std
::
vector
<
int64_t
>
&
sections
,
int
id
)
:
var_name
(
name
),
splited_var_names
(
names
),
epmap
(
emap
),
height_sections
(
sections
)
{}
height_sections
(
sections
),
trainer_id
(
id
)
{}
RpcContext
(
const
RpcContext
&
ctx
)
{
var_name
=
ctx
.
var_name
;
splited_var_names
=
ctx
.
splited_var_names
;
epmap
=
ctx
.
epmap
;
height_sections
=
ctx
.
height_sections
;
trainer_id
=
ctx
.
trainer_id
;
}
std
::
string
var_name
;
std
::
vector
<
std
::
string
>
splited_var_names
;
std
::
vector
<
std
::
string
>
epmap
;
std
::
vector
<
int64_t
>
height_sections
;
int
trainer_id
;
};
inline
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
RpcContext
&
rpc_ctx
)
{
...
...
paddle/fluid/operators/distributed_ops/recv_op.cc
浏览文件 @
be0c4823
...
...
@@ -50,17 +50,18 @@ class RecvOp : public framework::OperatorBase {
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
&
ctx
=
*
pool
.
Get
(
place
);
auto
trainer_id
=
Attr
<
int
>
(
"trainer_id"
);
distributed
::
RPCClient
*
rpc_client
=
distributed
::
RPCClient
::
GetInstance
<
RPCCLIENT_T
>
(
Attr
<
int
>
(
"trainer_id"
));
distributed
::
RPCClient
::
GetInstance
<
RPCCLIENT_T
>
(
trainer_id
);
std
::
vector
<
std
::
string
>
recv_varnames
=
Attr
<
std
::
vector
<
std
::
string
>>
(
"recv_varnames"
);
if
(
recv_varnames
.
size
()
>
0
)
{
auto
recv_functor
=
distributed
::
ParameterRecv
<
float
>
();
auto
rpc_ctx
=
distributed
::
RpcContext
(
outs
[
0
],
recv_varnames
,
epmap
,
{});
auto
rpc_ctx
=
distributed
::
RpcContext
(
outs
[
0
],
recv_varnames
,
epmap
,
{},
trainer_id
);
recv_functor
(
rpc_ctx
,
scope
);
}
else
{
if
(
with_barrier
)
{
...
...
paddle/fluid/operators/distributed_ops/send_op.cc
浏览文件 @
be0c4823
...
...
@@ -42,6 +42,7 @@ class SendOp : public framework::OperatorBase {
auto
epmap
=
Attr
<
std
::
vector
<
std
::
string
>>
(
"epmap"
);
int
sync_send
=
Attr
<
int
>
(
"sync_mode"
);
auto
trainer_id
=
Attr
<
int
>
(
"trainer_id"
);
auto
send_varnames
=
Attr
<
std
::
vector
<
std
::
string
>>
(
"send_varnames"
);
auto
height_sections
=
Attr
<
std
::
vector
<
int64_t
>>
(
"sections"
);
...
...
@@ -51,7 +52,7 @@ class SendOp : public framework::OperatorBase {
/*
auto send_functor = distributed::ParameterSend<float>();
auto rpc_ctx = distributed::RpcContext(ins[0], send_varnames, epmap,
height_sections);
height_sections
, trainer_id
);
send_functor(rpc_ctx, scope, static_cast<bool>(sync_send));
*/
VLOG
(
3
)
<<
"send "
<<
ins
[
0
];
...
...
@@ -63,8 +64,7 @@ class SendOp : public framework::OperatorBase {
auto
&
ctx
=
*
pool
.
Get
(
place
);
distributed
::
RPCClient
*
rpc_client
=
distributed
::
RPCClient
::
GetInstance
<
RPCCLIENT_T
>
(
Attr
<
int
>
(
"trainer_id"
));
distributed
::
RPCClient
::
GetInstance
<
RPCCLIENT_T
>
(
trainer_id
);
std
::
vector
<
distributed
::
VarHandlePtr
>
rets
;
for
(
size_t
i
=
0
;
i
<
ins
.
size
();
i
++
)
{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录