Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
103c9bb3
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
103c9bb3
编写于
3月 24, 2019
作者:
Q
Qiao Longfei
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
update rpc_client
上级
b7661d7e
变更
9
显示空白变更内容
内联
并排
Showing
9 changed file
with
116 addition
and
56 deletion
+116
-56
paddle/fluid/operators/distributed/async_sparse_param_update_recorder.h
...perators/distributed/async_sparse_param_update_recorder.h
+35
-13
paddle/fluid/operators/distributed/brpc/brpc_client.cc
paddle/fluid/operators/distributed/brpc/brpc_client.cc
+2
-1
paddle/fluid/operators/distributed/brpc/brpc_client.h
paddle/fluid/operators/distributed/brpc/brpc_client.h
+6
-7
paddle/fluid/operators/distributed/grpc/grpc_client.cc
paddle/fluid/operators/distributed/grpc/grpc_client.cc
+31
-25
paddle/fluid/operators/distributed/grpc/grpc_client.h
paddle/fluid/operators/distributed/grpc/grpc_client.h
+3
-1
paddle/fluid/operators/distributed/grpc/grpc_server.cc
paddle/fluid/operators/distributed/grpc/grpc_server.cc
+4
-1
paddle/fluid/operators/distributed/parameter_recv.cc
paddle/fluid/operators/distributed/parameter_recv.cc
+18
-5
paddle/fluid/operators/distributed/request_handler_impl.cc
paddle/fluid/operators/distributed/request_handler_impl.cc
+16
-3
paddle/fluid/operators/distributed/rpc_client.h
paddle/fluid/operators/distributed/rpc_client.h
+1
-0
未找到文件。
paddle/fluid/operators/distributed/async_sparse_param_update_recorder.h
浏览文件 @
103c9bb3
...
@@ -37,7 +37,16 @@ class ConcurrentSet {
...
@@ -37,7 +37,16 @@ class ConcurrentSet {
~
ConcurrentSet
()
{}
~
ConcurrentSet
()
{}
std
::
future
<
void
>
Update
(
const
std
::
vector
<
int64_t
>&
rows
)
{
std
::
future
<
void
>
Update
(
const
std
::
vector
<
int64_t
>&
rows
)
{
auto
task
=
[
this
,
&
rows
]
{
auto
task
=
[
this
,
rows
]
{
if
(
VLOG_IS_ON
(
3
))
{
std
::
ostringstream
sstream
;
sstream
<<
"["
;
for
(
auto
&
id
:
rows
)
{
sstream
<<
id
<<
", "
;
}
sstream
<<
"]"
;
VLOG
(
3
)
<<
"update ids -> "
<<
sstream
.
str
();
}
for
(
auto
row
:
rows
)
{
for
(
auto
row
:
rows
)
{
set_
.
insert
(
row
);
set_
.
insert
(
row
);
}
}
...
@@ -46,9 +55,21 @@ class ConcurrentSet {
...
@@ -46,9 +55,21 @@ class ConcurrentSet {
}
}
std
::
future
<
void
>
GetAndClear
(
std
::
vector
<
int64_t
>*
result
)
{
std
::
future
<
void
>
GetAndClear
(
std
::
vector
<
int64_t
>*
result
)
{
auto
task
=
[
this
,
result
]
{
auto
task
=
[
this
,
&
result
]
{
result
->
clear
();
result
->
clear
();
result
->
insert
(
result
->
end
(),
set_
.
begin
(),
set_
.
end
());
for
(
auto
&
id
:
set_
)
{
result
->
push_back
(
id
);
}
if
(
VLOG_IS_ON
(
3
))
{
std
::
ostringstream
sstream
;
sstream
<<
"["
;
for
(
auto
&
id
:
*
result
)
{
sstream
<<
id
<<
", "
;
}
sstream
<<
"]"
;
VLOG
(
3
)
<<
"result ids size: "
<<
result
->
size
()
<<
" "
<<
sstream
.
str
();
}
set_
.
clear
();
set_
.
clear
();
};
};
return
pool_
->
enqueue
(
std
::
move
(
task
));
return
pool_
->
enqueue
(
std
::
move
(
task
));
...
@@ -67,6 +88,7 @@ class AsyncSparseParamUpdateRecorder {
...
@@ -67,6 +88,7 @@ class AsyncSparseParamUpdateRecorder {
int
trainer_num
,
int
trainer_num
,
const
std
::
unordered_map
<
std
::
string
,
std
::
string
>&
grad_to_param
)
const
std
::
unordered_map
<
std
::
string
,
std
::
string
>&
grad_to_param
)
:
trainer_num_
(
trainer_num
),
grad_to_param_
(
grad_to_param
)
{
:
trainer_num_
(
trainer_num
),
grad_to_param_
(
grad_to_param
)
{
if
(
VLOG_IS_ON
(
3
))
{
std
::
ostringstream
sstream
;
std
::
ostringstream
sstream
;
sstream
<<
"["
;
sstream
<<
"["
;
for
(
auto
&
item
:
grad_to_param
)
{
for
(
auto
&
item
:
grad_to_param
)
{
...
@@ -74,7 +96,8 @@ class AsyncSparseParamUpdateRecorder {
...
@@ -74,7 +96,8 @@ class AsyncSparseParamUpdateRecorder {
}
}
sstream
<<
"]"
;
sstream
<<
"]"
;
VLOG
(
3
)
<<
"trainer_num: "
<<
trainer_num
VLOG
(
3
)
<<
"trainer_num: "
<<
trainer_num
<<
"grad_to_param_: "
<<
sstream
.
str
();
<<
" grad_to_param_: "
<<
sstream
.
str
();
}
for
(
auto
&
iter
:
grad_to_param
)
{
for
(
auto
&
iter
:
grad_to_param
)
{
param_to_grad_
[
iter
.
second
]
=
iter
.
first
;
param_to_grad_
[
iter
.
second
]
=
iter
.
first
;
auto
&
param_name
=
iter
.
second
;
auto
&
param_name
=
iter
.
second
;
...
@@ -103,13 +126,12 @@ class AsyncSparseParamUpdateRecorder {
...
@@ -103,13 +126,12 @@ class AsyncSparseParamUpdateRecorder {
void
GetAndClear
(
const
std
::
string
&
param_name
,
int
trainer_id
,
void
GetAndClear
(
const
std
::
string
&
param_name
,
int
trainer_id
,
std
::
vector
<
int64_t
>*
result
)
{
std
::
vector
<
int64_t
>*
result
)
{
VLOG
(
3
)
<<
"GetAndClear param: "
<<
param_name
<<
" for trainer: "
<<
trainer_id
;
PADDLE_ENFORCE_LT
(
trainer_id
,
trainer_num_
);
PADDLE_ENFORCE_LT
(
trainer_id
,
trainer_num_
);
param_to_updated_rows_
.
at
(
param_name
)[
trainer_id
]
param_to_updated_rows_
.
at
(
param_name
)[
trainer_id
]
->
GetAndClear
(
result
)
->
GetAndClear
(
result
)
.
wait
();
.
wait
();
VLOG
(
3
)
<<
"GetAndClear param: "
<<
param_name
<<
" for trainer: "
<<
trainer_id
<<
" with size: "
<<
result
->
size
();
}
}
bool
HasParam
(
const
std
::
string
&
param_name
)
{
bool
HasParam
(
const
std
::
string
&
param_name
)
{
...
...
paddle/fluid/operators/distributed/brpc/brpc_client.cc
浏览文件 @
103c9bb3
...
@@ -234,9 +234,10 @@ VarHandlePtr BRPCClient::AsyncGetVar(const std::string& ep,
...
@@ -234,9 +234,10 @@ VarHandlePtr BRPCClient::AsyncGetVar(const std::string& ep,
const
framework
::
Scope
&
scope
,
const
framework
::
Scope
&
scope
,
const
std
::
string
&
var_name
,
const
std
::
string
&
var_name
,
const
std
::
string
&
out_var_name
,
const
std
::
string
&
out_var_name
,
const
std
::
string
&
table_name
,
int64_t
time_out
)
{
int64_t
time_out
)
{
return
_AsyncGetVar
(
ep
,
ctx
,
scope
,
var_name
,
out_var_name
,
kGetRPC
,
return
_AsyncGetVar
(
ep
,
ctx
,
scope
,
var_name
,
out_var_name
,
kGetRPC
,
time_out
);
t
able_name
t
ime_out
);
}
}
VarHandlePtr
BRPCClient
::
AsyncPrefetchVar
(
const
std
::
string
&
ep
,
VarHandlePtr
BRPCClient
::
AsyncPrefetchVar
(
const
std
::
string
&
ep
,
...
...
paddle/fluid/operators/distributed/brpc/brpc_client.h
浏览文件 @
103c9bb3
...
@@ -66,6 +66,7 @@ class BRPCClient : public RPCClient {
...
@@ -66,6 +66,7 @@ class BRPCClient : public RPCClient {
const
framework
::
Scope
&
scope
,
const
framework
::
Scope
&
scope
,
const
std
::
string
&
var_name
,
const
std
::
string
&
var_name
,
const
std
::
string
&
out_var_name
,
const
std
::
string
&
out_var_name
,
const
std
::
string
&
table_name
=
""
,
int64_t
time_out
=
FLAGS_rpc_deadline
)
override
;
int64_t
time_out
=
FLAGS_rpc_deadline
)
override
;
VarHandlePtr
AsyncGetMonomerBarrier
(
VarHandlePtr
AsyncGetMonomerBarrier
(
...
@@ -107,13 +108,11 @@ class BRPCClient : public RPCClient {
...
@@ -107,13 +108,11 @@ class BRPCClient : public RPCClient {
void
SendComplete
()
override
;
void
SendComplete
()
override
;
private:
private:
VarHandlePtr
_AsyncGetVar
(
const
std
::
string
&
ep
,
VarHandlePtr
_AsyncGetVar
(
const
platform
::
DeviceContext
&
ctx
,
const
std
::
string
&
ep
,
const
platform
::
DeviceContext
&
ctx
,
const
framework
::
Scope
&
scope
,
const
framework
::
Scope
&
scope
,
const
std
::
string
&
var_name
,
const
std
::
string
&
var_name
,
const
std
::
string
&
out_var_name
,
const
std
::
string
&
method_name
,
const
std
::
string
&
out_var_name
,
const
std
::
string
&
table_name
,
int64_t
time_out
=
FLAGS_rpc_deadline
);
const
std
::
string
&
method_name
,
int64_t
time_out
=
FLAGS_rpc_deadline
);
void
Proceed
();
void
Proceed
();
ChannelQueuePtr
GetChannel
(
const
std
::
string
&
ep
);
ChannelQueuePtr
GetChannel
(
const
std
::
string
&
ep
);
...
...
paddle/fluid/operators/distributed/grpc/grpc_client.cc
浏览文件 @
103c9bb3
...
@@ -128,9 +128,11 @@ VarHandlePtr GRPCClient::AsyncGetVar(const std::string& ep,
...
@@ -128,9 +128,11 @@ VarHandlePtr GRPCClient::AsyncGetVar(const std::string& ep,
const
framework
::
Scope
&
scope
,
const
framework
::
Scope
&
scope
,
const
std
::
string
&
var_name
,
const
std
::
string
&
var_name
,
const
std
::
string
&
out_varname
,
const
std
::
string
&
out_varname
,
const
std
::
string
&
table_name
,
int64_t
time_out
)
{
int64_t
time_out
)
{
return
_AsyncGetVar
(
ep
,
ctx
,
scope
,
kGetRPC
,
var_name
,
out_varname
,
return
_AsyncGetVar
(
ep
,
ctx
,
scope
,
kGetRPC
,
var_name
,
out_varname
,
"/sendrecv.SendRecvService/GetVariable"
,
time_out
);
"/sendrecv.SendRecvService/GetVariable"
,
table_name
,
time_out
);
}
}
VarHandlePtr
GRPCClient
::
AsyncGetVarNoBarrier
(
VarHandlePtr
GRPCClient
::
AsyncGetVarNoBarrier
(
...
@@ -142,7 +144,7 @@ VarHandlePtr GRPCClient::AsyncGetVarNoBarrier(
...
@@ -142,7 +144,7 @@ VarHandlePtr GRPCClient::AsyncGetVarNoBarrier(
return
_AsyncGetVar
(
return
_AsyncGetVar
(
ep
,
ctx
,
scope
,
kGetNoBarrierRPC
,
var_name_no_barrier
,
out_varname
,
ep
,
ctx
,
scope
,
kGetNoBarrierRPC
,
var_name_no_barrier
,
out_varname
,
"/sendrecv.SendRecvService/GetVariableNoBarrier"
,
time_out
);
"/sendrecv.SendRecvService/GetVariableNoBarrier"
,
""
,
time_out
);
}
}
VarHandlePtr
GRPCClient
::
AsyncGetMonomerVariable
(
VarHandlePtr
GRPCClient
::
AsyncGetMonomerVariable
(
...
@@ -150,18 +152,21 @@ VarHandlePtr GRPCClient::AsyncGetMonomerVariable(
...
@@ -150,18 +152,21 @@ VarHandlePtr GRPCClient::AsyncGetMonomerVariable(
const
framework
::
Scope
&
scope
,
const
std
::
string
&
var_name
,
const
framework
::
Scope
&
scope
,
const
std
::
string
&
var_name
,
int64_t
time_out
)
{
int64_t
time_out
)
{
return
_AsyncGetVar
(
ep
,
ctx
,
scope
,
kGetMonomerRPC
,
var_name
,
var_name
,
return
_AsyncGetVar
(
ep
,
ctx
,
scope
,
kGetMonomerRPC
,
var_name
,
var_name
,
"/sendrecv.SendRecvService/GetMonomerVariable"
,
time_out
);
"/sendrecv.SendRecvService/GetMonomerVariable"
,
""
,
time_out
);
}
}
VarHandlePtr
GRPCClient
::
_AsyncGetVar
(
VarHandlePtr
GRPCClient
::
_AsyncGetVar
(
const
std
::
string
&
ep
,
const
platform
::
DeviceContext
&
ctx
,
const
std
::
string
&
ep
,
const
platform
::
DeviceContext
&
ctx
,
const
framework
::
Scope
&
scope
,
const
std
::
string
&
method
,
const
framework
::
Scope
&
scope
,
const
std
::
string
&
method
,
const
std
::
string
&
var_name
,
const
std
::
string
&
out_varname
,
const
std
::
string
&
var_name
,
const
std
::
string
&
out_varname
,
const
std
::
string
&
rpc_path
,
int64_t
time_out
)
{
const
std
::
string
&
rpc_path
,
const
std
::
string
&
table_name
,
int64_t
time_out
)
{
const
platform
::
DeviceContext
*
p_ctx
=
&
ctx
;
const
platform
::
DeviceContext
*
p_ctx
=
&
ctx
;
const
std
::
string
ep_val
=
ep
;
const
std
::
string
ep_val
=
ep
;
const
std
::
string
var_name_val
=
var_name
;
const
std
::
string
var_name_val
=
var_name
;
const
std
::
string
out_varname_val
=
out_varname
;
const
std
::
string
out_varname_val
=
out_varname
;
const
std
::
string
table_name_val
=
table_name
;
const
framework
::
Scope
*
p_scope
=
&
scope
;
const
framework
::
Scope
*
p_scope
=
&
scope
;
const
auto
ch
=
GetChannel
(
ep_val
);
const
auto
ch
=
GetChannel
(
ep_val
);
GetProcessor
*
s
=
new
GetProcessor
(
ch
);
GetProcessor
*
s
=
new
GetProcessor
(
ch
);
...
@@ -169,13 +174,14 @@ VarHandlePtr GRPCClient::_AsyncGetVar(
...
@@ -169,13 +174,14 @@ VarHandlePtr GRPCClient::_AsyncGetVar(
VarHandlePtr
h
(
new
VarHandle
(
ep
,
method
,
out_varname_val
,
p_ctx
,
p_scope
));
VarHandlePtr
h
(
new
VarHandle
(
ep
,
method
,
out_varname_val
,
p_ctx
,
p_scope
));
s
->
Prepare
(
h
,
time_out
);
s
->
Prepare
(
h
,
time_out
);
framework
::
AsyncIO
(
framework
::
AsyncIO
(
[
var_name_val
,
out_varname_val
,
table_name_val
,
s
,
method
,
[
var_name_val
,
out_varname_val
,
s
,
method
,
p_ctx
,
h
,
rpc_path
,
this
]
{
p_ctx
,
h
,
rpc_path
,
this
]
{
// prepare input
// prepare input
sendrecv
::
VariableMessage
req
;
sendrecv
::
VariableMessage
req
;
req
.
set_varname
(
var_name_val
);
req
.
set_varname
(
var_name_val
);
req
.
set_out_varname
(
out_varname_val
);
req
.
set_out_varname
(
out_varname_val
);
req
.
set_trainer_id
(
trainer_id_
);
req
.
set_trainer_id
(
trainer_id_
);
req
.
set_table_name
(
table_name_val
);
::
grpc
::
ByteBuffer
buf
;
::
grpc
::
ByteBuffer
buf
;
RequestToByteBuffer
<
sendrecv
::
VariableMessage
>
(
req
,
&
buf
);
RequestToByteBuffer
<
sendrecv
::
VariableMessage
>
(
req
,
&
buf
);
...
...
paddle/fluid/operators/distributed/grpc/grpc_client.h
浏览文件 @
103c9bb3
...
@@ -187,6 +187,7 @@ class GRPCClient : public RPCClient {
...
@@ -187,6 +187,7 @@ class GRPCClient : public RPCClient {
const
framework
::
Scope
&
scope
,
const
framework
::
Scope
&
scope
,
const
std
::
string
&
var_name
,
const
std
::
string
&
var_name
,
const
std
::
string
&
out_varname
,
const
std
::
string
&
out_varname
,
const
std
::
string
&
table_name
=
""
,
int64_t
time_out
=
FLAGS_rpc_deadline
)
override
;
int64_t
time_out
=
FLAGS_rpc_deadline
)
override
;
VarHandlePtr
AsyncGetVarNoBarrier
(
VarHandlePtr
AsyncGetVarNoBarrier
(
...
@@ -239,7 +240,8 @@ class GRPCClient : public RPCClient {
...
@@ -239,7 +240,8 @@ class GRPCClient : public RPCClient {
const
std
::
string
&
ep
,
const
platform
::
DeviceContext
&
ctx
,
const
std
::
string
&
ep
,
const
platform
::
DeviceContext
&
ctx
,
const
framework
::
Scope
&
scope
,
const
std
::
string
&
method
,
const
framework
::
Scope
&
scope
,
const
std
::
string
&
method
,
const
std
::
string
&
var_name
,
const
std
::
string
&
out_varname
,
const
std
::
string
&
var_name
,
const
std
::
string
&
out_varname
,
const
std
::
string
&
rpc_path
,
int64_t
time_out
=
FLAGS_rpc_deadline
);
const
std
::
string
&
rpc_path
,
const
std
::
string
&
table_name
=
""
,
int64_t
time_out
=
FLAGS_rpc_deadline
);
private:
private:
grpc
::
CompletionQueue
cq_
;
grpc
::
CompletionQueue
cq_
;
...
...
paddle/fluid/operators/distributed/grpc/grpc_server.cc
浏览文件 @
103c9bb3
...
@@ -136,6 +136,7 @@ class RequestGet final : public RequestBase {
...
@@ -136,6 +136,7 @@ class RequestGet final : public RequestBase {
// proc request.
// proc request.
std
::
string
varname
=
request_
.
varname
();
std
::
string
varname
=
request_
.
varname
();
std
::
string
out_varname
=
request_
.
out_varname
();
std
::
string
out_varname
=
request_
.
out_varname
();
std
::
string
table_name
=
request_
.
table_name
();
int
trainer_id
=
request_
.
trainer_id
();
int
trainer_id
=
request_
.
trainer_id
();
VLOG
(
4
)
<<
"RequestGet "
<<
out_varname
<<
" from "
<<
varname
;
VLOG
(
4
)
<<
"RequestGet "
<<
out_varname
<<
" from "
<<
varname
;
...
@@ -146,12 +147,14 @@ class RequestGet final : public RequestBase {
...
@@ -146,12 +147,14 @@ class RequestGet final : public RequestBase {
auto
*
tmp_scope
=
scope
->
NewTmpScope
();
auto
*
tmp_scope
=
scope
->
NewTmpScope
();
request_handler_
->
Handle
(
varname
,
tmp_scope
,
invar
,
&
outvar
,
trainer_id
,
request_handler_
->
Handle
(
varname
,
tmp_scope
,
invar
,
&
outvar
,
trainer_id
,
out_varname
);
out_varname
,
table_name
);
VLOG
(
1
)
<<
"before SerializeToByteBuffer"
;
if
(
outvar
)
{
if
(
outvar
)
{
SerializeToByteBuffer
(
out_varname
,
outvar
,
*
request_handler_
->
dev_ctx
(),
SerializeToByteBuffer
(
out_varname
,
outvar
,
*
request_handler_
->
dev_ctx
(),
&
reply_
);
&
reply_
);
}
}
VLOG
(
1
)
<<
"after SerializeToByteBuffer"
;
delete
tmp_scope
;
delete
tmp_scope
;
Finish
(
reply_
,
&
responder_
);
Finish
(
reply_
,
&
responder_
);
}
}
...
...
paddle/fluid/operators/distributed/parameter_recv.cc
浏览文件 @
103c9bb3
...
@@ -41,7 +41,7 @@ using DDim = framework::DDim;
...
@@ -41,7 +41,7 @@ using DDim = framework::DDim;
template
<
typename
T
>
template
<
typename
T
>
void
ParameterRecv
<
T
>::
operator
()(
const
RpcContext
&
rpc_ctx
,
void
ParameterRecv
<
T
>::
operator
()(
const
RpcContext
&
rpc_ctx
,
const
framework
::
Scope
&
scope
)
{
const
framework
::
Scope
&
scope
)
{
VLOG
(
3
)
<<
"ParameterRecv in
"
;
VLOG
(
3
)
<<
"ParameterRecv in
"
<<
rpc_ctx
.
var_name
;
framework
::
Scope
*
local_scope
=
scope
.
NewTmpScope
();
framework
::
Scope
*
local_scope
=
scope
.
NewTmpScope
();
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
...
@@ -61,7 +61,7 @@ void ParameterRecv<T>::operator()(const RpcContext &rpc_ctx,
...
@@ -61,7 +61,7 @@ void ParameterRecv<T>::operator()(const RpcContext &rpc_ctx,
VLOG
(
3
)
<<
"recv "
<<
recv_var_name
<<
" from "
<<
rpc_ctx
.
epmap
[
i
];
VLOG
(
3
)
<<
"recv "
<<
recv_var_name
<<
" from "
<<
rpc_ctx
.
epmap
[
i
];
rets
.
push_back
(
rpc_client
->
AsyncGetVar
(
rpc_ctx
.
epmap
[
i
],
cpu_ctx
,
rets
.
push_back
(
rpc_client
->
AsyncGetVar
(
rpc_ctx
.
epmap
[
i
],
cpu_ctx
,
*
local_scope
,
recv_var_name
,
*
local_scope
,
recv_var_name
,
recv_var_name
));
recv_var_name
,
recv_var_name
));
}
}
for
(
size_t
i
=
0
;
i
<
rets
.
size
();
i
++
)
{
for
(
size_t
i
=
0
;
i
<
rets
.
size
();
i
++
)
{
PADDLE_ENFORCE
(
rets
[
i
]
->
Wait
(),
"internal error in RPCClient"
);
PADDLE_ENFORCE
(
rets
[
i
]
->
Wait
(),
"internal error in RPCClient"
);
...
@@ -73,6 +73,7 @@ void ParameterRecv<T>::operator()(const RpcContext &rpc_ctx,
...
@@ -73,6 +73,7 @@ void ParameterRecv<T>::operator()(const RpcContext &rpc_ctx,
// concat recved tensor into one var
// concat recved tensor into one var
{
{
size_t
output_offset
=
0
;
size_t
output_offset
=
0
;
size_t
row_offset
=
0
;
framework
::
Tensor
*
recv_tensor
=
framework
::
Tensor
*
recv_tensor
=
recv_var
->
GetMutable
<
framework
::
LoDTensor
>
();
recv_var
->
GetMutable
<
framework
::
LoDTensor
>
();
auto
dev_ctx
=
paddle
::
platform
::
CPUDeviceContext
();
auto
dev_ctx
=
paddle
::
platform
::
CPUDeviceContext
();
...
@@ -92,16 +93,28 @@ void ParameterRecv<T>::operator()(const RpcContext &rpc_ctx,
...
@@ -92,16 +93,28 @@ void ParameterRecv<T>::operator()(const RpcContext &rpc_ctx,
auto
&
recv_slr
=
recv_var
->
Get
<
framework
::
SelectedRows
>
();
auto
&
recv_slr
=
recv_var
->
Get
<
framework
::
SelectedRows
>
();
auto
&
recv_dims
=
recv_tensor
->
dims
();
auto
&
recv_dims
=
recv_tensor
->
dims
();
int64_t
width
=
recv_dims
[
1
];
int64_t
width
=
recv_dims
[
1
];
PADDLE_ENFORCE_EQ
(
recv_slr
.
height
(),
recv_dims
[
0
])
;
recv_numel
+=
recv_slr
.
height
()
*
width
;
PADDLE_ENFORCE_EQ
(
recv_slr
.
value
().
dims
()[
1
],
width
);
PADDLE_ENFORCE_EQ
(
recv_slr
.
value
().
dims
()[
1
],
width
);
PADDLE_ENFORCE_EQ
(
recv_slr
.
value
().
dims
()[
0
],
recv_slr
.
rows
().
size
());
PADDLE_ENFORCE_EQ
(
recv_slr
.
value
().
dims
()[
0
],
recv_slr
.
rows
().
size
());
VLOG
(
3
)
<<
"recv slr "
<<
recv_var_name
<<
" dims "
VLOG
(
3
)
<<
"recv slr "
<<
recv_var_name
<<
" dims "
<<
recv_slr
.
value
().
dims
();
<<
recv_slr
.
value
().
dims
();
if
(
VLOG_IS_ON
(
3
))
{
std
::
ostringstream
sstream
;
sstream
<<
"["
;
for
(
auto
&
row_id
:
recv_slr
.
rows
())
{
sstream
<<
row_id
<<
", "
;
}
sstream
<<
"]"
;
VLOG
(
3
)
<<
"recv_slr size: "
<<
recv_slr
.
rows
().
size
()
<<
" "
<<
sstream
.
str
();
}
for
(
auto
i
=
0
;
i
<
recv_slr
.
rows
().
size
();
++
i
)
{
for
(
auto
i
=
0
;
i
<
recv_slr
.
rows
().
size
();
++
i
)
{
auto
row_id
=
recv_slr
.
rows
()[
i
];
auto
row_id
=
recv_slr
.
rows
()[
i
]
+
row_offset
;
PADDLE_ENFORCE_LT
(
row_id
,
recv_dims
[
1
]);
memcpy
(
recv_tensor
->
data
<
T
>
()
+
row_id
*
width
,
memcpy
(
recv_tensor
->
data
<
T
>
()
+
row_id
*
width
,
recv_slr
.
value
().
data
<
T
>
()
+
i
*
width
,
sizeof
(
T
)
*
width
);
recv_slr
.
value
().
data
<
T
>
()
+
i
*
width
,
sizeof
(
T
)
*
width
);
}
}
row_offset
+=
recv_slr
.
height
();
}
else
{
}
else
{
PADDLE_THROW
(
"unsupported recieved var type"
);
PADDLE_THROW
(
"unsupported recieved var type"
);
}
}
...
@@ -110,7 +123,7 @@ void ParameterRecv<T>::operator()(const RpcContext &rpc_ctx,
...
@@ -110,7 +123,7 @@ void ParameterRecv<T>::operator()(const RpcContext &rpc_ctx,
}
}
delete
local_scope
;
delete
local_scope
;
VLOG
(
3
)
<<
"ParameterRecv out"
;
VLOG
(
3
)
<<
"ParameterRecv out"
<<
rpc_ctx
.
var_name
;
}
}
template
struct
ParameterRecv
<
float
>;
template
struct
ParameterRecv
<
float
>;
...
...
paddle/fluid/operators/distributed/request_handler_impl.cc
浏览文件 @
103c9bb3
...
@@ -89,8 +89,9 @@ bool RequestGetHandler::Handle(const std::string& varname,
...
@@ -89,8 +89,9 @@ bool RequestGetHandler::Handle(const std::string& varname,
const
int
trainer_id
,
const
int
trainer_id
,
const
std
::
string
&
out_var_name
,
const
std
::
string
&
out_var_name
,
const
std
::
string
&
table_name
)
{
const
std
::
string
&
table_name
)
{
VLOG
(
4
)
<<
"RequestGetHandler:"
<<
varname
VLOG
(
3
)
<<
"RequestGetHandler:"
<<
varname
<<
" out_var_name: "
<<
out_var_name
;
<<
" out_var_name: "
<<
out_var_name
<<
" trainer_id: "
<<
trainer_id
<<
" table_name: "
<<
table_name
;
if
(
sync_mode_
)
{
if
(
sync_mode_
)
{
if
(
varname
==
FETCH_BARRIER_MESSAGE
)
{
if
(
varname
==
FETCH_BARRIER_MESSAGE
)
{
...
@@ -115,10 +116,21 @@ bool RequestGetHandler::Handle(const std::string& varname,
...
@@ -115,10 +116,21 @@ bool RequestGetHandler::Handle(const std::string& varname,
VLOG
(
3
)
<<
"copying "
<<
varname
<<
" to "
<<
param_bak_name
;
VLOG
(
3
)
<<
"copying "
<<
varname
<<
" to "
<<
param_bak_name
;
framework
::
TensorCopy
(
t_orig
,
dev_ctx_
->
GetPlace
(),
t
);
framework
::
TensorCopy
(
t_orig
,
dev_ctx_
->
GetPlace
(),
t
);
}
}
if
(
AsyncSparseParamUpdateRecorder
::
GetInstance
()
->
HasParam
(
varname
))
{
if
(
AsyncSparseParamUpdateRecorder
::
GetInstance
()
->
HasParam
(
varname
)
&&
!
table_name
.
empty
())
{
std
::
vector
<
int64_t
>
updated_rows
;
std
::
vector
<
int64_t
>
updated_rows
;
AsyncSparseParamUpdateRecorder
::
GetInstance
()
->
GetAndClear
(
AsyncSparseParamUpdateRecorder
::
GetInstance
()
->
GetAndClear
(
varname
,
trainer_id
,
&
updated_rows
);
varname
,
trainer_id
,
&
updated_rows
);
if
(
VLOG_IS_ON
(
3
))
{
std
::
ostringstream
sstream
;
sstream
<<
"["
;
for
(
auto
&
row_id
:
updated_rows
)
{
sstream
<<
row_id
<<
", "
;
}
sstream
<<
"]"
;
VLOG
(
3
)
<<
"updated_rows size: "
<<
updated_rows
.
size
()
<<
" "
<<
sstream
.
str
();
}
auto
&
origin_tensor
=
auto
&
origin_tensor
=
scope_
->
FindVar
(
varname
)
->
Get
<
framework
::
LoDTensor
>
();
scope_
->
FindVar
(
varname
)
->
Get
<
framework
::
LoDTensor
>
();
auto
*
origin_tensor_data
=
origin_tensor
.
data
<
float
>
();
auto
*
origin_tensor_data
=
origin_tensor
.
data
<
float
>
();
...
@@ -133,6 +145,7 @@ bool RequestGetHandler::Handle(const std::string& varname,
...
@@ -133,6 +145,7 @@ bool RequestGetHandler::Handle(const std::string& varname,
out_dims
,
origin_tensor
.
place
());
out_dims
,
origin_tensor
.
place
());
auto
width
=
dims
[
1
];
auto
width
=
dims
[
1
];
for
(
auto
i
=
0
;
i
<
updated_rows
.
size
();
++
i
)
{
for
(
auto
i
=
0
;
i
<
updated_rows
.
size
();
++
i
)
{
PADDLE_ENFORCE_LT
(
updated_rows
[
i
],
dims
[
0
]);
memcpy
(
data
+
i
*
width
,
origin_tensor_data
+
updated_rows
[
i
]
*
width
,
memcpy
(
data
+
i
*
width
,
origin_tensor_data
+
updated_rows
[
i
]
*
width
,
sizeof
(
float
)
*
width
);
sizeof
(
float
)
*
width
);
}
}
...
...
paddle/fluid/operators/distributed/rpc_client.h
浏览文件 @
103c9bb3
...
@@ -44,6 +44,7 @@ class RPCClient {
...
@@ -44,6 +44,7 @@ class RPCClient {
const
framework
::
Scope
&
scope
,
const
framework
::
Scope
&
scope
,
const
std
::
string
&
var_name
,
const
std
::
string
&
var_name
,
const
std
::
string
&
out_varname
,
const
std
::
string
&
out_varname
,
const
std
::
string
&
table_name
=
""
,
int64_t
time_out
=
FLAGS_rpc_deadline
)
=
0
;
int64_t
time_out
=
FLAGS_rpc_deadline
)
=
0
;
virtual
VarHandlePtr
AsyncGetVarNoBarrier
(
virtual
VarHandlePtr
AsyncGetVarNoBarrier
(
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录