Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
6d934560
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
6d934560
编写于
4月 26, 2018
作者:
Q
Qiao Longfei
提交者:
GitHub
4月 26, 2018
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #10042 from jacquesqiao/add-async-listen-and-serv-op
listen_and_serv_op support async update
上级
f457d5da
3295f310
变更
9
隐藏空白更改
内联
并排
Showing
9 changed file
with
220 addition
and
62 deletion
+220
-62
paddle/fluid/operators/detail/grpc_server.cc
paddle/fluid/operators/detail/grpc_server.cc
+38
-22
paddle/fluid/operators/detail/grpc_server.h
paddle/fluid/operators/detail/grpc_server.h
+3
-1
paddle/fluid/operators/detail/grpc_server_test.cc
paddle/fluid/operators/detail/grpc_server_test.cc
+1
-1
paddle/fluid/operators/detail/variable_response.h
paddle/fluid/operators/detail/variable_response.h
+5
-1
paddle/fluid/operators/listen_and_serv_op.cc
paddle/fluid/operators/listen_and_serv_op.cc
+111
-2
paddle/fluid/operators/listen_and_serv_op.h
paddle/fluid/operators/listen_and_serv_op.h
+5
-0
paddle/fluid/operators/send_op.cc
paddle/fluid/operators/send_op.cc
+9
-4
paddle/fluid/operators/send_recv_op_test.cc
paddle/fluid/operators/send_recv_op_test.cc
+2
-0
python/paddle/fluid/distribute_transpiler.py
python/paddle/fluid/distribute_transpiler.py
+46
-31
未找到文件。
paddle/fluid/operators/detail/grpc_server.cc
浏览文件 @
6d934560
...
...
@@ -30,9 +30,13 @@ enum CallStatus { PROCESS = 0, FINISH };
class
RequestBase
{
public:
explicit
RequestBase
(
GrpcService
::
AsyncService
*
service
,
::
grpc
::
ServerCompletionQueue
*
cq
,
::
grpc
::
ServerCompletionQueue
*
cq
,
bool
sync_mode
,
const
platform
::
DeviceContext
*
dev_ctx
)
:
service_
(
service
),
cq_
(
cq
),
status_
(
PROCESS
),
dev_ctx_
(
dev_ctx
)
{
:
service_
(
service
),
cq_
(
cq
),
sync_mode_
(
sync_mode
),
status_
(
PROCESS
),
dev_ctx_
(
dev_ctx
)
{
PADDLE_ENFORCE
(
cq_
);
}
virtual
~
RequestBase
()
{}
...
...
@@ -49,6 +53,7 @@ class RequestBase {
::
grpc
::
ServerContext
ctx_
;
GrpcService
::
AsyncService
*
service_
;
::
grpc
::
ServerCompletionQueue
*
cq_
;
const
bool
sync_mode_
;
CallStatus
status_
;
const
platform
::
DeviceContext
*
dev_ctx_
;
};
...
...
@@ -56,11 +61,17 @@ class RequestBase {
class
RequestSend
final
:
public
RequestBase
{
public:
explicit
RequestSend
(
GrpcService
::
AsyncService
*
service
,
::
grpc
::
ServerCompletionQueue
*
cq
,
::
grpc
::
ServerCompletionQueue
*
cq
,
bool
sync_mode
,
framework
::
Scope
*
scope
,
ReceivedQueue
*
queue
,
const
platform
::
DeviceContext
*
dev_ctx
)
:
RequestBase
(
service
,
cq
,
dev_ctx
),
queue_
(
queue
),
responder_
(
&
ctx_
)
{
request_
.
reset
(
new
VariableResponse
(
scope
,
dev_ctx_
));
:
RequestBase
(
service
,
cq
,
sync_mode
,
dev_ctx
),
queue_
(
queue
),
responder_
(
&
ctx_
)
{
if
(
sync_mode_
)
{
request_
.
reset
(
new
VariableResponse
(
scope
,
dev_ctx_
,
false
));
}
else
{
request_
.
reset
(
new
VariableResponse
(
scope
,
dev_ctx_
,
true
));
}
int
method_id
=
static_cast
<
int
>
(
detail
::
GrpcMethod
::
kSendVariable
);
service_
->
RequestAsyncUnary
(
method_id
,
&
ctx_
,
request_
.
get
(),
&
responder_
,
cq_
,
cq_
,
this
);
...
...
@@ -87,11 +98,11 @@ class RequestSend final : public RequestBase {
class
RequestGet
final
:
public
RequestBase
{
public:
explicit
RequestGet
(
GrpcService
::
AsyncService
*
service
,
::
grpc
::
ServerCompletionQueue
*
cq
,
::
grpc
::
ServerCompletionQueue
*
cq
,
bool
sync_mode
,
framework
::
Scope
*
scope
,
const
platform
::
DeviceContext
*
dev_ctx
,
framework
::
BlockingQueue
<
MessageWithName
>*
queue
)
:
RequestBase
(
service
,
cq
,
dev_ctx
),
:
RequestBase
(
service
,
cq
,
sync_mode
,
dev_ctx
),
responder_
(
&
ctx_
),
scope_
(
scope
),
queue_
(
queue
)
{
...
...
@@ -134,19 +145,23 @@ class RequestGet final : public RequestBase {
class
RequestPrefetch
final
:
public
RequestBase
{
public:
explicit
RequestPrefetch
(
GrpcService
::
AsyncService
*
service
,
::
grpc
::
ServerCompletionQueue
*
cq
,
::
grpc
::
ServerCompletionQueue
*
cq
,
bool
sync_mode
,
framework
::
Scope
*
scope
,
const
platform
::
DeviceContext
*
dev_ctx
,
framework
::
Executor
*
executor
,
framework
::
ProgramDesc
*
program
,
framework
::
ExecutorPrepareContext
*
prefetch_ctx
)
:
RequestBase
(
service
,
cq
,
dev_ctx
),
:
RequestBase
(
service
,
cq
,
sync_mode
,
dev_ctx
),
responder_
(
&
ctx_
),
scope_
(
scope
),
executor_
(
executor
),
program_
(
program
),
prefetch_ctx_
(
prefetch_ctx
)
{
request_
.
reset
(
new
VariableResponse
(
scope
,
dev_ctx_
));
if
(
sync_mode_
)
{
request_
.
reset
(
new
VariableResponse
(
scope
,
dev_ctx_
,
false
));
}
else
{
request_
.
reset
(
new
VariableResponse
(
scope
,
dev_ctx_
,
true
));
}
int
method_id
=
static_cast
<
int
>
(
detail
::
GrpcMethod
::
kPrefetchVariable
);
service_
->
RequestAsyncUnary
(
method_id
,
&
ctx_
,
request_
.
get
(),
&
responder_
,
cq_
,
cq_
,
this
);
...
...
@@ -181,7 +196,6 @@ class RequestPrefetch final : public RequestBase {
framework
::
Executor
*
executor_
;
framework
::
ProgramDesc
*
program_
;
framework
::
ExecutorPrepareContext
*
prefetch_ctx_
;
int
blkid_
;
};
void
AsyncGRPCServer
::
WaitClientGet
(
int
count
)
{
...
...
@@ -254,8 +268,8 @@ void AsyncGRPCServer::TryToRegisterNewSendOne() {
VLOG
(
3
)
<<
"shutdown, do not TryToRegisterNewSendOne"
;
return
;
}
RequestSend
*
send
=
new
RequestSend
(
&
service_
,
cq_send_
.
get
(),
s
cop
e_
,
&
var_recv_queue_
,
dev_ctx_
);
RequestSend
*
send
=
new
RequestSend
(
&
service_
,
cq_send_
.
get
(),
s
ync_mod
e_
,
scope_
,
&
var_recv_queue_
,
dev_ctx_
);
VLOG
(
4
)
<<
"Create RequestSend status:"
<<
send
->
Status
();
}
...
...
@@ -265,8 +279,8 @@ void AsyncGRPCServer::TryToRegisterNewGetOne() {
VLOG
(
3
)
<<
"shutdown, do not TryToRegisterNewGetOne"
;
return
;
}
RequestGet
*
get
=
new
RequestGet
(
&
service_
,
cq_get_
.
get
(),
s
cope_
,
dev_ctx
_
,
&
var_get_queue_
);
RequestGet
*
get
=
new
RequestGet
(
&
service_
,
cq_get_
.
get
(),
s
ync_mode_
,
scope
_
,
dev_ctx_
,
&
var_get_queue_
);
VLOG
(
4
)
<<
"Create RequestGet status:"
<<
get
->
Status
();
}
...
...
@@ -277,8 +291,8 @@ void AsyncGRPCServer::TryToRegisterNewPrefetchOne() {
return
;
}
RequestPrefetch
*
prefetch
=
new
RequestPrefetch
(
&
service_
,
cq_prefetch_
.
get
(),
s
cope_
,
dev_ctx
_
,
executor_
,
program_
,
prefetch_ctx_
);
new
RequestPrefetch
(
&
service_
,
cq_prefetch_
.
get
(),
s
ync_mode_
,
scope
_
,
dev_ctx_
,
executor_
,
program_
,
prefetch_ctx_
);
VLOG
(
4
)
<<
"Create RequestPrefetch status:"
<<
prefetch
->
Status
();
}
...
...
@@ -301,9 +315,11 @@ void AsyncGRPCServer::HandleRequest(::grpc::ServerCompletionQueue* cq,
VLOG
(
3
)
<<
"HandleRequest for "
<<
cq_name
<<
" while after Next"
;
PADDLE_ENFORCE
(
tag
);
// FIXME(typhoonzero): de-couple the barriers with recv_op
if
(
!
is_shut_down_
&&
cq_name
==
"cq_get"
)
WaitCond
(
1
);
if
(
!
is_shut_down_
&&
cq_name
==
"cq_send"
)
WaitCond
(
0
);
if
(
sync_mode_
)
{
// FIXME(typhoonzero): de-couple the barriers with recv_op
if
(
!
is_shut_down_
&&
cq_name
==
"cq_get"
)
WaitCond
(
1
);
if
(
!
is_shut_down_
&&
cq_name
==
"cq_send"
)
WaitCond
(
0
);
}
RequestBase
*
base
=
reinterpret_cast
<
RequestBase
*>
(
tag
);
// reference:
...
...
@@ -320,13 +336,13 @@ void AsyncGRPCServer::HandleRequest(::grpc::ServerCompletionQueue* cq,
switch
(
base
->
Status
())
{
case
PROCESS
:
{
VLOG
(
4
)
<<
cq_name
<<
" status:"
<<
base
->
Status
();
VLOG
(
4
)
<<
cq_name
<<
"
PROCESS
status:"
<<
base
->
Status
();
TryToRegisterNewOne
();
base
->
Process
();
break
;
}
case
FINISH
:
{
VLOG
(
4
)
<<
cq_name
<<
" status:"
<<
base
->
Status
();
VLOG
(
4
)
<<
cq_name
<<
"
FINISH
status:"
<<
base
->
Status
();
delete
base
;
break
;
}
...
...
paddle/fluid/operators/detail/grpc_server.h
浏览文件 @
6d934560
...
...
@@ -44,7 +44,8 @@ class RequestBase;
class
AsyncGRPCServer
final
{
public:
explicit
AsyncGRPCServer
(
const
std
::
string
&
address
)
:
address_
(
address
)
{}
explicit
AsyncGRPCServer
(
const
std
::
string
&
address
,
bool
sync_mode
)
:
address_
(
address
),
sync_mode_
(
sync_mode
)
{}
void
RunSyncUpdate
();
...
...
@@ -95,6 +96,7 @@ class AsyncGRPCServer final {
std
::
unique_ptr
<::
grpc
::
Server
>
server_
;
std
::
string
address_
;
const
bool
sync_mode_
;
framework
::
Scope
*
scope_
;
const
platform
::
DeviceContext
*
dev_ctx_
;
...
...
paddle/fluid/operators/detail/grpc_server_test.cc
浏览文件 @
6d934560
...
...
@@ -89,7 +89,7 @@ void InitTensorsOnServer(framework::Scope* scope, platform::CPUPlace* place,
}
void
StartServer
(
const
std
::
string
&
endpoint
)
{
rpc_service_
.
reset
(
new
detail
::
AsyncGRPCServer
(
endpoint
));
rpc_service_
.
reset
(
new
detail
::
AsyncGRPCServer
(
endpoint
,
true
));
framework
::
ProgramDesc
program
;
framework
::
Scope
scope
;
platform
::
CPUPlace
place
;
...
...
paddle/fluid/operators/detail/variable_response.h
浏览文件 @
6d934560
...
...
@@ -46,7 +46,9 @@ class VariableResponse {
}
virtual
~
VariableResponse
()
{
if
(
create_scope_
)
scope_
->
DeleteScope
(
local_scope_
);
if
(
create_scope_
)
{
scope_
->
DeleteScope
(
local_scope_
);
}
}
// return:
...
...
@@ -63,6 +65,8 @@ class VariableResponse {
// Read-only access to the local scope: dereferences local_scope_ and returns
// it by const reference. No null check is performed here.
const
framework
::
Scope
&
GetLocalScope
()
const
{
return
*
local_scope_
;
}
// Returns the raw local_scope_ pointer so callers can modify the scope.
// The pointer is returned as stored; this accessor does not check it for null.
framework
::
Scope
*
GetMutableLocalScope
()
const
{
return
local_scope_
;
}
// Returns, by value, the varname field of meta_.
inline
std
::
string
Varname
()
{
return
meta_
.
varname
();
}
// Returns, by value, the out_varname field of meta_.
inline
std
::
string
OutVarname
()
{
return
meta_
.
out_varname
();
}
...
...
paddle/fluid/operators/listen_and_serv_op.cc
浏览文件 @
6d934560
...
...
@@ -27,6 +27,38 @@ void RunServer(std::shared_ptr<detail::AsyncGRPCServer> service) {
VLOG
(
4
)
<<
"RunServer thread end"
;
}
// Splits `str` at every occurrence of `sep` and stores the fields in `*pieces`.
//
// `*pieces` is cleared first, so an empty `str` yields an empty vector. Empty
// fields before a separator are kept ("a::b" -> {"a", "", "b"}), while an
// empty field after the last separator is dropped ("a:" -> {"a"}).
static void split(const std::string& str, char sep,
                  std::vector<std::string>* pieces) {
  pieces->clear();
  if (str.empty()) {
    return;
  }

  size_t begin = 0;
  for (size_t hit = str.find(sep, begin); hit != std::string::npos;
       hit = str.find(sep, begin)) {
    pieces->push_back(str.substr(begin, hit - begin));
    begin = hit + 1;
  }

  // Whatever follows the last separator is appended only when non-empty.
  const std::string tail = str.substr(begin);
  if (!tail.empty()) {
    pieces->push_back(tail);
  }
}
// Runs one prepared block via framework::Async and waits for that task to
// complete before returning.
//
// executor: executor whose RunPreparedContext() is invoked.
// prepared: prepared context of the block to run.
// scope:    scope handed to RunPreparedContext().
//
// Exceptions derived from std::exception that escape the run are logged and
// swallowed inside the task; they are not rethrown to the caller.
static void AsyncExecuteBlock(framework::Executor *executor,
                              framework::ExecutorPrepareContext *prepared,
                              framework::Scope *scope) {
  // Capture the pointers by value. References to the parameters would dangle
  // as soon as this function returned, which matters if the wait() below is
  // ever removed as the TODO suggests.
  std::future<void> future = framework::Async([executor, prepared, scope]() {
    try {
      executor->RunPreparedContext(prepared, scope, false, false);
    } catch (const std::exception &e) {
      LOG(ERROR) << "run sub program error " << e.what();
    }
  });
  // TODO(qiao) maybe we can remove this
  future.wait();
}
static
void
ParallelExecuteBlocks
(
const
std
::
vector
<
size_t
>
&
parallel_blkids
,
framework
::
Executor
*
executor
,
const
std
::
vector
<
std
::
shared_ptr
<
framework
::
ExecutorPrepareContext
>>
...
...
@@ -169,15 +201,82 @@ void ListenAndServOp::RunSyncLoop(framework::Executor *executor,
}
// while(true)
}
// Serving loop used when the op is not in sync mode.
//
// Parses the "grad_to_block_id" attribute (entries of the form "<grad>:<id>"),
// prepares every block of `program` except block 0, and then repeatedly takes
// a received message from rpc_service_ and runs the block prepared for the
// received gradient name on that message's local scope. The loop ends when
// LISTEN_TERMINATE_MESSAGE is received.
//
// executor:       executor used to prepare and run the blocks.
// program:        server program; must contain at least 2 blocks.
// recv_scope:     not used by this function body.
// prefetch_block: not used by this function body.
//
// Raises (via PADDLE_ENFORCE / PADDLE_THROW) when an attribute entry is
// malformed or duplicated, when the program has fewer than 2 blocks, when the
// received variable is missing, or when no prepared block exists for the
// received gradient name. std::stoi may throw on a non-numeric block id.
void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
                                   framework::ProgramDesc *program,
                                   framework::Scope *recv_scope,
                                   framework::BlockDesc *prefetch_block) const {
  VLOG(3) << "RunAsyncLoop in";
  // grad name to block id
  std::unordered_map<std::string, int32_t> grad_to_block_id;
  std::unordered_map<int32_t, std::string> id_to_grad;

  const auto &grad_to_block_id_str =
      Attr<std::vector<std::string>>("grad_to_block_id");
  for (const auto &grad_and_id : grad_to_block_id_str) {
    std::vector<std::string> pieces;
    split(grad_and_id, ':', &pieces);
    // Check the field count before indexing: a malformed entry (for example
    // an empty string) yields fewer than two pieces, and reading pieces[0] or
    // pieces[1] first would be an out-of-bounds access.
    PADDLE_ENFORCE_EQ(pieces.size(), 2);
    VLOG(3) << "after split, grad = " << pieces[0] << ", id=" << pieces[1];
    PADDLE_ENFORCE_EQ(grad_to_block_id.count(pieces[0]), 0);
    int block_id = std::stoi(pieces[1]);
    grad_to_block_id[pieces[0]] = block_id;
    id_to_grad[block_id] = pieces[0];
  }

  size_t num_blocks = program->Size();
  PADDLE_ENFORCE_GE(num_blocks, 2,
                    "server program should have at least 2 blocks");

  // Every block except block 0 is prepared.
  std::vector<int> block_list;
  for (size_t blkid = 1; blkid < num_blocks; ++blkid) {
    block_list.push_back(static_cast<int>(blkid));
  }
  auto optimize_prepared = executor->Prepare(*program, block_list);

  // Map each gradient name to the prepared context of its block. Block ids
  // absent from the attribute map to the empty name, because operator[] on
  // id_to_grad default-inserts an empty string for them.
  std::unordered_map<std::string,
                     std::shared_ptr<framework::ExecutorPrepareContext>>
      grad_to_prepared_ctx;
  for (size_t i = 0; i < block_list.size(); ++i) {
    grad_to_prepared_ctx[id_to_grad[block_list[i]]] = optimize_prepared[i];
  }

  VLOG(3) << "RunAsyncLoop into while";
  while (true) {
    const detail::ReceivedMessage v = rpc_service_->Get();
    auto recv_var_name = v.first;
    if (recv_var_name == LISTEN_TERMINATE_MESSAGE) {
      LOG(INFO) << "received terminate message and exit";
      // NOTE(review): the original loop had an `if (exit_flag)
      // { rpc_service_->ShutDown(); break; }` tail that could never run,
      // because this branch already leaves the loop. Behaviour is kept as-is
      // (no ShutDown here) -- confirm that shutdown is handled elsewhere.
      break;
    }

    VLOG(3) << "received grad: " << recv_var_name;
    auto var = v.second->GetVar();
    if (var == nullptr) {
      LOG(ERROR) << "Can not find server side var: " << recv_var_name;
      PADDLE_THROW("Can not find server side var");
    }

    // Look the context up with find(): operator[] would silently insert a
    // null context for an unknown gradient name and hand it to the executor.
    auto ctx_it = grad_to_prepared_ctx.find(recv_var_name);
    if (ctx_it == grad_to_prepared_ctx.end() || ctx_it->second == nullptr) {
      LOG(ERROR) << "Can not find optimize block for grad: " << recv_var_name;
      PADDLE_THROW("Can not find optimize block for grad");
    }
    AsyncExecuteBlock(executor, ctx_it->second.get(),
                      v.second->GetMutableLocalScope());
  }  // while(true)
}
void
ListenAndServOp
::
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
{
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
&
dev_ctx
=
*
pool
.
Get
(
dev_place
);
framework
::
Scope
&
recv_scope
=
scope
.
NewScope
();
bool
sync_mode
=
Attr
<
bool
>
(
"sync_mode"
);
PADDLE_ENFORCE
(
!
rpc_service_
);
std
::
string
endpoint
=
Attr
<
std
::
string
>
(
"endpoint"
);
rpc_service_
.
reset
(
new
detail
::
AsyncGRPCServer
(
endpoint
));
rpc_service_
.
reset
(
new
detail
::
AsyncGRPCServer
(
endpoint
,
sync_mode
));
auto
*
optimize_block
=
Attr
<
framework
::
BlockDesc
*>
(
kOptimizeBlock
);
auto
*
prefetch_block
=
Attr
<
framework
::
BlockDesc
*>
(
kPrefetchBlock
);
...
...
@@ -202,7 +301,11 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
sleep
(
5
);
// Write to a file of server selected port for python use.
SavePort
(
rpc_service_
);
RunSyncLoop
(
&
executor
,
program
,
&
recv_scope
,
prefetch_block
);
if
(
sync_mode
)
{
RunSyncLoop
(
&
executor
,
program
,
&
recv_scope
,
prefetch_block
);
}
else
{
RunAsyncLoop
(
&
executor
,
program
,
&
recv_scope
,
prefetch_block
);
}
}
class
ListenAndServOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
...
...
@@ -221,6 +324,12 @@ from send_op and send back variables to recv_op.
"IP address to listen on."
)
.
SetDefault
(
"127.0.0.1:6164"
)
.
AddCustomChecker
([](
const
std
::
string
&
ip
)
{
return
!
ip
.
empty
();
});
AddAttr
<
std
::
vector
<
std
::
string
>>
(
"grad_to_block_id"
,
"['param1@GRAD.block0:1', 'param2@GRAD.blockn:2'] "
"a map from grad name to it's optimize block id"
)
.
SetDefault
({});
AddAttr
<
bool
>
(
"sync_mode"
,
"if works at sync_mode or not"
).
SetDefault
(
true
);
AddAttr
<
framework
::
BlockDesc
*>
(
kOptimizeBlock
,
"BlockID to run on server side."
);
AddAttr
<
framework
::
BlockDesc
*>
(
kPrefetchBlock
,
...
...
paddle/fluid/operators/listen_and_serv_op.h
浏览文件 @
6d934560
...
...
@@ -46,6 +46,11 @@ class ListenAndServOp : public framework::OperatorBase {
framework
::
Scope
*
recv_scope
,
framework
::
BlockDesc
*
prefetch_block
)
const
;
void
RunAsyncLoop
(
framework
::
Executor
*
executor
,
framework
::
ProgramDesc
*
program
,
framework
::
Scope
*
recv_scope
,
framework
::
BlockDesc
*
prefetch_block
)
const
;
void
Stop
()
override
;
void
RunImpl
(
const
framework
::
Scope
&
scope
,
...
...
paddle/fluid/operators/send_op.cc
浏览文件 @
6d934560
...
...
@@ -41,6 +41,8 @@ class SendOp : public framework::OperatorBase {
std
::
vector
<
std
::
string
>
endpoints
=
Attr
<
std
::
vector
<
std
::
string
>>
(
"endpoints"
);
bool
sync_mode
=
Attr
<
bool
>
(
"sync_mode"
);
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
&
ctx
=
*
pool
.
Get
(
place
);
...
...
@@ -64,11 +66,13 @@ class SendOp : public framework::OperatorBase {
}
PADDLE_ENFORCE
(
rpc_client
->
Wait
());
for
(
auto
&
ep
:
endpoints
)
{
VLOG
(
3
)
<<
"batch barrier, ep: "
<<
ep
;
rpc_client
->
AsyncSendBatchBarrier
(
ep
);
if
(
sync_mode
)
{
for
(
auto
&
ep
:
endpoints
)
{
VLOG
(
3
)
<<
"batch barrier, ep: "
<<
ep
;
rpc_client
->
AsyncSendBatchBarrier
(
ep
);
}
PADDLE_ENFORCE
(
rpc_client
->
Wait
());
}
PADDLE_ENFORCE
(
rpc_client
->
Wait
());
if
(
outs
.
size
()
>
0
)
{
for
(
size_t
i
=
0
;
i
<
outs
.
size
();
i
++
)
{
...
...
@@ -112,6 +116,7 @@ This operator will send tensor to recv_op at the parameter server.
"Server endpoints in the order of input "
"variables for mapping"
)
.
SetDefault
({});
AddAttr
<
bool
>
(
"sync_mode"
,
"work in sync_mode or not"
).
SetDefault
(
true
);
}
};
...
...
paddle/fluid/operators/send_recv_op_test.cc
浏览文件 @
6d934560
...
...
@@ -137,6 +137,8 @@ void StartServerNet(bool is_sparse) {
attrs
.
insert
({
"GradList"
,
std
::
vector
<
std
::
string
>
({
"x1"
})});
attrs
.
insert
({
"OptimizeBlock"
,
optimize_block
});
attrs
.
insert
({
"PrefetchBlock"
,
prefetch_block
});
attrs
.
insert
({
"grad_to_block_id"
,
std
::
vector
<
std
::
string
>
({
""
})});
attrs
.
insert
({
"sync_mode"
,
true
});
listen_and_serv_op
=
f
::
OpRegistry
::
CreateOp
(
"listen_and_serv"
,
{{
"X"
,
{
"x1"
}}},
{},
attrs
);
listen_and_serv_op
->
Run
(
scope
,
place
);
...
...
python/paddle/fluid/distribute_transpiler.py
浏览文件 @
6d934560
...
...
@@ -143,7 +143,8 @@ class DistributeTranspiler:
program
=
None
,
pservers
=
"127.0.0.1:6174"
,
trainers
=
1
,
split_method
=
splitter
.
round_robin
):
split_method
=
splitter
.
round_robin
,
sync_mode
=
True
):
"""
Transpile the program to distributed data-parallelism programs.
The main_program will be transformed to use a remote parameter server
...
...
@@ -184,6 +185,9 @@ class DistributeTranspiler:
:param split_method: A function to determine how to split variables
to different servers equally.
:type split_method: function
:type sync_mode: boolean default True
:param sync_mode: if sync_mode is set True, it means that dist transpiler
will transpile the program into sync_mode pserver and trainer program.
"""
assert
(
callable
(
split_method
))
if
program
is
None
:
...
...
@@ -191,6 +195,7 @@ class DistributeTranspiler:
self
.
origin_program
=
program
self
.
trainer_num
=
trainers
self
.
optimize_ops
=
optimize_ops
self
.
sync_mode
=
sync_mode
# TODO(typhoonzero): currently trainer_id is fetched from cluster system
# like Kubernetes, we should port this to use etcd later when developing
# fluid distributed training with fault-tolerance.
...
...
@@ -295,8 +300,11 @@ class DistributeTranspiler:
inputs
=
{
"X"
:
send_inputs
},
outputs
=
{
"Out"
:
send_outputs
,
"RPCClient"
:
rpc_client_var
},
attrs
=
{
"endpoints"
:
pserver_endpoints
,
"epmap"
:
eplist
})
attrs
=
{
"endpoints"
:
pserver_endpoints
,
"epmap"
:
eplist
,
"sync_mode"
:
self
.
sync_mode
})
# step4: Concat the parameters splits together after recv.
for
varname
,
splited_var
in
param_var_mapping
.
iteritems
():
if
len
(
splited_var
)
<=
1
:
...
...
@@ -356,7 +364,7 @@ class DistributeTranspiler:
type
=
v
.
type
,
dtype
=
v
.
dtype
,
shape
=
v
.
shape
)
if
self
.
trainer_num
>
1
:
if
self
.
sync_mode
and
self
.
trainer_num
>
1
:
for
trainer_id
in
xrange
(
self
.
trainer_num
):
var
=
pserver_program
.
global_block
().
create_var
(
name
=
"%s.trainer_%d"
%
(
orig_var_name
,
trainer_id
),
...
...
@@ -402,13 +410,13 @@ class DistributeTranspiler:
for
op
in
self
.
optimize_ops
:
if
op
.
type
==
"scale"
:
for
in_name
in
op
.
input_arg_names
:
if
in_name
.
startswith
(
"beta1_pow_acc"
)
or
\
in_name
.
startswith
(
"beta2_pow_acc"
):
if
in_name
.
startswith
(
"beta1_pow_acc"
)
or
\
in_name
.
startswith
(
"beta2_pow_acc"
):
global_ops
.
append
(
op
)
def
__append_optimize_op__
(
op
,
block
):
def
__append_optimize_op__
(
op
,
block
,
grad_to_block_id
):
if
self
.
_is_opt_op
(
op
):
self
.
_append_pserver_ops
(
block
,
op
,
endpoint
,
self
.
_append_pserver_ops
(
block
,
op
,
endpoint
,
grad_to_block_id
,
default_main_program
())
else
:
self
.
_append_pserver_non_opt_ops
(
block
,
op
)
...
...
@@ -422,16 +430,16 @@ class DistributeTranspiler:
self
.
_append_pserver_non_opt_ops
(
lr_decay_block
,
op
)
# append op to the current block
grad_to_block_id
=
[]
pre_block_idx
=
pserver_program
.
num_blocks
-
1
for
idx
,
opt_op
in
enumerate
(
opt_op_on_pserver
):
per_opt_block
=
pserver_program
.
create_block
(
pre_block_idx
)
for
_
,
op
in
enumerate
(
self
.
optimize_ops
):
# optimizer is connected to itself
if
ufind
.
is_connected
(
op
,
opt_op
)
and
op
not
in
global_ops
:
__append_optimize_op__
(
op
,
per_opt_block
)
__append_optimize_op__
(
op
,
per_opt_block
,
grad_to_block_id
)
# append global ops
opt_state_block
=
None
if
global_ops
:
opt_state_block
=
pserver_program
.
create_block
(
pserver_program
.
num_blocks
-
1
)
...
...
@@ -472,7 +480,9 @@ class DistributeTranspiler:
"OptimizeBlock"
:
pserver_program
.
block
(
1
),
"endpoint"
:
endpoint
,
"Fanin"
:
self
.
trainer_num
,
"PrefetchBlock"
:
prefetch_block
"PrefetchBlock"
:
prefetch_block
,
"sync_mode"
:
self
.
sync_mode
,
"grad_to_block_id"
:
grad_to_block_id
})
pserver_program
.
sync_with_cpp
()
...
...
@@ -683,17 +693,6 @@ class DistributeTranspiler:
self
.
table_name
)],
persistable
=
False
)
# create grad vars in pserver program
table_grad_var
=
self
.
table_param_grad
[
1
]
table_grad_list
=
[
pserver_program
.
global_block
().
create_var
(
name
=
"%s.trainer_%d.pserver_%d"
%
(
table_grad_var
.
name
,
index
,
pserver_index
),
type
=
table_grad_var
.
type
,
shape
=
table_grad_var
.
shape
,
dtype
=
table_grad_var
.
dtype
)
for
index
in
range
(
self
.
trainer_num
)
]
# create table optimize block in pserver program
table_opt_op
=
[
op
for
op
in
self
.
optimize_ops
...
...
@@ -703,11 +702,24 @@ class DistributeTranspiler:
# only support sgd now
assert
table_opt_op
.
type
==
"sgd"
# append sum op for table_grad_list
table_opt_block
.
append_op
(
type
=
"sum"
,
inputs
=
{
"X"
:
table_grad_list
},
outputs
=
{
"Out"
:
[
grad_var
]})
if
self
.
sync_mode
:
# create grad vars in pserver program
table_grad_var
=
self
.
table_param_grad
[
1
]
table_grad_list
=
[
pserver_program
.
global_block
().
create_var
(
name
=
"%s.trainer_%d.pserver_%d"
%
(
table_grad_var
.
name
,
index
,
pserver_index
),
type
=
table_grad_var
.
type
,
shape
=
table_grad_var
.
shape
,
dtype
=
table_grad_var
.
dtype
)
for
index
in
range
(
self
.
trainer_num
)
]
# append sum op for table_grad_list
table_opt_block
.
append_op
(
type
=
"sum"
,
inputs
=
{
"X"
:
table_grad_list
},
outputs
=
{
"Out"
:
[
grad_var
]})
lr_var
=
pserver_program
.
global_block
().
vars
[
table_opt_op
.
input
(
"LearningRate"
)[
0
]]
...
...
@@ -746,7 +758,7 @@ class DistributeTranspiler:
for
varname
,
splited
in
block_map
.
iteritems
():
orig_var
=
program
.
global_block
().
var
(
varname
)
if
len
(
splited
)
==
1
:
if
add_trainer_suffix
:
if
self
.
sync_mode
and
add_trainer_suffix
:
new_var_name
=
"%s.trainer_%d"
%
\
(
orig_var
.
name
,
self
.
trainer_id
)
program
.
global_block
().
rename_var
(
varname
,
new_var_name
)
...
...
@@ -770,7 +782,7 @@ class DistributeTranspiler:
if
len
(
orig_shape
)
>=
2
:
splited_shape
.
extend
(
orig_shape
[
1
:])
new_var_name
=
""
if
add_trainer_suffix
:
if
self
.
sync_mode
and
add_trainer_suffix
:
new_var_name
=
"%s.block%d.trainer_%d"
%
\
(
varname
,
i
,
self
.
trainer_id
)
else
:
...
...
@@ -879,7 +891,7 @@ class DistributeTranspiler:
return
orig_var_name
def
_append_pserver_ops
(
self
,
optimize_block
,
opt_op
,
endpoint
,
origin_program
):
grad_to_block_id
,
origin_program
):
program
=
optimize_block
.
program
pserver_block
=
program
.
global_block
()
new_inputs
=
dict
()
...
...
@@ -900,7 +912,9 @@ class DistributeTranspiler:
return
merged_var
=
\
pserver_block
.
vars
[
self
.
_orig_varname
(
grad_block
.
name
)]
if
self
.
trainer_num
>
1
:
grad_to_block_id
.
append
(
merged_var
.
name
+
":"
+
str
(
optimize_block
.
idx
))
if
self
.
sync_mode
and
self
.
trainer_num
>
1
:
vars2merge
=
[]
for
i
in
xrange
(
self
.
trainer_num
):
per_trainer_name
=
"%s.trainer_%d"
%
\
...
...
@@ -918,6 +932,7 @@ class DistributeTranspiler:
inputs
=
{
"X"
:
merged_var
},
outputs
=
{
"Out"
:
merged_var
},
attrs
=
{
"scale"
:
1.0
/
float
(
self
.
trainer_num
)})
new_inputs
[
key
]
=
merged_var
elif
key
==
"Param"
:
# param is already created on global program
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录