PaddlePaddle / Paddle — commit 8b50ad80 (unverified)
Authored January 23, 2019 by tangwei12; committed via GitHub on January 23, 2019.
checkpoint at distributed training (#14854)
checkpoint for distributed training.
Parent: 07dc5a15
Showing 21 changed files with 1,122 additions and 280 deletions (+1122 −280).
paddle/fluid/operators/distributed/grpc/grpc_client.cc            +52  −37
paddle/fluid/operators/distributed/grpc/grpc_client.h             +12  −5
paddle/fluid/operators/distributed/grpc/grpc_server.cc            +55  −4
paddle/fluid/operators/distributed/grpc/grpc_service.h            +3   −0
paddle/fluid/operators/distributed/request_handler.h              +13  −0
paddle/fluid/operators/distributed/request_handler_impl.cc        +29  −1
paddle/fluid/operators/distributed/request_handler_impl.h         +10  −0
paddle/fluid/operators/distributed/rpc_client.h                   +7   −0
paddle/fluid/operators/distributed/send_recv.proto.in             +18  −0
paddle/fluid/operators/distributed_ops/listen_and_serv_op.cc      +5   −0
paddle/fluid/operators/distributed_ops/listen_and_serv_op.h       +2   −1
paddle/fluid/operators/distributed_ops/recv_op.cc                 +47  −16
paddle/fluid/platform/mkldnn_reuse.h                              +2   −2
python/paddle/fluid/framework.py                                  +12  −3
python/paddle/fluid/io.py                                         +322 −132
python/paddle/fluid/tests/unittests/dist_save_load.py             +43  −14
python/paddle/fluid/tests/unittests/dist_simnet_bow.py            +9   −4
python/paddle/fluid/tests/unittests/test_dist_base.py             +4   −2
python/paddle/fluid/tests/unittests/test_dist_save_load.py        +71  −2
python/paddle/fluid/tests/unittests/test_dist_transpiler.py       +34  −15
python/paddle/fluid/transpiler/distribute_transpiler.py           +372 −42
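The user-visible effect of this change sits in the Python API: a transpiled trainer program saves its persistables (including variables that live remotely on parameter servers), and each parameter server can later restore its own shard from the same directory. The sketch below is assembled from the docstrings and tests added further down in this diff; the endpoint strings, trainer count, and model directory are placeholders, not values taken from the commit.

import paddle.fluid as fluid
from paddle.fluid.transpiler import distribute_transpiler

exe = fluid.Executor(fluid.CPUPlace())
param_path = "./my_paddle_model"  # placeholder checkpoint directory

# Transpile as in test_dist_base.py below; endpoints and counts are placeholders.
t = distribute_transpiler.DistributeTranspiler()
t.transpile(trainer_id=0,
            program=fluid.default_main_program(),
            pservers="127.0.0.1:6174",
            trainers=1,
            sync_mode=True,
            current_endpoint="127.0.0.1:6174")

# Trainer side: save local persistables and merge the remote ones back.
fluid.io.save_persistables(executor=exe, dirname=param_path,
                           main_program=t.get_trainer_program())

# Parameter-server side: reload only the slices owned by this endpoint.
pserver_prog = t.get_pserver_program("127.0.0.1:6174")
fluid.io.load_persistables(executor=exe, dirname=param_path,
                           main_program=pserver_prog)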
paddle/fluid/operators/distributed/grpc/grpc_client.cc

@@ -74,7 +74,7 @@ VarHandlePtr GRPCClient::AsyncSendVar(const std::string& ep,
   const framework::Scope* p_scope = &scope;
   const auto ch = GetChannel(ep_val);
   SendProcessor* s = new SendProcessor(ch);
-  const std::string method = "SendRPC";
+  const std::string method = kSendRPC;
   VarHandlePtr h(new VarHandle(ep, method, var_name_val, p_ctx, p_scope));
   s->Prepare(h, time_out);

@@ -107,7 +107,7 @@ VarHandlePtr GRPCClient::AsyncSendVar(const std::string& ep,
 void ProcGetResponse(const VarHandle& var_h,
                      const ::grpc::ByteBuffer& ret_msg) {
-  VLOG(100) << "ProcGetResponse";
+  VLOG(4) << "ProcGetResponse";
   framework::Variable* outvar = nullptr;
   // get response's trainer_id is not used
   int trainer_id;

@@ -127,39 +127,54 @@ VarHandlePtr GRPCClient::AsyncGetVar(const std::string& ep,
                                      const platform::DeviceContext& ctx,
                                      const framework::Scope& scope,
                                      const std::string& var_name,
+                                     const std::string& out_varname,
                                      int64_t time_out) {
-  return _AsyncGetVar(ep, ctx, scope, var_name,
+  return _AsyncGetVar(ep, ctx, scope, kGetRPC, var_name, out_varname,
                       "/sendrecv.SendRecvService/GetVariable", time_out);
 }

+VarHandlePtr GRPCClient::AsyncGetVarNoBarrier(
+    const std::string& ep, const platform::DeviceContext& ctx,
+    const framework::Scope& scope, const std::string& var_name,
+    const std::string& out_varname, int64_t time_out) {
+  std::string var_name_no_barrier =
+      string::Sprintf("%s%s", var_name, WITHOUT_BARRIER_MESSAGE);
+
+  return _AsyncGetVar(
+      ep, ctx, scope, kGetNoBarrierRPC, var_name_no_barrier, out_varname,
+      "/sendrecv.SendRecvService/GetVariableNoBarrier", time_out);
+}
+
 VarHandlePtr GRPCClient::AsyncGetMonomerVariable(
     const std::string& ep, const platform::DeviceContext& ctx,
     const framework::Scope& scope, const std::string& var_name,
     int64_t time_out) {
-  return _AsyncGetVar(ep, ctx, scope, var_name,
+  return _AsyncGetVar(ep, ctx, scope, kGetMonomerRPC, var_name, var_name,
                       "/sendrecv.SendRecvService/GetMonomerVariable", time_out);
 }

-VarHandlePtr GRPCClient::_AsyncGetVar(const std::string& ep,
-                                      const platform::DeviceContext& ctx,
-                                      const framework::Scope& scope,
-                                      const std::string& var_name,
-                                      const std::string& rpc_path,
-                                      int64_t time_out) {
+VarHandlePtr GRPCClient::_AsyncGetVar(
+    const std::string& ep, const platform::DeviceContext& ctx,
+    const framework::Scope& scope, const std::string& method,
+    const std::string& var_name, const std::string& out_varname,
+    const std::string& rpc_path, int64_t time_out) {
   const platform::DeviceContext* p_ctx = &ctx;
   const std::string ep_val = ep;
   const std::string var_name_val = var_name;
+  const std::string out_varname_val = out_varname;
   const framework::Scope* p_scope = &scope;
   const auto ch = GetChannel(ep_val);
   GetProcessor* s = new GetProcessor(ch);
-  const std::string method = "GetRPC";
-  VarHandlePtr h(new VarHandle(ep, method, var_name_val, p_ctx, p_scope));
+
+  VarHandlePtr h(new VarHandle(ep, method, out_varname_val, p_ctx, p_scope));
   s->Prepare(h, time_out);

-  framework::AsyncIO([var_name_val, s, method, p_ctx, h, rpc_path, this] {
+  framework::AsyncIO(
+      [var_name_val, out_varname_val, s, method, p_ctx, h, rpc_path, this] {
     // prepare input
     sendrecv::VariableMessage req;
     req.set_varname(var_name_val);
+    req.set_out_varname(out_varname_val);
     req.set_trainer_id(trainer_id_);
     ::grpc::ByteBuffer buf;
     RequestToByteBuffer<sendrecv::VariableMessage>(req, &buf);

@@ -202,7 +217,7 @@ VarHandlePtr GRPCClient::AsyncPrefetchVar(const std::string& ep,
   const auto ch = GetChannel(ep_val);
   GetProcessor* s = new GetProcessor(ch);
-  const std::string method = "PrefetchRPC";
+  const std::string method = kPrefetchRPC;
   VarHandlePtr h(new VarHandle(ep, method, out_var_name_val, p_ctx, p_scope));
   s->Prepare(h, time_out);

@@ -242,7 +257,7 @@ VarHandlePtr GRPCClient::AsyncSendBatchBarrier(const std::string& ep,
   const auto ch = GetChannel(ep);
   BatchBarrierProcessor* s = new BatchBarrierProcessor(ch);
-  const std::string method = "BatchBarrierRPC";
+  const std::string method = kBatchBarrierRPC;
   VarHandlePtr h(
       new VarHandle(ep, method, BATCH_BARRIER_MESSAGE, nullptr, nullptr));
   s->Prepare(h, time_out);

@@ -267,7 +282,7 @@ VarHandlePtr GRPCClient::AsyncSendFetchBarrier(const std::string& ep,
                                                int64_t time_out) {
   const auto ch = GetChannel(ep);
   FetchBarrierProcessor* s = new FetchBarrierProcessor(ch);
-  const std::string method = "FetchBarrierRPC";
+  const std::string method = kFetchBarrierRPC;
   VarHandlePtr h(
       new VarHandle(ep, method, FETCH_BARRIER_MESSAGE, nullptr, nullptr));
   s->Prepare(h, time_out);

@@ -293,7 +308,7 @@ VarHandlePtr GRPCClient::AsyncGetMonomerBarrier(const std::string& ep,
                                                 int64_t time_out) {
   const auto ch = GetChannel(ep);
   BatchBarrierProcessor* s = new BatchBarrierProcessor(ch);
-  const std::string method = "SendMonomerFetchBarrierRPC";
+  const std::string method = kSendMonomerFetchBarrierRPC;
   VarHandlePtr h(new VarHandle(ep, method, var_name, nullptr, nullptr));
   s->Prepare(h, time_out);

@@ -320,7 +335,7 @@ VarHandlePtr GRPCClient::AsyncSendComplete(const std::string& ep,
   const auto ch = GetChannel(ep);
   BatchBarrierProcessor* s = new BatchBarrierProcessor(ch);
-  const std::string method = "SendCompleteRPC";
+  const std::string method = kSendCompleteRPC;
   VarHandlePtr h(new VarHandle(ep, method, COMPLETE_MESSAGE, nullptr, nullptr));
   s->Prepare(h, time_out);

@@ -347,7 +362,7 @@ VarHandlePtr GRPCClient::AsyncCheckpointNotify(const std::string& ep,
   CheckpointNotifyProcessor* s = new CheckpointNotifyProcessor(ch);
-  const std::string method = "CheckPointNotifyRPC";
+  const std::string method = kCheckPointNotifyRPC;
   VarHandlePtr h(
       new VarHandle(ep, method, CHECKPOINT_SAVE_MESSAGE, nullptr, nullptr));
paddle/fluid/operators/distributed/grpc/grpc_client.h

@@ -186,6 +186,13 @@ class GRPCClient : public RPCClient {
                            const platform::DeviceContext& ctx,
                            const framework::Scope& scope,
                            const std::string& var_name,
+                           const std::string& out_varname,
                            int64_t time_out = FLAGS_rpc_deadline) override;

+  VarHandlePtr AsyncGetVarNoBarrier(
+      const std::string& ep, const platform::DeviceContext& ctx,
+      const framework::Scope& scope, const std::string& var_name,
+      const std::string& out_varname,
+      int64_t time_out = FLAGS_rpc_deadline) override;
+
   VarHandlePtr AsyncGetMonomerVariable(

@@ -228,11 +235,11 @@ class GRPCClient : public RPCClient {
   void Proceed();

   std::shared_ptr<grpc::Channel> GetChannel(const std::string& ep);
-  VarHandlePtr _AsyncGetVar(const std::string& ep,
-                            const platform::DeviceContext& ctx,
-                            const framework::Scope& scope,
-                            const std::string& var_name,
-                            const std::string& rpc, int64_t time_out);
+  VarHandlePtr _AsyncGetVar(
+      const std::string& ep, const platform::DeviceContext& ctx,
+      const framework::Scope& scope, const std::string& method,
+      const std::string& var_name, const std::string& out_varname,
+      const std::string& rpc_path, int64_t time_out = FLAGS_rpc_deadline);

  private:
   grpc::CompletionQueue cq_;
paddle/fluid/operators/distributed/grpc/grpc_server.cc

@@ -136,17 +136,65 @@ class RequestGet final : public RequestBase {
   void Process() override {
     // proc request.
     std::string varname = request_.varname();
+    std::string out_varname = request_.out_varname();
     int trainer_id = request_.trainer_id();
-    VLOG(4) << "RequestGet " << varname;
+
+    VLOG(4) << "RequestGet " << out_varname << " from " << varname;

     auto scope = request_handler_->scope();
-    auto invar = scope->FindVar(varname);
+    framework::Variable* invar = nullptr;
     framework::Variable* outvar = nullptr;

-    request_handler_->Handle(varname, scope, invar, &outvar, trainer_id);
+    request_handler_->Handle(varname, scope, invar, &outvar, trainer_id,
+                             out_varname);

     if (outvar) {
-      SerializeToByteBuffer(varname, outvar, *request_handler_->dev_ctx(),
+      SerializeToByteBuffer(out_varname, outvar, *request_handler_->dev_ctx(),
                             &reply_);
     }
     Finish(reply_, &responder_);
   }

  protected:
   sendrecv::VariableMessage request_;
   ::grpc::ByteBuffer reply_;
   ServerAsyncResponseWriter<::grpc::ByteBuffer> responder_;
 };

+class RequestGetNoBarrier final : public RequestBase {
+ public:
+  explicit RequestGetNoBarrier(GrpcService::AsyncService* service,
+                               ::grpc::ServerCompletionQueue* cq,
+                               RequestHandler* request_handler, int req_id)
+      : RequestBase(service, cq, request_handler, req_id), responder_(&ctx_) {
+    auto method_id =
+        static_cast<int>(distributed::GrpcMethod::kGetVariableNoBarrier);
+    service_->RequestAsyncUnary(
+        method_id, &ctx_, &request_, &responder_, cq_, cq_,
+        reinterpret_cast<void*>(static_cast<intptr_t>(req_id)));
+  }
+
+  virtual ~RequestGetNoBarrier() {}
+
+  std::string GetReqName() override { return request_.varname(); }
+
+  void Process() override {
+    // proc request.
+    std::string varname = request_.varname();
+    std::string out_varname = request_.out_varname();
+    int trainer_id = request_.trainer_id();
+
+    VLOG(4) << "RequestGetNoBarrier " << out_varname << " from " << varname;
+
+    auto scope = request_handler_->scope();
+    framework::Variable* invar = nullptr;
+    framework::Variable* outvar = nullptr;
+
+    request_handler_->Handle(varname, scope, invar, &outvar, trainer_id,
+                             out_varname);
+
+    if (outvar) {
+      SerializeToByteBuffer(out_varname, outvar, *request_handler_->dev_ctx(),
+                            &reply_);
+    }
+    Finish(reply_, &responder_);

@@ -460,6 +508,9 @@ void AsyncGRPCServer::TryToRegisterNewOne(const std::string& rpc_name,
     b = new RequestSend(&service_, cq.get(), handler, req_id);
   } else if (rpc_name == kRequestGet) {
     b = new RequestGet(&service_, cq.get(), handler, req_id);
+
+  } else if (rpc_name == kRequestGetNoBarrier) {
+    b = new RequestGetNoBarrier(&service_, cq.get(), handler, req_id);
   } else if (rpc_name == kRequestGetMonomerVariable) {
     b = new RequestGetMonomerVariable(&service_, cq.get(), handler, req_id,
                                       this);
paddle/fluid/operators/distributed/grpc/grpc_service.h

@@ -81,6 +81,7 @@ enum class GrpcMethod {
   kGetVariable,
   kPrefetchVariable,
   kCheckpointNotify,
+  kGetVariableNoBarrier,
   kGetMonomerVariable,
   kGetMonomerBarrier,
 };

@@ -94,6 +95,8 @@ inline const char* GrpcMethodName(GrpcMethod id) {
       return "/sendrecv.SendRecvService/SendVariable";
     case GrpcMethod::kGetVariable:
       return "/sendrecv.SendRecvService/GetVariable";
+    case GrpcMethod::kGetVariableNoBarrier:
+      return "/sendrecv.SendRecvService/GetVariableNoBarrier";
     case GrpcMethod::kGetMonomerVariable:
       return "/sendrecv.SendRecvService/GetMonomerVariable";
     case GrpcMethod::kGetMonomerBarrier:
paddle/fluid/operators/distributed/request_handler.h

@@ -42,11 +42,24 @@ constexpr char kRequestGetMonomerBarrier[] = "RequestGetMonomerBarrier";
 constexpr char kRequestPrefetch[] = "RequestPrefetch";
 constexpr char kRequestCheckpoint[] = "RequestCheckpoint";
 constexpr char kRequestPassBarrier[] = "RequestPassBarrier";
+constexpr char kRequestGetNoBarrier[] = "GetVariableNoBarrier";
+
+constexpr char kSendRPC[] = "SendRPC";
+constexpr char kGetRPC[] = "GetRPC";
+constexpr char kGetNoBarrierRPC[] = "GetNoBarrierRPC";
+constexpr char kGetMonomerRPC[] = "GetMonomerRPC";
+constexpr char kPrefetchRPC[] = "PrefetchRPC";
+constexpr char kBatchBarrierRPC[] = "BatchBarrierRPC";
+constexpr char kFetchBarrierRPC[] = "FetchBarrierRPC";
+constexpr char kSendMonomerFetchBarrierRPC[] = "SendMonomerFetchBarrierRPC";
+constexpr char kSendCompleteRPC[] = "SendCompleteRPC";
+constexpr char kCheckPointNotifyRPC[] = "CheckPointNotifyRPC";

 #define LISTEN_TERMINATE_MESSAGE "TERMINATE@RECV"
 #define BATCH_BARRIER_MESSAGE "BATCH_BARRIER@RECV"
 #define FETCH_BARRIER_MESSAGE "FETCH_BARRIER@RECV"
 #define COMPLETE_MESSAGE "COMPLETE@RECV"
+#define WITHOUT_BARRIER_MESSAGE "@WITHOUT_BARRIER@RECV"

 #define CHECKPOINT_SAVE_MESSAGE "SAVE@CHECKPOINTNOTIFY"
 #define CHECKPOINT_LOAD_MESSAGE "LOAD@CHECKPOINTNOTIFY"
paddle/fluid/operators/distributed/request_handler_impl.cc

@@ -23,6 +23,7 @@
 #include "paddle/fluid/framework/selected_rows.h"
 #include "paddle/fluid/framework/variable_helper.h"
 #include "paddle/fluid/operators/distributed/rpc_server.h"
+#include "paddle/fluid/string/piece.h"
 #include "paddle/fluid/string/printf.h"

 namespace paddle {

@@ -81,7 +82,8 @@ bool RequestGetHandler::Handle(const std::string& varname,
                                const int trainer_id,
                                const std::string& out_var_name,
                                const std::string& table_name) {
-  VLOG(4) << "RequestGetHandler:" << varname;
+  VLOG(4) << "RequestGetHandler:" << varname
+          << " out_var_name: " << out_var_name;

   if (sync_mode_) {
     if (varname == FETCH_BARRIER_MESSAGE) {

@@ -112,6 +114,32 @@ bool RequestGetHandler::Handle(const std::string& varname,
   return true;
 }

+bool RequestGetNoBarrierHandler::Handle(const std::string& varname,
+                                        framework::Scope* scope,
+                                        framework::Variable* invar,
+                                        framework::Variable** outvar,
+                                        const int trainer_id,
+                                        const std::string& out_var_name,
+                                        const std::string& table_name) {
+  VLOG(4) << "RequestGetNoBarrierHandler:" << varname
+          << " out_var_name: " << out_var_name;
+
+  // get var from pserver immediately without barriers
+  string::Piece without_barrier_piece(WITHOUT_BARRIER_MESSAGE);
+  string::Piece var_name_piece = string::Piece(varname);
+
+  if (string::Contains(var_name_piece, without_barrier_piece)) {
+    var_name_piece = string::TrimSuffix(var_name_piece, without_barrier_piece);
+    VLOG(4) << "Get var " << var_name_piece << " with "
+            << WITHOUT_BARRIER_MESSAGE;
+    *outvar = scope_->FindVar(var_name_piece.ToString());
+    return true;
+  } else {
+    PADDLE_THROW("GetNoBarrier must contain %s", WITHOUT_BARRIER_MESSAGE);
+  }
+  return true;
+}
+
 bool RequestPrefetchHandler::Handle(const std::string& varname,
                                     framework::Scope* scope,
                                     framework::Variable* invar,
paddle/fluid/operators/distributed/request_handler_impl.h

@@ -67,6 +67,16 @@ class RequestGetHandler final : public RequestHandler {
   bool enable_dc_asgd_;
 };

+class RequestGetNoBarrierHandler final : public RequestHandler {
+ public:
+  RequestGetNoBarrierHandler() : RequestHandler(false) {}
+  virtual ~RequestGetNoBarrierHandler() {}
+  bool Handle(const std::string& varname, framework::Scope* scope,
+              framework::Variable* var, framework::Variable** outvar,
+              const int trainer_id, const std::string& out_var_name = "",
+              const std::string& table_name = "") override;
+};
+
 static inline void BuildVar(const std::string& param_name,
                             std::initializer_list<const char*> arguments,
                             paddle::framework::proto::OpDesc::Var* var) {
paddle/fluid/operators/distributed/rpc_client.h

@@ -43,6 +43,13 @@ class RPCClient {
                                    const platform::DeviceContext& ctx,
                                    const framework::Scope& scope,
                                    const std::string& var_name,
+                                   const std::string& out_varname,
                                    int64_t time_out = FLAGS_rpc_deadline) = 0;

+  virtual VarHandlePtr AsyncGetVarNoBarrier(
+      const std::string& ep, const platform::DeviceContext& ctx,
+      const framework::Scope& scope, const std::string& var_name,
+      const std::string& out_varname,
+      int64_t time_out = FLAGS_rpc_deadline) = 0;
+
   virtual VarHandlePtr AsyncGetMonomerVariable(
paddle/fluid/operators/distributed/send_recv.proto.in

@@ -17,8 +17,14 @@ package sendrecv;
 option cc_generic_services = @cc_generic_services@;

 service SendRecvService {
   // For parameter server round-robin like hashing, do not split tensors.
   // Send and recv only one tensor
   // TODO(typhoonzero): add streaming API
   rpc SendVariable(VariableMessage) returns (VoidMessage) {}
   // Argument VariableMessage for GetVariable should only contain varname.
   rpc GetVariable(VariableMessage) returns (VariableMessage) {}
+  rpc GetVariableNoBarrier(VariableMessage) returns (VariableMessage) {}
   // pre-fetch variable by given variable name and Ids
   rpc PrefetchVariable(VariableMessage) returns (VariableMessage) {}

   rpc CheckpointNotify(VariableMessage) returns (VoidMessage) {}

@@ -27,12 +33,17 @@ service SendRecvService {
   rpc GetMonomerBarrier(VariableMessage) returns (VoidMessage) {}
 }

 // It can be: LoDTensor、SelectedRows or NCCL_ID
 enum VarType {
   LOD_TENSOR = 0;
   SELECTED_ROWS = 1;
   NCCL_ID = 2;
 }

 // VariableMessage is serialized paddle variable message.
 // NOTICE(gongwb): don't modify this proto if you are not
 // not familar with how we serialize in sendrecvop_utils.h
 // and deserilize it in variable_response.h.
 message VariableMessage {
   enum Type {
     // Pod Types

@@ -49,14 +60,21 @@ message VariableMessage {
   string varname = 1;
   // TODO(Yancey1989): reference framework::proto::VarDesc::VarType
   VarType type = 2;
   // bool persistable is not needed for sending.
   // tensor info:
   Type data_type = 3;
   repeated int64 dims = 4;

   // lod details:
   int64 lod_level = 5;
   repeated LodData lod = 6;
   // selected_rows height, aka. original dim0
   int64 slr_height = 7;
   // tensor data
   bytes serialized = 8;
   // selected_rows data
   bytes rows = 9;
   // Look up table block execution output variable name.
   string out_varname = 10;
   // If 1, the ps server will start profiling, the ps
   // server stops profiling and generates a profile to /tmp/profile_ps_*
paddle/fluid/operators/distributed_ops/listen_and_serv_op.cc

@@ -347,6 +347,8 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
       new distributed::RequestPrefetchHandler(sync_mode));
   request_checkpoint_handler_.reset(new distributed::RequestCheckpointHandler(
       sync_mode, checkpoint_block_id));
+  request_get_no_barrier_handler_.reset(
+      new distributed::RequestGetNoBarrierHandler());

   rpc_service_->RegisterRPC(distributed::kRequestSend,
                             request_send_handler_.get(),

@@ -359,6 +361,8 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
                             FLAGS_rpc_prefetch_thread_num);
   rpc_service_->RegisterRPC(distributed::kRequestCheckpoint,
                             request_checkpoint_handler_.get());
+  rpc_service_->RegisterRPC(distributed::kRequestGetNoBarrier,
+                            request_get_no_barrier_handler_.get());

   auto optimize_blocks =
       Attr<std::vector<framework::BlockDesc *>>(kOptimizeBlocks);

@@ -413,6 +417,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
   f(request_get_handler_.get());
   f(request_prefetch_handler_.get());
   f(request_checkpoint_handler_.get());
+  f(request_get_no_barrier_handler_.get());

   // start the server listening after all member initialized.
   server_thread_.reset(new std::thread(RunServer, rpc_service_));
paddle/fluid/operators/distributed_ops/listen_and_serv_op.h

@@ -55,7 +55,6 @@ class ListenAndServOp : public framework::OperatorBase {
                   const framework::VariableNameMap &inputs,
                   const framework::VariableNameMap &outputs,
                   const framework::AttributeMap &attrs);
-
   virtual ~ListenAndServOp();

   void RunSyncLoop(framework::Executor *executor,

@@ -89,6 +88,8 @@ class ListenAndServOp : public framework::OperatorBase {
   mutable std::shared_ptr<distributed::RPCServer> rpc_service_;
   mutable std::shared_ptr<distributed::RequestHandler> request_send_handler_;
   mutable std::shared_ptr<distributed::RequestHandler> request_get_handler_;
+  mutable std::shared_ptr<distributed::RequestHandler>
+      request_get_no_barrier_handler_;
   mutable std::shared_ptr<distributed::RequestHandler>
       request_prefetch_handler_;
   mutable std::shared_ptr<distributed::RequestHandler>
paddle/fluid/operators/distributed_ops/recv_op.cc

@@ -27,34 +27,54 @@ namespace operators {
 class RecvOp : public framework::OperatorBase {
  public:
   RecvOp(const std::string& type, const framework::VariableNameMap& inputs,
          const framework::VariableNameMap& outputs,
          const framework::AttributeMap& attrs)
       : OperatorBase(type, inputs, outputs, attrs) {}

   void RunImpl(const framework::Scope& scope,
                const platform::Place& place) const override {
-    auto outs = Outputs("Out");
     std::vector<std::string> epmap = Attr<std::vector<std::string>>("epmap");
+    std::vector<std::string> varnames =
+        Attr<std::vector<std::string>>("varnames");
     int sync_mode = Attr<int>("sync_mode");
+    auto outs = Outputs("Out");
+    bool with_barrier = Attr<bool>("with_barrier");

     platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
     auto& ctx = *pool.Get(place);

     distributed::RPCClient* rpc_client =
         distributed::RPCClient::GetInstance<RPCCLIENT_T>(
             Attr<int>("trainer_id"));

-    std::vector<distributed::VarHandlePtr> rets;
-    for (size_t i = 0; i < outs.size(); i++) {
-      VLOG(3) << "getting " << outs[i] << " from " << epmap[i];
-      rets.push_back(rpc_client->AsyncGetVar(epmap[i], ctx, scope, outs[i]));
-    }
-    if (sync_mode) {
+    if (with_barrier) {
+      std::vector<distributed::VarHandlePtr> rets;
+      for (size_t i = 0; i < outs.size(); i++) {
+        std::string varname = varnames.size() == 0 ? outs[i] : varnames[i];
+        VLOG(4) << "recv " << outs[i] << " from " << epmap[i] << " with "
+                << varname << " and with AsyncGetVar";
+        rets.push_back(
+            rpc_client->AsyncGetVar(epmap[i], ctx, scope, varname, outs[i]));
+      }
+      if (sync_mode) {
+        for (size_t i = 0; i < rets.size(); i++) {
+          PADDLE_ENFORCE(rets[i]->Wait(), "internal error in RPCClient");
+        }
+      }
+    } else {
+      std::vector<distributed::VarHandlePtr> rets;
+      for (size_t i = 0; i < outs.size(); i++) {
+        std::string varname = varnames.size() == 0 ? outs[i] : varnames[i];
+        VLOG(4) << "recv " << outs[i] << " from " << epmap[i] << " with "
+                << varname << " and with AsyncGetVarNoBarrier";
+        rets.push_back(rpc_client->AsyncGetVarNoBarrier(epmap[i], ctx, scope,
+                                                        varname, outs[i]));
+      }
       for (size_t i = 0; i < rets.size(); i++) {
         PADDLE_ENFORCE(rets[i]->Wait(), "internal error in RPCClient");
       }
     }
   }
 };

@@ -79,12 +99,23 @@ This operator can get variables from server side.
                  "(int, default 0)"
                  "sync recv or async recv.")
         .SetDefault(0);
+    AddAttr<bool>("with_barrier",
+                  "(bool, default True) if with_barrier=False, will use "
+                  "AsyncGetVarNoBarrier get variable from pserver immediately")
+        .SetDefault(true);
+    AddAttr<std::vector<std::string>>(
+        "varnames",
+        "(string vector, default {}) "
+        "sometimes we need to put received var in another name "
+        "for example: we need var named 'moment_1@127.0.0.1:1001', "
+        "and it real name on parameter server is 'moment_1'. ")
+        .SetDefault({});
   }
 };

 class RecvOpShapeInference : public framework::InferShapeBase {
  public:
   void operator()(framework::InferShapeContext* ctx) const override {}
 };

 }  // namespace operators
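The reworked RecvOp is what lets a worker pull a variable whose name on the parameter server differs from its local name (via `varnames`) and, with `with_barrier=False`, fetch it immediately without waiting on batch barriers. Below is a hedged Python sketch of how such a recv op is appended to a program, modelled on the `__save_remote_params` helper added to python/paddle/fluid/io.py later in this diff; the variable name, shape, and endpoint are illustrative only.

import paddle.fluid as fluid

prog = fluid.Program()
block = prog.global_block()

# Destination variable on this worker (name/shape/dtype are illustrative).
fetched = block.create_var(
    name="moment_1.fetched", shape=[8, 8], dtype="float32", persistable=True)

# Fetch "moment_1" from one pserver immediately, without batch barriers.
block.append_op(
    type='recv',
    inputs={"X": []},
    outputs={"Out": [fetched]},
    attrs={
        "epmap": ["127.0.0.1:6174"],  # illustrative pserver endpoint
        "varnames": ["moment_1"],     # real name on the parameter server
        "with_barrier": False,        # routes through AsyncGetVarNoBarrier
        "sync_mode": True             # wait for the RPC before continuing
    })

fluid.Executor(fluid.CPUPlace()).run(prog)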
paddle/fluid/platform/mkldnn_reuse.h  (+2 −2, diff collapsed on the original page)
python/paddle/fluid/framework.py

@@ -1696,12 +1696,20 @@ class Program(object):
         self._current_role = core.op_proto_and_checker_maker.OpRole.Forward
         self._op_role_var = []

-        # for distribute
+        # for distribute training
+        # _is_distributed = True if under distributed training
         self._is_distributed = False
+        # _is_chief = True if the trainer is the first one, usually No.0
         self._is_chief = False
-        self._slice_vars_and_attrs = []
+        # _parameters_on_pservers records all the parameters distributed on parameter servers.
+        self._parameters_on_pservers = None
+        # _endpoints is a list about parameter servers ip:port, such as ["ip:port","ip:port"]
         self._endpoints = []
+        # if current role is parameter server, the _ps_endpoint is its "ip:port"
+        self._ps_endpoint = None
+        # trainers_endpoints, it is used for distribution.
         self._trainers_endpoints = []
+        # the distributed lookup table names
         self._distributed_lookup_table = None

     @property

@@ -2232,8 +2240,9 @@ class Program(object):
                 "Program")
         self._is_distributed = other._is_distributed
         self._is_chief = other._is_chief
-        self._slice_vars_and_attrs = other._slice_vars_and_attrs
+        self._parameters_on_pservers = other._parameters_on_pservers
         self._endpoints = other._endpoints
+        self._ps_endpoint = other._ps_endpoint
         self._distributed_lookup_table = other._distributed_lookup_table

     def _copy_data_info_from(self, other):
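The new Program attributes are pure bookkeeping that the save/load code in io.py (below) keys off; their values are filled in by DistributeTranspiler. A small reading sketch, with an invented helper name, of what a transpiled Program now carries:

def describe_distributed_program(main_program):
    # Field names come from the Program changes above; this only inspects them.
    print("is distributed:", main_program._is_distributed)
    print("is chief trainer:", main_program._is_chief)
    print("pserver endpoints:", main_program._endpoints)
    print("this pserver endpoint:", main_program._ps_endpoint)
    if main_program._parameters_on_pservers:
        # VarsDistributed.overview() is added in distribute_transpiler.py below.
        print(main_program._parameters_on_pservers.overview())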
python/paddle/fluid/io.py

@@ -19,6 +19,7 @@ import errno
 import time
 import shutil
 import six
+from functools import reduce

 from paddle.fluid.executor import Executor
 from paddle.fluid.evaluator import Evaluator

@@ -183,8 +184,6 @@ def save_vars(executor,
             # NOTE: don't save the variable which type is RAW
             if each_var.type == core.VarDesc.VarType.RAW:
                 continue
-            if each_var.name == main_program._distributed_lookup_table:
-                continue
             new_var = _clone_var_in_block_(save_block, each_var)
             if filename is None:
                 save_block.append_op(

@@ -206,16 +205,6 @@ def save_vars(executor,
                     outputs={},
                     attrs={'file_path': os.path.join(dirname, filename)})

-        # if there is lookup table, the trainer 0 will notify all pserver to save.
-        if main_program._is_distributed and main_program._is_chief and main_program._distributed_lookup_table:
-            lookup_table_filename = os.path.join(dirname, "__lookup_table__")
-            attrs = {}
-            attrs['epmap'] = main_program._endpoints
-            attrs['dir'] = lookup_table_filename
-            attrs['lookup_table'] = main_program._distributed_lookup_table
-            save_block.append_op(
-                type='checkpoint_notify', inputs={}, outputs={}, attrs=attrs)
-
         executor.run(save_program)

@@ -267,6 +256,186 @@ def save_params(executor, dirname, main_program=None, filename=None):
         filename=filename)


+def _save_distributed_persistables(executor, dirname, main_program):
+    """
+    save_persistables for distributed training.
+    the method will do things listed below:
+    1.save part of persistable variables on trainer.
+    2.receive "remote prefetch variables" from parameter servers and merge them.
+    3.save "distributed lookup table" on parameter servers.
+    4.receive "optimizer variables" from parameter servers and merge them.
+
+    Args:
+        executor(Executor): The executor to run for saving parameters.
+        dirname(str): The saving directory path.
+        main_program(Program): The program whose parameters will be
+                            saved. the main_program must be the trainer_program
+                            get after transpiler.
+
+    Returns:
+        None
+
+    Examples:
+        .. code-block:: python
+
+            exe = fluid.Executor(fluid.CPUPlace())
+            param_path = "./my_paddle_model"
+            t = distribute_transpiler.DistributeTranspiler()
+            t.transpile(...)
+            train_program = t.get_trainer_program()
+            _save_distributed_persistables(executor=exe, dirname=param_path, main_program=train_program)
+    """
+
+    def __save_remote_params(executor, dirname, remote_params_map):
+        """
+        recive params on pserver through rpc.
+        if the params are be sliced, will concat them to one, then save it.
+        """
+        if not remote_params_map:
+            return
+
+        prog = Program()
+        block = prog.global_block()
+
+        # recv optimize vars from pserver
+        for name, remote_params in remote_params_map.items():
+            origin_var = None
+            is_slice = False
+            slice_vars = [0] * len(remote_params)
+            slice_var_names = [""] * len(remote_params)
+            endpoints = [""] * len(remote_params)
+
+            for idx, optimizer in enumerate(remote_params):
+                origin = optimizer.origin
+                slice = optimizer.slice
+                is_slice = optimizer.is_slice
+                block_id = optimizer.block_id
+                endpoint = optimizer.endpoint
+
+                if idx == 0:
+                    origin_var = block.create_var(
+                        name=origin.name,
+                        type=origin.type,
+                        shape=origin.shape,
+                        dtype=origin.dtype,
+                        persistable=True)
+
+                slice_var = block.create_var(
+                    name="{}.slice.{}".format(slice.name, idx),
+                    type=slice.type,
+                    shape=slice.shape,
+                    dtype=slice.dtype,
+                    persistable=True)
+
+                index = block_id if is_slice else idx
+                slice_vars[index] = slice_var
+                slice_var_names[index] = slice.name
+                endpoints[index] = endpoint
+
+            if is_slice:
+                block.append_op(
+                    type='recv',
+                    inputs={"X": []},
+                    outputs={"Out": slice_vars},
+                    attrs={
+                        "epmap": endpoints,
+                        "with_barrier": False,
+                        "varnames": slice_var_names,
+                        "sync_mode": True
+                    })
+                block.append_op(
+                    type='concat',
+                    inputs={'X': slice_vars},
+                    outputs={'Out': origin_var},
+                    attrs={})
+            else:
+                block.append_op(
+                    type='recv',
+                    inputs={"X": []},
+                    outputs={"Out": [origin_var]},
+                    attrs={
+                        "epmap": endpoints[:1],
+                        "with_barrier": False,
+                        "varnames": slice_var_names,
+                        "sync_mode": True
+                    })
+            block.append_op(
+                type='save',
+                inputs={'X': [origin_var]},
+                outputs={},
+                attrs={'file_path': os.path.join(dirname, origin_var.name)})
+            block.append_op(type='delete_var', inputs={'X': slice_vars})
+        executor.run(prog)
+
+    def __save_distributed_lookup_tables(executor, dirname,
+                                         distributed_lookup_table, endpoints):
+        """
+        because the distributed lookup table may too huge to merge and save at one place,
+        it will be saved at parameter server independent respectively.
+        the save directory is dirname/"__lookup_table__".
+        """
+        prog = Program()
+        block = prog.global_block()
+
+        # if there is lookup table, the trainer 0 will notify all pserver to save.
+        lookup_table_filename = os.path.join(dirname, "__lookup_table__")
+        attrs = {}
+        attrs['epmap'] = endpoints
+        attrs['dir'] = lookup_table_filename
+        attrs['lookup_table'] = distributed_lookup_table
+        block.append_op(
+            type='checkpoint_notify', inputs={}, outputs={}, attrs=attrs)
+        executor.run(prog)
+
+    def __exclude_vars(exclude_var_names=[]):
+        def is_valid(var):
+            if var.name in exclude_var_names:
+                return False
+            if var.desc.type() == core.VarDesc.VarType.FEED_MINIBATCH or \
+                    var.desc.type() == core.VarDesc.VarType.FETCH_LIST or \
+                    var.desc.type() == core.VarDesc.VarType.READER:
+                return False
+            return var.persistable
+
+        return is_valid
+
+    if not isinstance(main_program, Program):
+        raise ValueError("'main_program' should be an instance of Program.")
+
+    if not main_program._is_distributed:
+        raise ValueError(
+            "'_save_distributed_persistables' just be designed for distributed training."
+        )
+
+    remote_params_map = main_program._parameters_on_pservers.get_distributed_vars_by_vtypes(
+        ["Optimizer", "RemotePrefetch"], groupby=True)
+
+    exclude_var_names = []
+    if remote_params_map:
+        exclude_var_names.extend(remote_params_map.keys())
+
+    if main_program._distributed_lookup_table:
+        if isinstance(main_program._distributed_lookup_table, list):
+            exclude_var_names.extend(main_program._distributed_lookup_table)
+        else:
+            exclude_var_names.append(main_program._distributed_lookup_table)
+
+    local_vars = list(
+        filter(__exclude_vars(exclude_var_names), main_program.list_vars()))
+    save_vars(
+        executor, main_program=main_program, dirname=dirname, vars=local_vars)
+
+    if main_program._is_chief:
+        if remote_params_map:
+            __save_remote_params(executor, dirname, remote_params_map)
+        if main_program._distributed_lookup_table:
+            __save_distributed_lookup_tables(
+                executor, dirname, main_program._distributed_lookup_table,
+                main_program._endpoints)
+
+
 def save_persistables(executor, dirname, main_program=None, filename=None):
     """
     This function filters out all variables with `persistable==True` from the

@@ -301,6 +470,12 @@ def save_persistables(executor, dirname, main_program=None, filename=None):
            fluid.io.save_persistables(executor=exe, dirname=param_path,
                                       main_program=None)
    """
+    if main_program and main_program._is_distributed:
+        _save_distributed_persistables(
+            executor, dirname=dirname, main_program=main_program)
+    else:
         save_vars(
             executor,
             dirname=dirname,

@@ -402,17 +577,11 @@ def load_vars(executor,
         if not isinstance(main_program, Program):
             raise TypeError("program should be as Program type or None")

-        load_slice_vars = []
-        for each_var in main_program._slice_vars_and_attrs:
-            load_slice_vars.append(each_var[2].name)
-
         load_var_map = {}
         for each_var in vars:
             assert isinstance(each_var, Variable)
             if each_var.type == core.VarDesc.VarType.RAW:
                 continue
-            if each_var.name in load_slice_vars:
-                continue
             new_var = _clone_var_in_block_(load_block, each_var)
             if filename is None:
                 load_block.append_op(

@@ -435,10 +604,6 @@ def load_vars(executor,
                     attrs={'file_path': os.path.join(dirname, filename)})

         executor.run(load_prog)

-        # load slice vars on pserver, if have it.
-        _load_slice_up_vars(executor, dirname,
-                            main_program._slice_vars_and_attrs)
-

 def load_params(executor, dirname, main_program=None, filename=None):
     """

@@ -521,6 +686,11 @@ def load_persistables(executor, dirname, main_program=None, filename=None):
            fluid.io.load_persistables(executor=exe, dirname=param_path,
                                       main_program=None)
    """
+    if main_program and main_program._is_distributed:
+        _load_distributed_persistables(
+            executor, dirname=dirname, main_program=main_program)
+    else:
         load_vars(
             executor,
             dirname=dirname,

@@ -529,6 +699,123 @@ def load_persistables(executor, dirname, main_program=None, filename=None):
         filename=filename)


+def _load_distributed_persistables(executor, dirname, main_program=None):
+    """
+    customized load_persistables for distributed training.
+    it should be used on parameter server,
+
+    Args:
+        executor(Executor): The executor to run for saving parameters.
+        dirname(str): The load directory path.
+        main_program(Program): The program whose parameters will be
+                            loaded. the main_program must be the pserver_program
+                            get after transpiler.
+
+    Returns:
+        None
+
+    Examples:
+        .. code-block:: python
+
+            exe = fluid.Executor(fluid.CPUPlace())
+            param_path = "./my_paddle_model"
+            t = distribute_transpiler.DistributeTranspiler()
+            t.transpile(...)
+            pserver_prog = t.get_pserver_program(...)
+            _load_distributed_persistables(executor=exe, dirname=param_path, main_program=pserver_prog)
+    """
+
+    def __is_distributed_part_var(varname):
+        trainer_idx = varname.find(".trainer_")
+        block_idx = varname.find(".block")
+        return trainer_idx or block_idx
+
+    def __load_persistable_vars(executor, dirname, need_load_vars):
+        load_prog = Program()
+        load_block = load_prog.global_block()
+        need_delete_vars = []
+
+        for param in need_load_vars:
+            origin_var = param.origin
+            slice_var = param.slice
+            is_slice = param.is_slice
+            offset = param.offset
+
+            if is_slice:
+                origin = load_block.create_var(
+                    name="{}.load".format(origin_var.name),
+                    type=origin_var.type,
+                    shape=origin_var.shape,
+                    dtype=origin_var.dtype,
+                    persistable=True)
+
+                load_block.append_op(
+                    type='load',
+                    inputs={},
+                    outputs={'Out': [origin]},
+                    attrs={
+                        'file_path': os.path.join(dirname, origin_var.name)
+                    })
+
+                slice = load_block.create_var(
+                    name=slice_var.name,
+                    type=slice_var.type,
+                    shape=slice_var.shape,
+                    dtype=slice_var.dtype,
+                    persistable=True)
+
+                dim1_flatten = reduce(lambda x, y: x * y, slice.shape[1:])
+                start = int(offset / dim1_flatten)
+                end = int(offset / dim1_flatten + slice.shape[0])
+
+                load_block.append_op(
+                    type="slice",
+                    inputs={'Input': origin},
+                    outputs={'Out': slice},
+                    attrs={'axes': [0],
+                           'starts': [start],
+                           'ends': [end]})
+
+                need_delete_vars.append(origin)
+            else:
+                origin = load_block.create_var(
+                    name="{}".format(origin_var.name),
+                    type=origin_var.type,
+                    shape=origin_var.shape,
+                    dtype=origin_var.dtype,
+                    persistable=True)
+                load_block.append_op(
+                    type='load',
+                    inputs={},
+                    outputs={'Out': [origin]},
+                    attrs={
+                        'file_path': os.path.join(dirname, origin_var.name)
+                    })
+
+        load_block.append_op(
+            type='delete_var',
+            inputs={'X': need_delete_vars}, )
+
+        executor.run(load_prog)
+
+    if not isinstance(main_program, Program):
+        raise ValueError("'main_program' should be an instance of Program.")
+
+    if not main_program._is_distributed:
+        raise ValueError(
+            "'_load_distributed_persistables' just be designed for distributed training."
+        )
+
+    if not main_program._ps_endpoint:
+        raise ValueError(
+            "'_load_distributed_persistables' need current_endpoint set in DistributeTranspiler.transpile"
+        )
+
+    need_load_vars = main_program._parameters_on_pservers.get_distributed_vars_by_ep(
+        main_program._ps_endpoint)
+    __load_persistable_vars(executor, dirname, need_load_vars)
+
+
 def prepend_feed_ops(inference_program,
                      feed_target_names,
                      feed_holder_name='feed'):

@@ -795,52 +1082,6 @@ def load_inference_model(dirname,
     return [program, feed_target_names, fetch_targets]


-def _save_lookup_tables_by_notify(executor, dirname, lookup_table,
-                                  pserver_endpoints):
-    """
-    This function will send checkpoint notify message from Trainer 0
-    to all the pservers.
-    The checkpoint notify message contains lookup table name,
-    the absolute path on pserver to save lookup_table.
-
-    Args:
-        executor(Executor): The executor to run for send checkpoint notify.
-        dirname(str): The folder where to save.
-        lookup_table(string): the lookup table name, when use distribute
-            lookup table, we can get lookup table name by DistributeTranspiler.
-            table_name
-        ps_endpoint_list(list): the parameter server ip:port list.
-            when use distribute lookup table, we can get ps_endpoint_list by
-            distribute arguments.
-    Return:
-        None
-
-    Examples:
-        .. code-block:: python
-
-            exe = fluid.Executor(fluid.CPUPlace())
-            param_path = "./my_paddle_model"
-            table_name = "share_w"
-            ps_endpoints = ["127.0.0.1:6000","127.0.0.1:6001"]
-
-            _save_pserver_vars_by_notify(executor=exe,
-                    dirname=param_path, lookup_table=table_name,
-                    pserver_endpoints=ps_endpoints)
-    """
-
-    pserver_notify_program = Program()
-    pserver_notify_block = pserver_notify_program.global_block()
-
-    attrs = {}
-    attrs['epmap'] = pserver_endpoints
-    attrs['dir'] = dirname
-    attrs['lookup_table'] = lookup_table
-
-    pserver_notify_block.append_op(
-        type='checkpoint_notify', inputs={}, outputs={}, attrs=attrs)
-    executor.run(pserver_notify_program)
-
-
 def _endpoints_replacement(program, endpoints):
     ENDPOINT_MAP = "epmap"
     for op in program.global_block().ops:

@@ -911,54 +1152,3 @@ def get_parameter_value_by_name(name, executor, program=None):
         program = default_main_program()
     var = program.global_block().var(name)
     return get_parameter_value(var, executor)
-
-
-def _load_slice_up_vars(executor, dirname, slice_vars_and_attrs):
-    if not slice_vars_and_attrs:
-        return
-
-    load_prog = Program()
-    load_block = load_prog.global_block()
-    need_delete_vars = []
-
-    for var_tuple in slice_vars_and_attrs:
-        orig_var = var_tuple[0]
-        start = var_tuple[1]
-        slice_var = var_tuple[2]
-        end = start + slice_var.shape[0]
-
-        orig_var_name = orig_var.name
-        orig_var.name = "{}.origin".format(orig_var_name)
-
-        clone_orig_var = load_block.create_var(
-            name=orig_var.name,
-            type=orig_var.type,
-            shape=orig_var.shape,
-            dtype=orig_var.dtype,
-            persistable=True)
-
-        clone_slice_var = load_block.create_var(
-            name=slice_var.name,
-            type=slice_var.type,
-            shape=slice_var.shape,
-            dtype=slice_var.dtype,
-            persistable=True)
-
-        load_block.append_op(
-            type='load',
-            inputs={},
-            outputs={'Out': [clone_orig_var]},
-            attrs={'file_path': os.path.join(dirname, orig_var_name)})
-        load_block.append_op(
-            type="slice",
-            inputs={'Input': clone_orig_var},
-            outputs={'Out': clone_slice_var},
-            attrs={'axes': [0],
-                   'starts': [start],
-                   'ends': [end]})
-        need_delete_vars.append(clone_orig_var)
-
-    load_block.append_op(
-        type='delete_var',
-        inputs={'X': need_delete_vars}, )
-    executor.run(load_prog)
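The slice reconstruction in `__load_persistable_vars` is worth spelling out: a parameter sliced across pservers is saved once in its full shape, and each pserver recovers its rows from the stored `offset` (the number of elements that precede its slice). A small worked example of that arithmetic, using made-up shapes:

from functools import reduce

# Full parameter saved on disk: shape [10, 4]  (made-up example)
origin_shape = [10, 4]
# This pserver owns a slice of shape [6, 4] that starts after 16 elements.
slice_shape = [6, 4]
offset = 16

# Same computation as in __load_persistable_vars above:
dim1_flatten = reduce(lambda x, y: x * y, slice_shape[1:])  # 4 elements per row
start = int(offset / dim1_flatten)                          # row 4
end = int(offset / dim1_flatten + slice_shape[0])           # row 10

assert (start, end) == (4, 10)  # the pserver slices rows [4, 10) of the file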
python/paddle/fluid/tests/unittests/dist_save_load.py

@@ -80,7 +80,8 @@ class TestDistSaveLoad2x2(TestDistSimnetBow2x2):
         # NOTE: pserver should not call memory optimize
         t = self.get_transpiler(args.trainer_id,
                                 fluid.default_main_program(), args.endpoints,
-                                args.trainers, args.sync_mode)
+                                args.trainers, args.sync_mode, False,
+                                args.current_endpoint)
         pserver_prog = t.get_pserver_program(args.current_endpoint)
         startup_prog = t.get_startup_program(args.current_endpoint,
                                              pserver_prog)

@@ -93,7 +94,8 @@ class TestDistSaveLoad2x2(TestDistSimnetBow2x2):
         exe.run(startup_prog)

         if need_load and model_dir:
-            self._load_persistable_vars(exe, model_dir, startup_prog)
+            fluid.io.load_persistables(exe, model_dir, pserver_prog)
+
         exe.run(pserver_prog)

     def run_trainer(self, args):

@@ -158,7 +160,9 @@ class TestDistSaveLoad2x2(TestDistSimnetBow2x2):
         need_save = bool(int(os.getenv("SAVE", "0")))
         model_dir = os.getenv("MODEL_DIR", "")
+        save_mode = os.getenv("SAVE_MODE", "")

+        if save_mode == "LOCAL":
             if need_save:
                 for _ in six.moves.xrange(RUN_STEP):
                     loss, = exe.run(fetch_list=[avg_cost.name],

@@ -166,12 +170,37 @@ class TestDistSaveLoad2x2(TestDistSimnetBow2x2):
             if need_save and model_dir:
                 io.save_persistables(startup_exe, model_dir, trainer_prog)

-        var = np.array(fluid.global_scope().find_var('__fc_b__').get_tensor())
+            var = np.array(fluid.global_scope().find_var('__fc_b__').get_tensor(
+            ))
+            if six.PY2:
+                print(pickle.dumps(np.ravel(var).tolist()))
+            else:
+                sys.stdout.buffer.write(pickle.dumps(np.ravel(var).tolist()))
+
+        elif save_mode == "DIST":
+            skip_steps = int(os.getenv("SKIP_STEPS"))
+            loss = None
+            if need_save:
+                for idx in six.moves.xrange(8):
+                    loss, = exe.run(fetch_list=[avg_cost.name],
+                                    feed=feeder.feed(get_data()))
+                    if need_save and model_dir and idx == skip_steps and args.trainer_id == 0:
+                        io.save_persistables(startup_exe, model_dir,
+                                             trainer_prog)
+            else:
+                for idx in six.moves.xrange(8):
+                    data = get_data()
+                    if idx <= skip_steps:
+                        continue
+                    loss, = exe.run(fetch_list=[avg_cost.name],
+                                    feed=feeder.feed(data))
+            if six.PY2:
+                print(pickle.dumps(loss.tolist()))
+            else:
+                sys.stdout.buffer.write(pickle.dumps(loss.tolist()))
+        else:
+            raise Exception("save_mode must be LOCAL or DIST")


 if __name__ == "__main__":
     paddle.dataset.common.download(DATA_URL, 'simnet', DATA_MD5, "train")
python/paddle/fluid/tests/unittests/dist_simnet_bow.py

@@ -75,8 +75,12 @@ def get_loss(cos_q_pt, cos_q_nt):
     return avg_cost


-def get_optimizer():
-    # SGD optimizer
-    optimizer = fluid.optimizer.SGD(learning_rate=base_lr)
+def get_optimizer(op="sgd"):
+    if op.upper() == "sgd".upper():
+        optimizer = fluid.optimizer.SGD(learning_rate=base_lr)
+    elif op.upper() == "adam".upper():
+        optimizer = fluid.optimizer.Adam(learning_rate=base_lr)
+    else:
+        optimizer = fluid.optimizer.SGD(learning_rate=base_lr)
     return optimizer

@@ -237,7 +241,8 @@ class TestDistSimnetBow2x2(TestDistRunnerBase):
         inference_program = fluid.default_main_program().clone()

         # Optimization
-        opt = get_optimizer()
+        opt = os.getenv('OPTIMIZER', 'sgd')
+        opt = get_optimizer(opt)
         opt.minimize(avg_cost)

         # Reader
python/paddle/fluid/tests/unittests/test_dist_base.py

@@ -43,7 +43,8 @@ class TestDistRunnerBase(object):
                        pserver_endpoints,
                        trainers,
                        sync_mode,
-                       dc_asgd=False):
+                       dc_asgd=False,
+                       current_endpoint=None):
         # NOTE: import fluid until runtime, or else forking processes will cause error.
         config = fluid.DistributeTranspilerConfig()
         config.enable_dc_asgd = dc_asgd

@@ -53,7 +54,8 @@ class TestDistRunnerBase(object):
             program=main_program,
             pservers=pserver_endpoints,
             trainers=trainers,
-            sync_mode=sync_mode)
+            sync_mode=sync_mode,
+            current_endpoint=current_endpoint)
         return t

     def run_pserver(self, args):
python/paddle/fluid/tests/unittests/test_dist_save_load.py

@@ -33,7 +33,6 @@ class TestDistSaveLoadDense2x2(TestDistBase):
                          delta=1e-3,
                          check_error_log=False,
                          need_envs={}):
-
         required_envs = {
             "PATH": os.getenv("PATH", ""),
             "PYTHONPATH": os.getenv("PYTHONPATH", ""),

@@ -77,7 +76,77 @@ class TestDistSaveLoadDense2x2(TestDistBase):
         need_envs = {
             "IS_DISTRIBUTED": '0',
             "IS_SPARSE": '0',
-            'IS_SELF_CONTAINED_LR': '1'
+            'IS_SELF_CONTAINED_LR': '1',
+            'SAVE_MODE': 'LOCAL',
         }
         self.check_with_place(
             "dist_save_load.py",
             delta=0,
             check_error_log=False,
             need_envs=need_envs)


+class TestDistSaveLoadWithPServerStateDense2x2(TestDistBase):
+    def _setup_config(self):
+        self._sync_mode = True
+        self._enforce_place = "CPU"
+
+    def check_with_place(self,
+                         model_file,
+                         delta=1e-3,
+                         check_error_log=False,
+                         need_envs={}):
+        required_envs = {
+            "PATH": os.getenv("PATH", ""),
+            "PYTHONPATH": os.getenv("PYTHONPATH", ""),
+            "LD_LIBRARY_PATH": os.getenv("LD_LIBRARY_PATH", ""),
+            "http_proxy": ""
+        }
+
+        required_envs.update(need_envs)
+
+        if check_error_log:
+            required_envs["GLOG_v"] = "3"
+            required_envs["GLOG_logtostderr"] = "1"
+
+        model_dir = tempfile.mkdtemp()
+
+        save_env = {}
+        save_env["SAVE_MODE"] = "DIST"
+        save_env["SAVE"] = "1"
+        save_env["MODEL_DIR"] = model_dir
+        save_env.update(required_envs)
+
+        tr0_var_1, tr1_var_1 = self._run_cluster(model_file, save_env,
+                                                 check_error_log)
+
+        load_env = {}
+        load_env["LOAD"] = "1"
+        load_env["MODEL_DIR"] = model_dir
+        load_env.update(required_envs)
+
+        tr0_var_2, tr1_var_2 = self._run_cluster(model_file, load_env,
+                                                 check_error_log)
+
+        shutil.rmtree(model_dir)
+
+        train0_1_np = np.array(tr0_var_1)
+        train1_1_np = np.array(tr1_var_1)
+        train0_2_np = np.array(tr0_var_2)
+        train1_2_np = np.array(tr1_var_2)
+
+        self.assertAlmostEqual(
+            train0_1_np.all(), train0_2_np.all(), delta=delta)
+        self.assertAlmostEqual(
+            train1_1_np.all(), train1_2_np.all(), delta=delta)
+
+    def test_dist(self):
+        need_envs = {
+            "IS_DISTRIBUTED": '0',
+            "IS_SPARSE": '0',
+            'IS_SELF_CONTAINED_LR': '1',
+            'SAVE_MODE': 'DIST',
+            'OPTIMIZER': 'ADAM',
+            'SKIP_STEPS': str(np.random.randint(2, 6))
+        }
+        self.check_with_place(
+            "dist_save_load.py",
python/paddle/fluid/tests/unittests/test_dist_transpiler.py

@@ -741,21 +741,40 @@ class TestLoadSliceVar(TranspilerTest):
         pserver, _ = self.get_pserver(self.pserver1_ep)
         pserver2, _ = self.get_pserver(self.pserver2_ep)

-        self.assertTrue(pserver._slice_vars_and_attrs)
-        self.assertTrue(pserver2._slice_vars_and_attrs)
-
-        for idx in six.moves.xrange(len(pserver._slice_vars_and_attrs)):
-            self.assertEqual(pserver._slice_vars_and_attrs[idx][0],
-                             pserver2._slice_vars_and_attrs[idx][0])
-
-            total_numel = six.moves.reduce(
-                lambda x, y: x * y, pserver._slice_vars_and_attrs[idx][0].shape)
-            self.assertEqual(
-                total_numel,
-                six.moves.reduce(lambda x, y: x * y,
-                                 pserver._slice_vars_and_attrs[idx][2].shape) +
-                six.moves.reduce(lambda x, y: x * y,
-                                 pserver2._slice_vars_and_attrs[idx][2].shape))
+        vars_ps1 = pserver._parameters_on_pservers.get_distributed_vars_by_ep(
+            self.pserver1_ep)
+        vars_ps2 = pserver._parameters_on_pservers.get_distributed_vars_by_ep(
+            self.pserver2_ep)
+
+        self.assertTrue(vars_ps1)
+        self.assertTrue(vars_ps2)
+
+        for idx in six.moves.xrange(len(vars_ps1)):
+            total_numel = 0
+            ps1_numel, ps2_numel = 0, 0
+
+            ps1_var = vars_ps1[idx]
+
+            if not ps1_var.is_slice:
+                total_numel = six.moves.reduce(lambda x, y: x * y,
+                                               vars_ps1[idx].origin.shape)
+                ps1_numel = six.moves.reduce(lambda x, y: x * y,
+                                             vars_ps1[idx].slice.shape)
+            else:
+                ps2_var = None
+                for var in vars_ps2:
+                    if var.origin.name == ps1_var.origin.name:
+                        ps2_var = var
+                        break
+
+                total_numel = six.moves.reduce(lambda x, y: x * y,
+                                               ps1_var.origin.shape)
+                ps1_numel = six.moves.reduce(lambda x, y: x * y,
+                                             ps1_var.slice.shape)
+                ps2_numel = six.moves.reduce(lambda x, y: x * y,
+                                             ps2_var.slice.shape)
+
+            self.assertEqual(total_numel, ps1_numel + ps2_numel)


 class TestNCCL2Transpile(TranspilerTest):
python/paddle/fluid/transpiler/distribute_transpiler.py
浏览文件 @
8b50ad80
...
...
@@ -39,7 +39,7 @@ from .ps_dispatcher import RoundRobin, PSDispatcher
from
..
import
core
,
framework
,
unique_name
from
..framework
import
Program
,
default_main_program
,
\
default_startup_program
,
Block
,
\
Parameter
,
grad_var_name
Parameter
,
Variable
,
grad_var_name
from
.details
import
*
from
..distribute_lookup_table
import
find_distributed_lookup_table
from
functools
import
reduce
...
...
@@ -62,6 +62,260 @@ def log(*args):
print
(
args
)
class
VarStruct
(
object
):
"""
record part properties of a Variable in python.
"""
def
__init__
(
self
,
name
,
shape
,
dtype
,
type
,
lod_level
,
persistable
):
self
.
name
=
name
self
.
shape
=
shape
self
.
dtype
=
dtype
self
.
type
=
type
self
.
lod_level
=
lod_level
self
.
persistable
=
persistable
class
VarDistributed
(
object
):
"""
a class to record the var distributed on parameter servers.
the class will record the relationship between origin var and slice var.
the slice var's properties, such as type/shape/offset/endpoint.
"""
def
__init__
(
self
,
origin_var
,
slice_var
,
is_slice
=
None
,
block_id
=
None
,
offset
=
None
,
vtype
=
None
,
endpoint
=
None
):
"""
Args:
origin_var(Variable|VarStruct): origin var properties
slice_var(Variable|VarStruct): slice var properties
is_slice(bool|None): slice or not, slice_var=True/False and its block size > 8192 are the judgement standard.
block_id(int|None): the number about the slice var.
offset(int|None): if the slice var is sliced, offset is the numel before the var.
vtype(str|None): a tag, such as Optimizer/Param/RemoteProfetch.
endpoint(str|None): which parameter the slice var on, such as "127.0.0.1:1001"
"""
if
isinstance
(
origin_var
,
Variable
):
self
.
origin
=
self
.
__create_var_struct
(
origin_var
)
else
:
self
.
origin
=
origin_var
if
isinstance
(
slice_var
,
Variable
):
self
.
slice
=
self
.
__create_var_struct
(
slice_var
)
else
:
self
.
slice
=
slice_var
if
self
.
equal
(
self
.
origin
,
self
.
slice
):
self
.
is_slice
=
False
self
.
block_id
=
0
self
.
offset
=
0
else
:
self
.
is_slice
=
True
self
.
block_id
=
0
self
.
offset
=
0
if
is_slice
is
not
None
:
self
.
is_slice
=
is_slice
if
block_id
is
not
None
:
self
.
block_id
=
block_id
if
offset
is
not
None
:
self
.
offset
=
offset
self
.
vtype
=
vtype
self
.
endpoint
=
endpoint
@
staticmethod
def
__create_var_struct
(
var
):
return
VarStruct
(
var
.
name
,
var
.
shape
,
var
.
dtype
,
var
.
type
,
var
.
lod_level
,
var
.
persistable
)
@
staticmethod
def
equal
(
var1
,
var2
):
"""
the two var is equal or not.
Returns:
bool: equal will return True else False
"""
assert
isinstance
(
var1
,
VarStruct
)
and
isinstance
(
var2
,
VarStruct
)
return
var1
.
name
==
var2
.
name
and
\
var1
.
type
==
var2
.
type
and
\
var1
.
shape
==
var2
.
shape
and
\
var1
.
dtype
==
var2
.
dtype
and
\
var1
.
lod_level
==
var2
.
lod_level
and
\
var1
.
persistable
==
var2
.
persistable
def
__str__
(
self
):
origin_var_str
=
"{name} : fluid.{type}.shape{shape}.astype({dtype})"
.
\
format
(
i
=
"{"
,
e
=
"}"
,
name
=
self
.
origin
.
name
,
type
=
self
.
origin
.
type
,
shape
=
self
.
origin
.
shape
,
dtype
=
self
.
origin
.
dtype
)
slice_var_str
=
"{name} : fluid.{type}.shape{shape}.astype({dtype})"
\
".slice({is_slice}).block({block_id}).offset({offset})"
.
\
format
(
i
=
"{"
,
e
=
"}"
,
name
=
self
.
slice
.
name
,
type
=
self
.
slice
.
type
,
shape
=
self
.
slice
.
shape
,
dtype
=
self
.
slice
.
dtype
,
is_slice
=
self
.
is_slice
,
block_id
=
self
.
block_id
,
offset
=
self
.
offset
)
return
"var owned: {}, origin var: ( {} ), slice var: ( {} ), endpoint: {} "
.
format
(
self
.
vtype
,
origin_var_str
,
slice_var_str
,
self
.
endpoint
)
class
VarsDistributed
(
object
):
"""
a gather about VarDistributed with many methods to find distributed vars.
through the class, we can get overview about the distributed parameters on parameter servers.
this class may centralized and convenient for developer to manage and get variable's distribute.
other module can also use this to find variables such io.py.
"""
def
__init__
(
self
):
self
.
distributed_vars
=
[]
    def add_distributed_var(self,
                            origin_var,
                            slice_var,
                            is_slice=None,
                            block_id=None,
                            offset=None,
                            vtype=None,
                            endpoint=None):
        """
        add a distributed var to this collection.

        Args:
            origin_var(Variable|VarStruct): origin var properties
            slice_var(Variable|VarStruct): slice var properties
            is_slice(bool|None): whether the var is a slice of the origin var; a var is judged to be a slice when it is split and its block size is greater than 8192.
            block_id(int|None): the block index of the slice var.
            offset(int|None): if the var is sliced, the numel (number of elements) of the origin var that comes before this slice.
            vtype(str|None): a tag, such as Optimizer/Param/RemotePrefetch.
            endpoint(str|None): the parameter server endpoint the slice var is placed on, such as "127.0.0.1:1001"

        Returns:
            None
        """
        self.distributed_vars.append(
            VarDistributed(origin_var, slice_var, is_slice, block_id, offset,
                           vtype, endpoint))
    def get_distributed_var_by_slice(self, var_name):
        """
        get a distributed var by its slice var name.

        Args:
            var_name(str): slice var name, such as "w.trainer0.block1"

        Returns:
            VarDistributed: distributed var.
        """
        for dist_var in self.distributed_vars:
            if dist_var.slice.name == var_name:
                return dist_var
        return None
    @staticmethod
    def equal(var1, var2):
        """
        whether the two vars are equal.

        Returns:
            bool: True if equal, otherwise False
        """
        return var1.name == var2.name and \
               var1.type == var2.type and \
               var1.shape == var2.shape and \
               var1.dtype == var2.dtype and \
               var1.lod_level == var2.lod_level and \
               var1.persistable == var2.persistable
    def get_distributed_var_by_origin_and_ep(self, origin_var_name, endpoint):
        """
        get a distributed var by its origin var name and endpoint.

        Args:
            origin_var_name(str): origin var name
            endpoint(str): the parameter server endpoint, such as "127.0.0.1:1001"

        Returns:
            VarDistributed: distributed var.
        """
        for dist_var in self.distributed_vars:
            if dist_var.origin.name == origin_var_name and dist_var.endpoint == endpoint:
                return dist_var
        return None
    def get_distributed_vars_by_vtypes(self, vtypes, groupby=False):
        """
        get distributed vars by their vtypes.

        Args:
            vtypes(list): list of vtypes, such as ["Optimizer", "RemotePrefetch"]
            groupby(bool|False): group the result by origin var or not.

        Returns:
            list: distributed var list.
            dict: distributed var map when groupby=True
        """
        vtype_vars = []
        for var in self.distributed_vars:
            if var.vtype in vtypes:
                vtype_vars.append(var)
        if not groupby:
            return vtype_vars

        params_map = {}
        for var in vtype_vars:
            origin_var_name = var.origin.name

            if origin_var_name in params_map.keys():
                optimizers = params_map.get(origin_var_name)
            else:
                optimizers = []
            optimizers.append(var)
            params_map[origin_var_name] = optimizers
        return params_map
    def get_distributed_vars_by_ep(self, endpoint, vtype=None):
        """
        get distributed vars by conditions.

        Args:
            endpoint(str): the parameter server endpoint, such as "127.0.0.1:2001"
            vtype(str|None): distributed var's vtype, such as "Optimizer", "RemotePrefetch"

        Returns:
            list: distributed var list.
        """
        endpoint_vars = []
        for var in self.distributed_vars:
            if var.endpoint == endpoint:
                endpoint_vars.append(var)
        if not vtype:
            return endpoint_vars

        vtype_vars = []
        for var in endpoint_vars:
            if var.vtype == vtype:
                vtype_vars.append(var)
        return vtype_vars
    def overview(self):
        """
        get the overview string about all params on all parameter servers.

        Returns:
            Str: overview string.
        """
        vars_str = []
        for var in self.distributed_vars:
            vars_str.append(str(var))
        return "\n".join(vars_str)
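Continuing the sketch above, the VarsDistributed registry collects such records and answers the lookups used later by the transpiler and by io.py; all names and endpoints here are illustrative.

overview = VarsDistributed()
overview.add_distributed_var(
    origin_var=origin,
    slice_var=slice1,
    is_slice=True,
    block_id=1,
    offset=5000 * 8,
    vtype="Param",
    endpoint="127.0.0.1:6174")

by_slice = overview.get_distributed_var_by_slice("w.block1")
by_origin = overview.get_distributed_var_by_origin_and_ep("w", "127.0.0.1:6174")
grouped = overview.get_distributed_vars_by_vtypes(["Param", "Optimizer"], groupby=True)  # origin name -> list
on_ps0 = overview.get_distributed_vars_by_ep("127.0.0.1:6174", vtype="Param")
print(overview.overview())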
class VarBlock:
    def __init__(self, varname, offset, size):
        self.varname = varname
...
...
@@ -223,16 +477,13 @@ class DistributeTranspiler(object):
                           trainer_id,
                           trainers,
                           current_endpoint,
                           startup_program=None,
                           wait_port=True):
                           startup_program=None):
        if not startup_program:
            startup_program = default_startup_program()
        if trainer_id >= 0:
            worker_endpoints = trainers.split(",")
            # send NCCL_ID to others or recv from trainer 0
            worker_endpoints.remove(current_endpoint)
            if trainer_id == 0 and wait_port:
                wait_server_ready(worker_endpoints)

        nccl_id_var = startup_program.global_block().create_var(
            name="NCCLID", persistable=True, type=core.VarDesc.VarType.RAW)
...
...
@@ -313,13 +564,11 @@ class DistributeTranspiler(object):
        if self.config.mode == "nccl2":
            assert (isinstance(trainers, str))
            self.origin_program._trainers_endpoints = trainers.split(",")
            self._transpile_nccl2(
                trainer_id,
                trainers,
                current_endpoint,
                startup_program=startup_program,
                wait_port=self.config.wait_port)
                startup_program=startup_program)
            return

        self.trainer_num = trainers
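For context, a usage sketch of the nccl2 branch above; the endpoints are illustrative and the config fields (mode, wait_port) are the ones the call references.

import paddle.fluid as fluid

config = fluid.DistributeTranspilerConfig()
config.mode = "nccl2"
config.wait_port = True      # wait for the other trainers' ports before exchanging NCCLID

t = fluid.DistributeTranspiler(config=config)
t.transpile(trainer_id=0,
            trainers="192.168.0.1:6174,192.168.0.2:6174",
            current_endpoint="192.168.0.1:6174")
trainer_prog = t.get_trainer_program()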
...
...
@@ -327,6 +576,7 @@ class DistributeTranspiler(object):
        self.trainer_id = trainer_id
        pserver_endpoints = pservers.split(",")
        self.pserver_endpoints = pserver_endpoints
        self.vars_overview = VarsDistributed()
        self.optimize_ops, self.params_grads = self._get_optimize_pass()

        ps_dispatcher = self.config.split_method(self.pserver_endpoints)
...
...
@@ -347,6 +597,7 @@ class DistributeTranspiler(object):
        # add distributed attrs to program
        self.origin_program._is_distributed = True
        self.origin_program._endpoints = self.pserver_endpoints
        self.origin_program._ps_endpoint = current_endpoint
        self.origin_program._is_chief = self.trainer_id == 0
        self.origin_program._distributed_lookup_table = self.table_name if self.table_name else None
...
...
@@ -454,6 +705,10 @@ class DistributeTranspiler(object):
                self.param_grad_ep_mapping[ep]["params"].append(recv_vars[i])
                self.param_grad_ep_mapping[ep]["grads"].append(send_vars[i])

                distributed_var = self.vars_overview.get_distributed_var_by_slice(
                    recv_vars[i].name)
                distributed_var.endpoint = ep

        # step4: Concat the parameters splits together after recv.
        all_recv_outputs = []
        for param_varname, splited_var in six.iteritems(self.param_var_mapping):
...
...
@@ -480,6 +735,12 @@ class DistributeTranspiler(object):
                    recv_op_role_var_name = splited_trainer_grad[0].name

                if param_varname in self.sparse_param_to_height_sections:

                    for table_name in table_names:
                        distributed_var = self.vars_overview.get_distributed_var_by_slice(
                            table_name)
                        distributed_var.vtype = "RemotePrefetch"

                    height_sections = self.sparse_param_to_height_sections[
                        param_varname]
                    self._update_remote_sparse_update_op(
...
...
@@ -532,6 +793,9 @@ class DistributeTranspiler(object):
                                                  pserver_endpoints)
            self._split_table_grad_and_add_send_vars(program, pserver_endpoints)

        self._get_distributed_optimizer_vars()

        self.origin_program._parameters_on_pservers = self.vars_overview
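The _parameters_on_pservers assignment above is what later exposes the whole VarsDistributed overview to consumers such as the checkpoint logic in io.py. A hedged sketch of reading it back after a pserver-mode transpile, assuming the transpiled program is the default main program (endpoints are illustrative):

import paddle.fluid as fluid

t = fluid.DistributeTranspiler()
t.transpile(trainer_id=0,
            pservers="127.0.0.1:6174,127.0.0.1:6175",
            trainers=2)

overview = fluid.default_main_program()._parameters_on_pservers
for dist_var in overview.get_distributed_vars_by_ep("127.0.0.1:6174", vtype="Param"):
    print(dist_var)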
    def get_trainer_program(self, wait_port=True):
        """
        Get transpiled trainer side program.
...
...
@@ -541,6 +805,7 @@ class DistributeTranspiler(object):
"""
# remove optimize ops and add a send op to main_program
# FIXME(typhoonzero): Also ops like clip_gradient, lrn_decay?
lr_ops
=
self
.
_get_lr_ops
()
delete_ops
(
self
.
origin_program
.
global_block
(),
self
.
optimize_ops
)
delete_ops
(
self
.
origin_program
.
global_block
(),
lr_ops
)
...
...
@@ -665,9 +930,14 @@ class DistributeTranspiler(object):
        # NOTE: assume blocks of the same variable is not distributed
        # on the same pserver, only change param/grad varnames for
        # trainers to fetch.
        sys.stderr.write(
            "get_pserver_program() is deprecated, call get_pserver_programs() to get pserver main and startup in a single call.\n"
        )
        # step1
        pserver_program = Program()
        pserver_program.random_seed = self.origin_program.random_seed
        pserver_program._copy_dist_param_info_from(self.origin_program)

        # step2: Create vars to receive vars at parameter servers.
        recv_inputs = []
        for v in self.param_grad_ep_mapping[endpoint]["params"]:
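Following the deprecation notice above, a minimal sketch of the preferred call on a parameter server process, reusing the transpiler t from the earlier sketch (endpoint and executor setup are illustrative):

import paddle.fluid as fluid

current_endpoint = "127.0.0.1:6174"      # this pserver's endpoint
pserver_prog, pserver_startup = t.get_pserver_programs(current_endpoint)

exe = fluid.Executor(fluid.CPUPlace())
exe.run(pserver_startup)
exe.run(pserver_prog)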
...
...
@@ -703,9 +973,6 @@ class DistributeTranspiler(object):
            else:
                recv_inputs.append(single_trainer_var)

        self._slice_params_and_optimizes = self._get_slice_vars_and_attrs(
            endpoint)

        # step 3
        # Create a union-find data structure from optimize ops,
        # If two ops are connected, we could add these two ops
...
...
@@ -882,10 +1149,6 @@ class DistributeTranspiler(object):
                outputs={},
                attrs=attrs)

        # add distributed attrs
        pserver_program._slice_vars_and_attrs = list(
            self._slice_params_and_optimizes.values())

        pserver_program._sync_with_cpp()
        # save pserver program to generate pserver side startup relatively.
        self.pserver_program = pserver_program
...
...
@@ -984,30 +1247,88 @@ class DistributeTranspiler(object):
                inputs={"X": startup_param_var},
                outputs={"Out": startup_tmpvar})

        # add slice vars
        s_prog._slice_vars_and_attrs = pserver_program._slice_vars_and_attrs

        return s_prog

    def _get_slice_vars_and_attrs(self, endpoint):
        slice_vars_and_attrs = {}
    # ====================== private transpiler functions =====================
    def _get_slice_var_info(self, slice_var):
        block_suffix = "block"
        for param in self.param_grad_ep_mapping[endpoint]["params"]:
            orig_var_name, block_name, _ = self._get_varname_parts(param.name)
        block_idx = 0
        offset = 0
        is_slice = False

        orig_var_name, block_name, _ = self._get_varname_parts(slice_var.name)

        if not block_name:
            continue
            return is_slice, block_idx, offset

        block_idx = int(block_name.split(block_suffix)[1])
        orig_var = self.origin_program.global_block().vars[orig_var_name]

        skip_dim0 = 0
        slice_vars = self.param_var_mapping[orig_var_name]

        orig_dim1_flatten = reduce(lambda x, y: x * y, slice_vars[0].shape[1:])

        for slice_var in slice_vars[:block_idx]:
            skip_dim0 += slice_var.shape[0]
            slice_vars_and_attrs[param.name] = [orig_var, skip_dim0, param]

        return slice_vars_and_attrs
    # ====================== private transpiler functions =====================
        offset = skip_dim0 * orig_dim1_flatten
        is_slice = True
        return is_slice, block_idx, offset
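A small worked example of the offset computation in _get_slice_var_info above, with made-up shapes: if a (10, 4) parameter "w" is split into blocks of shape (6, 4) and (4, 4), then for "w.block1" skip_dim0 = 6 and orig_dim1_flatten = 4, so 24 elements of "w" precede the slice.

from functools import reduce

slice_shapes = [(6, 4), (4, 4)]    # hypothetical slices of a (10, 4) parameter "w"
block_idx = 1                      # asking about "w.block1"

orig_dim1_flatten = reduce(lambda x, y: x * y, slice_shapes[0][1:])   # 4
skip_dim0 = sum(s[0] for s in slice_shapes[:block_idx])               # 6
offset = skip_dim0 * orig_dim1_flatten                                # 24
print(True, block_idx, offset)     # is_slice, block_id, offset -> True 1 24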
    def _get_distributed_optimizer_vars(self):
        def _get_distributed_optimizer_var(endpoint):
            opt_op_on_pserver = []
            for _, op in enumerate(self.optimize_ops):
                if self._is_optimizer_op(op) and self._is_opt_op_on_pserver(
                        endpoint, op):
                    opt_op_on_pserver.append(op)

            for opt_op in opt_op_on_pserver:
                dist_var = None
                for key in opt_op.input_names:
                    if key == "Param":
                        param_name = opt_op.input(key)[0]
                        dist_var = self.vars_overview.get_distributed_var_by_origin_and_ep(
                            param_name, endpoint)
                        break
                for key in opt_op.input_names:
                    if key in ["Param", "Grad", "LearningRate"]:
                        continue
                    origin_var = self.origin_program.global_block().vars[
                        opt_op.input(key)[0]]
                    # update accumulator variable shape
                    new_shape = self._get_optimizer_input_shape(
                        opt_op.type, key, origin_var.shape,
                        dist_var.slice.shape)

                    if new_shape == dist_var.slice.shape:
                        splited_var = VarStruct(
                            name=origin_var.name,
                            shape=new_shape,
                            dtype=origin_var.dtype,
                            type=origin_var.type,
                            lod_level=origin_var.lod_level,
                            persistable=origin_var.persistable)

                        self.vars_overview.add_distributed_var(
                            origin_var=origin_var,
                            slice_var=splited_var,
                            is_slice=dist_var.is_slice,
                            block_id=dist_var.block_id,
                            offset=dist_var.offset,
                            vtype="Optimizer",
                            endpoint=endpoint)
                    else:
                        self.vars_overview.add_distributed_var(
                            origin_var=origin_var,
                            slice_var=origin_var,
                            is_slice=False,
                            block_id=0,
                            offset=0,
                            vtype="Optimizer",
                            endpoint=endpoint)

        for ep in self.pserver_endpoints:
            _get_distributed_optimizer_var(ep)

    def _update_dist_lookup_table_vars(self, param_list, grad_list,
                                       params_grads):
...
...
@@ -1093,6 +1414,22 @@ class DistributeTranspiler(object):
        # origin_param_name -> [splited_param_vars]
        self.param_var_mapping = self._create_vars_from_blocklist(
            self.origin_program, param_blocks)

        for orig_name, splited_vars in self.param_var_mapping.items():
            orig_var = self.origin_program.global_block().var(orig_name)

            for splited_var in splited_vars:
                is_slice, block_id, offset = self._get_slice_var_info(
                    splited_var)

                self.vars_overview.add_distributed_var(
                    origin_var=orig_var,
                    slice_var=splited_var,
                    block_id=block_id,
                    offset=offset,
                    is_slice=is_slice,
                    vtype="Param")

        # origin_grad_name -> [splited_grad_vars]
        self.grad_var_mapping = self._create_vars_from_blocklist(
            self.origin_program,
...
...
@@ -1729,13 +2066,6 @@ class DistributeTranspiler(object):
                    shape=new_shape)
            new_inputs[key] = tmpvar

            # var shape been changed
            if new_shape != var.shape:
                slice_var_args = self._slice_params_and_optimizes[
                    param_var.name]
                self._slice_params_and_optimizes[var.name] = [
                    var, slice_var_args[1], tmpvar
                ]

        # change output's ParamOut variable
        outputs = self._get_output_map_from_op(
            self.origin_program.global_block().vars, opt_op)
...
...
@@ -1763,7 +2093,7 @@ class DistributeTranspiler(object):
            # skip per trainer vars
            if g.name.find(".trainer_") == -1:
                # only param or grads have splited blocks
                if self._orig_varname(g.name) in self.grad_name_to_param_name or \
                if self._orig_varname(g.name) in self.grad_name_to_param_name or \
                        self._orig_varname(g.name) in self.param_name_to_grad_name:
                    grad_block = g
                    break
...
...