Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
6edfae42
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
6edfae42
编写于
9月 07, 2018
作者:
Y
Yancey1989
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
reset received vars on pserver
上级
f76f42c2
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
58 addition
and
26 deletion
+58
-26
paddle/fluid/operators/distributed/request_handler_impl.cc
paddle/fluid/operators/distributed/request_handler_impl.cc
+0
-13
paddle/fluid/operators/distributed/request_handler_impl.h
paddle/fluid/operators/distributed/request_handler_impl.h
+0
-5
paddle/fluid/operators/distributed/rpc_server.cc
paddle/fluid/operators/distributed/rpc_server.cc
+8
-0
paddle/fluid/operators/distributed/rpc_server.h
paddle/fluid/operators/distributed/rpc_server.h
+5
-1
paddle/fluid/operators/listen_and_serv_op.cc
paddle/fluid/operators/listen_and_serv_op.cc
+36
-6
paddle/fluid/operators/listen_and_serv_op.h
paddle/fluid/operators/listen_and_serv_op.h
+9
-1
未找到文件。
paddle/fluid/operators/distributed/request_handler_impl.cc
浏览文件 @
6edfae42
...
@@ -67,24 +67,11 @@ bool RequestSendHandler::Handle(const std::string& varname,
...
@@ -67,24 +67,11 @@ bool RequestSendHandler::Handle(const std::string& varname,
LOG
(
FATAL
)
<<
"sync: Can not find server side var: "
<<
varname
;
LOG
(
FATAL
)
<<
"sync: Can not find server side var: "
<<
varname
;
return
false
;
return
false
;
}
}
if
(
invar
->
IsType
<
framework
::
SelectedRows
>
())
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_sparse_vars_
);
sparse_vars_
.
push_back
(
invar
);
}
}
}
}
}
return
true
;
return
true
;
}
}
void
RequestSendHandler
::
ResetSparseVarRecorder
()
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_sparse_vars_
);
for
(
auto
*
var
:
sparse_vars_
)
{
var
->
GetMutable
<
framework
::
SelectedRows
>
()
->
mutable_rows
()
->
clear
();
}
sparse_vars_
.
clear
();
}
bool
RequestGetHandler
::
Handle
(
const
std
::
string
&
varname
,
bool
RequestGetHandler
::
Handle
(
const
std
::
string
&
varname
,
framework
::
Scope
*
scope
,
framework
::
Scope
*
scope
,
framework
::
Variable
*
invar
,
framework
::
Variable
*
invar
,
...
...
paddle/fluid/operators/distributed/request_handler_impl.h
浏览文件 @
6edfae42
...
@@ -41,11 +41,6 @@ class RequestSendHandler final : public RequestHandler {
...
@@ -41,11 +41,6 @@ class RequestSendHandler final : public RequestHandler {
bool
Handle
(
const
std
::
string
&
varname
,
framework
::
Scope
*
scope
,
bool
Handle
(
const
std
::
string
&
varname
,
framework
::
Scope
*
scope
,
framework
::
Variable
*
var
,
framework
::
Variable
**
outvar
,
framework
::
Variable
*
var
,
framework
::
Variable
**
outvar
,
const
std
::
string
&
out_var_name
=
""
)
override
;
const
std
::
string
&
out_var_name
=
""
)
override
;
void
ResetSparseVarRecorder
();
private:
std
::
mutex
mutex_sparse_vars_
;
std
::
vector
<
framework
::
Variable
*>
sparse_vars_
;
};
};
class
RequestGetHandler
final
:
public
RequestHandler
{
class
RequestGetHandler
final
:
public
RequestHandler
{
...
...
paddle/fluid/operators/distributed/rpc_server.cc
浏览文件 @
6edfae42
...
@@ -101,6 +101,8 @@ void RPCServer::Complete() {
...
@@ -101,6 +101,8 @@ void RPCServer::Complete() {
{
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
client_num_
--
;
client_num_
--
;
need_reset_all_vars_
=
true
;
VLOG
(
4
)
<<
"decrease client_num to: "
<<
client_num_
;
VLOG
(
4
)
<<
"decrease client_num to: "
<<
client_num_
;
if
(
cur_cond_
.
load
()
==
rpc_cond_map_
[
kRequestGet
])
{
if
(
cur_cond_
.
load
()
==
rpc_cond_map_
[
kRequestGet
])
{
barrier_counter_
[
kRequestGet
]
--
;
barrier_counter_
[
kRequestGet
]
--
;
...
@@ -109,6 +111,11 @@ void RPCServer::Complete() {
...
@@ -109,6 +111,11 @@ void RPCServer::Complete() {
barrier_cond_
.
notify_all
();
barrier_cond_
.
notify_all
();
}
}
bool
RPCServer
::
NeedResetAllVars
()
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
return
need_reset_all_vars_
;
}
int
RPCServer
::
GetClientNum
()
{
int
RPCServer
::
GetClientNum
()
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
return
client_num_
;
return
client_num_
;
...
@@ -120,6 +127,7 @@ void RPCServer::ResetBarrierCounter() {
...
@@ -120,6 +127,7 @@ void RPCServer::ResetBarrierCounter() {
for
(
auto
&
t
:
barrier_counter_
)
{
for
(
auto
&
t
:
barrier_counter_
)
{
t
.
second
=
0
;
t
.
second
=
0
;
}
}
need_reset_all_vars_
=
false
;
}
}
void
RPCServer
::
RegisterRPC
(
const
std
::
string
&
rpc_name
,
void
RPCServer
::
RegisterRPC
(
const
std
::
string
&
rpc_name
,
...
...
paddle/fluid/operators/distributed/rpc_server.h
浏览文件 @
6edfae42
...
@@ -49,7 +49,8 @@ class RPCServer {
...
@@ -49,7 +49,8 @@ class RPCServer {
bind_address_
(
address
),
bind_address_
(
address
),
exit_flag_
(
false
),
exit_flag_
(
false
),
selected_port_
(
0
),
selected_port_
(
0
),
client_num_
(
client_num
)
{}
client_num_
(
client_num
),
need_reset_all_vars_
(
false
)
{}
virtual
~
RPCServer
()
{}
virtual
~
RPCServer
()
{}
virtual
void
StartServer
()
=
0
;
virtual
void
StartServer
()
=
0
;
...
@@ -86,6 +87,8 @@ class RPCServer {
...
@@ -86,6 +87,8 @@ class RPCServer {
void
ResetBarrierCounter
();
void
ResetBarrierCounter
();
RPCServerProfiler
&
Profiler
()
{
return
profiler_
;
}
RPCServerProfiler
&
Profiler
()
{
return
profiler_
;
}
bool
NeedResetAllVars
();
protected:
protected:
virtual
void
ShutDownImpl
()
=
0
;
virtual
void
ShutDownImpl
()
=
0
;
...
@@ -104,6 +107,7 @@ class RPCServer {
...
@@ -104,6 +107,7 @@ class RPCServer {
std
::
atomic
<
int
>
exit_flag_
;
std
::
atomic
<
int
>
exit_flag_
;
int
selected_port_
;
int
selected_port_
;
int
client_num_
;
int
client_num_
;
bool
need_reset_all_vars_
;
std
::
unordered_map
<
std
::
string
,
RequestHandler
*>
rpc_call_map_
;
std
::
unordered_map
<
std
::
string
,
RequestHandler
*>
rpc_call_map_
;
std
::
unordered_map
<
std
::
string
,
int
>
rpc_thread_num_
;
std
::
unordered_map
<
std
::
string
,
int
>
rpc_thread_num_
;
...
...
paddle/fluid/operators/listen_and_serv_op.cc
浏览文件 @
6edfae42
...
@@ -22,6 +22,7 @@ limitations under the License. */
...
@@ -22,6 +22,7 @@ limitations under the License. */
#include "gflags/gflags.h"
#include "gflags/gflags.h"
#include "paddle/fluid/operators/detail/macros.h"
#include "paddle/fluid/operators/detail/macros.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/distributed/request_handler_impl.h"
#include "paddle/fluid/operators/distributed/request_handler_impl.h"
#include "paddle/fluid/operators/listen_and_serv_op.h"
#include "paddle/fluid/operators/listen_and_serv_op.h"
...
@@ -101,9 +102,10 @@ static int64_t GetTimestamp() {
...
@@ -101,9 +102,10 @@ static int64_t GetTimestamp() {
void
ListenAndServOp
::
RunSyncLoop
(
void
ListenAndServOp
::
RunSyncLoop
(
framework
::
Executor
*
executor
,
framework
::
ProgramDesc
*
program
,
framework
::
Executor
*
executor
,
framework
::
ProgramDesc
*
program
,
framework
::
Scope
*
recv_scope
,
framework
::
Scope
*
recv_scope
,
platform
::
DeviceContext
*
dev_ctx
,
const
std
::
vector
<
int
>
&
prefetch_block_id_list
,
const
std
::
vector
<
int
>
&
prefetch_block_id_list
,
const
int
checkpoint_point_block_id
)
const
{
const
int
checkpoint_point_block_id
,
const
std
::
vector
<
std
::
string
>
&
recv_varnames
)
const
{
VLOG
(
2
)
<<
"RunSyncLoop"
;
VLOG
(
2
)
<<
"RunSyncLoop"
;
size_t
num_blocks
=
program
->
Size
();
size_t
num_blocks
=
program
->
Size
();
auto
optimize_blocks
=
auto
optimize_blocks
=
...
@@ -166,8 +168,8 @@ void ListenAndServOp::RunSyncLoop(
...
@@ -166,8 +168,8 @@ void ListenAndServOp::RunSyncLoop(
VLOG
(
2
)
<<
"run all blocks spent "
<<
GetTimestamp
()
-
ts
<<
"(ms)"
;
VLOG
(
2
)
<<
"run all blocks spent "
<<
GetTimestamp
()
-
ts
<<
"(ms)"
;
// reset received sparse vars to avoid reuse it in the next mini-batch
// reset received sparse vars to avoid reuse it in the next mini-batch
dynamic_cast
<
distributed
::
RequestSendHandler
*>
(
request_send_handler_
.
get
())
ResetReceivedVars
(
recv_varnames
,
recv_scope
,
dev_ctx
,
->
ResetSparseVarRecorder
(
);
!
rpc_service_
->
NeedResetAllVars
()
);
rpc_service_
->
SetCond
(
distributed
::
kRequestGet
);
rpc_service_
->
SetCond
(
distributed
::
kRequestGet
);
rpc_service_
->
WaitBarrier
(
distributed
::
kRequestGet
);
rpc_service_
->
WaitBarrier
(
distributed
::
kRequestGet
);
...
@@ -175,6 +177,33 @@ void ListenAndServOp::RunSyncLoop(
...
@@ -175,6 +177,33 @@ void ListenAndServOp::RunSyncLoop(
}
// while(true)
}
// while(true)
}
}
void
ListenAndServOp
::
ResetReceivedVars
(
const
std
::
vector
<
std
::
string
>
&
recv_varnames
,
framework
::
Scope
*
recv_scope
,
platform
::
DeviceContext
*
dev_ctx
,
bool
only_sparse_vars
)
const
{
for
(
auto
&
varname
:
recv_varnames
)
{
auto
var
=
recv_scope
->
FindVar
(
varname
);
if
(
var
==
nullptr
)
{
VLOG
(
2
)
<<
"can not find var "
<<
varname
<<
" in received scope"
;
continue
;
}
if
(
var
->
IsType
<
framework
::
SelectedRows
>
())
{
var
->
GetMutable
<
framework
::
SelectedRows
>
()
->
mutable_rows
()
->
clear
();
}
if
(
!
only_sparse_vars
)
{
if
(
var
->
IsType
<
framework
::
LoDTensor
>
())
{
math
::
set_constant
(
*
dev_ctx
,
var
->
GetMutable
<
framework
::
LoDTensor
>
(),
static_cast
<
float
>
(
0
));
}
else
if
(
var
->
IsType
<
framework
::
Tensor
>
())
{
math
::
set_constant
(
*
dev_ctx
,
var
->
GetMutable
<
framework
::
Tensor
>
(),
static_cast
<
float
>
(
0
));
}
else
{
PADDLE_THROW
(
"received var should be in [SelectedRows, LoDTensor, Tensor]"
);
}
}
}
}
void
ListenAndServOp
::
RunAsyncLoop
(
framework
::
Executor
*
executor
,
void
ListenAndServOp
::
RunAsyncLoop
(
framework
::
Executor
*
executor
,
framework
::
ProgramDesc
*
program
,
framework
::
ProgramDesc
*
program
,
framework
::
Scope
*
recv_scope
)
const
{
framework
::
Scope
*
recv_scope
)
const
{
...
@@ -258,6 +287,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
...
@@ -258,6 +287,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
bool
sync_mode
=
Attr
<
bool
>
(
"sync_mode"
);
bool
sync_mode
=
Attr
<
bool
>
(
"sync_mode"
);
auto
fan_in
=
Attr
<
int
>
(
"Fanin"
);
auto
fan_in
=
Attr
<
int
>
(
"Fanin"
);
auto
inputs
=
Inputs
(
"X"
);
PADDLE_ENFORCE
(
!
rpc_service_
);
PADDLE_ENFORCE
(
!
rpc_service_
);
std
::
string
endpoint
=
Attr
<
std
::
string
>
(
"endpoint"
);
std
::
string
endpoint
=
Attr
<
std
::
string
>
(
"endpoint"
);
...
@@ -351,8 +381,8 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
...
@@ -351,8 +381,8 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
// Write to a file of server selected port for python use.
// Write to a file of server selected port for python use.
SavePort
();
SavePort
();
if
(
sync_mode
)
{
if
(
sync_mode
)
{
RunSyncLoop
(
&
executor
,
program
,
&
recv_scope
,
prefetch_block_id_list
,
RunSyncLoop
(
&
executor
,
program
,
&
recv_scope
,
&
dev_ctx
,
checkpoint_block_id
);
prefetch_block_id_list
,
checkpoint_block_id
,
inputs
);
}
else
{
}
else
{
RunAsyncLoop
(
&
executor
,
program
,
&
recv_scope
);
RunAsyncLoop
(
&
executor
,
program
,
&
recv_scope
);
}
}
...
...
paddle/fluid/operators/listen_and_serv_op.h
浏览文件 @
6edfae42
...
@@ -26,6 +26,7 @@ limitations under the License. */
...
@@ -26,6 +26,7 @@ limitations under the License. */
#include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/operators/distributed/request_handler.h"
#include "paddle/fluid/operators/distributed/request_handler.h"
#include "paddle/fluid/operators/distributed/rpc_server.h"
#include "paddle/fluid/operators/distributed/rpc_server.h"
#include "paddle/fluid/platform/device_context.h"
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
...
@@ -48,8 +49,10 @@ class ListenAndServOp : public framework::OperatorBase {
...
@@ -48,8 +49,10 @@ class ListenAndServOp : public framework::OperatorBase {
void
RunSyncLoop
(
framework
::
Executor
*
executor
,
void
RunSyncLoop
(
framework
::
Executor
*
executor
,
framework
::
ProgramDesc
*
program
,
framework
::
ProgramDesc
*
program
,
framework
::
Scope
*
recv_scope
,
framework
::
Scope
*
recv_scope
,
platform
::
DeviceContext
*
dev_ctx
,
const
std
::
vector
<
int
>&
prefetch_block_id_list
,
const
std
::
vector
<
int
>&
prefetch_block_id_list
,
const
int
checkpoint_point_block_id
)
const
;
const
int
checkpoint_point_block_id
,
const
std
::
vector
<
std
::
string
>&
recv_varnames
)
const
;
void
RunAsyncLoop
(
framework
::
Executor
*
executor
,
void
RunAsyncLoop
(
framework
::
Executor
*
executor
,
framework
::
ProgramDesc
*
program
,
framework
::
ProgramDesc
*
program
,
...
@@ -64,6 +67,11 @@ class ListenAndServOp : public framework::OperatorBase {
...
@@ -64,6 +67,11 @@ class ListenAndServOp : public framework::OperatorBase {
void
RunImpl
(
const
framework
::
Scope
&
scope
,
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
;
const
platform
::
Place
&
dev_place
)
const
override
;
void
ResetReceivedVars
(
const
std
::
vector
<
std
::
string
>&
recv_varnames
,
framework
::
Scope
*
recv_scope
,
platform
::
DeviceContext
*
dev_ctx
,
bool
only_sparse_vars
=
true
)
const
;
protected:
protected:
mutable
std
::
shared_ptr
<
distributed
::
RPCServer
>
rpc_service_
;
mutable
std
::
shared_ptr
<
distributed
::
RPCServer
>
rpc_service_
;
mutable
std
::
shared_ptr
<
distributed
::
RequestHandler
>
request_send_handler_
;
mutable
std
::
shared_ptr
<
distributed
::
RequestHandler
>
request_send_handler_
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录