Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
6d934560
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
6d934560
编写于
4月 26, 2018
作者:
Q
Qiao Longfei
提交者:
GitHub
4月 26, 2018
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #10042 from jacquesqiao/add-async-listen-and-serv-op
listen_and_serv_op support async update
上级
f457d5da
3295f310
变更
9
显示空白变更内容
内联
并排
Showing
9 changed file
with
220 addition
and
62 deletion
+220
-62
paddle/fluid/operators/detail/grpc_server.cc
paddle/fluid/operators/detail/grpc_server.cc
+38
-22
paddle/fluid/operators/detail/grpc_server.h
paddle/fluid/operators/detail/grpc_server.h
+3
-1
paddle/fluid/operators/detail/grpc_server_test.cc
paddle/fluid/operators/detail/grpc_server_test.cc
+1
-1
paddle/fluid/operators/detail/variable_response.h
paddle/fluid/operators/detail/variable_response.h
+5
-1
paddle/fluid/operators/listen_and_serv_op.cc
paddle/fluid/operators/listen_and_serv_op.cc
+111
-2
paddle/fluid/operators/listen_and_serv_op.h
paddle/fluid/operators/listen_and_serv_op.h
+5
-0
paddle/fluid/operators/send_op.cc
paddle/fluid/operators/send_op.cc
+9
-4
paddle/fluid/operators/send_recv_op_test.cc
paddle/fluid/operators/send_recv_op_test.cc
+2
-0
python/paddle/fluid/distribute_transpiler.py
python/paddle/fluid/distribute_transpiler.py
+46
-31
未找到文件。
paddle/fluid/operators/detail/grpc_server.cc
浏览文件 @
6d934560
...
...
@@ -30,9 +30,13 @@ enum CallStatus { PROCESS = 0, FINISH };
class
RequestBase
{
public:
explicit
RequestBase
(
GrpcService
::
AsyncService
*
service
,
::
grpc
::
ServerCompletionQueue
*
cq
,
::
grpc
::
ServerCompletionQueue
*
cq
,
bool
sync_mode
,
const
platform
::
DeviceContext
*
dev_ctx
)
:
service_
(
service
),
cq_
(
cq
),
status_
(
PROCESS
),
dev_ctx_
(
dev_ctx
)
{
:
service_
(
service
),
cq_
(
cq
),
sync_mode_
(
sync_mode
),
status_
(
PROCESS
),
dev_ctx_
(
dev_ctx
)
{
PADDLE_ENFORCE
(
cq_
);
}
virtual
~
RequestBase
()
{}
...
...
@@ -49,6 +53,7 @@ class RequestBase {
::
grpc
::
ServerContext
ctx_
;
GrpcService
::
AsyncService
*
service_
;
::
grpc
::
ServerCompletionQueue
*
cq_
;
const
bool
sync_mode_
;
CallStatus
status_
;
const
platform
::
DeviceContext
*
dev_ctx_
;
};
...
...
@@ -56,11 +61,17 @@ class RequestBase {
class
RequestSend
final
:
public
RequestBase
{
public:
explicit
RequestSend
(
GrpcService
::
AsyncService
*
service
,
::
grpc
::
ServerCompletionQueue
*
cq
,
::
grpc
::
ServerCompletionQueue
*
cq
,
bool
sync_mode
,
framework
::
Scope
*
scope
,
ReceivedQueue
*
queue
,
const
platform
::
DeviceContext
*
dev_ctx
)
:
RequestBase
(
service
,
cq
,
dev_ctx
),
queue_
(
queue
),
responder_
(
&
ctx_
)
{
request_
.
reset
(
new
VariableResponse
(
scope
,
dev_ctx_
));
:
RequestBase
(
service
,
cq
,
sync_mode
,
dev_ctx
),
queue_
(
queue
),
responder_
(
&
ctx_
)
{
if
(
sync_mode_
)
{
request_
.
reset
(
new
VariableResponse
(
scope
,
dev_ctx_
,
false
));
}
else
{
request_
.
reset
(
new
VariableResponse
(
scope
,
dev_ctx_
,
true
));
}
int
method_id
=
static_cast
<
int
>
(
detail
::
GrpcMethod
::
kSendVariable
);
service_
->
RequestAsyncUnary
(
method_id
,
&
ctx_
,
request_
.
get
(),
&
responder_
,
cq_
,
cq_
,
this
);
...
...
@@ -87,11 +98,11 @@ class RequestSend final : public RequestBase {
class
RequestGet
final
:
public
RequestBase
{
public:
explicit
RequestGet
(
GrpcService
::
AsyncService
*
service
,
::
grpc
::
ServerCompletionQueue
*
cq
,
::
grpc
::
ServerCompletionQueue
*
cq
,
bool
sync_mode
,
framework
::
Scope
*
scope
,
const
platform
::
DeviceContext
*
dev_ctx
,
framework
::
BlockingQueue
<
MessageWithName
>*
queue
)
:
RequestBase
(
service
,
cq
,
dev_ctx
),
:
RequestBase
(
service
,
cq
,
sync_mode
,
dev_ctx
),
responder_
(
&
ctx_
),
scope_
(
scope
),
queue_
(
queue
)
{
...
...
@@ -134,19 +145,23 @@ class RequestGet final : public RequestBase {
class
RequestPrefetch
final
:
public
RequestBase
{
public:
explicit
RequestPrefetch
(
GrpcService
::
AsyncService
*
service
,
::
grpc
::
ServerCompletionQueue
*
cq
,
::
grpc
::
ServerCompletionQueue
*
cq
,
bool
sync_mode
,
framework
::
Scope
*
scope
,
const
platform
::
DeviceContext
*
dev_ctx
,
framework
::
Executor
*
executor
,
framework
::
ProgramDesc
*
program
,
framework
::
ExecutorPrepareContext
*
prefetch_ctx
)
:
RequestBase
(
service
,
cq
,
dev_ctx
),
:
RequestBase
(
service
,
cq
,
sync_mode
,
dev_ctx
),
responder_
(
&
ctx_
),
scope_
(
scope
),
executor_
(
executor
),
program_
(
program
),
prefetch_ctx_
(
prefetch_ctx
)
{
request_
.
reset
(
new
VariableResponse
(
scope
,
dev_ctx_
));
if
(
sync_mode_
)
{
request_
.
reset
(
new
VariableResponse
(
scope
,
dev_ctx_
,
false
));
}
else
{
request_
.
reset
(
new
VariableResponse
(
scope
,
dev_ctx_
,
true
));
}
int
method_id
=
static_cast
<
int
>
(
detail
::
GrpcMethod
::
kPrefetchVariable
);
service_
->
RequestAsyncUnary
(
method_id
,
&
ctx_
,
request_
.
get
(),
&
responder_
,
cq_
,
cq_
,
this
);
...
...
@@ -181,7 +196,6 @@ class RequestPrefetch final : public RequestBase {
framework
::
Executor
*
executor_
;
framework
::
ProgramDesc
*
program_
;
framework
::
ExecutorPrepareContext
*
prefetch_ctx_
;
int
blkid_
;
};
void
AsyncGRPCServer
::
WaitClientGet
(
int
count
)
{
...
...
@@ -254,8 +268,8 @@ void AsyncGRPCServer::TryToRegisterNewSendOne() {
VLOG
(
3
)
<<
"shutdown, do not TryToRegisterNewSendOne"
;
return
;
}
RequestSend
*
send
=
new
RequestSend
(
&
service_
,
cq_send_
.
get
(),
s
cop
e_
,
&
var_recv_queue_
,
dev_ctx_
);
RequestSend
*
send
=
new
RequestSend
(
&
service_
,
cq_send_
.
get
(),
s
ync_mod
e_
,
scope_
,
&
var_recv_queue_
,
dev_ctx_
);
VLOG
(
4
)
<<
"Create RequestSend status:"
<<
send
->
Status
();
}
...
...
@@ -265,8 +279,8 @@ void AsyncGRPCServer::TryToRegisterNewGetOne() {
VLOG
(
3
)
<<
"shutdown, do not TryToRegisterNewGetOne"
;
return
;
}
RequestGet
*
get
=
new
RequestGet
(
&
service_
,
cq_get_
.
get
(),
s
cope_
,
dev_ctx
_
,
&
var_get_queue_
);
RequestGet
*
get
=
new
RequestGet
(
&
service_
,
cq_get_
.
get
(),
s
ync_mode_
,
scope
_
,
dev_ctx_
,
&
var_get_queue_
);
VLOG
(
4
)
<<
"Create RequestGet status:"
<<
get
->
Status
();
}
...
...
@@ -277,8 +291,8 @@ void AsyncGRPCServer::TryToRegisterNewPrefetchOne() {
return
;
}
RequestPrefetch
*
prefetch
=
new
RequestPrefetch
(
&
service_
,
cq_prefetch_
.
get
(),
s
cope_
,
dev_ctx
_
,
executor_
,
program_
,
prefetch_ctx_
);
new
RequestPrefetch
(
&
service_
,
cq_prefetch_
.
get
(),
s
ync_mode_
,
scope
_
,
dev_ctx_
,
executor_
,
program_
,
prefetch_ctx_
);
VLOG
(
4
)
<<
"Create RequestPrefetch status:"
<<
prefetch
->
Status
();
}
...
...
@@ -301,9 +315,11 @@ void AsyncGRPCServer::HandleRequest(::grpc::ServerCompletionQueue* cq,
VLOG
(
3
)
<<
"HandleRequest for "
<<
cq_name
<<
" while after Next"
;
PADDLE_ENFORCE
(
tag
);
if
(
sync_mode_
)
{
// FIXME(typhoonzero): de-couple the barriers with recv_op
if
(
!
is_shut_down_
&&
cq_name
==
"cq_get"
)
WaitCond
(
1
);
if
(
!
is_shut_down_
&&
cq_name
==
"cq_send"
)
WaitCond
(
0
);
}
RequestBase
*
base
=
reinterpret_cast
<
RequestBase
*>
(
tag
);
// reference:
...
...
@@ -320,13 +336,13 @@ void AsyncGRPCServer::HandleRequest(::grpc::ServerCompletionQueue* cq,
switch
(
base
->
Status
())
{
case
PROCESS
:
{
VLOG
(
4
)
<<
cq_name
<<
" status:"
<<
base
->
Status
();
VLOG
(
4
)
<<
cq_name
<<
"
PROCESS
status:"
<<
base
->
Status
();
TryToRegisterNewOne
();
base
->
Process
();
break
;
}
case
FINISH
:
{
VLOG
(
4
)
<<
cq_name
<<
" status:"
<<
base
->
Status
();
VLOG
(
4
)
<<
cq_name
<<
"
FINISH
status:"
<<
base
->
Status
();
delete
base
;
break
;
}
...
...
paddle/fluid/operators/detail/grpc_server.h
浏览文件 @
6d934560
...
...
@@ -44,7 +44,8 @@ class RequestBase;
class
AsyncGRPCServer
final
{
public:
explicit
AsyncGRPCServer
(
const
std
::
string
&
address
)
:
address_
(
address
)
{}
explicit
AsyncGRPCServer
(
const
std
::
string
&
address
,
bool
sync_mode
)
:
address_
(
address
),
sync_mode_
(
sync_mode
)
{}
void
RunSyncUpdate
();
...
...
@@ -95,6 +96,7 @@ class AsyncGRPCServer final {
std
::
unique_ptr
<::
grpc
::
Server
>
server_
;
std
::
string
address_
;
const
bool
sync_mode_
;
framework
::
Scope
*
scope_
;
const
platform
::
DeviceContext
*
dev_ctx_
;
...
...
paddle/fluid/operators/detail/grpc_server_test.cc
浏览文件 @
6d934560
...
...
@@ -89,7 +89,7 @@ void InitTensorsOnServer(framework::Scope* scope, platform::CPUPlace* place,
}
void
StartServer
(
const
std
::
string
&
endpoint
)
{
rpc_service_
.
reset
(
new
detail
::
AsyncGRPCServer
(
endpoint
));
rpc_service_
.
reset
(
new
detail
::
AsyncGRPCServer
(
endpoint
,
true
));
framework
::
ProgramDesc
program
;
framework
::
Scope
scope
;
platform
::
CPUPlace
place
;
...
...
paddle/fluid/operators/detail/variable_response.h
浏览文件 @
6d934560
...
...
@@ -46,7 +46,9 @@ class VariableResponse {
}
virtual
~
VariableResponse
()
{
if
(
create_scope_
)
scope_
->
DeleteScope
(
local_scope_
);
if
(
create_scope_
)
{
scope_
->
DeleteScope
(
local_scope_
);
}
}
// return:
...
...
@@ -63,6 +65,8 @@ class VariableResponse {
const
framework
::
Scope
&
GetLocalScope
()
const
{
return
*
local_scope_
;
}
framework
::
Scope
*
GetMutableLocalScope
()
const
{
return
local_scope_
;
}
inline
std
::
string
Varname
()
{
return
meta_
.
varname
();
}
inline
std
::
string
OutVarname
()
{
return
meta_
.
out_varname
();
}
...
...
paddle/fluid/operators/listen_and_serv_op.cc
浏览文件 @
6d934560
...
...
@@ -27,6 +27,38 @@ void RunServer(std::shared_ptr<detail::AsyncGRPCServer> service) {
VLOG
(
4
)
<<
"RunServer thread end"
;
}
static
void
split
(
const
std
::
string
&
str
,
char
sep
,
std
::
vector
<
std
::
string
>
*
pieces
)
{
pieces
->
clear
();
if
(
str
.
empty
())
{
return
;
}
size_t
pos
=
0
;
size_t
next
=
str
.
find
(
sep
,
pos
);
while
(
next
!=
std
::
string
::
npos
)
{
pieces
->
push_back
(
str
.
substr
(
pos
,
next
-
pos
));
pos
=
next
+
1
;
next
=
str
.
find
(
sep
,
pos
);
}
if
(
!
str
.
substr
(
pos
).
empty
())
{
pieces
->
push_back
(
str
.
substr
(
pos
));
}
}
static
void
AsyncExecuteBlock
(
framework
::
Executor
*
executor
,
framework
::
ExecutorPrepareContext
*
prepared
,
framework
::
Scope
*
scope
)
{
std
::
future
<
void
>
future
=
framework
::
Async
([
&
executor
,
&
prepared
,
&
scope
]()
{
try
{
executor
->
RunPreparedContext
(
prepared
,
scope
,
false
,
false
);
}
catch
(
std
::
exception
&
e
)
{
LOG
(
ERROR
)
<<
"run sub program error "
<<
e
.
what
();
}
});
// TODO(qiao) maybe we can remove this
future
.
wait
();
}
static
void
ParallelExecuteBlocks
(
const
std
::
vector
<
size_t
>
&
parallel_blkids
,
framework
::
Executor
*
executor
,
const
std
::
vector
<
std
::
shared_ptr
<
framework
::
ExecutorPrepareContext
>>
...
...
@@ -169,15 +201,82 @@ void ListenAndServOp::RunSyncLoop(framework::Executor *executor,
}
// while(true)
}
void
ListenAndServOp
::
RunAsyncLoop
(
framework
::
Executor
*
executor
,
framework
::
ProgramDesc
*
program
,
framework
::
Scope
*
recv_scope
,
framework
::
BlockDesc
*
prefetch_block
)
const
{
VLOG
(
3
)
<<
"RunAsyncLoop in"
;
// grad name to block id
std
::
unordered_map
<
std
::
string
,
int32_t
>
grad_to_block_id
;
std
::
unordered_map
<
int32_t
,
std
::
string
>
id_to_grad
;
auto
grad_to_block_id_str
=
Attr
<
std
::
vector
<
std
::
string
>>
(
"grad_to_block_id"
);
for
(
auto
&
grad_and_id
:
grad_to_block_id_str
)
{
std
::
vector
<
std
::
string
>
pieces
;
split
(
grad_and_id
,
':'
,
&
pieces
);
VLOG
(
3
)
<<
"after split, grad = "
<<
pieces
[
0
]
<<
", id="
<<
pieces
[
1
];
PADDLE_ENFORCE_EQ
(
pieces
.
size
(),
2
);
PADDLE_ENFORCE_EQ
(
grad_to_block_id
.
count
(
pieces
[
0
]),
0
);
int
block_id
=
std
::
stoi
(
pieces
[
1
]);
grad_to_block_id
[
pieces
[
0
]]
=
block_id
;
id_to_grad
[
block_id
]
=
pieces
[
0
];
}
size_t
num_blocks
=
program
->
Size
();
PADDLE_ENFORCE_GE
(
num_blocks
,
2
,
"server program should have at least 2 blocks"
);
std
::
vector
<
int
>
block_list
;
for
(
size_t
blkid
=
1
;
blkid
<
num_blocks
;
++
blkid
)
{
block_list
.
push_back
(
blkid
);
}
auto
optimize_prepared
=
executor
->
Prepare
(
*
program
,
block_list
);
std
::
unordered_map
<
std
::
string
,
std
::
shared_ptr
<
framework
::
ExecutorPrepareContext
>>
grad_to_prepared_ctx
;
for
(
size_t
i
=
0
;
i
<
block_list
.
size
();
++
i
)
{
grad_to_prepared_ctx
[
id_to_grad
[
block_list
[
i
]]]
=
optimize_prepared
[
i
];
}
VLOG
(
3
)
<<
"RunAsyncLoop into while"
;
bool
exit_flag
=
false
;
while
(
!
exit_flag
)
{
const
detail
::
ReceivedMessage
v
=
rpc_service_
->
Get
();
auto
recv_var_name
=
v
.
first
;
if
(
recv_var_name
==
LISTEN_TERMINATE_MESSAGE
)
{
LOG
(
INFO
)
<<
"received terminate message and exit"
;
exit_flag
=
true
;
break
;
}
else
{
VLOG
(
3
)
<<
"received grad: "
<<
recv_var_name
;
auto
var
=
v
.
second
->
GetVar
();
if
(
var
==
nullptr
)
{
LOG
(
ERROR
)
<<
"Can not find server side var: "
<<
recv_var_name
;
PADDLE_THROW
(
"Can not find server side var"
);
}
AsyncExecuteBlock
(
executor
,
grad_to_prepared_ctx
[
recv_var_name
].
get
(),
v
.
second
->
GetMutableLocalScope
());
}
if
(
exit_flag
)
{
rpc_service_
->
ShutDown
();
break
;
}
}
// while(true)
}
void
ListenAndServOp
::
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
{
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
&
dev_ctx
=
*
pool
.
Get
(
dev_place
);
framework
::
Scope
&
recv_scope
=
scope
.
NewScope
();
bool
sync_mode
=
Attr
<
bool
>
(
"sync_mode"
);
PADDLE_ENFORCE
(
!
rpc_service_
);
std
::
string
endpoint
=
Attr
<
std
::
string
>
(
"endpoint"
);
rpc_service_
.
reset
(
new
detail
::
AsyncGRPCServer
(
endpoint
));
rpc_service_
.
reset
(
new
detail
::
AsyncGRPCServer
(
endpoint
,
sync_mode
));
auto
*
optimize_block
=
Attr
<
framework
::
BlockDesc
*>
(
kOptimizeBlock
);
auto
*
prefetch_block
=
Attr
<
framework
::
BlockDesc
*>
(
kPrefetchBlock
);
...
...
@@ -202,7 +301,11 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
sleep
(
5
);
// Write to a file of server selected port for python use.
SavePort
(
rpc_service_
);
if
(
sync_mode
)
{
RunSyncLoop
(
&
executor
,
program
,
&
recv_scope
,
prefetch_block
);
}
else
{
RunAsyncLoop
(
&
executor
,
program
,
&
recv_scope
,
prefetch_block
);
}
}
class
ListenAndServOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
...
...
@@ -221,6 +324,12 @@ from send_op and send back variables to recv_op.
"IP address to listen on."
)
.
SetDefault
(
"127.0.0.1:6164"
)
.
AddCustomChecker
([](
const
std
::
string
&
ip
)
{
return
!
ip
.
empty
();
});
AddAttr
<
std
::
vector
<
std
::
string
>>
(
"grad_to_block_id"
,
"['param1@GRAD.block0:1', 'param2@GRAD.blockn:2'] "
"a map from grad name to it's optimize block id"
)
.
SetDefault
({});
AddAttr
<
bool
>
(
"sync_mode"
,
"if works at sync_mode or not"
).
SetDefault
(
true
);
AddAttr
<
framework
::
BlockDesc
*>
(
kOptimizeBlock
,
"BlockID to run on server side."
);
AddAttr
<
framework
::
BlockDesc
*>
(
kPrefetchBlock
,
...
...
paddle/fluid/operators/listen_and_serv_op.h
浏览文件 @
6d934560
...
...
@@ -46,6 +46,11 @@ class ListenAndServOp : public framework::OperatorBase {
framework
::
Scope
*
recv_scope
,
framework
::
BlockDesc
*
prefetch_block
)
const
;
void
RunAsyncLoop
(
framework
::
Executor
*
executor
,
framework
::
ProgramDesc
*
program
,
framework
::
Scope
*
recv_scope
,
framework
::
BlockDesc
*
prefetch_block
)
const
;
void
Stop
()
override
;
void
RunImpl
(
const
framework
::
Scope
&
scope
,
...
...
paddle/fluid/operators/send_op.cc
浏览文件 @
6d934560
...
...
@@ -41,6 +41,8 @@ class SendOp : public framework::OperatorBase {
std
::
vector
<
std
::
string
>
endpoints
=
Attr
<
std
::
vector
<
std
::
string
>>
(
"endpoints"
);
bool
sync_mode
=
Attr
<
bool
>
(
"sync_mode"
);
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
&
ctx
=
*
pool
.
Get
(
place
);
...
...
@@ -64,11 +66,13 @@ class SendOp : public framework::OperatorBase {
}
PADDLE_ENFORCE
(
rpc_client
->
Wait
());
if
(
sync_mode
)
{
for
(
auto
&
ep
:
endpoints
)
{
VLOG
(
3
)
<<
"batch barrier, ep: "
<<
ep
;
rpc_client
->
AsyncSendBatchBarrier
(
ep
);
}
PADDLE_ENFORCE
(
rpc_client
->
Wait
());
}
if
(
outs
.
size
()
>
0
)
{
for
(
size_t
i
=
0
;
i
<
outs
.
size
();
i
++
)
{
...
...
@@ -112,6 +116,7 @@ This operator will send tensor to recv_op at the parameter server.
"Server endpoints in the order of input "
"variables for mapping"
)
.
SetDefault
({});
AddAttr
<
bool
>
(
"sync_mode"
,
"work in sync_mode or not"
).
SetDefault
(
true
);
}
};
...
...
paddle/fluid/operators/send_recv_op_test.cc
浏览文件 @
6d934560
...
...
@@ -137,6 +137,8 @@ void StartServerNet(bool is_sparse) {
attrs
.
insert
({
"GradList"
,
std
::
vector
<
std
::
string
>
({
"x1"
})});
attrs
.
insert
({
"OptimizeBlock"
,
optimize_block
});
attrs
.
insert
({
"PrefetchBlock"
,
prefetch_block
});
attrs
.
insert
({
"grad_to_block_id"
,
std
::
vector
<
std
::
string
>
({
""
})});
attrs
.
insert
({
"sync_mode"
,
true
});
listen_and_serv_op
=
f
::
OpRegistry
::
CreateOp
(
"listen_and_serv"
,
{{
"X"
,
{
"x1"
}}},
{},
attrs
);
listen_and_serv_op
->
Run
(
scope
,
place
);
...
...
python/paddle/fluid/distribute_transpiler.py
浏览文件 @
6d934560
...
...
@@ -143,7 +143,8 @@ class DistributeTranspiler:
program
=
None
,
pservers
=
"127.0.0.1:6174"
,
trainers
=
1
,
split_method
=
splitter
.
round_robin
):
split_method
=
splitter
.
round_robin
,
sync_mode
=
True
):
"""
Transpile the program to distributed data-parallelism programs.
The main_program will be transformed to use a remote parameter server
...
...
@@ -184,6 +185,9 @@ class DistributeTranspiler:
:param split_method: A function to determin how to split variables
to different servers equally.
:type split_method: function
:type sync_mode: boolean default True
:param sync_mode: if sync_mode is set True, it means that dist transpiler
will transpile the program into sync_mode pserver and trainer program.
"""
assert
(
callable
(
split_method
))
if
program
is
None
:
...
...
@@ -191,6 +195,7 @@ class DistributeTranspiler:
self
.
origin_program
=
program
self
.
trainer_num
=
trainers
self
.
optimize_ops
=
optimize_ops
self
.
sync_mode
=
sync_mode
# TODO(typhoonzero): currently trainer_id is fetched from cluster system
# like Kubernetes, we should port this to use etcd later when developing
# fluid distributed training with fault-tolerance.
...
...
@@ -295,8 +300,11 @@ class DistributeTranspiler:
inputs
=
{
"X"
:
send_inputs
},
outputs
=
{
"Out"
:
send_outputs
,
"RPCClient"
:
rpc_client_var
},
attrs
=
{
"endpoints"
:
pserver_endpoints
,
"epmap"
:
eplist
})
attrs
=
{
"endpoints"
:
pserver_endpoints
,
"epmap"
:
eplist
,
"sync_mode"
:
self
.
sync_mode
})
# step4: Concat the parameters splits together after recv.
for
varname
,
splited_var
in
param_var_mapping
.
iteritems
():
if
len
(
splited_var
)
<=
1
:
...
...
@@ -356,7 +364,7 @@ class DistributeTranspiler:
type
=
v
.
type
,
dtype
=
v
.
dtype
,
shape
=
v
.
shape
)
if
self
.
trainer_num
>
1
:
if
self
.
sync_mode
and
self
.
trainer_num
>
1
:
for
trainer_id
in
xrange
(
self
.
trainer_num
):
var
=
pserver_program
.
global_block
().
create_var
(
name
=
"%s.trainer_%d"
%
(
orig_var_name
,
trainer_id
),
...
...
@@ -402,13 +410,13 @@ class DistributeTranspiler:
for
op
in
self
.
optimize_ops
:
if
op
.
type
==
"scale"
:
for
in_name
in
op
.
input_arg_names
:
if
in_name
.
startswith
(
"beta1_pow_acc"
)
or
\
if
in_name
.
startswith
(
"beta1_pow_acc"
)
or
\
in_name
.
startswith
(
"beta2_pow_acc"
):
global_ops
.
append
(
op
)
def
__append_optimize_op__
(
op
,
block
):
def
__append_optimize_op__
(
op
,
block
,
grad_to_block_id
):
if
self
.
_is_opt_op
(
op
):
self
.
_append_pserver_ops
(
block
,
op
,
endpoint
,
self
.
_append_pserver_ops
(
block
,
op
,
endpoint
,
grad_to_block_id
,
default_main_program
())
else
:
self
.
_append_pserver_non_opt_ops
(
block
,
op
)
...
...
@@ -422,16 +430,16 @@ class DistributeTranspiler:
self
.
_append_pserver_non_opt_ops
(
lr_decay_block
,
op
)
# append op to the current block
grad_to_block_id
=
[]
pre_block_idx
=
pserver_program
.
num_blocks
-
1
for
idx
,
opt_op
in
enumerate
(
opt_op_on_pserver
):
per_opt_block
=
pserver_program
.
create_block
(
pre_block_idx
)
for
_
,
op
in
enumerate
(
self
.
optimize_ops
):
# optimizer is connected to itself
if
ufind
.
is_connected
(
op
,
opt_op
)
and
op
not
in
global_ops
:
__append_optimize_op__
(
op
,
per_opt_block
)
__append_optimize_op__
(
op
,
per_opt_block
,
grad_to_block_id
)
# append global ops
opt_state_block
=
None
if
global_ops
:
opt_state_block
=
pserver_program
.
create_block
(
pserver_program
.
num_blocks
-
1
)
...
...
@@ -472,7 +480,9 @@ class DistributeTranspiler:
"OptimizeBlock"
:
pserver_program
.
block
(
1
),
"endpoint"
:
endpoint
,
"Fanin"
:
self
.
trainer_num
,
"PrefetchBlock"
:
prefetch_block
"PrefetchBlock"
:
prefetch_block
,
"sync_mode"
:
self
.
sync_mode
,
"grad_to_block_id"
:
grad_to_block_id
})
pserver_program
.
sync_with_cpp
()
...
...
@@ -683,6 +693,16 @@ class DistributeTranspiler:
self
.
table_name
)],
persistable
=
False
)
# create table optimize block in pserver program
table_opt_op
=
[
op
for
op
in
self
.
optimize_ops
if
op
.
input
(
"Param"
)[
0
]
==
self
.
table_name
][
0
]
table_opt_block
=
pserver_program
.
create_block
(
pre_block_idx
)
# only support sgd now
assert
table_opt_op
.
type
==
"sgd"
if
self
.
sync_mode
:
# create grad vars in pserver program
table_grad_var
=
self
.
table_param_grad
[
1
]
table_grad_list
=
[
...
...
@@ -691,18 +711,10 @@ class DistributeTranspiler:
(
table_grad_var
.
name
,
index
,
pserver_index
),
type
=
table_grad_var
.
type
,
shape
=
table_grad_var
.
shape
,
dtype
=
table_grad_var
.
dtype
)
for
index
in
range
(
self
.
trainer_num
)
dtype
=
table_grad_var
.
dtype
)
for
index
in
range
(
self
.
trainer_num
)
]
# create table optimize block in pserver program
table_opt_op
=
[
op
for
op
in
self
.
optimize_ops
if
op
.
input
(
"Param"
)[
0
]
==
self
.
table_name
][
0
]
table_opt_block
=
pserver_program
.
create_block
(
pre_block_idx
)
# only support sgd now
assert
table_opt_op
.
type
==
"sgd"
# append sum op for table_grad_list
table_opt_block
.
append_op
(
type
=
"sum"
,
...
...
@@ -746,7 +758,7 @@ class DistributeTranspiler:
for
varname
,
splited
in
block_map
.
iteritems
():
orig_var
=
program
.
global_block
().
var
(
varname
)
if
len
(
splited
)
==
1
:
if
add_trainer_suffix
:
if
self
.
sync_mode
and
add_trainer_suffix
:
new_var_name
=
"%s.trainer_%d"
%
\
(
orig_var
.
name
,
self
.
trainer_id
)
program
.
global_block
().
rename_var
(
varname
,
new_var_name
)
...
...
@@ -770,7 +782,7 @@ class DistributeTranspiler:
if
len
(
orig_shape
)
>=
2
:
splited_shape
.
extend
(
orig_shape
[
1
:])
new_var_name
=
""
if
add_trainer_suffix
:
if
self
.
sync_mode
and
add_trainer_suffix
:
new_var_name
=
"%s.block%d.trainer_%d"
%
\
(
varname
,
i
,
self
.
trainer_id
)
else
:
...
...
@@ -879,7 +891,7 @@ class DistributeTranspiler:
return
orig_var_name
def
_append_pserver_ops
(
self
,
optimize_block
,
opt_op
,
endpoint
,
origin_program
):
grad_to_block_id
,
origin_program
):
program
=
optimize_block
.
program
pserver_block
=
program
.
global_block
()
new_inputs
=
dict
()
...
...
@@ -900,7 +912,9 @@ class DistributeTranspiler:
return
merged_var
=
\
pserver_block
.
vars
[
self
.
_orig_varname
(
grad_block
.
name
)]
if
self
.
trainer_num
>
1
:
grad_to_block_id
.
append
(
merged_var
.
name
+
":"
+
str
(
optimize_block
.
idx
))
if
self
.
sync_mode
and
self
.
trainer_num
>
1
:
vars2merge
=
[]
for
i
in
xrange
(
self
.
trainer_num
):
per_trainer_name
=
"%s.trainer_%d"
%
\
...
...
@@ -918,6 +932,7 @@ class DistributeTranspiler:
inputs
=
{
"X"
:
merged_var
},
outputs
=
{
"Out"
:
merged_var
},
attrs
=
{
"scale"
:
1.0
/
float
(
self
.
trainer_num
)})
new_inputs
[
key
]
=
merged_var
elif
key
==
"Param"
:
# param is already created on global program
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录