Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
8b50ad80
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
8b50ad80
编写于
1月 23, 2019
作者:
T
tangwei12
提交者:
GitHub
1月 23, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
checkpoint at distributed training (#14854)
checkpoint for distributed training.
上级
07dc5a15
变更
21
展开全部
隐藏空白更改
内联
并排
Showing
21 changed file
with
1122 addition
and
280 deletion
+1122
-280
paddle/fluid/operators/distributed/grpc/grpc_client.cc
paddle/fluid/operators/distributed/grpc/grpc_client.cc
+52
-37
paddle/fluid/operators/distributed/grpc/grpc_client.h
paddle/fluid/operators/distributed/grpc/grpc_client.h
+12
-5
paddle/fluid/operators/distributed/grpc/grpc_server.cc
paddle/fluid/operators/distributed/grpc/grpc_server.cc
+55
-4
paddle/fluid/operators/distributed/grpc/grpc_service.h
paddle/fluid/operators/distributed/grpc/grpc_service.h
+3
-0
paddle/fluid/operators/distributed/request_handler.h
paddle/fluid/operators/distributed/request_handler.h
+13
-0
paddle/fluid/operators/distributed/request_handler_impl.cc
paddle/fluid/operators/distributed/request_handler_impl.cc
+29
-1
paddle/fluid/operators/distributed/request_handler_impl.h
paddle/fluid/operators/distributed/request_handler_impl.h
+10
-0
paddle/fluid/operators/distributed/rpc_client.h
paddle/fluid/operators/distributed/rpc_client.h
+7
-0
paddle/fluid/operators/distributed/send_recv.proto.in
paddle/fluid/operators/distributed/send_recv.proto.in
+18
-0
paddle/fluid/operators/distributed_ops/listen_and_serv_op.cc
paddle/fluid/operators/distributed_ops/listen_and_serv_op.cc
+5
-0
paddle/fluid/operators/distributed_ops/listen_and_serv_op.h
paddle/fluid/operators/distributed_ops/listen_and_serv_op.h
+2
-1
paddle/fluid/operators/distributed_ops/recv_op.cc
paddle/fluid/operators/distributed_ops/recv_op.cc
+47
-16
paddle/fluid/platform/mkldnn_reuse.h
paddle/fluid/platform/mkldnn_reuse.h
+2
-2
python/paddle/fluid/framework.py
python/paddle/fluid/framework.py
+12
-3
python/paddle/fluid/io.py
python/paddle/fluid/io.py
+322
-132
python/paddle/fluid/tests/unittests/dist_save_load.py
python/paddle/fluid/tests/unittests/dist_save_load.py
+43
-14
python/paddle/fluid/tests/unittests/dist_simnet_bow.py
python/paddle/fluid/tests/unittests/dist_simnet_bow.py
+9
-4
python/paddle/fluid/tests/unittests/test_dist_base.py
python/paddle/fluid/tests/unittests/test_dist_base.py
+4
-2
python/paddle/fluid/tests/unittests/test_dist_save_load.py
python/paddle/fluid/tests/unittests/test_dist_save_load.py
+71
-2
python/paddle/fluid/tests/unittests/test_dist_transpiler.py
python/paddle/fluid/tests/unittests/test_dist_transpiler.py
+34
-15
python/paddle/fluid/transpiler/distribute_transpiler.py
python/paddle/fluid/transpiler/distribute_transpiler.py
+372
-42
未找到文件。
paddle/fluid/operators/distributed/grpc/grpc_client.cc
浏览文件 @
8b50ad80
...
...
@@ -74,7 +74,7 @@ VarHandlePtr GRPCClient::AsyncSendVar(const std::string& ep,
const
framework
::
Scope
*
p_scope
=
&
scope
;
const
auto
ch
=
GetChannel
(
ep_val
);
SendProcessor
*
s
=
new
SendProcessor
(
ch
);
const
std
::
string
method
=
"SendRPC"
;
const
std
::
string
method
=
kSendRPC
;
VarHandlePtr
h
(
new
VarHandle
(
ep
,
method
,
var_name_val
,
p_ctx
,
p_scope
));
s
->
Prepare
(
h
,
time_out
);
...
...
@@ -107,7 +107,7 @@ VarHandlePtr GRPCClient::AsyncSendVar(const std::string& ep,
void
ProcGetResponse
(
const
VarHandle
&
var_h
,
const
::
grpc
::
ByteBuffer
&
ret_msg
)
{
VLOG
(
100
)
<<
"ProcGetResponse"
;
VLOG
(
4
)
<<
"ProcGetResponse"
;
framework
::
Variable
*
outvar
=
nullptr
;
// get response's trainer_id is not used
int
trainer_id
;
...
...
@@ -127,59 +127,74 @@ VarHandlePtr GRPCClient::AsyncGetVar(const std::string& ep,
const
platform
::
DeviceContext
&
ctx
,
const
framework
::
Scope
&
scope
,
const
std
::
string
&
var_name
,
const
std
::
string
&
out_varname
,
int64_t
time_out
)
{
return
_AsyncGetVar
(
ep
,
ctx
,
scope
,
var_
name
,
return
_AsyncGetVar
(
ep
,
ctx
,
scope
,
kGetRPC
,
var_name
,
out_var
name
,
"/sendrecv.SendRecvService/GetVariable"
,
time_out
);
}
VarHandlePtr
GRPCClient
::
AsyncGetVarNoBarrier
(
const
std
::
string
&
ep
,
const
platform
::
DeviceContext
&
ctx
,
const
framework
::
Scope
&
scope
,
const
std
::
string
&
var_name
,
const
std
::
string
&
out_varname
,
int64_t
time_out
)
{
std
::
string
var_name_no_barrier
=
string
::
Sprintf
(
"%s%s"
,
var_name
,
WITHOUT_BARRIER_MESSAGE
);
return
_AsyncGetVar
(
ep
,
ctx
,
scope
,
kGetNoBarrierRPC
,
var_name_no_barrier
,
out_varname
,
"/sendrecv.SendRecvService/GetVariableNoBarrier"
,
time_out
);
}
VarHandlePtr
GRPCClient
::
AsyncGetMonomerVariable
(
const
std
::
string
&
ep
,
const
platform
::
DeviceContext
&
ctx
,
const
framework
::
Scope
&
scope
,
const
std
::
string
&
var_name
,
int64_t
time_out
)
{
return
_AsyncGetVar
(
ep
,
ctx
,
scope
,
var_name
,
return
_AsyncGetVar
(
ep
,
ctx
,
scope
,
kGetMonomerRPC
,
var_name
,
var_name
,
"/sendrecv.SendRecvService/GetMonomerVariable"
,
time_out
);
}
VarHandlePtr
GRPCClient
::
_AsyncGetVar
(
const
std
::
string
&
ep
,
const
platform
::
DeviceContext
&
ctx
,
const
framework
::
Scope
&
scope
,
const
std
::
string
&
var_name
,
const
std
::
string
&
rpc_path
,
int64_t
time_out
)
{
VarHandlePtr
GRPCClient
::
_AsyncGetVar
(
const
std
::
string
&
ep
,
const
platform
::
DeviceContext
&
ctx
,
const
framework
::
Scope
&
scope
,
const
std
::
string
&
method
,
const
std
::
string
&
var_name
,
const
std
::
string
&
out_varname
,
const
std
::
string
&
rpc_path
,
int64_t
time_out
)
{
const
platform
::
DeviceContext
*
p_ctx
=
&
ctx
;
const
std
::
string
ep_val
=
ep
;
const
std
::
string
var_name_val
=
var_name
;
const
std
::
string
out_varname_val
=
out_varname
;
const
framework
::
Scope
*
p_scope
=
&
scope
;
const
auto
ch
=
GetChannel
(
ep_val
);
GetProcessor
*
s
=
new
GetProcessor
(
ch
);
const
std
::
string
method
=
"GetRPC"
;
VarHandlePtr
h
(
new
VarHandle
(
ep
,
method
,
var_
name_val
,
p_ctx
,
p_scope
));
VarHandlePtr
h
(
new
VarHandle
(
ep
,
method
,
out_var
name_val
,
p_ctx
,
p_scope
));
s
->
Prepare
(
h
,
time_out
);
framework
::
AsyncIO
([
var_name_val
,
s
,
method
,
p_ctx
,
h
,
rpc_path
,
this
]
{
// prepare input
sendrecv
::
VariableMessage
req
;
req
.
set_varname
(
var_name_val
);
req
.
set_trainer_id
(
trainer_id_
);
::
grpc
::
ByteBuffer
buf
;
RequestToByteBuffer
<
sendrecv
::
VariableMessage
>
(
req
,
&
buf
);
framework
::
AsyncIO
(
[
var_name_val
,
out_varname_val
,
s
,
method
,
p_ctx
,
h
,
rpc_path
,
this
]
{
// prepare input
sendrecv
::
VariableMessage
req
;
req
.
set_varname
(
var_name_val
);
req
.
set_out_varname
(
out_varname_val
);
req
.
set_trainer_id
(
trainer_id_
);
::
grpc
::
ByteBuffer
buf
;
RequestToByteBuffer
<
sendrecv
::
VariableMessage
>
(
req
,
&
buf
);
VLOG
(
3
)
<<
s
->
GetVarHandlePtr
()
->
String
()
<<
" begin"
;
VLOG
(
3
)
<<
s
->
GetVarHandlePtr
()
->
String
()
<<
" begin"
;
// stub context
s
->
response_call_back_
=
ProcGetResponse
;
// stub context
s
->
response_call_back_
=
ProcGetResponse
;
platform
::
RecordRPCEvent
record_event
(
method
,
p_ctx
);
platform
::
RecordRPCEvent
record_event
(
method
,
p_ctx
);
auto
call
=
s
->
stub_g_
.
PrepareUnaryCall
(
s
->
context_
.
get
(),
rpc_path
,
buf
,
&
cq_
);
call
->
StartCall
();
call
->
Finish
(
&
s
->
reply_
,
&
s
->
status_
,
reinterpret_cast
<
void
*>
(
s
));
auto
call
=
s
->
stub_g_
.
PrepareUnaryCall
(
s
->
context_
.
get
(),
rpc_path
,
buf
,
&
cq_
);
call
->
StartCall
();
call
->
Finish
(
&
s
->
reply_
,
&
s
->
status_
,
reinterpret_cast
<
void
*>
(
s
));
if
(
UNLIKELY
(
platform
::
IsProfileEnabled
()))
{
h
->
Wait
();
}
});
if
(
UNLIKELY
(
platform
::
IsProfileEnabled
()))
{
h
->
Wait
();
}
});
req_count_
++
;
...
...
@@ -202,7 +217,7 @@ VarHandlePtr GRPCClient::AsyncPrefetchVar(const std::string& ep,
const
auto
ch
=
GetChannel
(
ep_val
);
GetProcessor
*
s
=
new
GetProcessor
(
ch
);
const
std
::
string
method
=
"PrefetchRPC"
;
const
std
::
string
method
=
kPrefetchRPC
;
VarHandlePtr
h
(
new
VarHandle
(
ep
,
method
,
out_var_name_val
,
p_ctx
,
p_scope
));
s
->
Prepare
(
h
,
time_out
);
...
...
@@ -242,7 +257,7 @@ VarHandlePtr GRPCClient::AsyncSendBatchBarrier(const std::string& ep,
const
auto
ch
=
GetChannel
(
ep
);
BatchBarrierProcessor
*
s
=
new
BatchBarrierProcessor
(
ch
);
const
std
::
string
method
=
"BatchBarrierRPC"
;
const
std
::
string
method
=
kBatchBarrierRPC
;
VarHandlePtr
h
(
new
VarHandle
(
ep
,
method
,
BATCH_BARRIER_MESSAGE
,
nullptr
,
nullptr
));
s
->
Prepare
(
h
,
time_out
);
...
...
@@ -267,7 +282,7 @@ VarHandlePtr GRPCClient::AsyncSendFetchBarrier(const std::string& ep,
int64_t
time_out
)
{
const
auto
ch
=
GetChannel
(
ep
);
FetchBarrierProcessor
*
s
=
new
FetchBarrierProcessor
(
ch
);
const
std
::
string
method
=
"FetchBarrierRPC"
;
const
std
::
string
method
=
kFetchBarrierRPC
;
VarHandlePtr
h
(
new
VarHandle
(
ep
,
method
,
FETCH_BARRIER_MESSAGE
,
nullptr
,
nullptr
));
s
->
Prepare
(
h
,
time_out
);
...
...
@@ -293,7 +308,7 @@ VarHandlePtr GRPCClient::AsyncGetMonomerBarrier(const std::string& ep,
int64_t
time_out
)
{
const
auto
ch
=
GetChannel
(
ep
);
BatchBarrierProcessor
*
s
=
new
BatchBarrierProcessor
(
ch
);
const
std
::
string
method
=
"SendMonomerFetchBarrierRPC"
;
const
std
::
string
method
=
kSendMonomerFetchBarrierRPC
;
VarHandlePtr
h
(
new
VarHandle
(
ep
,
method
,
var_name
,
nullptr
,
nullptr
));
s
->
Prepare
(
h
,
time_out
);
...
...
@@ -320,7 +335,7 @@ VarHandlePtr GRPCClient::AsyncSendComplete(const std::string& ep,
const
auto
ch
=
GetChannel
(
ep
);
BatchBarrierProcessor
*
s
=
new
BatchBarrierProcessor
(
ch
);
const
std
::
string
method
=
"SendCompleteRPC"
;
const
std
::
string
method
=
kSendCompleteRPC
;
VarHandlePtr
h
(
new
VarHandle
(
ep
,
method
,
COMPLETE_MESSAGE
,
nullptr
,
nullptr
));
s
->
Prepare
(
h
,
time_out
);
...
...
@@ -347,7 +362,7 @@ VarHandlePtr GRPCClient::AsyncCheckpointNotify(const std::string& ep,
CheckpointNotifyProcessor
*
s
=
new
CheckpointNotifyProcessor
(
ch
);
const
std
::
string
method
=
"CheckPointNotifyRPC"
;
const
std
::
string
method
=
kCheckPointNotifyRPC
;
VarHandlePtr
h
(
new
VarHandle
(
ep
,
method
,
CHECKPOINT_SAVE_MESSAGE
,
nullptr
,
nullptr
));
...
...
paddle/fluid/operators/distributed/grpc/grpc_client.h
浏览文件 @
8b50ad80
...
...
@@ -186,8 +186,15 @@ class GRPCClient : public RPCClient {
const
platform
::
DeviceContext
&
ctx
,
const
framework
::
Scope
&
scope
,
const
std
::
string
&
var_name
,
const
std
::
string
&
out_varname
,
int64_t
time_out
=
FLAGS_rpc_deadline
)
override
;
VarHandlePtr
AsyncGetVarNoBarrier
(
const
std
::
string
&
ep
,
const
platform
::
DeviceContext
&
ctx
,
const
framework
::
Scope
&
scope
,
const
std
::
string
&
var_name
,
const
std
::
string
&
out_varname
,
int64_t
time_out
=
FLAGS_rpc_deadline
)
override
;
VarHandlePtr
AsyncGetMonomerVariable
(
const
std
::
string
&
ep
,
const
platform
::
DeviceContext
&
ctx
,
const
framework
::
Scope
&
scope
,
const
std
::
string
&
var_name
,
...
...
@@ -228,11 +235,11 @@ class GRPCClient : public RPCClient {
void
Proceed
();
std
::
shared_ptr
<
grpc
::
Channel
>
GetChannel
(
const
std
::
string
&
ep
);
VarHandlePtr
_AsyncGetVar
(
const
std
::
string
&
ep
,
const
platform
::
DeviceContext
&
ctx
,
const
framework
::
Scope
&
scope
,
const
std
::
string
&
var_name
,
const
std
::
string
&
rpc
,
int64_t
time_out
);
VarHandlePtr
_AsyncGetVar
(
const
std
::
string
&
ep
,
const
platform
::
DeviceContext
&
ctx
,
const
framework
::
Scope
&
scope
,
const
std
::
string
&
method
,
const
std
::
string
&
var_name
,
const
std
::
string
&
out_varname
,
const
std
::
string
&
rpc_path
,
int64_t
time_out
=
FLAGS_rpc_deadline
);
private:
grpc
::
CompletionQueue
cq_
;
...
...
paddle/fluid/operators/distributed/grpc/grpc_server.cc
浏览文件 @
8b50ad80
...
...
@@ -136,17 +136,65 @@ class RequestGet final : public RequestBase {
void
Process
()
override
{
// proc request.
std
::
string
varname
=
request_
.
varname
();
std
::
string
out_varname
=
request_
.
out_varname
();
int
trainer_id
=
request_
.
trainer_id
();
VLOG
(
4
)
<<
"RequestGet "
<<
varname
;
VLOG
(
4
)
<<
"RequestGet "
<<
out_varname
<<
" from "
<<
varname
;
auto
scope
=
request_handler_
->
scope
();
auto
invar
=
scope
->
FindVar
(
varname
)
;
framework
::
Variable
*
invar
=
nullptr
;
framework
::
Variable
*
outvar
=
nullptr
;
request_handler_
->
Handle
(
varname
,
scope
,
invar
,
&
outvar
,
trainer_id
);
request_handler_
->
Handle
(
varname
,
scope
,
invar
,
&
outvar
,
trainer_id
,
out_varname
);
if
(
outvar
)
{
SerializeToByteBuffer
(
varname
,
outvar
,
*
request_handler_
->
dev_ctx
(),
SerializeToByteBuffer
(
out_varname
,
outvar
,
*
request_handler_
->
dev_ctx
(),
&
reply_
);
}
Finish
(
reply_
,
&
responder_
);
}
protected:
sendrecv
::
VariableMessage
request_
;
::
grpc
::
ByteBuffer
reply_
;
ServerAsyncResponseWriter
<::
grpc
::
ByteBuffer
>
responder_
;
};
class
RequestGetNoBarrier
final
:
public
RequestBase
{
public:
explicit
RequestGetNoBarrier
(
GrpcService
::
AsyncService
*
service
,
::
grpc
::
ServerCompletionQueue
*
cq
,
RequestHandler
*
request_handler
,
int
req_id
)
:
RequestBase
(
service
,
cq
,
request_handler
,
req_id
),
responder_
(
&
ctx_
)
{
auto
method_id
=
static_cast
<
int
>
(
distributed
::
GrpcMethod
::
kGetVariableNoBarrier
);
service_
->
RequestAsyncUnary
(
method_id
,
&
ctx_
,
&
request_
,
&
responder_
,
cq_
,
cq_
,
reinterpret_cast
<
void
*>
(
static_cast
<
intptr_t
>
(
req_id
)));
}
virtual
~
RequestGetNoBarrier
()
{}
std
::
string
GetReqName
()
override
{
return
request_
.
varname
();
}
void
Process
()
override
{
// proc request.
std
::
string
varname
=
request_
.
varname
();
std
::
string
out_varname
=
request_
.
out_varname
();
int
trainer_id
=
request_
.
trainer_id
();
VLOG
(
4
)
<<
"RequestGetNoBarrier "
<<
out_varname
<<
" from "
<<
varname
;
auto
scope
=
request_handler_
->
scope
();
framework
::
Variable
*
invar
=
nullptr
;
framework
::
Variable
*
outvar
=
nullptr
;
request_handler_
->
Handle
(
varname
,
scope
,
invar
,
&
outvar
,
trainer_id
,
out_varname
);
if
(
outvar
)
{
SerializeToByteBuffer
(
out_varname
,
outvar
,
*
request_handler_
->
dev_ctx
(),
&
reply_
);
}
Finish
(
reply_
,
&
responder_
);
...
...
@@ -460,6 +508,9 @@ void AsyncGRPCServer::TryToRegisterNewOne(const std::string& rpc_name,
b
=
new
RequestSend
(
&
service_
,
cq
.
get
(),
handler
,
req_id
);
}
else
if
(
rpc_name
==
kRequestGet
)
{
b
=
new
RequestGet
(
&
service_
,
cq
.
get
(),
handler
,
req_id
);
}
else
if
(
rpc_name
==
kRequestGetNoBarrier
)
{
b
=
new
RequestGetNoBarrier
(
&
service_
,
cq
.
get
(),
handler
,
req_id
);
}
else
if
(
rpc_name
==
kRequestGetMonomerVariable
)
{
b
=
new
RequestGetMonomerVariable
(
&
service_
,
cq
.
get
(),
handler
,
req_id
,
this
);
...
...
paddle/fluid/operators/distributed/grpc/grpc_service.h
浏览文件 @
8b50ad80
...
...
@@ -81,6 +81,7 @@ enum class GrpcMethod {
kGetVariable
,
kPrefetchVariable
,
kCheckpointNotify
,
kGetVariableNoBarrier
,
kGetMonomerVariable
,
kGetMonomerBarrier
,
};
...
...
@@ -94,6 +95,8 @@ inline const char* GrpcMethodName(GrpcMethod id) {
return
"/sendrecv.SendRecvService/SendVariable"
;
case
GrpcMethod
::
kGetVariable
:
return
"/sendrecv.SendRecvService/GetVariable"
;
case
GrpcMethod
::
kGetVariableNoBarrier
:
return
"/sendrecv.SendRecvService/GetVariableNoBarrier"
;
case
GrpcMethod
::
kGetMonomerVariable
:
return
"/sendrecv.SendRecvService/GetMonomerVariable"
;
case
GrpcMethod
::
kGetMonomerBarrier
:
...
...
paddle/fluid/operators/distributed/request_handler.h
浏览文件 @
8b50ad80
...
...
@@ -42,11 +42,24 @@ constexpr char kRequestGetMonomerBarrier[] = "RequestGetMonomerBarrier";
constexpr
char
kRequestPrefetch
[]
=
"RequestPrefetch"
;
constexpr
char
kRequestCheckpoint
[]
=
"RequestCheckpoint"
;
constexpr
char
kRequestPassBarrier
[]
=
"RequestPassBarrier"
;
constexpr
char
kRequestGetNoBarrier
[]
=
"GetVariableNoBarrier"
;
constexpr
char
kSendRPC
[]
=
"SendRPC"
;
constexpr
char
kGetRPC
[]
=
"GetRPC"
;
constexpr
char
kGetNoBarrierRPC
[]
=
"GetNoBarrierRPC"
;
constexpr
char
kGetMonomerRPC
[]
=
"GetMonomerRPC"
;
constexpr
char
kPrefetchRPC
[]
=
"PrefetchRPC"
;
constexpr
char
kBatchBarrierRPC
[]
=
"BatchBarrierRPC"
;
constexpr
char
kFetchBarrierRPC
[]
=
"FetchBarrierRPC"
;
constexpr
char
kSendMonomerFetchBarrierRPC
[]
=
"SendMonomerFetchBarrierRPC"
;
constexpr
char
kSendCompleteRPC
[]
=
"SendCompleteRPC"
;
constexpr
char
kCheckPointNotifyRPC
[]
=
"CheckPointNotifyRPC"
;
#define LISTEN_TERMINATE_MESSAGE "TERMINATE@RECV"
#define BATCH_BARRIER_MESSAGE "BATCH_BARRIER@RECV"
#define FETCH_BARRIER_MESSAGE "FETCH_BARRIER@RECV"
#define COMPLETE_MESSAGE "COMPLETE@RECV"
#define WITHOUT_BARRIER_MESSAGE "@WITHOUT_BARRIER@RECV"
#define CHECKPOINT_SAVE_MESSAGE "SAVE@CHECKPOINTNOTIFY"
#define CHECKPOINT_LOAD_MESSAGE "LOAD@CHECKPOINTNOTIFY"
...
...
paddle/fluid/operators/distributed/request_handler_impl.cc
浏览文件 @
8b50ad80
...
...
@@ -23,6 +23,7 @@
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/operators/distributed/rpc_server.h"
#include "paddle/fluid/string/piece.h"
#include "paddle/fluid/string/printf.h"
namespace
paddle
{
...
...
@@ -81,7 +82,8 @@ bool RequestGetHandler::Handle(const std::string& varname,
const
int
trainer_id
,
const
std
::
string
&
out_var_name
,
const
std
::
string
&
table_name
)
{
VLOG
(
4
)
<<
"RequestGetHandler:"
<<
varname
;
VLOG
(
4
)
<<
"RequestGetHandler:"
<<
varname
<<
" out_var_name: "
<<
out_var_name
;
if
(
sync_mode_
)
{
if
(
varname
==
FETCH_BARRIER_MESSAGE
)
{
...
...
@@ -112,6 +114,32 @@ bool RequestGetHandler::Handle(const std::string& varname,
return
true
;
}
bool
RequestGetNoBarrierHandler
::
Handle
(
const
std
::
string
&
varname
,
framework
::
Scope
*
scope
,
framework
::
Variable
*
invar
,
framework
::
Variable
**
outvar
,
const
int
trainer_id
,
const
std
::
string
&
out_var_name
,
const
std
::
string
&
table_name
)
{
VLOG
(
4
)
<<
"RequestGetNoBarrierHandler:"
<<
varname
<<
" out_var_name: "
<<
out_var_name
;
// get var from pserver immediately without barriers
string
::
Piece
without_barrier_piece
(
WITHOUT_BARRIER_MESSAGE
);
string
::
Piece
var_name_piece
=
string
::
Piece
(
varname
);
if
(
string
::
Contains
(
var_name_piece
,
without_barrier_piece
))
{
var_name_piece
=
string
::
TrimSuffix
(
var_name_piece
,
without_barrier_piece
);
VLOG
(
4
)
<<
"Get var "
<<
var_name_piece
<<
" with "
<<
WITHOUT_BARRIER_MESSAGE
;
*
outvar
=
scope_
->
FindVar
(
var_name_piece
.
ToString
());
return
true
;
}
else
{
PADDLE_THROW
(
"GetNoBarrier must contain %s"
,
WITHOUT_BARRIER_MESSAGE
);
}
return
true
;
}
bool
RequestPrefetchHandler
::
Handle
(
const
std
::
string
&
varname
,
framework
::
Scope
*
scope
,
framework
::
Variable
*
invar
,
...
...
paddle/fluid/operators/distributed/request_handler_impl.h
浏览文件 @
8b50ad80
...
...
@@ -67,6 +67,16 @@ class RequestGetHandler final : public RequestHandler {
bool
enable_dc_asgd_
;
};
class
RequestGetNoBarrierHandler
final
:
public
RequestHandler
{
public:
RequestGetNoBarrierHandler
()
:
RequestHandler
(
false
)
{}
virtual
~
RequestGetNoBarrierHandler
()
{}
bool
Handle
(
const
std
::
string
&
varname
,
framework
::
Scope
*
scope
,
framework
::
Variable
*
var
,
framework
::
Variable
**
outvar
,
const
int
trainer_id
,
const
std
::
string
&
out_var_name
=
""
,
const
std
::
string
&
table_name
=
""
)
override
;
};
static
inline
void
BuildVar
(
const
std
::
string
&
param_name
,
std
::
initializer_list
<
const
char
*>
arguments
,
paddle
::
framework
::
proto
::
OpDesc
::
Var
*
var
)
{
...
...
paddle/fluid/operators/distributed/rpc_client.h
浏览文件 @
8b50ad80
...
...
@@ -43,8 +43,15 @@ class RPCClient {
const
platform
::
DeviceContext
&
ctx
,
const
framework
::
Scope
&
scope
,
const
std
::
string
&
var_name
,
const
std
::
string
&
out_varname
,
int64_t
time_out
=
FLAGS_rpc_deadline
)
=
0
;
virtual
VarHandlePtr
AsyncGetVarNoBarrier
(
const
std
::
string
&
ep
,
const
platform
::
DeviceContext
&
ctx
,
const
framework
::
Scope
&
scope
,
const
std
::
string
&
var_name
,
const
std
::
string
&
out_varname
,
int64_t
time_out
=
FLAGS_rpc_deadline
)
=
0
;
virtual
VarHandlePtr
AsyncGetMonomerVariable
(
const
std
::
string
&
ep
,
const
platform
::
DeviceContext
&
ctx
,
const
framework
::
Scope
&
scope
,
const
std
::
string
&
var_name
,
...
...
paddle/fluid/operators/distributed/send_recv.proto.in
浏览文件 @
8b50ad80
...
...
@@ -17,8 +17,14 @@ package sendrecv;
option
cc_generic_services
=
@
cc_generic_services
@;
service
SendRecvService
{
//
For
parameter
server
round
-
robin
like
hashing
,
do
not
split
tensors
.
//
Send
and
recv
only
one
tensor
//
TODO
(
typhoonzero
):
add
streaming
API
rpc
SendVariable
(
VariableMessage
)
returns
(
VoidMessage
)
{}
//
Argument
VariableMessage
for
GetVariable
should
only
contain
varname
.
rpc
GetVariable
(
VariableMessage
)
returns
(
VariableMessage
)
{}
rpc
GetVariableNoBarrier
(
VariableMessage
)
returns
(
VariableMessage
)
{}
//
pre
-
fetch
variable
by
given
variable
name
and
Ids
rpc
PrefetchVariable
(
VariableMessage
)
returns
(
VariableMessage
)
{}
rpc
CheckpointNotify
(
VariableMessage
)
returns
(
VoidMessage
)
{}
...
...
@@ -27,12 +33,17 @@ service SendRecvService {
rpc
GetMonomerBarrier
(
VariableMessage
)
returns
(
VoidMessage
)
{}
}
//
It
can
be
:
LoDTensor
、
SelectedRows
or
NCCL_ID
enum
VarType
{
LOD_TENSOR
=
0
;
SELECTED_ROWS
=
1
;
NCCL_ID
=
2
;
}
//
VariableMessage
is
serialized
paddle
variable
message
.
//
NOTICE
(
gongwb
):
don
't modify this proto if you are not
// not familar with how we serialize in sendrecvop_utils.h
// and deserilize it in variable_response.h.
message VariableMessage {
enum Type {
// Pod Types
...
...
@@ -49,14 +60,21 @@ message VariableMessage {
string varname = 1;
// TODO(Yancey1989): reference framework::proto::VarDesc::VarType
VarType type = 2;
// bool persistable is not needed for sending.
// tensor info:
Type data_type = 3;
repeated int64 dims = 4;
// lod details:
int64 lod_level = 5;
repeated LodData lod = 6;
// selected_rows height, aka. original dim0
int64 slr_height = 7;
// tensor data
bytes serialized = 8;
// selected_rows data
bytes rows = 9;
// Look up table block execution output variable name.
string out_varname = 10;
// If 1, the ps server will start profiling, the ps
// server stops profiling and generates a profile to /tmp/profile_ps_*
...
...
paddle/fluid/operators/distributed_ops/listen_and_serv_op.cc
浏览文件 @
8b50ad80
...
...
@@ -347,6 +347,8 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
new
distributed
::
RequestPrefetchHandler
(
sync_mode
));
request_checkpoint_handler_
.
reset
(
new
distributed
::
RequestCheckpointHandler
(
sync_mode
,
checkpoint_block_id
));
request_get_no_barrier_handler_
.
reset
(
new
distributed
::
RequestGetNoBarrierHandler
());
rpc_service_
->
RegisterRPC
(
distributed
::
kRequestSend
,
request_send_handler_
.
get
(),
...
...
@@ -359,6 +361,8 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
FLAGS_rpc_prefetch_thread_num
);
rpc_service_
->
RegisterRPC
(
distributed
::
kRequestCheckpoint
,
request_checkpoint_handler_
.
get
());
rpc_service_
->
RegisterRPC
(
distributed
::
kRequestGetNoBarrier
,
request_get_no_barrier_handler_
.
get
());
auto
optimize_blocks
=
Attr
<
std
::
vector
<
framework
::
BlockDesc
*>>
(
kOptimizeBlocks
);
...
...
@@ -413,6 +417,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
f
(
request_get_handler_
.
get
());
f
(
request_prefetch_handler_
.
get
());
f
(
request_checkpoint_handler_
.
get
());
f
(
request_get_no_barrier_handler_
.
get
());
// start the server listening after all member initialized.
server_thread_
.
reset
(
new
std
::
thread
(
RunServer
,
rpc_service_
));
...
...
paddle/fluid/operators/distributed_ops/listen_and_serv_op.h
浏览文件 @
8b50ad80
...
...
@@ -55,7 +55,6 @@ class ListenAndServOp : public framework::OperatorBase {
const
framework
::
VariableNameMap
&
inputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
);
virtual
~
ListenAndServOp
();
void
RunSyncLoop
(
framework
::
Executor
*
executor
,
...
...
@@ -89,6 +88,8 @@ class ListenAndServOp : public framework::OperatorBase {
mutable
std
::
shared_ptr
<
distributed
::
RPCServer
>
rpc_service_
;
mutable
std
::
shared_ptr
<
distributed
::
RequestHandler
>
request_send_handler_
;
mutable
std
::
shared_ptr
<
distributed
::
RequestHandler
>
request_get_handler_
;
mutable
std
::
shared_ptr
<
distributed
::
RequestHandler
>
request_get_no_barrier_handler_
;
mutable
std
::
shared_ptr
<
distributed
::
RequestHandler
>
request_prefetch_handler_
;
mutable
std
::
shared_ptr
<
distributed
::
RequestHandler
>
...
...
paddle/fluid/operators/distributed_ops/recv_op.cc
浏览文件 @
8b50ad80
...
...
@@ -27,30 +27,50 @@ namespace operators {
class
RecvOp
:
public
framework
::
OperatorBase
{
public:
RecvOp
(
const
std
::
string
&
type
,
const
framework
::
VariableNameMap
&
inputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
RecvOp
(
const
std
::
string
&
type
,
const
framework
::
VariableNameMap
&
inputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
auto
outs
=
Outputs
(
"Out"
);
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
std
::
vector
<
std
::
string
>
epmap
=
Attr
<
std
::
vector
<
std
::
string
>>
(
"epmap"
);
std
::
vector
<
std
::
string
>
varnames
=
Attr
<
std
::
vector
<
std
::
string
>>
(
"varnames"
);
int
sync_mode
=
Attr
<
int
>
(
"sync_mode"
);
auto
outs
=
Outputs
(
"Out"
);
bool
with_barrier
=
Attr
<
bool
>
(
"with_barrier"
);
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
&
ctx
=
*
pool
.
Get
(
place
);
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
&
ctx
=
*
pool
.
Get
(
place
);
distributed
::
RPCClient
*
rpc_client
=
distributed
::
RPCClient
*
rpc_client
=
distributed
::
RPCClient
::
GetInstance
<
RPCCLIENT_T
>
(
Attr
<
int
>
(
"trainer_id"
));
std
::
vector
<
distributed
::
VarHandlePtr
>
rets
;
for
(
size_t
i
=
0
;
i
<
outs
.
size
();
i
++
)
{
VLOG
(
3
)
<<
"getting "
<<
outs
[
i
]
<<
" from "
<<
epmap
[
i
];
rets
.
push_back
(
rpc_client
->
AsyncGetVar
(
epmap
[
i
],
ctx
,
scope
,
outs
[
i
]));
}
if
(
sync_mode
)
{
if
(
with_barrier
)
{
std
::
vector
<
distributed
::
VarHandlePtr
>
rets
;
for
(
size_t
i
=
0
;
i
<
outs
.
size
();
i
++
)
{
std
::
string
varname
=
varnames
.
size
()
==
0
?
outs
[
i
]
:
varnames
[
i
];
VLOG
(
4
)
<<
"recv "
<<
outs
[
i
]
<<
" from "
<<
epmap
[
i
]
<<
" with "
<<
varname
<<
" and with AsyncGetVar"
;
rets
.
push_back
(
rpc_client
->
AsyncGetVar
(
epmap
[
i
],
ctx
,
scope
,
varname
,
outs
[
i
]));
}
if
(
sync_mode
)
{
for
(
size_t
i
=
0
;
i
<
rets
.
size
();
i
++
)
{
PADDLE_ENFORCE
(
rets
[
i
]
->
Wait
(),
"internal error in RPCClient"
);
}
}
}
else
{
std
::
vector
<
distributed
::
VarHandlePtr
>
rets
;
for
(
size_t
i
=
0
;
i
<
outs
.
size
();
i
++
)
{
std
::
string
varname
=
varnames
.
size
()
==
0
?
outs
[
i
]
:
varnames
[
i
];
VLOG
(
4
)
<<
"recv "
<<
outs
[
i
]
<<
" from "
<<
epmap
[
i
]
<<
" with "
<<
varname
<<
" and with AsyncGetVarNoBarrier"
;
rets
.
push_back
(
rpc_client
->
AsyncGetVarNoBarrier
(
epmap
[
i
],
ctx
,
scope
,
varname
,
outs
[
i
]));
}
for
(
size_t
i
=
0
;
i
<
rets
.
size
();
i
++
)
{
PADDLE_ENFORCE
(
rets
[
i
]
->
Wait
(),
"internal error in RPCClient"
);
}
...
...
@@ -79,12 +99,23 @@ This operator can get variables from server side.
"(int, default 0)"
"sync recv or async recv."
)
.
SetDefault
(
0
);
AddAttr
<
bool
>
(
"with_barrier"
,
"(bool, default True) if with_barrier=False, will use "
"AsyncGetVarNoBarrier get variable from pserver immediately"
)
.
SetDefault
(
true
);
AddAttr
<
std
::
vector
<
std
::
string
>>
(
"varnames"
,
"(string vector, default {}) "
"sometimes we need to put received var in another name "
"for example: we need var named 'moment_1@127.0.0.1:1001', "
"and it real name on parameter server is 'moment_1'. "
)
.
SetDefault
({});
}
};
class
RecvOpShapeInference
:
public
framework
::
InferShapeBase
{
public:
void
operator
()(
framework
::
InferShapeContext
*
ctx
)
const
override
{}
void
operator
()(
framework
::
InferShapeContext
*
ctx
)
const
override
{}
};
}
// namespace operators
...
...
paddle/fluid/platform/mkldnn_reuse.h
浏览文件 @
8b50ad80
...
...
@@ -365,7 +365,7 @@ class TransposeMKLDNNHandler : public MKLDNNHandler {
mem_fmt
.
ndims
=
axis
.
size
();
for
(
unsigned
int
i
=
0
;
i
<
nchw_tz
.
size
();
++
i
)
{
mem_fmt
.
dims
[
i
]
=
nchw_tz
[
i
];
// logical dimensions (nchw format,
// regardless physical layout)
// regardless physical layout)
}
mem_fmt
.
data_type
=
mkldnn_f32
;
mem_fmt
.
format
=
mkldnn_blocked
;
...
...
@@ -374,7 +374,7 @@ class TransposeMKLDNNHandler : public MKLDNNHandler {
for
(
int
i
=
nchw_tz
.
size
()
-
1
;
i
>=
0
;
--
i
)
{
mem_fmt
.
layout_desc
.
blocking
.
padding_dims
[
i
]
=
nchw_tz
[
i
];
// logical dimensions (nchw format, regardless physical
// layout)
// layout)
mem_fmt
.
layout_desc
.
blocking
.
block_dims
[
i
]
=
1
;
mem_fmt
.
layout_desc
.
blocking
.
offset_padding_to_data
[
i
]
=
0
;
// no offset
mem_fmt
.
layout_desc
.
blocking
.
strides
[
0
][
axis
[
i
]]
=
total_stride
;
...
...
python/paddle/fluid/framework.py
浏览文件 @
8b50ad80
...
...
@@ -1696,12 +1696,20 @@ class Program(object):
self
.
_current_role
=
core
.
op_proto_and_checker_maker
.
OpRole
.
Forward
self
.
_op_role_var
=
[]
# for distribute
# for distribute training
# _is_distributed = True if under distributed training
self
.
_is_distributed
=
False
# _is_chief = True if the trainer is the first one, usually No.0
self
.
_is_chief
=
False
self
.
_slice_vars_and_attrs
=
[]
# _parameters_on_pservers records all the parameters distributed on parameter servers.
self
.
_parameters_on_pservers
=
None
# _endpoints is a list about parameter servers ip:port, such as ["ip:port","ip:port"]
self
.
_endpoints
=
[]
# if current role is parameter server, the _ps_endpoint is its "ip:port"
self
.
_ps_endpoint
=
None
# trainers_endpoints, it is used for distribution.
self
.
_trainers_endpoints
=
[]
# the distributed lookup table names
self
.
_distributed_lookup_table
=
None
@
property
...
...
@@ -2232,8 +2240,9 @@ class Program(object):
"Program"
)
self
.
_is_distributed
=
other
.
_is_distributed
self
.
_is_chief
=
other
.
_is_chief
self
.
_
slice_vars_and_attrs
=
other
.
_slice_vars_and_att
rs
self
.
_
parameters_on_pservers
=
other
.
_parameters_on_pserve
rs
self
.
_endpoints
=
other
.
_endpoints
self
.
_ps_endpoint
=
other
.
_ps_endpoint
self
.
_distributed_lookup_table
=
other
.
_distributed_lookup_table
def
_copy_data_info_from
(
self
,
other
):
...
...
python/paddle/fluid/io.py
浏览文件 @
8b50ad80
...
...
@@ -19,6 +19,7 @@ import errno
import
time
import
shutil
import
six
from
functools
import
reduce
from
paddle.fluid.executor
import
Executor
from
paddle.fluid.evaluator
import
Evaluator
...
...
@@ -183,8 +184,6 @@ def save_vars(executor,
# NOTE: don't save the variable which type is RAW
if
each_var
.
type
==
core
.
VarDesc
.
VarType
.
RAW
:
continue
if
each_var
.
name
==
main_program
.
_distributed_lookup_table
:
continue
new_var
=
_clone_var_in_block_
(
save_block
,
each_var
)
if
filename
is
None
:
save_block
.
append_op
(
...
...
@@ -206,16 +205,6 @@ def save_vars(executor,
outputs
=
{},
attrs
=
{
'file_path'
:
os
.
path
.
join
(
dirname
,
filename
)})
# if there is lookup table, the trainer 0 will notify all pserver to save.
if
main_program
.
_is_distributed
and
main_program
.
_is_chief
and
main_program
.
_distributed_lookup_table
:
lookup_table_filename
=
os
.
path
.
join
(
dirname
,
"__lookup_table__"
)
attrs
=
{}
attrs
[
'epmap'
]
=
main_program
.
_endpoints
attrs
[
'dir'
]
=
lookup_table_filename
attrs
[
'lookup_table'
]
=
main_program
.
_distributed_lookup_table
save_block
.
append_op
(
type
=
'checkpoint_notify'
,
inputs
=
{},
outputs
=
{},
attrs
=
attrs
)
executor
.
run
(
save_program
)
...
...
@@ -267,6 +256,186 @@ def save_params(executor, dirname, main_program=None, filename=None):
filename
=
filename
)
def
_save_distributed_persistables
(
executor
,
dirname
,
main_program
):
"""
save_persistables for distributed training.
the method will do things listed below:
1.save part of persistable variables on trainer.
2.receive "remote prefetch variables" from parameter servers and merge them.
3.save "distributed lookup table" on parameter servers.
4.receive "optimizer variables" from parameter servers and merge them.
Args:
executor(Executor): The executor to run for saving parameters.
dirname(str): The saving directory path.
main_program(Program): The program whose parameters will be
saved. the main_program must be the trainer_program
get after transpiler.
Returns:
None
Examples:
.. code-block:: python
exe = fluid.Executor(fluid.CPUPlace())
param_path = "./my_paddle_model"
t = distribute_transpiler.DistributeTranspiler()
t.transpile(...)
train_program = t.get_trainer_program()
_save_distributed_persistables(executor=exe, dirname=param_path, main_program=train_program)
"""
def
__save_remote_params
(
executor
,
dirname
,
remote_params_map
):
"""
recive params on pserver through rpc.
if the params are be sliced, will concat them to one, then save it.
"""
if
not
remote_params_map
:
return
prog
=
Program
()
block
=
prog
.
global_block
()
# recv optimize vars from pserver
for
name
,
remote_params
in
remote_params_map
.
items
():
origin_var
=
None
is_slice
=
False
slice_vars
=
[
0
]
*
len
(
remote_params
)
slice_var_names
=
[
""
]
*
len
(
remote_params
)
endpoints
=
[
""
]
*
len
(
remote_params
)
for
idx
,
optimizer
in
enumerate
(
remote_params
):
origin
=
optimizer
.
origin
slice
=
optimizer
.
slice
is_slice
=
optimizer
.
is_slice
block_id
=
optimizer
.
block_id
endpoint
=
optimizer
.
endpoint
if
idx
==
0
:
origin_var
=
block
.
create_var
(
name
=
origin
.
name
,
type
=
origin
.
type
,
shape
=
origin
.
shape
,
dtype
=
origin
.
dtype
,
persistable
=
True
)
slice_var
=
block
.
create_var
(
name
=
"{}.slice.{}"
.
format
(
slice
.
name
,
idx
),
type
=
slice
.
type
,
shape
=
slice
.
shape
,
dtype
=
slice
.
dtype
,
persistable
=
True
)
index
=
block_id
if
is_slice
else
idx
slice_vars
[
index
]
=
slice_var
slice_var_names
[
index
]
=
slice
.
name
endpoints
[
index
]
=
endpoint
if
is_slice
:
block
.
append_op
(
type
=
'recv'
,
inputs
=
{
"X"
:
[]},
outputs
=
{
"Out"
:
slice_vars
},
attrs
=
{
"epmap"
:
endpoints
,
"with_barrier"
:
False
,
"varnames"
:
slice_var_names
,
"sync_mode"
:
True
})
block
.
append_op
(
type
=
'concat'
,
inputs
=
{
'X'
:
slice_vars
},
outputs
=
{
'Out'
:
origin_var
},
attrs
=
{})
else
:
block
.
append_op
(
type
=
'recv'
,
inputs
=
{
"X"
:
[]},
outputs
=
{
"Out"
:
[
origin_var
]},
attrs
=
{
"epmap"
:
endpoints
[:
1
],
"with_barrier"
:
False
,
"varnames"
:
slice_var_names
,
"sync_mode"
:
True
})
block
.
append_op
(
type
=
'save'
,
inputs
=
{
'X'
:
[
origin_var
]},
outputs
=
{},
attrs
=
{
'file_path'
:
os
.
path
.
join
(
dirname
,
origin_var
.
name
)})
block
.
append_op
(
type
=
'delete_var'
,
inputs
=
{
'X'
:
slice_vars
})
executor
.
run
(
prog
)
def
__save_distributed_lookup_tables
(
executor
,
dirname
,
distributed_lookup_table
,
endpoints
):
"""
because the distributed lookup table may too huge to merge and save at one place,
it will be saved at parameter server independent respectively.
the save directory is dirname/"__lookup_table__".
"""
prog
=
Program
()
block
=
prog
.
global_block
()
# if there is lookup table, the trainer 0 will notify all pserver to save.
lookup_table_filename
=
os
.
path
.
join
(
dirname
,
"__lookup_table__"
)
attrs
=
{}
attrs
[
'epmap'
]
=
endpoints
attrs
[
'dir'
]
=
lookup_table_filename
attrs
[
'lookup_table'
]
=
distributed_lookup_table
block
.
append_op
(
type
=
'checkpoint_notify'
,
inputs
=
{},
outputs
=
{},
attrs
=
attrs
)
executor
.
run
(
prog
)
def
__exclude_vars
(
exclude_var_names
=
[]):
def
is_valid
(
var
):
if
var
.
name
in
exclude_var_names
:
return
False
if
var
.
desc
.
type
()
==
core
.
VarDesc
.
VarType
.
FEED_MINIBATCH
or
\
var
.
desc
.
type
()
==
core
.
VarDesc
.
VarType
.
FETCH_LIST
or
\
var
.
desc
.
type
()
==
core
.
VarDesc
.
VarType
.
READER
:
return
False
return
var
.
persistable
return
is_valid
if
not
isinstance
(
main_program
,
Program
):
raise
ValueError
(
"'main_program' should be an instance of Program."
)
if
not
main_program
.
_is_distributed
:
raise
ValueError
(
"'_save_distributed_persistables' just be designed for distributed training."
)
remote_params_map
=
main_program
.
_parameters_on_pservers
.
get_distributed_vars_by_vtypes
(
[
"Optimizer"
,
"RemotePrefetch"
],
groupby
=
True
)
exclude_var_names
=
[]
if
remote_params_map
:
exclude_var_names
.
extend
(
remote_params_map
.
keys
())
if
main_program
.
_distributed_lookup_table
:
if
isinstance
(
main_program
.
_distributed_lookup_table
,
list
):
exclude_var_names
.
extend
(
main_program
.
_distributed_lookup_table
)
else
:
exclude_var_names
.
append
(
main_program
.
_distributed_lookup_table
)
local_vars
=
list
(
filter
(
__exclude_vars
(
exclude_var_names
),
main_program
.
list_vars
()))
save_vars
(
executor
,
main_program
=
main_program
,
dirname
=
dirname
,
vars
=
local_vars
)
if
main_program
.
_is_chief
:
if
remote_params_map
:
__save_remote_params
(
executor
,
dirname
,
remote_params_map
)
if
main_program
.
_distributed_lookup_table
:
__save_distributed_lookup_tables
(
executor
,
dirname
,
main_program
.
_distributed_lookup_table
,
main_program
.
_endpoints
)
def
save_persistables
(
executor
,
dirname
,
main_program
=
None
,
filename
=
None
):
"""
This function filters out all variables with `persistable==True` from the
...
...
@@ -301,13 +470,19 @@ def save_persistables(executor, dirname, main_program=None, filename=None):
fluid.io.save_persistables(executor=exe, dirname=param_path,
main_program=None)
"""
save_vars
(
executor
,
dirname
=
dirname
,
main_program
=
main_program
,
vars
=
None
,
predicate
=
is_persistable
,
filename
=
filename
)
if
main_program
and
main_program
.
_is_distributed
:
_save_distributed_persistables
(
executor
,
dirname
=
dirname
,
main_program
=
main_program
)
else
:
save_vars
(
executor
,
dirname
=
dirname
,
main_program
=
main_program
,
vars
=
None
,
predicate
=
is_persistable
,
filename
=
filename
)
def
load_vars
(
executor
,
...
...
@@ -402,17 +577,11 @@ def load_vars(executor,
if
not
isinstance
(
main_program
,
Program
):
raise
TypeError
(
"program should be as Program type or None"
)
load_slice_vars
=
[]
for
each_var
in
main_program
.
_slice_vars_and_attrs
:
load_slice_vars
.
append
(
each_var
[
2
].
name
)
load_var_map
=
{}
for
each_var
in
vars
:
assert
isinstance
(
each_var
,
Variable
)
if
each_var
.
type
==
core
.
VarDesc
.
VarType
.
RAW
:
continue
if
each_var
.
name
in
load_slice_vars
:
continue
new_var
=
_clone_var_in_block_
(
load_block
,
each_var
)
if
filename
is
None
:
load_block
.
append_op
(
...
...
@@ -435,10 +604,6 @@ def load_vars(executor,
attrs
=
{
'file_path'
:
os
.
path
.
join
(
dirname
,
filename
)})
executor
.
run
(
load_prog
)
# load slice vars on pserver, if have it.
_load_slice_up_vars
(
executor
,
dirname
,
main_program
.
_slice_vars_and_attrs
)
def
load_params
(
executor
,
dirname
,
main_program
=
None
,
filename
=
None
):
"""
...
...
@@ -521,12 +686,134 @@ def load_persistables(executor, dirname, main_program=None, filename=None):
fluid.io.load_persistables(executor=exe, dirname=param_path,
main_program=None)
"""
load_vars
(
executor
,
dirname
=
dirname
,
main_program
=
main_program
,
predicate
=
is_persistable
,
filename
=
filename
)
if
main_program
and
main_program
.
_is_distributed
:
_load_distributed_persistables
(
executor
,
dirname
=
dirname
,
main_program
=
main_program
)
else
:
load_vars
(
executor
,
dirname
=
dirname
,
main_program
=
main_program
,
predicate
=
is_persistable
,
filename
=
filename
)
def
_load_distributed_persistables
(
executor
,
dirname
,
main_program
=
None
):
"""
customized load_persistables for distributed training.
it should be used on parameter server,
Args:
executor(Executor): The executor to run for saving parameters.
dirname(str): The load directory path.
main_program(Program): The program whose parameters will be
loaded. the main_program must be the pserver_program
get after transpiler.
Returns:
None
Examples:
.. code-block:: python
exe = fluid.Executor(fluid.CPUPlace())
param_path = "./my_paddle_model"
t = distribute_transpiler.DistributeTranspiler()
t.transpile(...)
pserver_prog = t.get_pserver_program(...)
_load_distributed_persistables(executor=exe, dirname=param_path, main_program=pserver_prog)
"""
def
__is_distributed_part_var
(
varname
):
trainer_idx
=
varname
.
find
(
".trainer_"
)
block_idx
=
varname
.
find
(
".block"
)
return
trainer_idx
or
block_idx
def
__load_persistable_vars
(
executor
,
dirname
,
need_load_vars
):
load_prog
=
Program
()
load_block
=
load_prog
.
global_block
()
need_delete_vars
=
[]
for
param
in
need_load_vars
:
origin_var
=
param
.
origin
slice_var
=
param
.
slice
is_slice
=
param
.
is_slice
offset
=
param
.
offset
if
is_slice
:
origin
=
load_block
.
create_var
(
name
=
"{}.load"
.
format
(
origin_var
.
name
),
type
=
origin_var
.
type
,
shape
=
origin_var
.
shape
,
dtype
=
origin_var
.
dtype
,
persistable
=
True
)
load_block
.
append_op
(
type
=
'load'
,
inputs
=
{},
outputs
=
{
'Out'
:
[
origin
]},
attrs
=
{
'file_path'
:
os
.
path
.
join
(
dirname
,
origin_var
.
name
)
})
slice
=
load_block
.
create_var
(
name
=
slice_var
.
name
,
type
=
slice_var
.
type
,
shape
=
slice_var
.
shape
,
dtype
=
slice_var
.
dtype
,
persistable
=
True
)
dim1_flatten
=
reduce
(
lambda
x
,
y
:
x
*
y
,
slice
.
shape
[
1
:])
start
=
int
(
offset
/
dim1_flatten
)
end
=
int
(
offset
/
dim1_flatten
+
slice
.
shape
[
0
])
load_block
.
append_op
(
type
=
"slice"
,
inputs
=
{
'Input'
:
origin
},
outputs
=
{
'Out'
:
slice
},
attrs
=
{
'axes'
:
[
0
],
'starts'
:
[
start
],
'ends'
:
[
end
]})
need_delete_vars
.
append
(
origin
)
else
:
origin
=
load_block
.
create_var
(
name
=
"{}"
.
format
(
origin_var
.
name
),
type
=
origin_var
.
type
,
shape
=
origin_var
.
shape
,
dtype
=
origin_var
.
dtype
,
persistable
=
True
)
load_block
.
append_op
(
type
=
'load'
,
inputs
=
{},
outputs
=
{
'Out'
:
[
origin
]},
attrs
=
{
'file_path'
:
os
.
path
.
join
(
dirname
,
origin_var
.
name
)
})
load_block
.
append_op
(
type
=
'delete_var'
,
inputs
=
{
'X'
:
need_delete_vars
},
)
executor
.
run
(
load_prog
)
if
not
isinstance
(
main_program
,
Program
):
raise
ValueError
(
"'main_program' should be an instance of Program."
)
if
not
main_program
.
_is_distributed
:
raise
ValueError
(
"'_load_distributed_persistables' just be designed for distributed training."
)
if
not
main_program
.
_ps_endpoint
:
raise
ValueError
(
"'_load_distributed_persistables' need current_endpoint set in DistributeTranspiler.transpile"
)
need_load_vars
=
main_program
.
_parameters_on_pservers
.
get_distributed_vars_by_ep
(
main_program
.
_ps_endpoint
)
__load_persistable_vars
(
executor
,
dirname
,
need_load_vars
)
def
prepend_feed_ops
(
inference_program
,
...
...
@@ -795,52 +1082,6 @@ def load_inference_model(dirname,
return
[
program
,
feed_target_names
,
fetch_targets
]
def
_save_lookup_tables_by_notify
(
executor
,
dirname
,
lookup_table
,
pserver_endpoints
):
"""
This function will send checkpoint notify message from Trainer 0
to all the pservers.
The checkpoint notify message contains lookup table name,
the absolute path on pserver to save lookup_table.
Args:
executor(Executor): The executor to run for send checkpoint notify.
dirname(str): The folder where to save.
lookup_table(string): the lookup table name, when use distribute
lookup table, we can get lookup table name by DistributeTranspiler.
table_name
ps_endpoint_list(list): the parameter server ip:port list.
when use distribute lookup table, we can get ps_endpoint_list by
distribute arguments.
Return:
None
Examples:
.. code-block:: python
exe = fluid.Executor(fluid.CPUPlace())
param_path = "./my_paddle_model"
table_name = "share_w"
ps_endpoints = ["127.0.0.1:6000","127.0.0.1:6001"]
_save_pserver_vars_by_notify(executor=exe,
dirname=param_path, lookup_table=table_name,
pserver_endpoints=ps_endpoints)
"""
pserver_notify_program
=
Program
()
pserver_notify_block
=
pserver_notify_program
.
global_block
()
attrs
=
{}
attrs
[
'epmap'
]
=
pserver_endpoints
attrs
[
'dir'
]
=
dirname
attrs
[
'lookup_table'
]
=
lookup_table
pserver_notify_block
.
append_op
(
type
=
'checkpoint_notify'
,
inputs
=
{},
outputs
=
{},
attrs
=
attrs
)
executor
.
run
(
pserver_notify_program
)
def
_endpoints_replacement
(
program
,
endpoints
):
ENDPOINT_MAP
=
"epmap"
for
op
in
program
.
global_block
().
ops
:
...
...
@@ -911,54 +1152,3 @@ def get_parameter_value_by_name(name, executor, program=None):
program
=
default_main_program
()
var
=
program
.
global_block
().
var
(
name
)
return
get_parameter_value
(
var
,
executor
)
def
_load_slice_up_vars
(
executor
,
dirname
,
slice_vars_and_attrs
):
if
not
slice_vars_and_attrs
:
return
load_prog
=
Program
()
load_block
=
load_prog
.
global_block
()
need_delete_vars
=
[]
for
var_tuple
in
slice_vars_and_attrs
:
orig_var
=
var_tuple
[
0
]
start
=
var_tuple
[
1
]
slice_var
=
var_tuple
[
2
]
end
=
start
+
slice_var
.
shape
[
0
]
orig_var_name
=
orig_var
.
name
orig_var
.
name
=
"{}.origin"
.
format
(
orig_var_name
)
clone_orig_var
=
load_block
.
create_var
(
name
=
orig_var
.
name
,
type
=
orig_var
.
type
,
shape
=
orig_var
.
shape
,
dtype
=
orig_var
.
dtype
,
persistable
=
True
)
clone_slice_var
=
load_block
.
create_var
(
name
=
slice_var
.
name
,
type
=
slice_var
.
type
,
shape
=
slice_var
.
shape
,
dtype
=
slice_var
.
dtype
,
persistable
=
True
)
load_block
.
append_op
(
type
=
'load'
,
inputs
=
{},
outputs
=
{
'Out'
:
[
clone_orig_var
]},
attrs
=
{
'file_path'
:
os
.
path
.
join
(
dirname
,
orig_var_name
)})
load_block
.
append_op
(
type
=
"slice"
,
inputs
=
{
'Input'
:
clone_orig_var
},
outputs
=
{
'Out'
:
clone_slice_var
},
attrs
=
{
'axes'
:
[
0
],
'starts'
:
[
start
],
'ends'
:
[
end
]})
need_delete_vars
.
append
(
clone_orig_var
)
load_block
.
append_op
(
type
=
'delete_var'
,
inputs
=
{
'X'
:
need_delete_vars
},
)
executor
.
run
(
load_prog
)
python/paddle/fluid/tests/unittests/dist_save_load.py
浏览文件 @
8b50ad80
...
...
@@ -80,7 +80,8 @@ class TestDistSaveLoad2x2(TestDistSimnetBow2x2):
# NOTE: pserver should not call memory optimize
t
=
self
.
get_transpiler
(
args
.
trainer_id
,
fluid
.
default_main_program
(),
args
.
endpoints
,
args
.
trainers
,
args
.
sync_mode
)
args
.
trainers
,
args
.
sync_mode
,
False
,
args
.
current_endpoint
)
pserver_prog
=
t
.
get_pserver_program
(
args
.
current_endpoint
)
startup_prog
=
t
.
get_startup_program
(
args
.
current_endpoint
,
pserver_prog
)
...
...
@@ -93,7 +94,8 @@ class TestDistSaveLoad2x2(TestDistSimnetBow2x2):
exe
.
run
(
startup_prog
)
if
need_load
and
model_dir
:
self
.
_load_persistable_vars
(
exe
,
model_dir
,
startup_prog
)
fluid
.
io
.
load_persistables
(
exe
,
model_dir
,
pserver_prog
)
exe
.
run
(
pserver_prog
)
def
run_trainer
(
self
,
args
):
...
...
@@ -158,19 +160,46 @@ class TestDistSaveLoad2x2(TestDistSimnetBow2x2):
need_save
=
bool
(
int
(
os
.
getenv
(
"SAVE"
,
"0"
)))
model_dir
=
os
.
getenv
(
"MODEL_DIR"
,
""
)
if
need_save
:
for
_
in
six
.
moves
.
xrange
(
RUN_STEP
):
loss
,
=
exe
.
run
(
fetch_list
=
[
avg_cost
.
name
],
feed
=
feeder
.
feed
(
get_data
()))
if
need_save
and
model_dir
:
io
.
save_persistables
(
startup_exe
,
model_dir
,
trainer_prog
)
var
=
np
.
array
(
fluid
.
global_scope
().
find_var
(
'__fc_b__'
).
get_tensor
())
if
six
.
PY2
:
print
(
pickle
.
dumps
(
np
.
ravel
(
var
).
tolist
()))
save_mode
=
os
.
getenv
(
"SAVE_MODE"
,
""
)
if
save_mode
==
"LOCAL"
:
if
need_save
:
for
_
in
six
.
moves
.
xrange
(
RUN_STEP
):
loss
,
=
exe
.
run
(
fetch_list
=
[
avg_cost
.
name
],
feed
=
feeder
.
feed
(
get_data
()))
if
need_save
and
model_dir
:
io
.
save_persistables
(
startup_exe
,
model_dir
,
trainer_prog
)
var
=
np
.
array
(
fluid
.
global_scope
().
find_var
(
'__fc_b__'
).
get_tensor
(
))
if
six
.
PY2
:
print
(
pickle
.
dumps
(
np
.
ravel
(
var
).
tolist
()))
else
:
sys
.
stdout
.
buffer
.
write
(
pickle
.
dumps
(
np
.
ravel
(
var
).
tolist
()))
elif
save_mode
==
"DIST"
:
skip_steps
=
int
(
os
.
getenv
(
"SKIP_STEPS"
))
loss
=
None
if
need_save
:
for
idx
in
six
.
moves
.
xrange
(
8
):
loss
,
=
exe
.
run
(
fetch_list
=
[
avg_cost
.
name
],
feed
=
feeder
.
feed
(
get_data
()))
if
need_save
and
model_dir
and
idx
==
skip_steps
and
args
.
trainer_id
==
0
:
io
.
save_persistables
(
startup_exe
,
model_dir
,
trainer_prog
)
else
:
for
idx
in
six
.
moves
.
xrange
(
8
):
data
=
get_data
()
if
idx
<=
skip_steps
:
continue
loss
,
=
exe
.
run
(
fetch_list
=
[
avg_cost
.
name
],
feed
=
feeder
.
feed
(
data
))
if
six
.
PY2
:
print
(
pickle
.
dumps
(
loss
.
tolist
()))
else
:
sys
.
stdout
.
buffer
.
write
(
pickle
.
dumps
(
loss
.
tolist
()))
else
:
sys
.
stdout
.
buffer
.
write
(
pickle
.
dumps
(
np
.
ravel
(
var
).
tolist
())
)
raise
Exception
(
"save_mode must be LOCAL or DIST"
)
if
__name__
==
"__main__"
:
...
...
python/paddle/fluid/tests/unittests/dist_simnet_bow.py
浏览文件 @
8b50ad80
...
...
@@ -75,9 +75,13 @@ def get_loss(cos_q_pt, cos_q_nt):
return
avg_cost
def
get_optimizer
():
# SGD optimizer
optimizer
=
fluid
.
optimizer
.
SGD
(
learning_rate
=
base_lr
)
def
get_optimizer
(
op
=
"sgd"
):
if
op
.
upper
()
==
"sgd"
.
upper
():
optimizer
=
fluid
.
optimizer
.
SGD
(
learning_rate
=
base_lr
)
elif
op
.
upper
()
==
"adam"
.
upper
():
optimizer
=
fluid
.
optimizer
.
Adam
(
learning_rate
=
base_lr
)
else
:
optimizer
=
fluid
.
optimizer
.
SGD
(
learning_rate
=
base_lr
)
return
optimizer
...
...
@@ -237,7 +241,8 @@ class TestDistSimnetBow2x2(TestDistRunnerBase):
inference_program
=
fluid
.
default_main_program
().
clone
()
# Optimization
opt
=
get_optimizer
()
opt
=
os
.
getenv
(
'OPTIMIZER'
,
'sgd'
)
opt
=
get_optimizer
(
opt
)
opt
.
minimize
(
avg_cost
)
# Reader
...
...
python/paddle/fluid/tests/unittests/test_dist_base.py
浏览文件 @
8b50ad80
...
...
@@ -43,7 +43,8 @@ class TestDistRunnerBase(object):
pserver_endpoints
,
trainers
,
sync_mode
,
dc_asgd
=
False
):
dc_asgd
=
False
,
current_endpoint
=
None
):
# NOTE: import fluid until runtime, or else forking processes will cause error.
config
=
fluid
.
DistributeTranspilerConfig
()
config
.
enable_dc_asgd
=
dc_asgd
...
...
@@ -53,7 +54,8 @@ class TestDistRunnerBase(object):
program
=
main_program
,
pservers
=
pserver_endpoints
,
trainers
=
trainers
,
sync_mode
=
sync_mode
)
sync_mode
=
sync_mode
,
current_endpoint
=
current_endpoint
)
return
t
def
run_pserver
(
self
,
args
):
...
...
python/paddle/fluid/tests/unittests/test_dist_save_load.py
浏览文件 @
8b50ad80
...
...
@@ -33,7 +33,6 @@ class TestDistSaveLoadDense2x2(TestDistBase):
delta
=
1e-3
,
check_error_log
=
False
,
need_envs
=
{}):
required_envs
=
{
"PATH"
:
os
.
getenv
(
"PATH"
,
""
),
"PYTHONPATH"
:
os
.
getenv
(
"PYTHONPATH"
,
""
),
...
...
@@ -77,7 +76,77 @@ class TestDistSaveLoadDense2x2(TestDistBase):
need_envs
=
{
"IS_DISTRIBUTED"
:
'0'
,
"IS_SPARSE"
:
'0'
,
'IS_SELF_CONTAINED_LR'
:
'1'
'IS_SELF_CONTAINED_LR'
:
'1'
,
'SAVE_MODE'
:
'LOCAL'
,
}
self
.
check_with_place
(
"dist_save_load.py"
,
delta
=
0
,
check_error_log
=
False
,
need_envs
=
need_envs
)
class
TestDistSaveLoadWithPServerStateDense2x2
(
TestDistBase
):
def
_setup_config
(
self
):
self
.
_sync_mode
=
True
self
.
_enforce_place
=
"CPU"
def
check_with_place
(
self
,
model_file
,
delta
=
1e-3
,
check_error_log
=
False
,
need_envs
=
{}):
required_envs
=
{
"PATH"
:
os
.
getenv
(
"PATH"
,
""
),
"PYTHONPATH"
:
os
.
getenv
(
"PYTHONPATH"
,
""
),
"LD_LIBRARY_PATH"
:
os
.
getenv
(
"LD_LIBRARY_PATH"
,
""
),
"http_proxy"
:
""
}
required_envs
.
update
(
need_envs
)
if
check_error_log
:
required_envs
[
"GLOG_v"
]
=
"3"
required_envs
[
"GLOG_logtostderr"
]
=
"1"
model_dir
=
tempfile
.
mkdtemp
()
save_env
=
{}
save_env
[
"SAVE_MODE"
]
=
"DIST"
save_env
[
"SAVE"
]
=
"1"
save_env
[
"MODEL_DIR"
]
=
model_dir
save_env
.
update
(
required_envs
)
tr0_var_1
,
tr1_var_1
=
self
.
_run_cluster
(
model_file
,
save_env
,
check_error_log
)
load_env
=
{}
load_env
[
"LOAD"
]
=
"1"
load_env
[
"MODEL_DIR"
]
=
model_dir
load_env
.
update
(
required_envs
)
tr0_var_2
,
tr1_var_2
=
self
.
_run_cluster
(
model_file
,
load_env
,
check_error_log
)
shutil
.
rmtree
(
model_dir
)
train0_1_np
=
np
.
array
(
tr0_var_1
)
train1_1_np
=
np
.
array
(
tr1_var_1
)
train0_2_np
=
np
.
array
(
tr0_var_2
)
train1_2_np
=
np
.
array
(
tr1_var_2
)
self
.
assertAlmostEqual
(
train0_1_np
.
all
(),
train0_2_np
.
all
(),
delta
=
delta
)
self
.
assertAlmostEqual
(
train1_1_np
.
all
(),
train1_2_np
.
all
(),
delta
=
delta
)
def
test_dist
(
self
):
need_envs
=
{
"IS_DISTRIBUTED"
:
'0'
,
"IS_SPARSE"
:
'0'
,
'IS_SELF_CONTAINED_LR'
:
'1'
,
'SAVE_MODE'
:
'DIST'
,
'OPTIMIZER'
:
'ADAM'
,
'SKIP_STEPS'
:
str
(
np
.
random
.
randint
(
2
,
6
))
}
self
.
check_with_place
(
"dist_save_load.py"
,
...
...
python/paddle/fluid/tests/unittests/test_dist_transpiler.py
浏览文件 @
8b50ad80
...
...
@@ -741,21 +741,40 @@ class TestLoadSliceVar(TranspilerTest):
pserver
,
_
=
self
.
get_pserver
(
self
.
pserver1_ep
)
pserver2
,
_
=
self
.
get_pserver
(
self
.
pserver2_ep
)
self
.
assertTrue
(
pserver
.
_slice_vars_and_attrs
)
self
.
assertTrue
(
pserver2
.
_slice_vars_and_attrs
)
for
idx
in
six
.
moves
.
xrange
(
len
(
pserver
.
_slice_vars_and_attrs
)):
self
.
assertEqual
(
pserver
.
_slice_vars_and_attrs
[
idx
][
0
],
pserver2
.
_slice_vars_and_attrs
[
idx
][
0
])
total_numel
=
six
.
moves
.
reduce
(
lambda
x
,
y
:
x
*
y
,
pserver
.
_slice_vars_and_attrs
[
idx
][
0
].
shape
)
self
.
assertEqual
(
total_numel
,
six
.
moves
.
reduce
(
lambda
x
,
y
:
x
*
y
,
pserver
.
_slice_vars_and_attrs
[
idx
][
2
].
shape
)
+
six
.
moves
.
reduce
(
lambda
x
,
y
:
x
*
y
,
pserver2
.
_slice_vars_and_attrs
[
idx
][
2
].
shape
))
vars_ps1
=
pserver
.
_parameters_on_pservers
.
get_distributed_vars_by_ep
(
self
.
pserver1_ep
)
vars_ps2
=
pserver
.
_parameters_on_pservers
.
get_distributed_vars_by_ep
(
self
.
pserver2_ep
)
self
.
assertTrue
(
vars_ps1
)
self
.
assertTrue
(
vars_ps2
)
for
idx
in
six
.
moves
.
xrange
(
len
(
vars_ps1
)):
total_numel
=
0
ps1_numel
,
ps2_numel
=
0
,
0
ps1_var
=
vars_ps1
[
idx
]
if
not
ps1_var
.
is_slice
:
total_numel
=
six
.
moves
.
reduce
(
lambda
x
,
y
:
x
*
y
,
vars_ps1
[
idx
].
origin
.
shape
)
ps1_numel
=
six
.
moves
.
reduce
(
lambda
x
,
y
:
x
*
y
,
vars_ps1
[
idx
].
slice
.
shape
)
else
:
ps2_var
=
None
for
var
in
vars_ps2
:
if
var
.
origin
.
name
==
ps1_var
.
origin
.
name
:
ps2_var
=
var
break
total_numel
=
six
.
moves
.
reduce
(
lambda
x
,
y
:
x
*
y
,
ps1_var
.
origin
.
shape
)
ps1_numel
=
six
.
moves
.
reduce
(
lambda
x
,
y
:
x
*
y
,
ps1_var
.
slice
.
shape
)
ps2_numel
=
six
.
moves
.
reduce
(
lambda
x
,
y
:
x
*
y
,
ps2_var
.
slice
.
shape
)
self
.
assertEqual
(
total_numel
,
ps1_numel
+
ps2_numel
)
class
TestNCCL2Transpile
(
TranspilerTest
):
...
...
python/paddle/fluid/transpiler/distribute_transpiler.py
浏览文件 @
8b50ad80
此差异已折叠。
点击以展开。
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录