Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
6133efd9
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
6133efd9
编写于
7月 24, 2018
作者:
Y
Yancey
提交者:
GitHub
7月 24, 2018
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #12218 from Yancey1989/rpc_complete_interface
Add rpc complete interface
上级
24bea401
fb06ed7b
变更
14
隐藏空白更改
内联
并排
Showing
14 changed file
with
91 addition
and
103 deletion
+91
-103
paddle/fluid/API.spec
paddle/fluid/API.spec
+1
-2
paddle/fluid/framework/executor.cc
paddle/fluid/framework/executor.cc
+3
-9
paddle/fluid/framework/executor.h
paddle/fluid/framework/executor.h
+3
-9
paddle/fluid/operators/distributed/CMakeLists.txt
paddle/fluid/operators/distributed/CMakeLists.txt
+1
-1
paddle/fluid/operators/distributed/grpc_client.cc
paddle/fluid/operators/distributed/grpc_client.cc
+11
-28
paddle/fluid/operators/distributed/grpc_client.h
paddle/fluid/operators/distributed/grpc_client.h
+8
-9
paddle/fluid/operators/distributed/request_handler.h
paddle/fluid/operators/distributed/request_handler.h
+0
-2
paddle/fluid/operators/distributed/request_handler_impl.cc
paddle/fluid/operators/distributed/request_handler_impl.cc
+4
-7
paddle/fluid/operators/distributed/rpc_client.h
paddle/fluid/operators/distributed/rpc_client.h
+5
-9
paddle/fluid/operators/distributed/rpc_server.cc
paddle/fluid/operators/distributed/rpc_server.cc
+6
-12
paddle/fluid/operators/distributed/rpc_server.h
paddle/fluid/operators/distributed/rpc_server.h
+2
-3
paddle/fluid/operators/distributed/rpc_server_test.cc
paddle/fluid/operators/distributed/rpc_server_test.cc
+25
-4
paddle/fluid/pybind/pybind.cc
paddle/fluid/pybind/pybind.cc
+1
-4
python/paddle/fluid/executor.py
python/paddle/fluid/executor.py
+21
-4
未找到文件。
paddle/fluid/API.spec
浏览文件 @
6133efd9
...
@@ -35,8 +35,7 @@ paddle.fluid.program_guard ArgSpec(args=[], varargs='args', keywords='kwds', def
...
@@ -35,8 +35,7 @@ paddle.fluid.program_guard ArgSpec(args=[], varargs='args', keywords='kwds', def
paddle.fluid.get_var ArgSpec(args=['name', 'program'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.get_var ArgSpec(args=['name', 'program'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.Executor.__init__ ArgSpec(args=['self', 'place'], varargs=None, keywords=None, defaults=None)
paddle.fluid.Executor.__init__ ArgSpec(args=['self', 'place'], varargs=None, keywords=None, defaults=None)
paddle.fluid.Executor.as_lodtensor ArgSpec(args=['self', 'data'], varargs=None, keywords=None, defaults=None)
paddle.fluid.Executor.as_lodtensor ArgSpec(args=['self', 'data'], varargs=None, keywords=None, defaults=None)
paddle.fluid.Executor.begin_pass ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None)
paddle.fluid.Executor.close ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None)
paddle.fluid.Executor.end_pass ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None)
paddle.fluid.Executor.run ArgSpec(args=['self', 'program', 'feed', 'fetch_list', 'feed_var_name', 'fetch_var_name', 'scope', 'return_numpy', 'use_program_cache'], varargs=None, keywords=None, defaults=(None, None, None, 'feed', 'fetch', None, True, False))
paddle.fluid.Executor.run ArgSpec(args=['self', 'program', 'feed', 'fetch_list', 'feed_var_name', 'fetch_var_name', 'scope', 'return_numpy', 'use_program_cache'], varargs=None, keywords=None, defaults=(None, None, None, 'feed', 'fetch', None, True, False))
paddle.fluid.global_scope ArgSpec(args=[], varargs=None, keywords=None, defaults=None)
paddle.fluid.global_scope ArgSpec(args=[], varargs=None, keywords=None, defaults=None)
paddle.fluid.scope_guard ArgSpec(args=[], varargs='args', keywords='kwds', defaults=None)
paddle.fluid.scope_guard ArgSpec(args=[], varargs='args', keywords='kwds', defaults=None)
...
...
paddle/fluid/framework/executor.cc
浏览文件 @
6133efd9
...
@@ -45,19 +45,13 @@ ExecutorPrepareContext::~ExecutorPrepareContext() {
...
@@ -45,19 +45,13 @@ ExecutorPrepareContext::~ExecutorPrepareContext() {
Executor
::
Executor
(
const
platform
::
Place
&
place
)
:
place_
(
place
)
{}
Executor
::
Executor
(
const
platform
::
Place
&
place
)
:
place_
(
place
)
{}
void
Executor
::
Close
()
{
#ifdef PADDLE_WITH_DISTRIBUTE
#ifdef PADDLE_WITH_DISTRIBUTE
void
Executor
::
BeginPass
()
{
::
paddle
::
operators
::
distributed
::
RPCClient
::
GetInstance
<
::
paddle
::
operators
::
distributed
::
RPCClient
::
GetInstance
<
::
paddle
::
operators
::
distributed
::
GRPCClient
>
()
::
paddle
::
operators
::
distributed
::
GRPCClient
>
()
->
SendBeginPass
();
->
SendComplete
();
}
void
Executor
::
EndPass
()
{
::
paddle
::
operators
::
distributed
::
RPCClient
::
GetInstance
<
::
paddle
::
operators
::
distributed
::
GRPCClient
>
()
->
SendEndPass
();
}
#endif
#endif
}
void
InitializeVariable
(
Variable
*
var
,
proto
::
VarType
::
Type
var_type
)
{
void
InitializeVariable
(
Variable
*
var
,
proto
::
VarType
::
Type
var_type
)
{
if
(
var_type
==
proto
::
VarType
::
LOD_TENSOR
)
{
if
(
var_type
==
proto
::
VarType
::
LOD_TENSOR
)
{
...
...
paddle/fluid/framework/executor.h
浏览文件 @
6133efd9
...
@@ -44,17 +44,11 @@ class Executor {
...
@@ -44,17 +44,11 @@ class Executor {
explicit
Executor
(
const
platform
::
Place
&
place
);
explicit
Executor
(
const
platform
::
Place
&
place
);
#ifdef PADDLE_WITH_DISTRIBUTE
/*
/*
* Sending signal to pserver to mark current pass started.
* Close this Executor.
* Calling this method will send complete messages to all pserver instances.
*/
*/
void
BeginPass
();
void
Close
();
/*
* Sending signal to pserver to mark current pass finished.
*/
void
EndPass
();
#endif
/* @Brief
/* @Brief
* Runtime evaluation of the given ProgramDesc under certain Scope
* Runtime evaluation of the given ProgramDesc under certain Scope
...
...
paddle/fluid/operators/distributed/CMakeLists.txt
浏览文件 @
6133efd9
...
@@ -18,7 +18,7 @@ if(WITH_GRPC)
...
@@ -18,7 +18,7 @@ if(WITH_GRPC)
set_source_files_properties
(
grpc_serde_test.cc rpc_server_test.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
set_source_files_properties
(
grpc_serde_test.cc rpc_server_test.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
cc_test
(
grpc_serde_test SRCS grpc_serde_test.cc
cc_test
(
grpc_serde_test SRCS grpc_serde_test.cc
DEPS grpc++_unsecure grpc_unsecure gpr cares zlib protobuf sendrecvop_grpc scope profiler math_function SERIAL
)
DEPS grpc++_unsecure grpc_unsecure gpr cares zlib protobuf sendrecvop_grpc scope profiler math_function SERIAL
)
cc_test
(
g
rpc_server_test SRCS rpc_server_test.cc
cc_test
(
rpc_server_test SRCS rpc_server_test.cc
DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf executor proto_desc lookup_table_op SERIAL
)
DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf executor proto_desc lookup_table_op SERIAL
)
return
()
return
()
endif
()
endif
()
...
...
paddle/fluid/operators/distributed/grpc_client.cc
浏览文件 @
6133efd9
...
@@ -36,20 +36,16 @@ void GRPCClient::InitEventLoop() {
...
@@ -36,20 +36,16 @@ void GRPCClient::InitEventLoop() {
client_thread_
.
reset
(
new
std
::
thread
(
std
::
bind
(
&
GRPCClient
::
Proceed
,
this
)));
client_thread_
.
reset
(
new
std
::
thread
(
std
::
bind
(
&
GRPCClient
::
Proceed
,
this
)));
}
}
void
GRPCClient
::
SendBeginPass
()
{
void
GRPCClient
::
SendComplete
()
{
for
(
auto
&
it
:
channels_
)
{
std
::
unique_lock
<
std
::
mutex
>
lk
(
completed_mutex_
);
VLOG
(
3
)
<<
"send begin pass to: "
<<
it
.
first
;
if
(
!
completed_
)
{
this
->
AsyncSendBeginPass
(
it
.
first
);
for
(
auto
&
it
:
channels_
)
{
}
VLOG
(
3
)
<<
"send complete message to "
<<
it
.
first
;
this
->
Wait
();
this
->
AsyncSendComplete
(
it
.
first
);
}
}
PADDLE_ENFORCE
(
this
->
Wait
(),
"internal grpc error"
);
void
GRPCClient
::
SendEndPass
()
{
completed_
=
true
;
for
(
auto
&
it
:
channels_
)
{
VLOG
(
3
)
<<
"send end pass to "
<<
it
.
first
;
this
->
AsyncSendEndPass
(
it
.
first
);
}
}
this
->
Wait
();
}
}
GRPCClient
::~
GRPCClient
()
{
GRPCClient
::~
GRPCClient
()
{
...
@@ -239,32 +235,19 @@ void GRPCClient::AsyncSendFetchBarrier(const std::string& ep,
...
@@ -239,32 +235,19 @@ void GRPCClient::AsyncSendFetchBarrier(const std::string& ep,
req_count_
++
;
req_count_
++
;
}
}
void
GRPCClient
::
AsyncSend
BeginPass
(
const
std
::
string
&
ep
,
int64_t
time_out
)
{
void
GRPCClient
::
AsyncSend
Complete
(
const
std
::
string
&
ep
,
int64_t
time_out
)
{
const
auto
ch
=
GetChannel
(
ep
);
const
auto
ch
=
GetChannel
(
ep
);
BatchBarrierProcessor
*
s
=
new
BatchBarrierProcessor
(
ch
);
BatchBarrierProcessor
*
s
=
new
BatchBarrierProcessor
(
ch
);
s
->
Prepare
(
time_out
);
s
->
Prepare
(
time_out
);
sendrecv
::
VariableMessage
req
;
sendrecv
::
VariableMessage
req
;
req
.
set_varname
(
BEGIN_PASS
_MESSAGE
);
req
.
set_varname
(
COMPLETE
_MESSAGE
);
auto
rpc
=
s
->
stub_
->
AsyncSendVariable
(
s
->
context_
.
get
(),
req
,
&
cq_
);
auto
rpc
=
s
->
stub_
->
AsyncSendVariable
(
s
->
context_
.
get
(),
req
,
&
cq_
);
rpc
->
Finish
(
&
s
->
reply_
,
&
s
->
status_
,
reinterpret_cast
<
void
*>
(
s
));
rpc
->
Finish
(
&
s
->
reply_
,
&
s
->
status_
,
reinterpret_cast
<
void
*>
(
s
));
req_count_
++
;
req_count_
++
;
}
}
void
GRPCClient
::
AsyncSendEndPass
(
const
std
::
string
&
ep
,
int64_t
time_out
)
{
const
auto
ch
=
GetChannel
(
ep
);
FetchBarrierProcessor
*
s
=
new
FetchBarrierProcessor
(
ch
);
s
->
Prepare
(
time_out
);
sendrecv
::
VariableMessage
req
;
req
.
set_varname
(
END_PASS_MESSAGE
);
auto
rpc
=
s
->
stub_
->
AsyncGetVariable
(
s
->
context_
.
get
(),
req
,
&
cq_
);
rpc
->
Finish
(
&
s
->
reply_
,
&
s
->
status_
,
reinterpret_cast
<
void
*>
(
s
));
req_count_
++
;
}
void
GRPCClient
::
AsyncCheckpointNotify
(
const
std
::
string
&
ep
,
void
GRPCClient
::
AsyncCheckpointNotify
(
const
std
::
string
&
ep
,
const
std
::
string
&
dir
,
const
std
::
string
&
dir
,
int64_t
time_out
)
{
int64_t
time_out
)
{
...
...
paddle/fluid/operators/distributed/grpc_client.h
浏览文件 @
6133efd9
...
@@ -174,7 +174,7 @@ class CheckpointNotifyProcessor : public BaseProcessor {
...
@@ -174,7 +174,7 @@ class CheckpointNotifyProcessor : public BaseProcessor {
class
GRPCClient
:
public
RPCClient
{
class
GRPCClient
:
public
RPCClient
{
public:
public:
GRPCClient
()
:
ok_
(
true
)
{}
GRPCClient
()
:
ok_
(
true
)
,
completed_
(
false
)
{}
virtual
~
GRPCClient
();
virtual
~
GRPCClient
();
bool
AsyncSendVar
(
const
std
::
string
&
ep
,
const
platform
::
DeviceContext
&
ctx
,
bool
AsyncSendVar
(
const
std
::
string
&
ep
,
const
platform
::
DeviceContext
&
ctx
,
...
@@ -201,17 +201,12 @@ class GRPCClient : public RPCClient {
...
@@ -201,17 +201,12 @@ class GRPCClient : public RPCClient {
void
AsyncCheckpointNotify
(
const
std
::
string
&
ep
,
const
std
::
string
&
dir
,
void
AsyncCheckpointNotify
(
const
std
::
string
&
ep
,
const
std
::
string
&
dir
,
int64_t
time_out
=
FLAGS_rpc_deadline
)
override
;
int64_t
time_out
=
FLAGS_rpc_deadline
)
override
;
void
AsyncSendBeginPass
(
const
std
::
string
&
ep
,
void
AsyncSendComplete
(
const
std
::
string
&
ep
,
int64_t
time_out
=
FLAGS_rpc_deadline
)
override
;
int64_t
time_out
=
FLAGS_rpc_deadline
)
override
;
void
AsyncSendEndPass
(
const
std
::
string
&
ep
,
int64_t
time_out
=
FLAGS_rpc_deadline
)
override
;
bool
Wait
()
override
;
bool
Wait
()
override
;
void
SendBeginPass
()
override
;
void
SendComplete
()
override
;
void
SendEndPass
()
override
;
protected:
protected:
void
InitImpl
()
override
;
void
InitImpl
()
override
;
...
@@ -238,6 +233,10 @@ class GRPCClient : public RPCClient {
...
@@ -238,6 +233,10 @@ class GRPCClient : public RPCClient {
// mutex for GetChannel thread safety
// mutex for GetChannel thread safety
std
::
mutex
chan_mutex_
;
std
::
mutex
chan_mutex_
;
DISABLE_COPY_AND_ASSIGN
(
GRPCClient
);
DISABLE_COPY_AND_ASSIGN
(
GRPCClient
);
// mutex for sending complete message only once
std
::
mutex
completed_mutex_
;
bool
completed_
;
};
};
}
// namespace distributed
}
// namespace distributed
...
...
paddle/fluid/operators/distributed/request_handler.h
浏览文件 @
6133efd9
...
@@ -43,8 +43,6 @@ constexpr char kRequestPassBarrier[] = "RequestPassBarrier";
...
@@ -43,8 +43,6 @@ constexpr char kRequestPassBarrier[] = "RequestPassBarrier";
#define BATCH_BARRIER_MESSAGE "BATCH_BARRIER@RECV"
#define BATCH_BARRIER_MESSAGE "BATCH_BARRIER@RECV"
#define FETCH_BARRIER_MESSAGE "FETCH_BARRIER@RECV"
#define FETCH_BARRIER_MESSAGE "FETCH_BARRIER@RECV"
#define COMPLETE_MESSAGE "COMPLETE@RECV"
#define COMPLETE_MESSAGE "COMPLETE@RECV"
#define BEGIN_PASS_MESSAGE "BEGIN_PASS@RECV"
#define END_PASS_MESSAGE "END_PASS@RECV"
#define CHECKPOINT_SAVE_MESSAGE "SAVE@CHECKPOINTNOTIFY"
#define CHECKPOINT_SAVE_MESSAGE "SAVE@CHECKPOINTNOTIFY"
#define CHECKPOINT_LOAD_MESSAGE "LOAD@CHECKPOINTNOTIFY"
#define CHECKPOINT_LOAD_MESSAGE "LOAD@CHECKPOINTNOTIFY"
...
...
paddle/fluid/operators/distributed/request_handler_impl.cc
浏览文件 @
6133efd9
...
@@ -55,10 +55,9 @@ bool RequestSendHandler::Handle(const std::string& varname,
...
@@ -55,10 +55,9 @@ bool RequestSendHandler::Handle(const std::string& varname,
if
(
varname
==
BATCH_BARRIER_MESSAGE
)
{
if
(
varname
==
BATCH_BARRIER_MESSAGE
)
{
VLOG
(
3
)
<<
"sync: recv BATCH_BARRIER_MESSAGE"
;
VLOG
(
3
)
<<
"sync: recv BATCH_BARRIER_MESSAGE"
;
rpc_server_
->
IncreaseBatchBarrier
(
kRequestSend
);
rpc_server_
->
IncreaseBatchBarrier
(
kRequestSend
);
}
else
if
(
varname
==
BEGIN_PASS_MESSAGE
)
{
}
else
if
(
varname
==
COMPLETE_MESSAGE
)
{
VLOG
(
3
)
<<
"sync: recv begin pass message"
;
VLOG
(
3
)
<<
"sync: recv complete message"
;
rpc_server_
->
WaitCond
(
kRequestSend
);
rpc_server_
->
Complete
();
rpc_server_
->
BeginPass
();
}
else
{
}
else
{
VLOG
(
3
)
<<
"sync: received var_name: "
<<
varname
;
VLOG
(
3
)
<<
"sync: received var_name: "
<<
varname
;
rpc_server_
->
WaitCond
(
kRequestSend
);
rpc_server_
->
WaitCond
(
kRequestSend
);
...
@@ -94,14 +93,12 @@ bool RequestGetHandler::Handle(const std::string& varname,
...
@@ -94,14 +93,12 @@ bool RequestGetHandler::Handle(const std::string& varname,
if
(
varname
==
FETCH_BARRIER_MESSAGE
)
{
if
(
varname
==
FETCH_BARRIER_MESSAGE
)
{
VLOG
(
3
)
<<
"sync: recv fetch barrier message"
;
VLOG
(
3
)
<<
"sync: recv fetch barrier message"
;
rpc_server_
->
IncreaseBatchBarrier
(
kRequestGet
);
rpc_server_
->
IncreaseBatchBarrier
(
kRequestGet
);
}
else
if
(
varname
==
END_PASS_MESSAGE
)
{
rpc_server_
->
EndPass
();
}
else
{
}
else
{
rpc_server_
->
WaitCond
(
kRequestGet
);
rpc_server_
->
WaitCond
(
kRequestGet
);
*
outvar
=
scope_
->
FindVar
(
varname
);
*
outvar
=
scope_
->
FindVar
(
varname
);
}
}
}
else
{
}
else
{
if
(
varname
!=
FETCH_BARRIER_MESSAGE
&&
varname
!=
END_PASS
_MESSAGE
)
{
if
(
varname
!=
FETCH_BARRIER_MESSAGE
&&
varname
!=
COMPLETE
_MESSAGE
)
{
*
outvar
=
scope_
->
FindVar
(
varname
);
*
outvar
=
scope_
->
FindVar
(
varname
);
}
}
}
}
...
...
paddle/fluid/operators/distributed/rpc_client.h
浏览文件 @
6133efd9
...
@@ -60,17 +60,13 @@ class RPCClient {
...
@@ -60,17 +60,13 @@ class RPCClient {
const
std
::
string
&
dir
,
const
std
::
string
&
dir
,
int64_t
time_out
=
FLAGS_rpc_deadline
)
=
0
;
int64_t
time_out
=
FLAGS_rpc_deadline
)
=
0
;
virtual
void
AsyncSend
BeginPass
(
const
std
::
string
&
ep
,
virtual
void
AsyncSend
Complete
(
const
std
::
string
&
ep
,
int64_t
time_out
=
FLAGS_rpc_deadline
)
=
0
;
int64_t
time_out
=
FLAGS_rpc_deadline
)
=
0
;
virtual
void
AsyncSendEndPass
(
const
std
::
string
&
ep
,
// Complete tells all the pserver instances that finishe the training,
int64_t
time_out
=
FLAGS_rpc_deadline
)
=
0
;
// the pserver can reduce it's barrier count, and continue to train
// BeginePass/EndPass tells all the pserver that start/end a pass, so that
// the pserver can increase/reduce it's barrier count, and continue to train
// with other trainers.
// with other trainers.
virtual
void
SendBeginPass
()
=
0
;
virtual
void
SendComplete
()
=
0
;
virtual
void
SendEndPass
()
=
0
;
virtual
bool
Wait
()
=
0
;
virtual
bool
Wait
()
=
0
;
...
...
paddle/fluid/operators/distributed/rpc_server.cc
浏览文件 @
6133efd9
...
@@ -64,18 +64,7 @@ void RPCServer::IncreaseBatchBarrier(const std::string rpc_name) {
...
@@ -64,18 +64,7 @@ void RPCServer::IncreaseBatchBarrier(const std::string rpc_name) {
}
}
}
}
void
RPCServer
::
BeginPass
()
{
void
RPCServer
::
Complete
()
{
VLOG
(
4
)
<<
"RPCServer begin increase pass barrier"
;
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
client_num_
++
;
VLOG
(
4
)
<<
"increase client_num to: "
<<
client_num_
;
}
barrier_cond_
.
notify_all
();
}
void
RPCServer
::
EndPass
()
{
VLOG
(
4
)
<<
"RPCServer begin increase pass barrier"
;
{
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
client_num_
--
;
client_num_
--
;
...
@@ -87,6 +76,11 @@ void RPCServer::EndPass() {
...
@@ -87,6 +76,11 @@ void RPCServer::EndPass() {
barrier_cond_
.
notify_all
();
barrier_cond_
.
notify_all
();
}
}
int
RPCServer
::
GetClientNum
()
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
return
client_num_
;
}
void
RPCServer
::
ResetBarrierCounter
()
{
void
RPCServer
::
ResetBarrierCounter
()
{
VLOG
(
3
)
<<
"RPCServer ResetBarrierCounter "
;
VLOG
(
3
)
<<
"RPCServer ResetBarrierCounter "
;
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
...
...
paddle/fluid/operators/distributed/rpc_server.h
浏览文件 @
6133efd9
...
@@ -44,7 +44,7 @@ class RPCServer {
...
@@ -44,7 +44,7 @@ class RPCServer {
int
GetSelectedPort
()
const
{
return
selected_port_
;
}
int
GetSelectedPort
()
const
{
return
selected_port_
;
}
int
GetClientNum
()
const
;
int
GetClientNum
();
void
SavePort
()
const
;
void
SavePort
()
const
;
...
@@ -64,8 +64,7 @@ class RPCServer {
...
@@ -64,8 +64,7 @@ class RPCServer {
void
WaitCond
(
const
std
::
string
&
rpc_name
);
void
WaitCond
(
const
std
::
string
&
rpc_name
);
void
IncreaseBatchBarrier
(
const
std
::
string
rpc_name
);
void
IncreaseBatchBarrier
(
const
std
::
string
rpc_name
);
void
BeginPass
();
void
Complete
();
void
EndPass
();
void
ResetBarrierCounter
();
void
ResetBarrierCounter
();
...
...
paddle/fluid/operators/distributed/rpc_server_test.cc
浏览文件 @
6133efd9
...
@@ -91,7 +91,7 @@ void InitTensorsOnServer(framework::Scope* scope, platform::CPUPlace* place,
...
@@ -91,7 +91,7 @@ void InitTensorsOnServer(framework::Scope* scope, platform::CPUPlace* place,
}
}
}
}
void
StartServer
()
{
void
StartServer
(
const
std
::
string
&
rpc_name
)
{
framework
::
ProgramDesc
program
;
framework
::
ProgramDesc
program
;
framework
::
Scope
scope
;
framework
::
Scope
scope
;
platform
::
CPUPlace
place
;
platform
::
CPUPlace
place
;
...
@@ -107,14 +107,14 @@ void StartServer() {
...
@@ -107,14 +107,14 @@ void StartServer() {
std
::
shared_ptr
<
framework
::
ExecutorPrepareContext
>>
std
::
shared_ptr
<
framework
::
ExecutorPrepareContext
>>
prefetch_var_name_to_prepared
;
prefetch_var_name_to_prepared
;
prefetch_var_name_to_prepared
[
in_var_name
]
=
prepared
[
0
];
prefetch_var_name_to_prepared
[
in_var_name
]
=
prepared
[
0
];
g_req_handler
->
SetProgram
(
&
program
);
g_req_handler
->
SetProgram
(
&
program
);
g_req_handler
->
SetPrefetchPreparedCtx
(
&
prefetch_var_name_to_prepared
);
g_req_handler
->
SetPrefetchPreparedCtx
(
&
prefetch_var_name_to_prepared
);
g_req_handler
->
SetDevCtx
(
&
ctx
);
g_req_handler
->
SetDevCtx
(
&
ctx
);
g_req_handler
->
SetScope
(
&
scope
);
g_req_handler
->
SetScope
(
&
scope
);
g_req_handler
->
SetExecutor
(
&
exe
);
g_req_handler
->
SetExecutor
(
&
exe
);
g_rpc_service
->
RegisterRPC
(
distributed
::
kRequestPrefetch
,
g_rpc_service
->
RegisterRPC
(
rpc_name
,
g_req_handler
.
get
());
g_req_handler
.
get
());
g_req_handler
->
SetRPCServer
(
g_rpc_service
.
get
());
g_req_handler
->
SetRPCServer
(
g_rpc_service
.
get
());
std
::
thread
server_thread
(
std
::
thread
server_thread
(
...
@@ -129,7 +129,7 @@ TEST(PREFETCH, CPU) {
...
@@ -129,7 +129,7 @@ TEST(PREFETCH, CPU) {
distributed
::
RPCClient
*
client
=
distributed
::
RPCClient
*
client
=
distributed
::
RPCClient
::
GetInstance
<
RPCCLIENT_T
>
();
distributed
::
RPCClient
::
GetInstance
<
RPCCLIENT_T
>
();
std
::
thread
server_thread
(
StartServer
);
std
::
thread
server_thread
(
StartServer
,
distributed
::
kRequestPrefetch
);
g_rpc_service
->
WaitServerReady
();
g_rpc_service
->
WaitServerReady
();
int
port
=
g_rpc_service
->
GetSelectedPort
();
int
port
=
g_rpc_service
->
GetSelectedPort
();
...
@@ -162,3 +162,24 @@ TEST(PREFETCH, CPU) {
...
@@ -162,3 +162,24 @@ TEST(PREFETCH, CPU) {
g_rpc_service
.
reset
(
nullptr
);
g_rpc_service
.
reset
(
nullptr
);
g_req_handler
.
reset
(
nullptr
);
g_req_handler
.
reset
(
nullptr
);
}
}
TEST
(
COMPLETE
,
CPU
)
{
g_req_handler
.
reset
(
new
distributed
::
RequestSendHandler
(
true
));
g_rpc_service
.
reset
(
new
RPCSERVER_T
(
"127.0.0.1:0"
,
2
));
distributed
::
RPCClient
*
client
=
distributed
::
RPCClient
::
GetInstance
<
RPCCLIENT_T
>
();
PADDLE_ENFORCE
(
client
!=
nullptr
);
std
::
thread
server_thread
(
StartServer
,
distributed
::
kRequestSend
);
g_rpc_service
->
WaitServerReady
();
int
port
=
g_rpc_service
->
GetSelectedPort
();
std
::
string
ep
=
paddle
::
string
::
Sprintf
(
"127.0.0.1:%d"
,
port
);
client
->
AsyncSendComplete
(
ep
);
client
->
Wait
();
EXPECT_EQ
(
g_rpc_service
->
GetClientNum
(),
1
);
g_rpc_service
->
ShutDown
();
server_thread
.
join
();
g_rpc_service
.
reset
(
nullptr
);
g_req_handler
.
reset
(
nullptr
);
}
paddle/fluid/pybind/pybind.cc
浏览文件 @
6133efd9
...
@@ -498,10 +498,7 @@ All parameter, weight, gradient are variables in Paddle.
...
@@ -498,10 +498,7 @@ All parameter, weight, gradient are variables in Paddle.
py
::
class_
<
framework
::
Executor
>
(
m
,
"Executor"
)
py
::
class_
<
framework
::
Executor
>
(
m
,
"Executor"
)
.
def
(
py
::
init
<
const
platform
::
Place
&>
())
.
def
(
py
::
init
<
const
platform
::
Place
&>
())
#ifdef PADDLE_WITH_DISTRIBUTE
.
def
(
"close"
,
&
Executor
::
Close
)
.
def
(
"begin_pass"
,
&
Executor
::
BeginPass
)
.
def
(
"end_pass"
,
&
Executor
::
EndPass
)
#endif
.
def
(
"run"
,
[](
Executor
&
self
,
const
ProgramDesc
&
prog
,
Scope
*
scope
,
.
def
(
"run"
,
[](
Executor
&
self
,
const
ProgramDesc
&
prog
,
Scope
*
scope
,
int
block_id
,
bool
create_local_scope
,
bool
create_vars
)
{
int
block_id
,
bool
create_local_scope
,
bool
create_vars
)
{
pybind11
::
gil_scoped_release
release
;
pybind11
::
gil_scoped_release
release
;
...
...
python/paddle/fluid/executor.py
浏览文件 @
6133efd9
...
@@ -247,6 +247,7 @@ class Executor(object):
...
@@ -247,6 +247,7 @@ class Executor(object):
p
.
set_place
(
place
)
p
.
set_place
(
place
)
self
.
executor
=
core
.
Executor
(
p
)
self
.
executor
=
core
.
Executor
(
p
)
self
.
program_caches
=
dict
()
self
.
program_caches
=
dict
()
self
.
_closed
=
False
def
as_lodtensor
(
self
,
data
):
def
as_lodtensor
(
self
,
data
):
"""
"""
...
@@ -348,11 +349,23 @@ class Executor(object):
...
@@ -348,11 +349,23 @@ class Executor(object):
]
]
return
outs
return
outs
def
begin_pass
(
self
):
def
close
(
self
):
self
.
executor
.
begin_pass
()
"""
Close this executor.
def
end_pass
(
self
):
You can no long use this executor after calling this method.
self
.
executor
.
end_pass
()
For the distributed training, this method would free the resource on PServers related to
the current Trainer.
Example:
>>> cpu = core.CPUPlace()
>>> exe = Executor(cpu)
>>> ...
>>> exe.close()
"""
if
not
self
.
_closed
:
self
.
executor
.
close
()
self
.
_closed
=
True
def
run
(
self
,
def
run
(
self
,
program
=
None
,
program
=
None
,
...
@@ -405,6 +418,10 @@ class Executor(object):
...
@@ -405,6 +418,10 @@ class Executor(object):
>>> feed={'X': x},
>>> feed={'X': x},
>>> fetch_list=[loss.name])
>>> fetch_list=[loss.name])
"""
"""
if
self
.
_closed
:
raise
RuntimeError
(
"Attempted to use a closed Executor"
)
if
feed
is
None
:
if
feed
is
None
:
feed
=
{}
feed
=
{}
if
not
isinstance
(
feed
,
dict
):
if
not
isinstance
(
feed
,
dict
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录