Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
d117bbc3
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
d117bbc3
编写于
9月 10, 2018
作者:
Y
Yan Xu
提交者:
GitHub
9月 10, 2018
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #13291 from Yancey1989/reset_vars_on_pserver
reset received vars on pserver
上级
2fd1bf2e
5558784c
变更
6
显示空白变更内容
内联
并排
Showing
6 changed file
with
91 addition
and
25 deletion
+91
-25
paddle/fluid/operators/distributed/request_handler_impl.cc
paddle/fluid/operators/distributed/request_handler_impl.cc
+0
-13
paddle/fluid/operators/distributed/request_handler_impl.h
paddle/fluid/operators/distributed/request_handler_impl.h
+0
-5
paddle/fluid/operators/distributed/rpc_server.cc
paddle/fluid/operators/distributed/rpc_server.cc
+8
-0
paddle/fluid/operators/distributed/rpc_server.h
paddle/fluid/operators/distributed/rpc_server.h
+5
-1
paddle/fluid/operators/listen_and_serv_op.cc
paddle/fluid/operators/listen_and_serv_op.cc
+67
-6
paddle/fluid/operators/listen_and_serv_op.h
paddle/fluid/operators/listen_and_serv_op.h
+11
-0
未找到文件。
paddle/fluid/operators/distributed/request_handler_impl.cc
浏览文件 @
d117bbc3
...
@@ -67,24 +67,11 @@ bool RequestSendHandler::Handle(const std::string& varname,
...
@@ -67,24 +67,11 @@ bool RequestSendHandler::Handle(const std::string& varname,
LOG
(
FATAL
)
<<
"sync: Can not find server side var: "
<<
varname
;
LOG
(
FATAL
)
<<
"sync: Can not find server side var: "
<<
varname
;
return
false
;
return
false
;
}
}
if
(
invar
->
IsType
<
framework
::
SelectedRows
>
())
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_sparse_vars_
);
sparse_vars_
.
push_back
(
invar
);
}
}
}
}
}
return
true
;
return
true
;
}
}
void
RequestSendHandler
::
ResetSparseVarRecorder
()
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_sparse_vars_
);
for
(
auto
*
var
:
sparse_vars_
)
{
var
->
GetMutable
<
framework
::
SelectedRows
>
()
->
mutable_rows
()
->
clear
();
}
sparse_vars_
.
clear
();
}
bool
RequestGetHandler
::
Handle
(
const
std
::
string
&
varname
,
bool
RequestGetHandler
::
Handle
(
const
std
::
string
&
varname
,
framework
::
Scope
*
scope
,
framework
::
Scope
*
scope
,
framework
::
Variable
*
invar
,
framework
::
Variable
*
invar
,
...
...
paddle/fluid/operators/distributed/request_handler_impl.h
浏览文件 @
d117bbc3
...
@@ -41,11 +41,6 @@ class RequestSendHandler final : public RequestHandler {
...
@@ -41,11 +41,6 @@ class RequestSendHandler final : public RequestHandler {
bool
Handle
(
const
std
::
string
&
varname
,
framework
::
Scope
*
scope
,
bool
Handle
(
const
std
::
string
&
varname
,
framework
::
Scope
*
scope
,
framework
::
Variable
*
var
,
framework
::
Variable
**
outvar
,
framework
::
Variable
*
var
,
framework
::
Variable
**
outvar
,
const
std
::
string
&
out_var_name
=
""
)
override
;
const
std
::
string
&
out_var_name
=
""
)
override
;
void
ResetSparseVarRecorder
();
private:
std
::
mutex
mutex_sparse_vars_
;
std
::
vector
<
framework
::
Variable
*>
sparse_vars_
;
};
};
class
RequestGetHandler
final
:
public
RequestHandler
{
class
RequestGetHandler
final
:
public
RequestHandler
{
...
...
paddle/fluid/operators/distributed/rpc_server.cc
浏览文件 @
d117bbc3
...
@@ -101,6 +101,8 @@ void RPCServer::Complete() {
...
@@ -101,6 +101,8 @@ void RPCServer::Complete() {
{
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
client_num_
--
;
client_num_
--
;
need_reset_all_vars_
=
true
;
VLOG
(
4
)
<<
"decrease client_num to: "
<<
client_num_
;
VLOG
(
4
)
<<
"decrease client_num to: "
<<
client_num_
;
if
(
cur_cond_
.
load
()
==
rpc_cond_map_
[
kRequestGet
])
{
if
(
cur_cond_
.
load
()
==
rpc_cond_map_
[
kRequestGet
])
{
barrier_counter_
[
kRequestGet
]
--
;
barrier_counter_
[
kRequestGet
]
--
;
...
@@ -109,6 +111,11 @@ void RPCServer::Complete() {
...
@@ -109,6 +111,11 @@ void RPCServer::Complete() {
barrier_cond_
.
notify_all
();
barrier_cond_
.
notify_all
();
}
}
bool
RPCServer
::
NeedResetAllVars
()
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
return
need_reset_all_vars_
;
}
int
RPCServer
::
GetClientNum
()
{
int
RPCServer
::
GetClientNum
()
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
return
client_num_
;
return
client_num_
;
...
@@ -120,6 +127,7 @@ void RPCServer::ResetBarrierCounter() {
...
@@ -120,6 +127,7 @@ void RPCServer::ResetBarrierCounter() {
for
(
auto
&
t
:
barrier_counter_
)
{
for
(
auto
&
t
:
barrier_counter_
)
{
t
.
second
=
0
;
t
.
second
=
0
;
}
}
need_reset_all_vars_
=
false
;
}
}
void
RPCServer
::
RegisterRPC
(
const
std
::
string
&
rpc_name
,
void
RPCServer
::
RegisterRPC
(
const
std
::
string
&
rpc_name
,
...
...
paddle/fluid/operators/distributed/rpc_server.h
浏览文件 @
d117bbc3
...
@@ -49,7 +49,8 @@ class RPCServer {
...
@@ -49,7 +49,8 @@ class RPCServer {
bind_address_
(
address
),
bind_address_
(
address
),
exit_flag_
(
false
),
exit_flag_
(
false
),
selected_port_
(
0
),
selected_port_
(
0
),
client_num_
(
client_num
)
{}
client_num_
(
client_num
),
need_reset_all_vars_
(
false
)
{}
virtual
~
RPCServer
()
{}
virtual
~
RPCServer
()
{}
virtual
void
StartServer
()
=
0
;
virtual
void
StartServer
()
=
0
;
...
@@ -86,6 +87,8 @@ class RPCServer {
...
@@ -86,6 +87,8 @@ class RPCServer {
void
ResetBarrierCounter
();
void
ResetBarrierCounter
();
RPCServerProfiler
&
Profiler
()
{
return
profiler_
;
}
RPCServerProfiler
&
Profiler
()
{
return
profiler_
;
}
bool
NeedResetAllVars
();
protected:
protected:
virtual
void
ShutDownImpl
()
=
0
;
virtual
void
ShutDownImpl
()
=
0
;
...
@@ -104,6 +107,7 @@ class RPCServer {
...
@@ -104,6 +107,7 @@ class RPCServer {
std
::
atomic
<
int
>
exit_flag_
;
std
::
atomic
<
int
>
exit_flag_
;
int
selected_port_
;
int
selected_port_
;
int
client_num_
;
int
client_num_
;
bool
need_reset_all_vars_
;
std
::
unordered_map
<
std
::
string
,
RequestHandler
*>
rpc_call_map_
;
std
::
unordered_map
<
std
::
string
,
RequestHandler
*>
rpc_call_map_
;
std
::
unordered_map
<
std
::
string
,
int
>
rpc_thread_num_
;
std
::
unordered_map
<
std
::
string
,
int
>
rpc_thread_num_
;
...
...
paddle/fluid/operators/listen_and_serv_op.cc
浏览文件 @
d117bbc3
...
@@ -22,6 +22,7 @@ limitations under the License. */
...
@@ -22,6 +22,7 @@ limitations under the License. */
#include "gflags/gflags.h"
#include "gflags/gflags.h"
#include "paddle/fluid/operators/detail/macros.h"
#include "paddle/fluid/operators/detail/macros.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/distributed/request_handler_impl.h"
#include "paddle/fluid/operators/distributed/request_handler_impl.h"
#include "paddle/fluid/operators/listen_and_serv_op.h"
#include "paddle/fluid/operators/listen_and_serv_op.h"
...
@@ -101,7 +102,7 @@ static int64_t GetTimestamp() {
...
@@ -101,7 +102,7 @@ static int64_t GetTimestamp() {
void
ListenAndServOp
::
RunSyncLoop
(
void
ListenAndServOp
::
RunSyncLoop
(
framework
::
Executor
*
executor
,
framework
::
ProgramDesc
*
program
,
framework
::
Executor
*
executor
,
framework
::
ProgramDesc
*
program
,
framework
::
Scope
*
recv_scope
,
framework
::
Scope
*
recv_scope
,
platform
::
DeviceContext
*
dev_ctx
,
const
std
::
vector
<
int
>
&
prefetch_block_id_list
,
const
std
::
vector
<
int
>
&
prefetch_block_id_list
,
const
int
checkpoint_point_block_id
)
const
{
const
int
checkpoint_point_block_id
)
const
{
VLOG
(
2
)
<<
"RunSyncLoop"
;
VLOG
(
2
)
<<
"RunSyncLoop"
;
...
@@ -128,6 +129,7 @@ void ListenAndServOp::RunSyncLoop(
...
@@ -128,6 +129,7 @@ void ListenAndServOp::RunSyncLoop(
rpc_service_
->
SetCond
(
distributed
::
kRequestGet
);
rpc_service_
->
SetCond
(
distributed
::
kRequestGet
);
rpc_service_
->
WaitBarrier
(
distributed
::
kRequestGet
);
rpc_service_
->
WaitBarrier
(
distributed
::
kRequestGet
);
rpc_service_
->
ResetBarrierCounter
();
rpc_service_
->
ResetBarrierCounter
();
while
(
true
)
{
while
(
true
)
{
rpc_service_
->
Profiler
().
OneStep
();
rpc_service_
->
Profiler
().
OneStep
();
// Get from multiple trainers, we don't care about the order in which
// Get from multiple trainers, we don't care about the order in which
...
@@ -165,9 +167,7 @@ void ListenAndServOp::RunSyncLoop(
...
@@ -165,9 +167,7 @@ void ListenAndServOp::RunSyncLoop(
recv_scope
);
recv_scope
);
VLOG
(
2
)
<<
"run all blocks spent "
<<
GetTimestamp
()
-
ts
<<
"(ms)"
;
VLOG
(
2
)
<<
"run all blocks spent "
<<
GetTimestamp
()
-
ts
<<
"(ms)"
;
// reset received sparse vars to avoid reuse it in the next mini-batch
ResetReceivedVars
(
recv_scope
,
dev_ctx
,
rpc_service_
->
NeedResetAllVars
());
dynamic_cast
<
distributed
::
RequestSendHandler
*>
(
request_send_handler_
.
get
())
->
ResetSparseVarRecorder
();
rpc_service_
->
SetCond
(
distributed
::
kRequestGet
);
rpc_service_
->
SetCond
(
distributed
::
kRequestGet
);
rpc_service_
->
WaitBarrier
(
distributed
::
kRequestGet
);
rpc_service_
->
WaitBarrier
(
distributed
::
kRequestGet
);
...
@@ -175,6 +175,42 @@ void ListenAndServOp::RunSyncLoop(
...
@@ -175,6 +175,42 @@ void ListenAndServOp::RunSyncLoop(
}
// while(true)
}
// while(true)
}
}
void
ListenAndServOp
::
ResetReceivedVars
(
framework
::
Scope
*
recv_scope
,
platform
::
DeviceContext
*
dev_ctx
,
bool
reset_all
)
const
{
for
(
auto
&
varname
:
sparse_vars_
)
{
auto
var
=
recv_scope
->
FindVar
(
varname
);
if
(
var
==
nullptr
)
{
VLOG
(
2
)
<<
"can not find var "
<<
varname
<<
" in received scope"
;
continue
;
}
if
(
var
->
IsType
<
framework
::
SelectedRows
>
())
{
VLOG
(
3
)
<<
"reset sparse var: "
<<
varname
;
var
->
GetMutable
<
framework
::
SelectedRows
>
()
->
mutable_rows
()
->
clear
();
}
else
{
PADDLE_THROW
(
"The type of sparse var should be SelectedRows"
);
}
}
if
(
UNLIKELY
(
reset_all
))
{
for
(
auto
&
varname
:
dense_vars_
)
{
auto
var
=
recv_scope
->
FindVar
(
varname
);
if
(
var
==
nullptr
)
{
VLOG
(
2
)
<<
"can not find var "
<<
varname
<<
" in received scope"
;
continue
;
}
if
(
var
->
IsType
<
framework
::
LoDTensor
>
())
{
math
::
set_constant
(
*
dev_ctx
,
var
->
GetMutable
<
framework
::
LoDTensor
>
(),
static_cast
<
float
>
(
0
));
}
else
if
(
var
->
IsType
<
framework
::
Tensor
>
())
{
math
::
set_constant
(
*
dev_ctx
,
var
->
GetMutable
<
framework
::
Tensor
>
(),
static_cast
<
float
>
(
0
));
}
else
{
PADDLE_THROW
(
"The type of dense var should be in [LoDTensor, Tensor]"
);
}
}
}
}
void
ListenAndServOp
::
RunAsyncLoop
(
framework
::
Executor
*
executor
,
void
ListenAndServOp
::
RunAsyncLoop
(
framework
::
Executor
*
executor
,
framework
::
ProgramDesc
*
program
,
framework
::
ProgramDesc
*
program
,
framework
::
Scope
*
recv_scope
)
const
{
framework
::
Scope
*
recv_scope
)
const
{
...
@@ -248,6 +284,25 @@ static void FillRequestCtx(
...
@@ -248,6 +284,25 @@ static void FillRequestCtx(
h
->
SetCheckpointNotifyPreparedCtx
(
checkpoint_ctx
);
h
->
SetCheckpointNotifyPreparedCtx
(
checkpoint_ctx
);
}
}
void
ListenAndServOp
::
CacheVarsType
(
const
std
::
vector
<
std
::
string
>
&
varnames
,
const
framework
::
Scope
&
scope
)
const
{
for
(
const
auto
&
varname
:
varnames
)
{
auto
var
=
scope
.
FindVar
(
varname
);
PADDLE_ENFORCE
(
var
!=
nullptr
,
"Received var should be initialized in the received scope."
);
if
(
var
->
IsType
<
framework
::
SelectedRows
>
())
{
sparse_vars_
.
push_back
(
varname
);
}
else
if
(
var
->
IsType
<
framework
::
LoDTensor
>
()
||
var
->
IsType
<
framework
::
Tensor
>
())
{
dense_vars_
.
push_back
(
varname
);
}
else
{
PADDLE_THROW
(
"The type of received var should be in [SelectedRows, LoDTensor, "
"Tensor]."
);
}
}
}
void
ListenAndServOp
::
RunImpl
(
const
framework
::
Scope
&
scope
,
void
ListenAndServOp
::
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
{
const
platform
::
Place
&
dev_place
)
const
{
// Mark this as PS that it should decide profiling by listening from trainer.
// Mark this as PS that it should decide profiling by listening from trainer.
...
@@ -258,6 +313,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
...
@@ -258,6 +313,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
bool
sync_mode
=
Attr
<
bool
>
(
"sync_mode"
);
bool
sync_mode
=
Attr
<
bool
>
(
"sync_mode"
);
auto
fan_in
=
Attr
<
int
>
(
"Fanin"
);
auto
fan_in
=
Attr
<
int
>
(
"Fanin"
);
auto
inputs
=
Inputs
(
"X"
);
PADDLE_ENFORCE
(
!
rpc_service_
);
PADDLE_ENFORCE
(
!
rpc_service_
);
std
::
string
endpoint
=
Attr
<
std
::
string
>
(
"endpoint"
);
std
::
string
endpoint
=
Attr
<
std
::
string
>
(
"endpoint"
);
...
@@ -348,11 +404,16 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
...
@@ -348,11 +404,16 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
signal
(
SIGINT
,
SignalHandler
::
StopAndExit
);
signal
(
SIGINT
,
SignalHandler
::
StopAndExit
);
signal
(
SIGTERM
,
SignalHandler
::
StopAndExit
);
signal
(
SIGTERM
,
SignalHandler
::
StopAndExit
);
// Cache the type of the received vars as `sparse_vars_` and `dense_vars_`
// so that we can reset them at the end of each iteration.
// NOTE: only used in sync update
CacheVarsType
(
inputs
,
recv_scope
);
// Write to a file of server selected port for python use.
// Write to a file of server selected port for python use.
SavePort
();
SavePort
();
if
(
sync_mode
)
{
if
(
sync_mode
)
{
RunSyncLoop
(
&
executor
,
program
,
&
recv_scope
,
prefetch_block_id_list
,
RunSyncLoop
(
&
executor
,
program
,
&
recv_scope
,
&
dev_ctx
,
checkpoint_block_id
);
prefetch_block_id_list
,
checkpoint_block_id
);
}
else
{
}
else
{
RunAsyncLoop
(
&
executor
,
program
,
&
recv_scope
);
RunAsyncLoop
(
&
executor
,
program
,
&
recv_scope
);
}
}
...
...
paddle/fluid/operators/listen_and_serv_op.h
浏览文件 @
d117bbc3
...
@@ -26,6 +26,7 @@ limitations under the License. */
...
@@ -26,6 +26,7 @@ limitations under the License. */
#include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/operators/distributed/request_handler.h"
#include "paddle/fluid/operators/distributed/request_handler.h"
#include "paddle/fluid/operators/distributed/rpc_server.h"
#include "paddle/fluid/operators/distributed/rpc_server.h"
#include "paddle/fluid/platform/device_context.h"
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
...
@@ -48,6 +49,7 @@ class ListenAndServOp : public framework::OperatorBase {
...
@@ -48,6 +49,7 @@ class ListenAndServOp : public framework::OperatorBase {
void
RunSyncLoop
(
framework
::
Executor
*
executor
,
void
RunSyncLoop
(
framework
::
Executor
*
executor
,
framework
::
ProgramDesc
*
program
,
framework
::
ProgramDesc
*
program
,
framework
::
Scope
*
recv_scope
,
framework
::
Scope
*
recv_scope
,
platform
::
DeviceContext
*
dev_ctx
,
const
std
::
vector
<
int
>&
prefetch_block_id_list
,
const
std
::
vector
<
int
>&
prefetch_block_id_list
,
const
int
checkpoint_point_block_id
)
const
;
const
int
checkpoint_point_block_id
)
const
;
...
@@ -64,6 +66,13 @@ class ListenAndServOp : public framework::OperatorBase {
...
@@ -64,6 +66,13 @@ class ListenAndServOp : public framework::OperatorBase {
void
RunImpl
(
const
framework
::
Scope
&
scope
,
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
;
const
platform
::
Place
&
dev_place
)
const
override
;
void
ResetReceivedVars
(
framework
::
Scope
*
recv_scope
,
platform
::
DeviceContext
*
dev_ctx
,
bool
reset_all
=
false
)
const
;
void
CacheVarsType
(
const
std
::
vector
<
std
::
string
>&
varnames
,
const
framework
::
Scope
&
scope
)
const
;
protected:
protected:
mutable
std
::
shared_ptr
<
distributed
::
RPCServer
>
rpc_service_
;
mutable
std
::
shared_ptr
<
distributed
::
RPCServer
>
rpc_service_
;
mutable
std
::
shared_ptr
<
distributed
::
RequestHandler
>
request_send_handler_
;
mutable
std
::
shared_ptr
<
distributed
::
RequestHandler
>
request_send_handler_
;
...
@@ -74,6 +83,8 @@ class ListenAndServOp : public framework::OperatorBase {
...
@@ -74,6 +83,8 @@ class ListenAndServOp : public framework::OperatorBase {
request_checkpoint_handler_
;
request_checkpoint_handler_
;
mutable
std
::
shared_ptr
<
std
::
thread
>
server_thread_
;
mutable
std
::
shared_ptr
<
std
::
thread
>
server_thread_
;
mutable
std
::
vector
<
std
::
string
>
sparse_vars_
;
mutable
std
::
vector
<
std
::
string
>
dense_vars_
;
};
};
class
SignalHandler
{
class
SignalHandler
{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录