Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
7f2aa2db
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
7f2aa2db
编写于
8月 30, 2020
作者:
C
Chengmo
提交者:
GitHub
8月 30, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
【paddle.fleet】Support Heter Parameter Server (#25998)
* Support Heter Parameter Server
上级
ac63c7cd
变更
35
展开全部
显示空白变更内容
内联
并排
Showing
35 changed file
with
2506 addition
and
83 deletion
+2506
-83
paddle/fluid/operators/distributed/CMakeLists.txt
paddle/fluid/operators/distributed/CMakeLists.txt
+1
-1
paddle/fluid/operators/distributed/grpc/grpc_client.cc
paddle/fluid/operators/distributed/grpc/grpc_client.cc
+82
-0
paddle/fluid/operators/distributed/grpc/grpc_client.h
paddle/fluid/operators/distributed/grpc/grpc_client.h
+32
-0
paddle/fluid/operators/distributed/grpc/grpc_serde.cc
paddle/fluid/operators/distributed/grpc/grpc_serde.cc
+13
-3
paddle/fluid/operators/distributed/grpc/grpc_serde.h
paddle/fluid/operators/distributed/grpc/grpc_serde.h
+5
-0
paddle/fluid/operators/distributed/grpc/grpc_server.cc
paddle/fluid/operators/distributed/grpc/grpc_server.cc
+48
-0
paddle/fluid/operators/distributed/grpc/grpc_service.h
paddle/fluid/operators/distributed/grpc/grpc_service.h
+5
-1
paddle/fluid/operators/distributed/request_handler.h
paddle/fluid/operators/distributed/request_handler.h
+2
-0
paddle/fluid/operators/distributed/request_handler_impl.cc
paddle/fluid/operators/distributed/request_handler_impl.cc
+16
-0
paddle/fluid/operators/distributed/request_handler_impl.h
paddle/fluid/operators/distributed/request_handler_impl.h
+11
-0
paddle/fluid/operators/distributed/rpc_client.h
paddle/fluid/operators/distributed/rpc_client.h
+6
-0
paddle/fluid/operators/distributed/rpc_server_test.cc
paddle/fluid/operators/distributed/rpc_server_test.cc
+96
-13
paddle/fluid/operators/distributed/send_recv.proto.in
paddle/fluid/operators/distributed/send_recv.proto.in
+1
-1
paddle/fluid/operators/distributed/variable_response.h
paddle/fluid/operators/distributed/variable_response.h
+7
-0
paddle/fluid/operators/distributed_ops/listen_and_serv_op.cc
paddle/fluid/operators/distributed_ops/listen_and_serv_op.cc
+8
-1
paddle/fluid/operators/distributed_ops/listen_and_serv_op.h
paddle/fluid/operators/distributed_ops/listen_and_serv_op.h
+2
-0
paddle/fluid/operators/distributed_ops/send_and_recv_op.cc
paddle/fluid/operators/distributed_ops/send_and_recv_op.cc
+98
-0
python/paddle/distributed/fleet/base/fleet_base.py
python/paddle/distributed/fleet/base/fleet_base.py
+8
-7
python/paddle/distributed/fleet/base/role_maker.py
python/paddle/distributed/fleet/base/role_maker.py
+118
-13
python/paddle/distributed/fleet/meta_optimizers/__init__.py
python/paddle/distributed/fleet/meta_optimizers/__init__.py
+2
-2
python/paddle/distributed/fleet/meta_optimizers/parameter_server_graph_optimizer.py
...fleet/meta_optimizers/parameter_server_graph_optimizer.py
+6
-3
python/paddle/distributed/fleet/meta_optimizers/parameter_server_optimizer.py
...buted/fleet/meta_optimizers/parameter_server_optimizer.py
+23
-5
python/paddle/distributed/fleet/runtime/parameter_server_runtime.py
...dle/distributed/fleet/runtime/parameter_server_runtime.py
+21
-3
python/paddle/fluid/incubate/fleet/base/fleet_base.py
python/paddle/fluid/incubate/fleet/base/fleet_base.py
+1
-1
python/paddle/fluid/incubate/fleet/base/role_maker.py
python/paddle/fluid/incubate/fleet/base/role_maker.py
+0
-1
python/paddle/fluid/incubate/fleet/parameter_server/distribute_transpiler/__init__.py
.../fleet/parameter_server/distribute_transpiler/__init__.py
+4
-4
python/paddle/fluid/incubate/fleet/parameter_server/ir/heter_trainer_pass.py
.../incubate/fleet/parameter_server/ir/heter_trainer_pass.py
+100
-0
python/paddle/fluid/incubate/fleet/parameter_server/ir/pserver_pass.py
.../fluid/incubate/fleet/parameter_server/ir/pserver_pass.py
+3
-3
python/paddle/fluid/incubate/fleet/parameter_server/ir/public.py
...paddle/fluid/incubate/fleet/parameter_server/ir/public.py
+49
-18
python/paddle/fluid/incubate/fleet/parameter_server/ir/trainer_pass.py
.../fluid/incubate/fleet/parameter_server/ir/trainer_pass.py
+879
-1
python/paddle/fluid/tests/unittests/ctr_dataset_reader.py
python/paddle/fluid/tests/unittests/ctr_dataset_reader.py
+56
-2
python/paddle/fluid/tests/unittests/dist_fleet_heter_ctr.py
python/paddle/fluid/tests/unittests/dist_fleet_heter_ctr.py
+220
-0
python/paddle/fluid/tests/unittests/test_dist_fleet_heter_base.py
...addle/fluid/tests/unittests/test_dist_fleet_heter_base.py
+388
-0
python/paddle/fluid/tests/unittests/test_dist_fleet_heter_ctr.py
...paddle/fluid/tests/unittests/test_dist_fleet_heter_ctr.py
+56
-0
python/paddle/fluid/tests/unittests/test_dist_fleet_heter_program.py
...le/fluid/tests/unittests/test_dist_fleet_heter_program.py
+139
-0
未找到文件。
paddle/fluid/operators/distributed/CMakeLists.txt
浏览文件 @
7f2aa2db
...
@@ -56,7 +56,7 @@ endif()
...
@@ -56,7 +56,7 @@ endif()
cc_test
(
rpc_server_test SRCS rpc_server_test.cc
cc_test
(
rpc_server_test SRCS rpc_server_test.cc
DEPS
${
RPC_DEPS
}
executor scope proto_desc lookup_sparse_table_read_op
)
DEPS
${
RPC_DEPS
}
executor scope proto_desc lookup_sparse_table_read_op
scale_op
)
cc_test
(
varhandle_test SRCS varhandle_test.cc DEPS profiler scope
)
cc_test
(
varhandle_test SRCS varhandle_test.cc DEPS profiler scope
)
cc_library
(
parameter_prefetch SRCS parameter_prefetch.cc DEPS sendrecvop_rpc memory
)
cc_library
(
parameter_prefetch SRCS parameter_prefetch.cc DEPS sendrecvop_rpc memory
)
cc_library
(
parameter_send SRCS parameter_send.cc DEPS sendrecvop_rpc memory
)
cc_library
(
parameter_send SRCS parameter_send.cc DEPS sendrecvop_rpc memory
)
...
...
paddle/fluid/operators/distributed/grpc/grpc_client.cc
浏览文件 @
7f2aa2db
...
@@ -132,6 +132,15 @@ void ProcGetResponse(const VarHandle& var_h,
...
@@ -132,6 +132,15 @@ void ProcGetResponse(const VarHandle& var_h,
&
trainer_id
);
&
trainer_id
);
}
}
void
ProcGetRecvResponse
(
const
VarHandle
&
var_h
,
const
::
grpc
::
ByteBuffer
&
ret_msg
)
{
VLOG
(
4
)
<<
"ProcGetRecvResponse"
;
framework
::
Variable
*
outvar
=
nullptr
;
int
trainer_id
;
DeserializeRecvFromByteBuffer
(
ret_msg
,
*
var_h
.
ctx
(),
var_h
.
scope
(),
&
outvar
,
&
trainer_id
);
}
template
<
typename
T
>
template
<
typename
T
>
void
RequestToByteBuffer
(
const
T
&
proto
,
::
grpc
::
ByteBuffer
*
result
)
{
void
RequestToByteBuffer
(
const
T
&
proto
,
::
grpc
::
ByteBuffer
*
result
)
{
::
grpc
::
Slice
slice
(
proto
.
ByteSizeLong
());
::
grpc
::
Slice
slice
(
proto
.
ByteSizeLong
());
...
@@ -482,6 +491,79 @@ VarHandlePtr GRPCClient::AsyncDistributeNotify(
...
@@ -482,6 +491,79 @@ VarHandlePtr GRPCClient::AsyncDistributeNotify(
return
h
;
return
h
;
}
}
VarHandlePtr
GRPCClient
::
AsyncSendAndRecv
(
const
std
::
string
&
ep
,
const
platform
::
DeviceContext
&
ctx
,
const
framework
::
Scope
&
scope
,
const
std
::
string
&
send_var_name
,
const
std
::
string
&
recv_var_name
,
const
std
::
string
&
table_name
,
int64_t
time_out
)
{
const
platform
::
DeviceContext
*
p_ctx
=
&
ctx
;
const
std
::
string
ep_val
=
ep
;
const
std
::
string
send_var_name_val
=
send_var_name
;
const
std
::
string
recv_var_name_val
=
recv_var_name
;
const
std
::
string
table_name_val
=
table_name
;
const
framework
::
Scope
*
p_scope
=
&
scope
;
const
auto
ch
=
GetChannel
(
ep_val
);
const
std
::
string
method
=
kSendAndRecvRPC
;
VLOG
(
4
)
<<
"GRPCClient::SendAndRecv Begin ,Send_var_name: "
<<
send_var_name_val
<<
" Recv_var_name: "
<<
recv_var_name_val
;
int
retry_times_
=
0
;
while
(
true
)
{
SendAndRecvProcessor
*
s
=
new
SendAndRecvProcessor
(
ch
);
VarHandlePtr
h
(
new
VarHandle
(
ep
,
method
,
send_var_name_val
,
p_ctx
,
p_scope
));
VarHandlePtr
h_recv
(
new
VarHandle
(
ep
,
method
,
recv_var_name_val
,
p_ctx
,
p_scope
));
s
->
Prepare
(
h
,
time_out
);
s
->
RecvPrepare
(
h_recv
);
framework
::
AsyncIO
([
send_var_name_val
,
recv_var_name_val
,
table_name_val
,
p_scope
,
p_ctx
,
s
,
method
,
h
,
this
]
{
auto
*
send_var
=
p_scope
->
FindVar
(
send_var_name_val
);
send_var
->
GetMutable
<
framework
::
LoDTensor
>
()
->
set_lod
({});
::
grpc
::
ByteBuffer
buf
;
VLOG
(
4
)
<<
"SerializeToByteBuffer: send_var_name_val: "
<<
send_var_name_val
<<
" recv_var_name_val: "
<<
recv_var_name_val
;
SerializeToByteBuffer
(
send_var_name_val
,
send_var
,
*
p_ctx
,
&
buf
,
recv_var_name_val
,
trainer_id_
,
table_name_val
);
VLOG
(
3
)
<<
s
->
GetVarHandlePtr
()
->
String
()
<<
" begin"
;
// stub context
s
->
response_call_back_
=
ProcGetRecvResponse
;
platform
::
RecordRPCEvent
record_event
(
method
);
auto
call
=
s
->
stub_g_
.
PrepareUnaryCall
(
s
->
context_
.
get
(),
"/sendrecv.SendRecvService/SendAndRecvVariable"
,
buf
,
&
cq_
);
call
->
StartCall
();
call
->
Finish
(
&
s
->
reply_
,
&
s
->
status_
,
reinterpret_cast
<
void
*>
(
s
));
if
(
UNLIKELY
(
platform
::
IsProfileEnabled
()))
{
h
->
Wait
();
}
});
req_count_
++
;
if
(
FLAGS_rpc_retry_times
>
0
&&
retry_times_
<
FLAGS_rpc_retry_times
)
{
h
->
Wait
();
if
(
h
->
should_retry
)
{
VLOG
(
3
)
<<
"rpc call failed, retry times "
<<
retry_times_
;
retry_times_
++
;
std
::
random_device
rd
;
std
::
this_thread
::
sleep_for
(
std
::
chrono
::
milliseconds
(
rd
()
%
5
));
continue
;
}
}
return
h
;
}
}
bool
GRPCClient
::
Wait
()
{
bool
GRPCClient
::
Wait
()
{
std
::
unique_lock
<
std
::
mutex
>
lk
(
sync_mutex_
);
std
::
unique_lock
<
std
::
mutex
>
lk
(
sync_mutex_
);
sync_cond_
.
wait
(
lk
,
[
this
]
{
return
(
req_count_
==
0
||
ok_
==
false
);
});
sync_cond_
.
wait
(
lk
,
[
this
]
{
return
(
req_count_
==
0
||
ok_
==
false
);
});
...
...
paddle/fluid/operators/distributed/grpc/grpc_client.h
浏览文件 @
7f2aa2db
...
@@ -53,6 +53,8 @@ namespace distributed {
...
@@ -53,6 +53,8 @@ namespace distributed {
void
ProcGetResponse
(
const
VarHandle
&
var_h
,
const
grpc
::
ByteBuffer
&
msg
);
void
ProcGetResponse
(
const
VarHandle
&
var_h
,
const
grpc
::
ByteBuffer
&
msg
);
void
ProcGetRecvResponse
(
const
VarHandle
&
var_h
,
const
grpc
::
ByteBuffer
&
msg
);
class
BaseProcessor
{
class
BaseProcessor
{
public:
public:
BaseProcessor
()
{
context_
=
nullptr
;
}
BaseProcessor
()
{
context_
=
nullptr
;
}
...
@@ -131,6 +133,28 @@ class GetProcessor : public BaseProcessor {
...
@@ -131,6 +133,28 @@ class GetProcessor : public BaseProcessor {
RequestGetCallBack
response_call_back_
=
ProcGetResponse
;
RequestGetCallBack
response_call_back_
=
ProcGetResponse
;
};
};
class
SendAndRecvProcessor
:
public
BaseProcessor
{
public:
explicit
SendAndRecvProcessor
(
std
::
shared_ptr
<
grpc
::
Channel
>
ch
)
:
BaseProcessor
(),
stub_g_
(
ch
)
{}
virtual
~
SendAndRecvProcessor
()
{}
void
ProcessImpl
()
override
{
if
(
response_call_back_
)
{
response_call_back_
(
*
var_h_recv_
.
get
(),
reply_
);
var_h_recv_
->
Finish
(
true
);
}
}
void
RecvPrepare
(
VarHandlePtr
h_recv
)
{
var_h_recv_
=
h_recv
;
}
::
grpc
::
ByteBuffer
reply_
;
::
grpc
::
GenericStub
stub_g_
;
RequestGetCallBack
response_call_back_
=
ProcGetResponse
;
VarHandlePtr
var_h_recv_
;
};
class
BatchBarrierProcessor
:
public
BaseProcessor
{
class
BatchBarrierProcessor
:
public
BaseProcessor
{
public:
public:
explicit
BatchBarrierProcessor
(
std
::
shared_ptr
<
grpc
::
Channel
>
ch
)
explicit
BatchBarrierProcessor
(
std
::
shared_ptr
<
grpc
::
Channel
>
ch
)
...
@@ -231,6 +255,14 @@ class GRPCClient : public RPCClient {
...
@@ -231,6 +255,14 @@ class GRPCClient : public RPCClient {
const
framework
::
Scope
&
scope
,
const
std
::
string
&
var_name
,
const
framework
::
Scope
&
scope
,
const
std
::
string
&
var_name
,
int64_t
time_out
=
FLAGS_rpc_deadline
)
override
;
int64_t
time_out
=
FLAGS_rpc_deadline
)
override
;
VarHandlePtr
AsyncSendAndRecv
(
const
std
::
string
&
ep
,
const
platform
::
DeviceContext
&
ctx
,
const
framework
::
Scope
&
scope
,
const
std
::
string
&
send_var_name
,
const
std
::
string
&
recv_var_name
,
const
std
::
string
&
table_name
=
""
,
int64_t
time_out
=
FLAGS_rpc_deadline
)
override
;
VarHandlePtr
AsyncSendComplete
(
VarHandlePtr
AsyncSendComplete
(
const
std
::
string
&
ep
,
int64_t
time_out
=
FLAGS_rpc_deadline
)
override
;
const
std
::
string
&
ep
,
int64_t
time_out
=
FLAGS_rpc_deadline
)
override
;
...
...
paddle/fluid/operators/distributed/grpc/grpc_serde.cc
浏览文件 @
7f2aa2db
...
@@ -76,7 +76,6 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
...
@@ -76,7 +76,6 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
PADDLE_THROW
(
"Serialize does not support type: %s"
,
PADDLE_THROW
(
"Serialize does not support type: %s"
,
typeid
(
var
->
Type
()).
name
());
typeid
(
var
->
Type
()).
name
());
}
}
std
::
string
header
;
std
::
string
header
;
request
.
AppendToString
(
&
header
);
request
.
AppendToString
(
&
header
);
auto
buffer
=
std
::
unique_ptr
<
char
[]
>
(
new
char
[
1024
]);
auto
buffer
=
std
::
unique_ptr
<
char
[]
>
(
new
char
[
1024
]);
...
@@ -101,7 +100,6 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
...
@@ -101,7 +100,6 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
}
}
#endif
#endif
PADDLE_ENFORCE_NOT_NULL
(
payload
);
PADDLE_ENFORCE_NOT_NULL
(
payload
);
e
.
WriteVarlengthBeginning
(
VarMsg
::
kSerializedFieldNumber
,
e
.
WriteVarlengthBeginning
(
VarMsg
::
kSerializedFieldNumber
,
payload
->
memory_size
());
payload
->
memory_size
());
if
(
payload
->
memory_size
()
>=
std
::
numeric_limits
<
int
>::
max
())
{
if
(
payload
->
memory_size
()
>=
std
::
numeric_limits
<
int
>::
max
())
{
...
@@ -140,7 +138,6 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
...
@@ -140,7 +138,6 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
::
grpc
::
Slice
::
STEAL_REF
);
::
grpc
::
Slice
::
STEAL_REF
);
num_slices
=
4
;
num_slices
=
4
;
}
}
::
grpc
::
ByteBuffer
tmp
(
&
slices
[
0
],
num_slices
);
::
grpc
::
ByteBuffer
tmp
(
&
slices
[
0
],
num_slices
);
msg
->
Swap
(
&
tmp
);
msg
->
Swap
(
&
tmp
);
}
}
...
@@ -156,6 +153,19 @@ void DeserializeFromByteBuffer(const ::grpc::ByteBuffer& msg,
...
@@ -156,6 +153,19 @@ void DeserializeFromByteBuffer(const ::grpc::ByteBuffer& msg,
*
trainer_id
=
resp
.
GetTrainerId
();
*
trainer_id
=
resp
.
GetTrainerId
();
}
}
void
DeserializeRecvFromByteBuffer
(
const
::
grpc
::
ByteBuffer
&
msg
,
const
platform
::
DeviceContext
&
ctx
,
const
framework
::
Scope
*
scope
,
framework
::
Variable
**
var
,
int
*
trainer_id
)
{
platform
::
RecordRPCEvent
record_event
(
"deserial"
);
operators
::
distributed
::
GRPCVariableResponse
resp
(
scope
,
&
ctx
);
PADDLE_ENFORCE_EQ
(
resp
.
Parse
(
msg
),
0
,
platform
::
errors
::
InvalidArgument
(
"parse bytebuffer to tensor error!"
));
*
var
=
resp
.
GetRecvVar
();
*
trainer_id
=
resp
.
GetTrainerId
();
}
}
// namespace distributed
}
// namespace distributed
}
// namespace operators
}
// namespace operators
}
// namespace paddle
}
// namespace paddle
paddle/fluid/operators/distributed/grpc/grpc_serde.h
浏览文件 @
7f2aa2db
...
@@ -47,6 +47,11 @@ void DeserializeFromByteBuffer(const ::grpc::ByteBuffer& msg,
...
@@ -47,6 +47,11 @@ void DeserializeFromByteBuffer(const ::grpc::ByteBuffer& msg,
const
framework
::
Scope
*
scope
,
const
framework
::
Scope
*
scope
,
framework
::
Variable
**
var
,
int
*
trainer_id
);
framework
::
Variable
**
var
,
int
*
trainer_id
);
void
DeserializeRecvFromByteBuffer
(
const
::
grpc
::
ByteBuffer
&
msg
,
const
platform
::
DeviceContext
&
ctx
,
const
framework
::
Scope
*
scope
,
framework
::
Variable
**
var
,
int
*
trainer_id
);
}
// namespace distributed
}
// namespace distributed
}
// namespace operators
}
// namespace operators
}
// namespace paddle
}
// namespace paddle
paddle/fluid/operators/distributed/grpc/grpc_server.cc
浏览文件 @
7f2aa2db
...
@@ -28,6 +28,7 @@ DECLARE_int32(rpc_retry_bind_port);
...
@@ -28,6 +28,7 @@ DECLARE_int32(rpc_retry_bind_port);
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
namespace
distributed
{
namespace
distributed
{
enum
CallStatus
{
PROCESS
=
0
,
FINISH
};
enum
CallStatus
{
PROCESS
=
0
,
FINISH
};
// reference:
// reference:
...
@@ -433,6 +434,51 @@ class RequestNotify final : public RequestBase {
...
@@ -433,6 +434,51 @@ class RequestNotify final : public RequestBase {
ServerAsyncResponseWriter
<
sendrecv
::
VoidMessage
>
responder_
;
ServerAsyncResponseWriter
<
sendrecv
::
VoidMessage
>
responder_
;
};
};
class
RequestSendAndRecv
final
:
public
RequestBase
{
public:
explicit
RequestSendAndRecv
(
GrpcService
::
AsyncService
*
service
,
::
grpc
::
ServerCompletionQueue
*
cq
,
RequestHandler
*
request_handler
,
int
req_id
)
:
RequestBase
(
service
,
cq
,
request_handler
,
req_id
),
responder_
(
&
ctx_
)
{
request_
.
reset
(
new
GRPCVariableResponse
(
request_handler
->
scope
(),
request_handler
->
dev_ctx
(),
request_handler
->
distributed_mode
()));
int
method_id
=
static_cast
<
int
>
(
distributed
::
GrpcMethod
::
kRequestSendAndRecv
);
service_
->
RequestAsyncUnary
(
method_id
,
&
ctx_
,
request_
.
get
(),
&
responder_
,
cq_
,
cq_
,
reinterpret_cast
<
void
*>
(
static_cast
<
intptr_t
>
(
req_id
)));
}
virtual
~
RequestSendAndRecv
()
{}
std
::
string
GetReqName
()
override
{
return
request_
->
Varname
();
}
void
Process
()
override
{
std
::
string
in_var_name
=
request_
->
Varname
();
std
::
string
out_var_name
=
request_
->
OutVarname
();
std
::
string
table_name
=
request_
->
TableName
();
int
trainer_id
=
request_
->
GetTrainerId
();
VLOG
(
4
)
<<
"RequestSendAndRecv, in_var_name: "
<<
in_var_name
<<
" out_var_name: "
<<
out_var_name
<<
" trainer: "
<<
trainer_id
;
auto
scope
=
request_
->
GetMutableLocalScope
();
auto
invar
=
scope
->
FindVar
(
in_var_name
);
framework
::
Variable
*
outvar
=
nullptr
;
request_handler_
->
Handle
(
in_var_name
,
scope
,
invar
,
&
outvar
,
trainer_id
,
out_var_name
,
table_name
);
SerializeToByteBuffer
(
out_var_name
,
outvar
,
*
request_handler_
->
dev_ctx
(),
&
reply_
);
Finish
(
reply_
,
&
responder_
);
}
protected:
std
::
shared_ptr
<
GRPCVariableResponse
>
request_
;
::
grpc
::
ByteBuffer
reply_
;
ServerAsyncResponseWriter
<::
grpc
::
ByteBuffer
>
responder_
;
};
void
AsyncGRPCServer
::
WaitServerReady
()
{
void
AsyncGRPCServer
::
WaitServerReady
()
{
VLOG
(
4
)
<<
"AsyncGRPCServer is waiting server ready"
;
VLOG
(
4
)
<<
"AsyncGRPCServer is waiting server ready"
;
std
::
unique_lock
<
std
::
mutex
>
lock
(
this
->
mutex_ready_
);
std
::
unique_lock
<
std
::
mutex
>
lock
(
this
->
mutex_ready_
);
...
@@ -586,6 +632,8 @@ void AsyncGRPCServer::TryToRegisterNewOne(const std::string& rpc_name,
...
@@ -586,6 +632,8 @@ void AsyncGRPCServer::TryToRegisterNewOne(const std::string& rpc_name,
b
=
new
RequestCheckpointNotify
(
service_
.
get
(),
cq
.
get
(),
handler
,
req_id
);
b
=
new
RequestCheckpointNotify
(
service_
.
get
(),
cq
.
get
(),
handler
,
req_id
);
}
else
if
(
rpc_name
==
kRequestNotify
)
{
}
else
if
(
rpc_name
==
kRequestNotify
)
{
b
=
new
RequestNotify
(
service_
.
get
(),
cq
.
get
(),
handler
,
req_id
);
b
=
new
RequestNotify
(
service_
.
get
(),
cq
.
get
(),
handler
,
req_id
);
}
else
if
(
rpc_name
==
kRequestSendAndRecv
)
{
b
=
new
RequestSendAndRecv
(
service_
.
get
(),
cq
.
get
(),
handler
,
req_id
);
}
else
{
}
else
{
PADDLE_ENFORCE
(
false
,
"not supported rpc"
);
PADDLE_ENFORCE
(
false
,
"not supported rpc"
);
}
}
...
...
paddle/fluid/operators/distributed/grpc/grpc_service.h
浏览文件 @
7f2aa2db
...
@@ -85,10 +85,12 @@ enum class GrpcMethod {
...
@@ -85,10 +85,12 @@ enum class GrpcMethod {
kGetMonomerVariable
,
kGetMonomerVariable
,
kGetMonomerBarrier
,
kGetMonomerBarrier
,
kRequestNotify
,
kRequestNotify
,
kRequestSendAndRecv
,
// when you add new handler, change kGrpcNumMethods at the same time!
};
};
static
const
int
kGrpcNumMethods
=
static
const
int
kGrpcNumMethods
=
static_cast
<
int
>
(
GrpcMethod
::
kRequest
Notify
)
+
1
;
static_cast
<
int
>
(
GrpcMethod
::
kRequest
SendAndRecv
)
+
1
;
inline
const
char
*
GrpcMethodName
(
GrpcMethod
id
)
{
inline
const
char
*
GrpcMethodName
(
GrpcMethod
id
)
{
switch
(
id
)
{
switch
(
id
)
{
...
@@ -108,6 +110,8 @@ inline const char* GrpcMethodName(GrpcMethod id) {
...
@@ -108,6 +110,8 @@ inline const char* GrpcMethodName(GrpcMethod id) {
return
"/sendrecv.SendRecvService/CheckpointNotify"
;
return
"/sendrecv.SendRecvService/CheckpointNotify"
;
case
GrpcMethod
::
kRequestNotify
:
case
GrpcMethod
::
kRequestNotify
:
return
"/sendrecv.SendRecvService/DistributeNotify"
;
return
"/sendrecv.SendRecvService/DistributeNotify"
;
case
GrpcMethod
::
kRequestSendAndRecv
:
return
"/sendrecv.SendRecvService/SendAndRecvVariable"
;
}
}
// Shouldn't be reached.
// Shouldn't be reached.
...
...
paddle/fluid/operators/distributed/request_handler.h
浏览文件 @
7f2aa2db
...
@@ -46,6 +46,7 @@ constexpr char kRequestCheckpoint[] = "RequestCheckpoint";
...
@@ -46,6 +46,7 @@ constexpr char kRequestCheckpoint[] = "RequestCheckpoint";
constexpr
char
kRequestPassBarrier
[]
=
"RequestPassBarrier"
;
constexpr
char
kRequestPassBarrier
[]
=
"RequestPassBarrier"
;
constexpr
char
kRequestGetNoBarrier
[]
=
"GetVariableNoBarrier"
;
constexpr
char
kRequestGetNoBarrier
[]
=
"GetVariableNoBarrier"
;
constexpr
char
kRequestNotify
[]
=
"RequestNotify"
;
constexpr
char
kRequestNotify
[]
=
"RequestNotify"
;
constexpr
char
kRequestSendAndRecv
[]
=
"RequestSendAndRecv"
;
constexpr
char
kSendRPC
[]
=
"SendRPC"
;
constexpr
char
kSendRPC
[]
=
"SendRPC"
;
constexpr
char
kGetRPC
[]
=
"GetRPC"
;
constexpr
char
kGetRPC
[]
=
"GetRPC"
;
...
@@ -57,6 +58,7 @@ constexpr char kFetchBarrierRPC[] = "FetchBarrierRPC";
...
@@ -57,6 +58,7 @@ constexpr char kFetchBarrierRPC[] = "FetchBarrierRPC";
constexpr
char
kSendMonomerFetchBarrierRPC
[]
=
"SendMonomerFetchBarrierRPC"
;
constexpr
char
kSendMonomerFetchBarrierRPC
[]
=
"SendMonomerFetchBarrierRPC"
;
constexpr
char
kSendCompleteRPC
[]
=
"SendCompleteRPC"
;
constexpr
char
kSendCompleteRPC
[]
=
"SendCompleteRPC"
;
constexpr
char
kCheckPointNotifyRPC
[]
=
"CheckPointNotifyRPC"
;
constexpr
char
kCheckPointNotifyRPC
[]
=
"CheckPointNotifyRPC"
;
constexpr
char
kSendAndRecvRPC
[]
=
"SendAndRecvRPC"
;
constexpr
int64_t
kPrefetchTimeout
=
60000
;
constexpr
int64_t
kPrefetchTimeout
=
60000
;
#define LISTEN_TERMINATE_MESSAGE "TERMINATE@RECV"
#define LISTEN_TERMINATE_MESSAGE "TERMINATE@RECV"
...
...
paddle/fluid/operators/distributed/request_handler_impl.cc
浏览文件 @
7f2aa2db
...
@@ -325,6 +325,22 @@ bool RequestNotifyHandler::Handle(const std::string &varname,
...
@@ -325,6 +325,22 @@ bool RequestNotifyHandler::Handle(const std::string &varname,
return
true
;
return
true
;
}
}
bool
RequestSendAndRecvHandler
::
Handle
(
const
std
::
string
&
varname
,
framework
::
Scope
*
Scope
,
framework
::
Variable
*
var
,
framework
::
Variable
**
outvar
,
const
int
trainer_id
,
const
std
::
string
&
out_var_name
,
const
std
::
string
&
table_name
)
{
VLOG
(
3
)
<<
"SendAndRecvHandle: "
<<
varname
<<
" out_var_name: "
<<
out_var_name
<<
" , trainer_id: "
<<
trainer_id
;
executor_
->
RunPreparedContext
((
*
grad_to_prepared_ctx_
)[
varname
].
get
(),
Scope
);
*
outvar
=
Scope
->
FindVar
(
out_var_name
);
return
true
;
}
}
// namespace distributed
}
// namespace distributed
}
// namespace operators
}
// namespace operators
}
// namespace paddle
}
// namespace paddle
paddle/fluid/operators/distributed/request_handler_impl.h
浏览文件 @
7f2aa2db
...
@@ -176,6 +176,17 @@ class RequestNotifyHandler final : public RequestHandler {
...
@@ -176,6 +176,17 @@ class RequestNotifyHandler final : public RequestHandler {
std
::
unordered_map
<
int
,
int64_t
>
decay_counters
;
std
::
unordered_map
<
int
,
int64_t
>
decay_counters
;
};
};
class
RequestSendAndRecvHandler
final
:
public
RequestHandler
{
public:
explicit
RequestSendAndRecvHandler
(
int
distributed_mode
)
:
RequestHandler
(
distributed_mode
)
{}
virtual
~
RequestSendAndRecvHandler
()
{}
bool
Handle
(
const
std
::
string
&
varname
,
framework
::
Scope
*
Scope
,
framework
::
Variable
*
var
,
framework
::
Variable
**
outvar
,
const
int
trainer_id
,
const
std
::
string
&
out_var_name
=
""
,
const
std
::
string
&
table_name
=
""
)
override
;
};
}
// namespace distributed
}
// namespace distributed
}
// namespace operators
}
// namespace operators
}
// namespace paddle
}
// namespace paddle
paddle/fluid/operators/distributed/rpc_client.h
浏览文件 @
7f2aa2db
...
@@ -85,6 +85,12 @@ class RPCClient {
...
@@ -85,6 +85,12 @@ class RPCClient {
const
framework
::
Scope
&
scope
,
const
std
::
string
&
var_name
,
const
framework
::
Scope
&
scope
,
const
std
::
string
&
var_name
,
int64_t
time_out
=
FLAGS_rpc_deadline
)
=
0
;
int64_t
time_out
=
FLAGS_rpc_deadline
)
=
0
;
virtual
VarHandlePtr
AsyncSendAndRecv
(
const
std
::
string
&
ep
,
const
platform
::
DeviceContext
&
ctx
,
const
framework
::
Scope
&
scope
,
const
std
::
string
&
send_var_name
,
const
std
::
string
&
recv_var_name
,
const
std
::
string
&
table_name
=
""
,
int64_t
time_out
=
FLAGS_rpc_deadline
)
=
0
;
virtual
VarHandlePtr
AsyncSendComplete
(
virtual
VarHandlePtr
AsyncSendComplete
(
const
std
::
string
&
ep
,
int64_t
time_out
=
FLAGS_rpc_deadline
)
=
0
;
const
std
::
string
&
ep
,
int64_t
time_out
=
FLAGS_rpc_deadline
)
=
0
;
...
...
paddle/fluid/operators/distributed/rpc_server_test.cc
浏览文件 @
7f2aa2db
...
@@ -35,27 +35,24 @@ namespace platform = paddle::platform;
...
@@ -35,27 +35,24 @@ namespace platform = paddle::platform;
namespace
distributed
=
paddle
::
operators
::
distributed
;
namespace
distributed
=
paddle
::
operators
::
distributed
;
USE_NO_KERNEL_OP
(
lookup_sparse_table_read
);
USE_NO_KERNEL_OP
(
lookup_sparse_table_read
);
USE_OP
(
scale
);
std
::
unique_ptr
<
distributed
::
RPCServer
>
g_rpc_service
;
std
::
unique_ptr
<
distributed
::
RPCServer
>
g_rpc_service
;
std
::
unique_ptr
<
distributed
::
RequestHandler
>
g_req_handler
;
std
::
unique_ptr
<
distributed
::
RequestHandler
>
g_req_handler
;
framework
::
BlockDesc
*
Append
PrefetchBlco
k
(
framework
::
ProgramDesc
*
program
)
{
framework
::
BlockDesc
*
Append
SendAndRecvBloc
k
(
framework
::
ProgramDesc
*
program
)
{
auto
root_block
=
program
->
MutableBlock
(
0
);
auto
root_block
=
program
->
MutableBlock
(
0
);
auto
*
block
=
program
->
AppendBlock
(
*
root_block
);
auto
*
block
=
program
->
AppendBlock
(
*
root_block
);
framework
::
VariableNameMap
input
({{
"W"
,
{
"w"
}},
{
"Ids"
,
{
"ids"
}}});
framework
::
OpDesc
*
op
=
block
->
AppendOp
();
framework
::
VariableNameMap
output
({{
"Output"
,
{
"out"
}}});
op
->
SetType
(
"scale"
);
auto
op
=
block
->
AppendOp
();
op
->
SetInput
(
"X"
,
{
"x"
});
op
->
SetType
(
"lookup_sparse_table_read"
);
op
->
SetOutput
(
"Out"
,
{
"res"
});
op
->
SetInput
(
"W"
,
{
"w"
});
op
->
SetAttr
(
"scale"
,
0.5
f
);
op
->
SetInput
(
"Ids"
,
{
"ids"
});
op
->
SetOutput
(
"Out"
,
{
"out"
});
auto
&
out
=
*
root_block
->
Var
(
"res"
);
op
->
SetAttr
(
"tablename"
,
{
"w"
});
op
->
SetAttr
(
"value_names"
,
{
"Param"
});
auto
&
out
=
*
root_block
->
Var
(
"out"
);
out
.
SetType
(
framework
::
proto
::
VarType
::
LOD_TENSOR
);
out
.
SetType
(
framework
::
proto
::
VarType
::
LOD_TENSOR
);
out
.
SetShape
({
1
0
,
10
});
out
.
SetShape
({
1
,
10
});
return
block
;
return
block
;
}
}
...
@@ -69,6 +66,12 @@ void CreateVarsOnScope(framework::Scope* scope, platform::CPUPlace* place) {
...
@@ -69,6 +66,12 @@ void CreateVarsOnScope(framework::Scope* scope, platform::CPUPlace* place) {
auto
ids_var
=
scope
->
Var
(
"ids"
);
auto
ids_var
=
scope
->
Var
(
"ids"
);
ids_var
->
GetMutable
<
framework
::
LoDTensor
>
();
ids_var
->
GetMutable
<
framework
::
LoDTensor
>
();
auto
x_var
=
scope
->
Var
(
"x"
);
x_var
->
GetMutable
<
framework
::
LoDTensor
>
();
auto
res_var
=
scope
->
Var
(
"res"
);
res_var
->
GetMutable
<
framework
::
LoDTensor
>
();
}
}
void
InitTensorsOnClient
(
framework
::
Scope
*
scope
,
platform
::
CPUPlace
*
place
,
void
InitTensorsOnClient
(
framework
::
Scope
*
scope
,
platform
::
CPUPlace
*
place
,
...
@@ -78,6 +81,11 @@ void InitTensorsOnClient(framework::Scope* scope, platform::CPUPlace* place,
...
@@ -78,6 +81,11 @@ void InitTensorsOnClient(framework::Scope* scope, platform::CPUPlace* place,
int64_t
*
ids_ptr
=
int64_t
*
ids_ptr
=
ids_var
->
mutable_data
<
int64_t
>
(
framework
::
DDim
({
rows_numel
,
1
}),
*
place
);
ids_var
->
mutable_data
<
int64_t
>
(
framework
::
DDim
({
rows_numel
,
1
}),
*
place
);
for
(
int64_t
i
=
0
;
i
<
rows_numel
;
++
i
)
ids_ptr
[
i
]
=
i
*
2
;
for
(
int64_t
i
=
0
;
i
<
rows_numel
;
++
i
)
ids_ptr
[
i
]
=
i
*
2
;
auto
x_var
=
scope
->
Var
(
"x"
)
->
GetMutable
<
framework
::
LoDTensor
>
();
float
*
x_ptr
=
x_var
->
mutable_data
<
float
>
(
framework
::
DDim
({
1
,
rows_numel
}),
*
place
);
for
(
int64_t
i
=
0
;
i
<
rows_numel
;
++
i
)
x_ptr
[
i
]
=
1.0
;
}
}
void
InitTensorsOnServer
(
framework
::
Scope
*
scope
,
platform
::
CPUPlace
*
place
,
void
InitTensorsOnServer
(
framework
::
Scope
*
scope
,
platform
::
CPUPlace
*
place
,
...
@@ -124,6 +132,38 @@ void StartServer(const std::string& rpc_name) {
...
@@ -124,6 +132,38 @@ void StartServer(const std::string& rpc_name) {
server_thread
.
join
();
server_thread
.
join
();
}
}
void
StartSendAndRecvServer
(
const
std
::
string
&
rpc_name
)
{
framework
::
ProgramDesc
program
;
framework
::
Scope
scope
;
platform
::
CPUPlace
place
;
framework
::
Executor
exe
(
place
);
platform
::
CPUDeviceContext
ctx
(
place
);
auto
block
=
AppendSendAndRecvBlock
(
&
program
);
std
::
string
in_var_name
(
"x"
);
std
::
vector
<
int
>
prefetch_block_ids
{
block
->
ID
()};
auto
prepared
=
exe
.
Prepare
(
program
,
prefetch_block_ids
);
InitTensorsOnServer
(
&
scope
,
&
place
,
10
);
std
::
unordered_map
<
std
::
string
,
std
::
shared_ptr
<
framework
::
ExecutorPrepareContext
>>
grad_to_prepared_ctx
;
grad_to_prepared_ctx
[
in_var_name
]
=
prepared
[
0
];
g_req_handler
->
SetProgram
(
&
program
);
g_req_handler
->
SetGradToPreparedCtx
(
&
grad_to_prepared_ctx
);
g_req_handler
->
SetDevCtx
(
&
ctx
);
g_req_handler
->
SetScope
(
&
scope
);
g_req_handler
->
SetExecutor
(
&
exe
);
g_rpc_service
->
RegisterRPC
(
rpc_name
,
g_req_handler
.
get
());
g_req_handler
->
SetRPCServer
(
g_rpc_service
.
get
());
std
::
thread
server_thread
(
std
::
bind
(
&
distributed
::
RPCServer
::
StartServer
,
g_rpc_service
.
get
()));
server_thread
.
join
();
}
TEST
(
COMPLETE
,
CPU
)
{
TEST
(
COMPLETE
,
CPU
)
{
setenv
(
"http_proxy"
,
""
,
1
);
setenv
(
"http_proxy"
,
""
,
1
);
setenv
(
"https_proxy"
,
""
,
1
);
setenv
(
"https_proxy"
,
""
,
1
);
...
@@ -147,3 +187,46 @@ TEST(COMPLETE, CPU) {
...
@@ -147,3 +187,46 @@ TEST(COMPLETE, CPU) {
g_rpc_service
.
reset
(
nullptr
);
g_rpc_service
.
reset
(
nullptr
);
g_req_handler
.
reset
(
nullptr
);
g_req_handler
.
reset
(
nullptr
);
}
}
TEST
(
SENDANDRECV
,
CPU
)
{
setenv
(
"http_proxy"
,
""
,
1
);
setenv
(
"https_proxy"
,
""
,
1
);
g_req_handler
.
reset
(
new
distributed
::
RequestSendAndRecvHandler
(
distributed
::
DistributedMode
::
kAsync
));
g_rpc_service
.
reset
(
new
RPCSERVER_T
(
"127.0.0.1:0"
,
1
));
distributed
::
RPCClient
*
client
=
distributed
::
RPCClient
::
GetInstance
<
RPCCLIENT_T
>
(
0
);
PADDLE_ENFORCE_NE
(
client
,
nullptr
,
platform
::
errors
::
InvalidArgument
(
"Client Start Fail, Check Your Code & Env"
));
std
::
thread
server_thread
(
StartSendAndRecvServer
,
distributed
::
kRequestSendAndRecv
);
g_rpc_service
->
WaitServerReady
();
int
port
=
g_rpc_service
->
GetSelectedPort
();
std
::
string
ep
=
paddle
::
string
::
Sprintf
(
"127.0.0.1:%d"
,
port
);
framework
::
Scope
scope
;
platform
::
CPUPlace
place
;
platform
::
CPUDeviceContext
ctx
(
place
);
// create var on local scope
int64_t
rows_numel
=
10
;
InitTensorsOnClient
(
&
scope
,
&
place
,
rows_numel
);
std
::
string
in_var_name
(
"x"
);
std
::
string
out_var_name
(
"res"
);
client
->
AsyncSendAndRecv
(
ep
,
ctx
,
scope
,
in_var_name
,
out_var_name
);
client
->
Wait
();
auto
var
=
scope
.
Var
(
out_var_name
);
auto
value
=
var
->
GetMutable
<
framework
::
LoDTensor
>
();
auto
ptr
=
value
->
mutable_data
<
float
>
(
place
);
for
(
int64_t
i
=
0
;
i
<
rows_numel
;
++
i
)
{
EXPECT_EQ
(
ptr
[
i
],
0.5
);
}
g_rpc_service
->
ShutDown
();
server_thread
.
join
();
LOG
(
INFO
)
<<
"begin reset"
;
g_rpc_service
.
reset
(
nullptr
);
g_req_handler
.
reset
(
nullptr
);
}
paddle/fluid/operators/distributed/send_recv.proto.in
浏览文件 @
7f2aa2db
...
@@ -29,7 +29,7 @@ service SendRecvService {
...
@@ -29,7 +29,7 @@ service SendRecvService {
rpc
CheckpointNotify
(
VariableMessage
)
returns
(
VoidMessage
)
{}
rpc
CheckpointNotify
(
VariableMessage
)
returns
(
VoidMessage
)
{}
rpc
DistributeNotify
(
VariableMessage
)
returns
(
VoidMessage
)
{}
rpc
DistributeNotify
(
VariableMessage
)
returns
(
VoidMessage
)
{}
rpc
SendAndRecvVariable
(
VariableMessage
)
returns
(
VariableMessage
)
{}
rpc
GetMonomerVariable
(
VariableMessage
)
returns
(
VariableMessage
)
{}
rpc
GetMonomerVariable
(
VariableMessage
)
returns
(
VariableMessage
)
{}
rpc
GetMonomerBarrier
(
VariableMessage
)
returns
(
VoidMessage
)
{}
rpc
GetMonomerBarrier
(
VariableMessage
)
returns
(
VoidMessage
)
{}
}
}
...
...
paddle/fluid/operators/distributed/variable_response.h
浏览文件 @
7f2aa2db
...
@@ -96,6 +96,13 @@ class VariableResponse {
...
@@ -96,6 +96,13 @@ class VariableResponse {
return
scope_
->
FindVar
(
meta_
.
varname
());
return
scope_
->
FindVar
(
meta_
.
varname
());
}
}
framework
::
Variable
*
GetRecvVar
()
{
if
(
create_scope_
)
{
return
local_scope_
->
Var
(
meta_
.
out_varname
());
}
return
scope_
->
FindVar
(
meta_
.
out_varname
());
}
int
GetTrainerId
()
{
return
static_cast
<
int
>
(
meta_
.
trainer_id
());
}
int
GetTrainerId
()
{
return
static_cast
<
int
>
(
meta_
.
trainer_id
());
}
protected:
protected:
...
...
paddle/fluid/operators/distributed_ops/listen_and_serv_op.cc
浏览文件 @
7f2aa2db
...
@@ -268,7 +268,6 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
...
@@ -268,7 +268,6 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
size_t
num_blocks
=
program
->
Size
();
size_t
num_blocks
=
program
->
Size
();
PADDLE_ENFORCE_GE
(
num_blocks
,
2
,
PADDLE_ENFORCE_GE
(
num_blocks
,
2
,
"server program should have at least 2 blocks"
);
"server program should have at least 2 blocks"
);
std
::
vector
<
int
>
block_list
;
std
::
vector
<
int
>
block_list
;
for
(
size_t
blkid
=
1
;
blkid
<
num_blocks
;
++
blkid
)
{
for
(
size_t
blkid
=
1
;
blkid
<
num_blocks
;
++
blkid
)
{
block_list
.
push_back
(
blkid
);
block_list
.
push_back
(
blkid
);
...
@@ -295,6 +294,7 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
...
@@ -295,6 +294,7 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
request_send_handler_
->
SetGradToPreparedCtx
(
&
grad_to_prepared_ctx
);
request_send_handler_
->
SetGradToPreparedCtx
(
&
grad_to_prepared_ctx
);
request_get_handler_
->
SetGradToPreparedCtx
(
&
grad_to_prepared_ctx
);
request_get_handler_
->
SetGradToPreparedCtx
(
&
grad_to_prepared_ctx
);
request_prefetch_handler_
->
SetGradToPreparedCtx
(
&
grad_to_prepared_ctx
);
request_prefetch_handler_
->
SetGradToPreparedCtx
(
&
grad_to_prepared_ctx
);
request_send_and_recv_handler_
->
SetGradToPreparedCtx
(
&
grad_to_prepared_ctx
);
while
(
true
)
{
while
(
true
)
{
if
(
rpc_service_
->
IsExit
())
{
if
(
rpc_service_
->
IsExit
())
{
...
@@ -394,6 +394,8 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
...
@@ -394,6 +394,8 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
new
distributed
::
RequestGetNoBarrierHandler
());
new
distributed
::
RequestGetNoBarrierHandler
());
request_notify_handler_
.
reset
(
request_notify_handler_
.
reset
(
new
distributed
::
RequestNotifyHandler
(
distributed_mode
,
fan_in
));
new
distributed
::
RequestNotifyHandler
(
distributed_mode
,
fan_in
));
request_send_and_recv_handler_
.
reset
(
new
distributed
::
RequestSendAndRecvHandler
(
distributed_mode
));
rpc_service_
->
RegisterRPC
(
distributed
::
kRequestSend
,
rpc_service_
->
RegisterRPC
(
distributed
::
kRequestSend
,
request_send_handler_
.
get
(),
rpc_send_thread_num
);
request_send_handler_
.
get
(),
rpc_send_thread_num
);
...
@@ -408,6 +410,9 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
...
@@ -408,6 +410,9 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
request_get_no_barrier_handler_
.
get
());
request_get_no_barrier_handler_
.
get
());
rpc_service_
->
RegisterRPC
(
distributed
::
kRequestNotify
,
rpc_service_
->
RegisterRPC
(
distributed
::
kRequestNotify
,
request_notify_handler_
.
get
(),
rpc_send_thread_num
);
request_notify_handler_
.
get
(),
rpc_send_thread_num
);
rpc_service_
->
RegisterRPC
(
distributed
::
kRequestSendAndRecv
,
request_send_and_recv_handler_
.
get
(),
rpc_get_thread_num
);
auto
optimize_blocks
=
auto
optimize_blocks
=
Attr
<
std
::
vector
<
framework
::
BlockDesc
*>>
(
kOptimizeBlocks
);
Attr
<
std
::
vector
<
framework
::
BlockDesc
*>>
(
kOptimizeBlocks
);
...
@@ -416,6 +421,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
...
@@ -416,6 +421,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
"optimize blocks is less than 1. Optimize blocks "
"optimize blocks is less than 1. Optimize blocks "
"should be 1 at least on the pserver side."
));
"should be 1 at least on the pserver side."
));
auto
*
program
=
optimize_blocks
[
0
]
->
Program
();
auto
*
program
=
optimize_blocks
[
0
]
->
Program
();
framework
::
Executor
executor
(
dev_place
);
framework
::
Executor
executor
(
dev_place
);
std
::
shared_ptr
<
framework
::
ExecutorPrepareContext
>
ckpt_pre_context
=
nullptr
;
std
::
shared_ptr
<
framework
::
ExecutorPrepareContext
>
ckpt_pre_context
=
nullptr
;
...
@@ -488,6 +494,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
...
@@ -488,6 +494,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
f
(
request_checkpoint_handler_
.
get
());
f
(
request_checkpoint_handler_
.
get
());
f
(
request_get_no_barrier_handler_
.
get
());
f
(
request_get_no_barrier_handler_
.
get
());
f
(
request_notify_handler_
.
get
());
f
(
request_notify_handler_
.
get
());
f
(
request_send_and_recv_handler_
.
get
());
// register SIGINT(from ctrl+C) and SIGTERM(from kill) signal handlers
// register SIGINT(from ctrl+C) and SIGTERM(from kill) signal handlers
signal
(
SIGINT
,
SignalHandler
::
StopAndExit
);
signal
(
SIGINT
,
SignalHandler
::
StopAndExit
);
...
...
paddle/fluid/operators/distributed_ops/listen_and_serv_op.h
浏览文件 @
7f2aa2db
...
@@ -99,6 +99,8 @@ class ListenAndServOp : public framework::OperatorBase {
...
@@ -99,6 +99,8 @@ class ListenAndServOp : public framework::OperatorBase {
mutable
std
::
shared_ptr
<
distributed
::
RequestHandler
>
mutable
std
::
shared_ptr
<
distributed
::
RequestHandler
>
request_checkpoint_handler_
;
request_checkpoint_handler_
;
mutable
std
::
shared_ptr
<
distributed
::
RequestHandler
>
request_notify_handler_
;
mutable
std
::
shared_ptr
<
distributed
::
RequestHandler
>
request_notify_handler_
;
mutable
std
::
shared_ptr
<
distributed
::
RequestHandler
>
request_send_and_recv_handler_
;
mutable
std
::
shared_ptr
<
std
::
thread
>
server_thread_
;
mutable
std
::
shared_ptr
<
std
::
thread
>
server_thread_
;
mutable
std
::
vector
<
std
::
string
>
sparse_vars_
;
mutable
std
::
vector
<
std
::
string
>
sparse_vars_
;
...
...
paddle/fluid/operators/distributed_ops/send_and_recv_op.cc
0 → 100644
浏览文件 @
7f2aa2db
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <future> // NOLINT
#include <ostream>
#include "paddle/fluid/framework/blocking_queue.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/distributed/communicator.h"
#include "paddle/fluid/operators/distributed/communicator_common.h"
#include "paddle/fluid/operators/distributed/distributed.h"
#include "paddle/fluid/operators/distributed/parameter_send.h"
#include "paddle/fluid/operators/distributed_ops/send_recv_util.h"
#include "paddle/fluid/platform/profiler.h"
namespace
paddle
{
namespace
operators
{
template
<
typename
DeviceContext
,
typename
T
>
class
SendAndRecvKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
&
scope
=
ctx
.
scope
();
const
auto
&
place
=
ctx
.
GetPlace
();
auto
send_var_name
=
ctx
.
Attr
<
std
::
string
>
(
"send_var_name"
);
auto
recv_var_name
=
ctx
.
Attr
<
std
::
string
>
(
"recv_var_name"
);
auto
epmap
=
ctx
.
Attr
<
std
::
string
>
(
"endpoint"
);
auto
trainer_id
=
ctx
.
Attr
<
int
>
(
"trainer_id"
);
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
&
context
=
*
pool
.
Get
(
place
);
distributed
::
RPCClient
*
rpc_client
=
distributed
::
RPCClient
::
GetInstance
<
RPCCLIENT_T
>
(
trainer_id
);
VLOG
(
3
)
<<
"SendAndRecvOp Send_var_name: "
<<
send_var_name
<<
" Recv_var_name: "
<<
recv_var_name
;
distributed
::
VarHandlePtr
rets
=
rpc_client
->
AsyncSendAndRecv
(
epmap
,
context
,
scope
,
send_var_name
,
recv_var_name
);
rets
->
Wait
();
}
};
class
SendAndRecvOp
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{}
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
data_type
=
OperatorWithKernel
::
IndicateVarDataType
(
ctx
,
"X"
);
return
framework
::
OpKernelType
(
data_type
,
platform
::
CPUPlace
());
}
};
class
SendAndRecvOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
void
Make
()
{
AddInput
(
"X"
,
"Tensor Input variable to be sent"
).
AsDuplicable
();
AddOutput
(
"Out"
,
"Tensor Output varibale to be recv"
).
AsDuplicable
();
AddAttr
<
std
::
string
>
(
"send_var_name"
,
"Send Tensor's name"
)
.
SetDefault
(
std
::
string
(
""
));
AddAttr
<
std
::
string
>
(
"recv_var_name"
,
"Recv Tensor's name"
)
.
SetDefault
(
std
::
string
(
""
));
AddAttr
<
int
>
(
"trainer_id"
,
"trainer id from 0 ~ worker_num."
).
SetDefault
(
0
);
AddAttr
<
std
::
string
>
(
"endpoint"
,
"Server endpoint"
)
.
SetDefault
({
"127.0.0.1:6164"
});
AddComment
(
R"DOC(
SendAndRecv operator
This operator will send variables to listen_and_serve op at the parameter server.
And recv variable from parameter server of send variable's scope.
)DOC"
);
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OPERATOR
(
send_and_recv
,
ops
::
SendAndRecvOp
,
ops
::
SendAndRecvOpMaker
);
REGISTER_OP_CPU_KERNEL
(
send_and_recv
,
ops
::
SendAndRecvKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
)
python/paddle/distributed/fleet/base/fleet_base.py
浏览文件 @
7f2aa2db
...
@@ -200,7 +200,8 @@ class Fleet(object):
...
@@ -200,7 +200,8 @@ class Fleet(object):
bool: True if this is a node of server,
bool: True if this is a node of server,
False if not.
False if not.
"""
"""
return
self
.
_role_maker
.
is_server
()
return
self
.
_role_maker
.
is_server
(
)
or
self
.
_role_maker
.
_is_heter_worker
()
@
property
@
property
def
util
(
self
):
def
util
(
self
):
...
...
python/paddle/distributed/fleet/base/role_maker.py
浏览文件 @
7f2aa2db
...
@@ -14,6 +14,7 @@
...
@@ -14,6 +14,7 @@
"""Defination of Role Makers."""
"""Defination of Role Makers."""
import
os
import
os
import
numpy
as
np
import
numpy
as
np
import
warnings
from
multiprocessing
import
Process
,
Manager
from
multiprocessing
import
Process
,
Manager
import
paddle.fluid
as
fluid
import
paddle.fluid
as
fluid
...
@@ -23,6 +24,7 @@ import paddle.fluid as fluid
...
@@ -23,6 +24,7 @@ import paddle.fluid as fluid
class
Role
:
class
Role
:
WORKER
=
1
WORKER
=
1
SERVER
=
2
SERVER
=
2
HETER_WORKER
=
3
class
RoleMakerBase
(
object
):
class
RoleMakerBase
(
object
):
...
@@ -40,6 +42,11 @@ class RoleMakerBase(object):
...
@@ -40,6 +42,11 @@ class RoleMakerBase(object):
self
.
_role
=
None
self
.
_role
=
None
self
.
_current_id
=
-
1
self
.
_current_id
=
-
1
# for heter parameter server mode
self
.
_heter_trainer_endpoints
=
[]
self
.
_heter_trainer_device
=
"CPU"
self
.
_is_heter_parameter_server_mode
=
False
self
.
_node_type
=
None
self
.
_node_type
=
None
self
.
_node_type_comm
=
None
self
.
_node_type_comm
=
None
self
.
_all_comm
=
None
self
.
_all_comm
=
None
...
@@ -163,12 +170,58 @@ class RoleMakerBase(object):
...
@@ -163,12 +170,58 @@ class RoleMakerBase(object):
"""
"""
print
(
"warning: RoleMakerBase does not have barrier worker."
)
print
(
"warning: RoleMakerBase does not have barrier worker."
)
def
_is_heter_worker
(
self
):
"""
Return is_heter_worker() of current process
"""
warnings
.
warn
(
"RoleMakerBase does not have function: _is_heter_worker."
)
return
False
def
_heter_worker_num
(
self
):
"""
Get current total heter-worker number.
Returns:
int: heter_worker number
"""
warnings
.
warn
(
"RoleMakerBase does not have function: _heter_worker_num."
)
return
0
def
_get_heter_worker_endpoints
(
self
):
"""
Returns:
string: all heter_trainers'endpoints
"""
assert
self
.
_heter_trainer_endpoints
!=
[]
return
self
.
_heter_trainer_endpoints
def
_get_heter_worker_endpoint
(
self
):
"""
Returns:
int: corresponding heter_trainer's endpoint
e.g: if we have 4 cpu-trainer(default), 2 gpu-trainer(heter)
then No.0 and No.2 cpu-trainer will work with No.0 gpu-trainer
and No.1 and No.3 cpu-trainer will work with No.1 gpu-trainerr
"""
assert
self
.
_heter_trainer_endpoints
!=
[]
return
self
.
_heter_trainer_endpoints
[(
self
.
_current_id
+
1
)
%
self
.
_heter_worker_num
()]
def
_get_heter_worker_device
(
self
):
"""
Returns:
string: heter_trainer's device of current node, e.g: CPU/GPU/XPU
"""
return
self
.
_heter_trainer_device
.
upper
()
class
PaddleCloudRoleMaker
(
RoleMakerBase
):
class
PaddleCloudRoleMaker
(
RoleMakerBase
):
def
__init__
(
self
,
is_collective
=
False
,
**
kwargs
):
def
__init__
(
self
,
is_collective
=
False
,
**
kwargs
):
super
(
PaddleCloudRoleMaker
,
self
).
__init__
()
super
(
PaddleCloudRoleMaker
,
self
).
__init__
()
self
.
_is_collective
=
is_collective
self
.
_is_collective
=
is_collective
self
.
_init_gloo
=
False
#default no init gloo
self
.
_init_gloo
=
False
#
default no init gloo
self
.
_kwargs
=
kwargs
self
.
_kwargs
=
kwargs
self
.
_role_is_generated
=
False
self
.
_role_is_generated
=
False
...
@@ -278,10 +331,7 @@ class PaddleCloudRoleMaker(RoleMakerBase):
...
@@ -278,10 +331,7 @@ class PaddleCloudRoleMaker(RoleMakerBase):
"""
"""
get index of current node
get index of current node
"""
"""
if
self
.
is_server
():
return
self
.
_current_id
return
self
.
server_index
()
elif
self
.
is_worker
():
return
self
.
worker_index
()
def
worker_num
(
self
):
def
worker_num
(
self
):
"""
"""
...
@@ -323,6 +373,22 @@ class PaddleCloudRoleMaker(RoleMakerBase):
...
@@ -323,6 +373,22 @@ class PaddleCloudRoleMaker(RoleMakerBase):
self
.
generate_role
()
self
.
generate_role
()
return
self
.
_server_endpoints
return
self
.
_server_endpoints
def
_heter_worker_num
(
self
):
"""
get heter worker nums
"""
if
not
self
.
_role_is_generated
:
self
.
generate_role
()
return
self
.
_heter_trainers_num
def
_is_heter_worker
(
self
):
"""
whether current process is heter worker
"""
if
not
self
.
_role_is_generated
:
self
.
generate_role
()
return
self
.
_role
==
Role
.
HETER_WORKER
def
_get_rank
(
self
):
def
_get_rank
(
self
):
"""
"""
get current rank in all workers and pservers
get current rank in all workers and pservers
...
@@ -342,17 +408,47 @@ class PaddleCloudRoleMaker(RoleMakerBase):
...
@@ -342,17 +408,47 @@ class PaddleCloudRoleMaker(RoleMakerBase):
def
_ps_env
(
self
):
def
_ps_env
(
self
):
try
:
try
:
# Environment variable PADDLE_PSERVERS_IP_PORT_LIST must be set
# Environment variable PADDLE_PSERVERS_IP_PORT_LIST must be set
# format: string(ip:port), eg. 127.0.0.1:6001
# format: string(ip:port,ip:port), eg. 127.0.0.1:6001,127.0.0.1:6002
self
.
_server_endpoints
=
os
.
environ
[
self
.
_server_endpoints
=
os
.
getenv
(
"PADDLE_PSERVERS_IP_PORT_LIST"
,
"PADDLE_PSERVERS_IP_PORT_LIST"
].
split
(
","
)
""
).
split
(
","
)
assert
self
.
_server_endpoints
!=
""
self
.
_worker_endpoints
=
os
.
getenv
(
"PADDLE_TRAINER_ENDPOINTS"
,
self
.
_worker_endpoints
=
os
.
getenv
(
"PADDLE_TRAINER_ENDPOINTS"
,
""
).
split
(
","
)
""
).
split
(
","
)
assert
self
.
_server_endpoints
!=
""
trainers_num
=
int
(
os
.
environ
[
"PADDLE_TRAINERS_NUM"
])
trainers_num
=
int
(
os
.
environ
[
"PADDLE_TRAINERS_NUM"
])
training_role
=
os
.
environ
[
"TRAINING_ROLE"
]
training_role
=
os
.
environ
[
"TRAINING_ROLE"
]
if
training_role
not
in
[
"TRAINER"
,
"PSERVER"
]:
if
training_role
not
in
[
"TRAINER"
,
"PSERVER"
,
"HETER_TRAINER"
]:
raise
ValueError
(
"TRAINING_ROLE must be PSERVER or TRAINER"
)
raise
ValueError
(
"TRAINING_ROLE must be PSERVER or TRAINER or HETER_TRAINER, but get {}, please check your environment."
.
format
(
training_role
))
# For heter parameter server env setting
heter_trainer_eplist
=
os
.
getenv
(
"PADDLE_HETER_TRAINER_IP_PORT_LIST"
,
None
)
heter_trainer_device
=
os
.
getenv
(
"PADDLE_HETER_TRAINER_DEVICE"
,
None
)
if
heter_trainer_eplist
and
heter_trainer_device
:
try
:
heter_trainer_eplist
=
os
.
environ
[
"PADDLE_HETER_TRAINER_IP_PORT_LIST"
].
split
(
","
)
except
:
raise
ValueError
(
"Can not Find PADDLE_HETER_TRAINER_IP_PORT_LIST in env or its format doesn't match the requirement: 'IP:PORT,IP:PORT' ."
)
self
.
_is_heter_parameter_server_mode
=
True
heter_trainers_num
=
len
(
heter_trainer_eplist
)
current_node_device
=
heter_trainer_device
.
upper
()
if
current_node_device
not
in
[
"CPU"
,
"GPU"
,
"XPU"
]:
raise
ValueError
(
"Heter Trainer doesn't support {} device now, please use CPU / GPU / XPU(KunLun)"
.
format
(
heter_trainer_device
))
self
.
_heter_trainer_device
=
current_node_device
else
:
self
.
_is_heter_parameter_server_mode
=
False
heter_trainers_num
=
0
if
training_role
==
"TRAINER"
:
if
training_role
==
"TRAINER"
:
role
=
Role
.
WORKER
role
=
Role
.
WORKER
...
@@ -365,17 +461,26 @@ class PaddleCloudRoleMaker(RoleMakerBase):
...
@@ -365,17 +461,26 @@ class PaddleCloudRoleMaker(RoleMakerBase):
ip
=
os
.
environ
[
"POD_IP"
]
ip
=
os
.
environ
[
"POD_IP"
]
self
.
_cur_endpoint
=
ip
+
":"
+
port
self
.
_cur_endpoint
=
ip
+
":"
+
port
current_id
=
self
.
_server_endpoints
.
index
(
self
.
_cur_endpoint
)
current_id
=
self
.
_server_endpoints
.
index
(
self
.
_cur_endpoint
)
elif
training_role
==
"HETER_TRAINER"
:
role
=
Role
.
HETER_WORKER
cur_ip
=
os
.
environ
[
"POD_IP"
]
cur_port
=
os
.
environ
[
"PADDLE_PORT"
]
curr_endpoint
=
":"
.
join
([
cur_ip
,
cur_port
])
current_id
=
heter_trainer_eplist
.
index
(
curr_endpoint
)
else
:
else
:
raise
ValueError
(
"TRAINING_ROLE must be PSERVER or TRAINER"
)
except
ValueError
as
ve
:
raise
ValueError
(
raise
ValueError
(
"something wrong with PaddleCloud, please check environment"
)
"TRAINING_ROLE must be PSERVER or TRAINER or HETER_TRAINER"
)
except
ValueError
as
e
:
raise
ValueError
(
"Something wrong with PaddleCloud, please check environment"
)
self
.
_trainers_num
=
trainers_num
self
.
_trainers_num
=
trainers_num
self
.
_role
=
role
self
.
_role
=
role
self
.
_current_id
=
current_id
self
.
_current_id
=
current_id
self
.
_node_num
=
len
(
self
.
_node_num
=
len
(
set
([
x
.
split
(
':'
)[
0
]
for
x
in
self
.
_worker_endpoints
]))
set
([
x
.
split
(
':'
)[
0
]
for
x
in
self
.
_worker_endpoints
]))
self
.
_heter_trainers_num
=
heter_trainers_num
self
.
_heter_trainer_endpoints
=
heter_trainer_eplist
def
_collective_env
(
self
):
def
_collective_env
(
self
):
self
.
_current_id
=
int
(
os
.
getenv
(
"PADDLE_TRAINER_ID"
,
"0"
))
self
.
_current_id
=
int
(
os
.
getenv
(
"PADDLE_TRAINER_ID"
,
"0"
))
...
...
python/paddle/distributed/fleet/meta_optimizers/__init__.py
浏览文件 @
7f2aa2db
...
@@ -15,10 +15,10 @@ from .amp_optimizer import AMPOptimizer
...
@@ -15,10 +15,10 @@ from .amp_optimizer import AMPOptimizer
from
.recompute_optimizer
import
RecomputeOptimizer
from
.recompute_optimizer
import
RecomputeOptimizer
from
.gradient_merge_optimizer
import
GradientMergeOptimizer
from
.gradient_merge_optimizer
import
GradientMergeOptimizer
from
.graph_execution_optimizer
import
GraphExecutionOptimizer
from
.graph_execution_optimizer
import
GraphExecutionOptimizer
from
.
async_optimizer
import
AsyncMeta
Optimizer
from
.
parameter_server_optimizer
import
ParameterServer
Optimizer
from
.pipeline_optimizer
import
PipelineOptimizer
from
.pipeline_optimizer
import
PipelineOptimizer
from
.localsgd_optimizer
import
LocalSGDOptimizer
from
.localsgd_optimizer
import
LocalSGDOptimizer
from
.lars_optimizer
import
LarsOptimizer
from
.lars_optimizer
import
LarsOptimizer
from
.
async_graph_execution_optimizer
import
AsyncGraphExecution
Optimizer
from
.
parameter_server_graph_optimizer
import
ParameterServerGraph
Optimizer
from
.dgc_optimizer
import
DGCOptimizer
from
.dgc_optimizer
import
DGCOptimizer
from
.lamb_optimizer
import
LambOptimizer
from
.lamb_optimizer
import
LambOptimizer
python/paddle/distributed/fleet/meta_optimizers/
async_graph_execution
_optimizer.py
→
python/paddle/distributed/fleet/meta_optimizers/
parameter_server_graph
_optimizer.py
浏览文件 @
7f2aa2db
...
@@ -13,12 +13,12 @@
...
@@ -13,12 +13,12 @@
from
paddle
import
fluid
from
paddle
import
fluid
from
paddle.fluid
import
compiler
from
paddle.fluid
import
compiler
from
.
async_optimizer
import
AsyncMeta
Optimizer
from
.
parameter_server_optimizer
import
ParameterServer
Optimizer
class
AsyncGraphExecutionOptimizer
(
AsyncMeta
Optimizer
):
class
ParameterServerGraphOptimizer
(
ParameterServer
Optimizer
):
def
__init__
(
self
,
optimizer
):
def
__init__
(
self
,
optimizer
):
super
(
AsyncGraphExecution
Optimizer
,
self
).
__init__
(
optimizer
)
super
(
ParameterServerGraph
Optimizer
,
self
).
__init__
(
optimizer
)
self
.
inner_opt
=
optimizer
self
.
inner_opt
=
optimizer
# we do not allow meta optimizer to be inner optimizer currently
# we do not allow meta optimizer to be inner optimizer currently
self
.
meta_optimizers_white_list
=
[]
self
.
meta_optimizers_white_list
=
[]
...
@@ -31,6 +31,9 @@ class AsyncGraphExecutionOptimizer(AsyncMetaOptimizer):
...
@@ -31,6 +31,9 @@ class AsyncGraphExecutionOptimizer(AsyncMetaOptimizer):
if
self
.
role_maker
.
is_server
():
if
self
.
role_maker
.
is_server
():
return
False
return
False
if
self
.
role_maker
.
_is_heter_parameter_server_mode
:
return
False
return
True
return
True
def
_disable_strategy
(
self
,
dist_strategy
):
def
_disable_strategy
(
self
,
dist_strategy
):
...
...
python/paddle/distributed/fleet/meta_optimizers/
async
_optimizer.py
→
python/paddle/distributed/fleet/meta_optimizers/
parameter_server
_optimizer.py
浏览文件 @
7f2aa2db
...
@@ -15,9 +15,9 @@ from paddle import fluid
...
@@ -15,9 +15,9 @@ from paddle import fluid
from
.meta_optimizer_base
import
MetaOptimizerBase
from
.meta_optimizer_base
import
MetaOptimizerBase
class
AsyncMeta
Optimizer
(
MetaOptimizerBase
):
class
ParameterServer
Optimizer
(
MetaOptimizerBase
):
def
__init__
(
self
,
optimizer
):
def
__init__
(
self
,
optimizer
):
super
(
AsyncMeta
Optimizer
,
self
).
__init__
(
optimizer
)
super
(
ParameterServer
Optimizer
,
self
).
__init__
(
optimizer
)
self
.
inner_opt
=
optimizer
self
.
inner_opt
=
optimizer
# we do not allow meta optimizer to be inner optimizer currently
# we do not allow meta optimizer to be inner optimizer currently
self
.
meta_optimizers_white_list
=
[]
self
.
meta_optimizers_white_list
=
[]
...
@@ -68,6 +68,21 @@ class AsyncMetaOptimizer(MetaOptimizerBase):
...
@@ -68,6 +68,21 @@ class AsyncMetaOptimizer(MetaOptimizerBase):
_startup
=
worker
.
init_from_server_pass
(
_startup
,
compiled_config
)
_startup
=
worker
.
init_from_server_pass
(
_startup
,
compiled_config
)
_startup
=
worker
.
delet_extra_optimizes_pass
(
_startup
,
_startup
=
worker
.
delet_extra_optimizes_pass
(
_startup
,
compiled_config
)
compiled_config
)
# for heter program
if
self
.
role_maker
.
_is_heter_parameter_server_mode
:
from
paddle.fluid.incubate.fleet.parameter_server.ir
import
heter_trainer_pass
as
heter_worker
if
self
.
role_maker
.
_is_heter_worker
():
# for heter worker
_main
=
heter_worker
.
split_heter_worker_ops_pass
(
_main
,
compiled_config
)
else
:
# for default worker
_main
=
heter_worker
.
split_trainer_ops_pass
(
_main
,
compiled_config
)
# for startup change
_startup
=
heter_worker
.
delete_startup_useless_ops_var_pass
(
_startup
,
_main
,
compiled_config
)
else
:
else
:
_main
=
worker
.
append_send_ops_pass
(
_main
,
compiled_config
)
_main
=
worker
.
append_send_ops_pass
(
_main
,
compiled_config
)
_startup
=
_startup
_startup
=
_startup
...
@@ -129,9 +144,12 @@ class AsyncMetaOptimizer(MetaOptimizerBase):
...
@@ -129,9 +144,12 @@ class AsyncMetaOptimizer(MetaOptimizerBase):
_origin_startup_program
,
_origin_startup_program
,
strategy
,
self
.
role_maker
)
strategy
,
self
.
role_maker
)
main_program
,
startup_program
=
\
if
self
.
role_maker
.
is_worker
()
or
self
.
role_maker
.
_is_heter_worker
():
self
.
_build_trainer_programs
(
compiled_config
)
if
self
.
role_maker
.
is_worker
()
\
main_program
,
startup_program
=
self
.
_build_trainer_programs
(
else
self
.
_build_pserver_programs
(
compiled_config
)
compiled_config
)
elif
self
.
role_maker
.
is_server
():
main_program
,
startup_program
=
self
.
_build_pserver_programs
(
compiled_config
)
loss
.
block
.
program
=
main_program
loss
.
block
.
program
=
main_program
fluid
.
framework
.
switch_startup_program
(
startup_program
)
fluid
.
framework
.
switch_startup_program
(
startup_program
)
...
...
python/paddle/distributed/fleet/runtime/parameter_server_runtime.py
浏览文件 @
7f2aa2db
...
@@ -196,6 +196,18 @@ class ParameterServerRuntime(RuntimeBase):
...
@@ -196,6 +196,18 @@ class ParameterServerRuntime(RuntimeBase):
else
:
else
:
warnings
.
warn
(
"communicator has been initialized, skip"
)
warnings
.
warn
(
"communicator has been initialized, skip"
)
def
_get_executor
(
self
):
if
self
.
role_maker
.
_is_heter_worker
():
if
self
.
role_maker
.
_get_heter_worker_device
()
==
"GPU"
:
gpu_id
=
int
(
os
.
getenv
(
"FLAGS_selected_gpus"
,
"0"
))
executor
=
Executor
(
fluid
.
CUDAPlace
(
gpu_id
))
else
:
raise
ValueError
(
"Not Support Device {}"
.
format
(
self
.
role_maker
.
_get_heter_worker_device
()))
else
:
executor
=
fluid
.
Executor
(
fluid
.
CPUPlace
())
return
executor
def
_init_server
(
self
,
*
args
,
**
kwargs
):
def
_init_server
(
self
,
*
args
,
**
kwargs
):
if
len
(
args
)
>
1
:
if
len
(
args
)
>
1
:
raise
ValueError
(
"init server can only accept 1 args: `dirname`"
)
raise
ValueError
(
"init server can only accept 1 args: `dirname`"
)
...
@@ -204,9 +216,15 @@ class ParameterServerRuntime(RuntimeBase):
...
@@ -204,9 +216,15 @@ class ParameterServerRuntime(RuntimeBase):
else
:
else
:
model_dirname
=
None
model_dirname
=
None
executor
=
fluid
.
Executor
(
fluid
.
CPUPlace
())
if
self
.
role_maker
.
_is_heter_worker
():
self
.
_init_worker
()
executor
=
self
.
_get_executor
()
executor
.
run
(
fluid
.
default_startup_program
())
executor
.
run
(
fluid
.
default_startup_program
())
if
self
.
role_maker
.
_is_heter_worker
():
return
if
not
model_dirname
:
if
not
model_dirname
:
return
return
...
@@ -237,12 +255,12 @@ class ParameterServerRuntime(RuntimeBase):
...
@@ -237,12 +255,12 @@ class ParameterServerRuntime(RuntimeBase):
# self._load_sparse_params(dirname=model_dir, varnames=distribtued_varnames)
# self._load_sparse_params(dirname=model_dir, varnames=distribtued_varnames)
def
_run_server
(
self
):
def
_run_server
(
self
):
executor
=
fluid
.
Executor
(
fluid
.
CPUPlace
()
)
executor
=
self
.
_get_executor
(
)
executor
.
run
(
fluid
.
default_main_program
())
executor
.
run
(
fluid
.
default_main_program
())
def
_stop_worker
(
self
):
def
_stop_worker
(
self
):
self
.
_communicator
.
stop
()
self
.
_communicator
.
stop
()
executor
=
fluid
.
Executor
(
fluid
.
CPUPlace
()
)
executor
=
self
.
_get_executor
(
)
executor
.
close
()
executor
.
close
()
def
_get_optimizer_status
(
self
,
op
,
param_name
):
def
_get_optimizer_status
(
self
,
op
,
param_name
):
...
...
python/paddle/fluid/incubate/fleet/base/fleet_base.py
浏览文件 @
7f2aa2db
...
@@ -145,7 +145,7 @@ class Fleet(object):
...
@@ -145,7 +145,7 @@ class Fleet(object):
Returns:
Returns:
bool: True if this is a node of server,
bool: True if this is a node of server,
False if not
.
False if not
"""
"""
return
self
.
_role_maker
.
is_server
()
return
self
.
_role_maker
.
is_server
()
...
...
python/paddle/fluid/incubate/fleet/base/role_maker.py
浏览文件 @
7f2aa2db
...
@@ -343,7 +343,6 @@ class MPISymetricRoleMaker(MPIRoleMaker):
...
@@ -343,7 +343,6 @@ class MPISymetricRoleMaker(MPIRoleMaker):
def
get_pserver_endpoints
(
self
):
def
get_pserver_endpoints
(
self
):
"""
"""
get pserver endpoints
get pserver endpoints
Returns:
Returns:
endpoints(list): pserver endpoints
endpoints(list): pserver endpoints
"""
"""
...
...
python/paddle/fluid/incubate/fleet/parameter_server/distribute_transpiler/__init__.py
浏览文件 @
7f2aa2db
python/paddle/fluid/incubate/fleet/parameter_server/ir/heter_trainer_pass.py
0 → 100644
浏览文件 @
7f2aa2db
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
warnings
import
paddle.fluid.core
as
core
import
paddle.fluid.framework
as
framework
from
paddle.fluid.transpiler.details.program_utils
import
delete_ops
from
paddle.fluid.incubate.fleet.parameter_server.ir.trainer_pass
import
find_heter_ops
from
paddle.fluid.incubate.fleet.parameter_server.ir.trainer_pass
import
create_heter_program
from
paddle.fluid.incubate.fleet.parameter_server.ir.trainer_pass
import
create_trainer_program
from
paddle.fluid.incubate.fleet.parameter_server.ir.trainer_pass
import
find_block_joints
from
paddle.fluid.incubate.fleet.parameter_server.ir.trainer_pass
import
find_op_input_output
from
paddle.fluid.incubate.fleet.parameter_server.ir.trainer_pass
import
get_vars_name_in_block
def
split_heter_worker_ops_pass
(
program
,
config
):
"""
split heter worker program from origin-program
1. find heter op (located on different device)
2. find input&output of every heter-block
3. create heter worker program, add listen&serv op
"""
default_deveice
=
"cpu"
program
,
heter_ops
,
_
,
program_block_ops
=
find_heter_ops
(
program
,
default_deveice
)
if
len
(
heter_ops
)
==
0
:
warnings
.
warn
(
"Currently running in Heter Parameter Server mode, but no OP running on heterogeneous devices, Please check your code."
)
return
program
current_device
=
"gpu"
if
current_device
not
in
heter_ops
:
raise
ValueError
(
"Op which run on device {} not exist."
.
format
(
current_device
))
block_vars_detail
=
find_block_joints
(
program
,
program_block_ops
,
heter_ops
)
heter_program
=
framework
.
Program
()
create_heter_program
(
program
,
config
,
heter_program
,
heter_ops
,
block_vars_detail
,
current_device
)
return
heter_program
def
split_trainer_ops_pass
(
program
,
config
):
"""
split cpu-trainer program from origin-program
1. find heter op (located on different device)
2. find input&output of every heter-block
3. create cpu-trainer program, add send&recv op
"""
# Todo: support user define default_device (MrChengmo)
default_deveice
=
"cpu"
program
,
heter_ops
,
_
,
program_block_ops
=
find_heter_ops
(
program
,
default_deveice
)
block_vars_detail
=
find_block_joints
(
program
,
program_block_ops
,
heter_ops
)
create_trainer_program
(
program
,
config
,
heter_ops
,
block_vars_detail
)
return
program
def
delete_startup_useless_ops_var_pass
(
startup_program
,
main_program
,
config
):
"""
delete variable which not used in current main_program
"""
# find all op and its var
vars_in_main_program
=
get_vars_name_in_block
(
main_program
.
global_block
())
block_nums
=
startup_program
.
num_blocks
for
block_index
in
range
(
1
,
block_nums
):
current_block
=
startup_program
.
block
(
block_index
)
# delete useless op
need_delete_op
=
[]
for
op
in
current_block
.
ops
:
inputs
,
outputs
=
find_op_input_output
(
startup_program
,
current_block
,
op
)
inputs
+=
outputs
# Todo: delete some concat op
if
list
(
set
(
inputs
)
&
set
(
vars_in_main_program
))
==
None
:
need_delete_op
.
append
(
op
)
delete_ops
(
current_block
,
need_delete_op
)
# delete useless var
for
var
in
current_block
.
vars
:
if
var
.
name
not
in
vars_in_main_program
:
startup_program
.
_remove_var
(
var
.
name
)
return
startup_program
python/paddle/fluid/incubate/fleet/parameter_server/ir/pserver_pass.py
浏览文件 @
7f2aa2db
python/paddle/fluid/incubate/fleet/parameter_server/ir/public.py
浏览文件 @
7f2aa2db
...
@@ -12,33 +12,23 @@
...
@@ -12,33 +12,23 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
# Copyright(c) 2020 PaddlePaddle Authors.All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0(the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http: // www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
from
__future__
import
print_function
from
functools
import
reduce
from
functools
import
reduce
import
collections
import
collections
import
math
import
math
import
os
import
os
import
warnings
import
six
import
six
import
paddle.fluid
as
fluid
from
paddle.fluid
import
core
from
paddle.fluid
import
core
from
paddle.fluid.core
import
CommContext
from
paddle.fluid.core
import
CommContext
import
paddle.fluid.framework
as
framework
from
paddle.fluid.incubate.fleet.parameter_server.mode
import
DistributedMode
from
paddle.fluid.incubate.fleet.parameter_server.mode
import
DistributedMode
from
paddle.fluid.incubate.fleet.parameter_server.ir
import
vars_metatools
from
paddle.fluid.incubate.fleet.parameter_server.ir
import
vars_metatools
from
paddle.fluid.incubate.fleet.parameter_server.ir.ps_dispatcher
import
RoundRobin
,
PSDispatcher
from
paddle.fluid.incubate.fleet.parameter_server.ir.ps_dispatcher
import
RoundRobin
,
PSDispatcher
from
paddle.fluid.transpiler.details.program_utils
import
delete_ops
OP_NAME_SCOPE
=
"op_namescope"
OP_NAME_SCOPE
=
"op_namescope"
CLIP_OP_NAME_SCOPE
=
"@CLIP"
CLIP_OP_NAME_SCOPE
=
"@CLIP"
...
@@ -122,9 +112,20 @@ class MergedVariable:
...
@@ -122,9 +112,20 @@ class MergedVariable:
self
.
offsets
=
offsets
self
.
offsets
=
offsets
def
Singleton
(
cls
):
_instance
=
{}
def
_singleton
(
*
args
,
**
kargs
):
if
cls
not
in
_instance
:
_instance
[
cls
]
=
cls
(
*
args
,
**
kargs
)
return
_instance
[
cls
]
return
_singleton
@
Singleton
class
CompileTimeStrategy
(
object
):
class
CompileTimeStrategy
(
object
):
def
__init__
(
self
,
main_program
,
startup_program
,
strategy
,
role_maker
):
def
__init__
(
self
,
main_program
,
startup_program
,
strategy
,
role_maker
):
self
.
min_block_size
=
8192
self
.
min_block_size
=
8192
self
.
origin_main_program
=
main_program
self
.
origin_main_program
=
main_program
...
@@ -177,6 +178,12 @@ class CompileTimeStrategy(object):
...
@@ -177,6 +178,12 @@ class CompileTimeStrategy(object):
def
get_ps_endpoints
(
self
):
def
get_ps_endpoints
(
self
):
return
self
.
role_maker
.
get_pserver_endpoints
()
return
self
.
role_maker
.
get_pserver_endpoints
()
def
get_heter_worker_endpoints
(
self
):
return
self
.
role_maker
.
_get_heter_worker_endpoints
()
def
get_heter_worker_endpoint
(
self
):
return
self
.
role_maker
.
_get_heter_worker_endpoint
()
def
get_origin_programs
(
self
):
def
get_origin_programs
(
self
):
return
self
.
origin_main_program
,
self
.
origin_startup_program
return
self
.
origin_main_program
,
self
.
origin_startup_program
...
@@ -810,6 +817,30 @@ class CompileTimeStrategy(object):
...
@@ -810,6 +817,30 @@ class CompileTimeStrategy(object):
return
sparse_param_grads
,
dense_param_grads
return
sparse_param_grads
,
dense_param_grads
def
remove_var_pair_by_grad
(
self
,
var_name
):
for
index
,
pair
in
enumerate
(
self
.
merged_variables_pairs
):
var
=
pair
[
0
]
var_grad
=
pair
[
1
]
if
var_grad
.
merged_var
.
name
==
var_name
:
del
self
.
merged_variables_pairs
[
index
]
for
index
,
pair
in
enumerate
(
self
.
merged_dense_pairs
):
var
=
pair
[
0
]
var_grad
=
pair
[
1
]
if
var_grad
.
merged_var
.
name
==
var_name
:
del
self
.
merged_dense_pairs
[
index
]
return
for
index
,
pair
in
enumerate
(
self
.
merged_sparse_pairs
):
var
=
pair
[
0
]
var_grad
=
pair
[
1
]
if
var_grad
.
merged_var
.
name
==
var_name
:
del
self
.
merged_sparse_pairs
[
index
]
return
print
(
"Not find {} in self.merge_pairs"
.
format
(
var_name
))
def
_is_opt_role_op
(
op
):
def
_is_opt_role_op
(
op
):
# NOTE : depend on oprole to find out whether this op is for
# NOTE : depend on oprole to find out whether this op is for
...
...
python/paddle/fluid/incubate/fleet/parameter_server/ir/trainer_pass.py
浏览文件 @
7f2aa2db
此差异已折叠。
点击以展开。
python/paddle/fluid/tests/unittests/ctr_dataset_reader.py
浏览文件 @
7f2aa2db
...
@@ -17,8 +17,9 @@ from __future__ import print_function
...
@@ -17,8 +17,9 @@ from __future__ import print_function
import
os
import
os
import
logging
import
logging
import
tarfile
import
tarfile
import
tempfile
import
random
import
random
import
warnings
import
paddle
import
paddle
import
paddle.fluid.incubate.data_generator
as
data_generator
import
paddle.fluid.incubate.data_generator
as
data_generator
...
@@ -57,7 +58,7 @@ def load_dnn_input_record(sent):
...
@@ -57,7 +58,7 @@ def load_dnn_input_record(sent):
def
load_lr_input_record
(
sent
):
def
load_lr_input_record
(
sent
):
res
=
[]
res
=
[]
for
_
in
[
x
.
split
(
':'
)
for
x
in
sent
.
split
()]:
for
_
in
[
x
.
split
(
':'
)
for
x
in
sent
.
split
()]:
res
.
append
(
int
(
_
[
0
]))
res
.
append
(
int
(
_
[
0
])
%
10000
)
return
res
return
res
...
@@ -120,9 +121,62 @@ def prepare_data():
...
@@ -120,9 +121,62 @@ def prepare_data():
lr_input_dim
=
res
[
1
]
lr_input_dim
=
res
[
1
]
logger
.
info
(
'dnn input dim: %d'
%
dnn_input_dim
)
logger
.
info
(
'dnn input dim: %d'
%
dnn_input_dim
)
logger
.
info
(
'lr input dim: %d'
%
lr_input_dim
)
logger
.
info
(
'lr input dim: %d'
%
lr_input_dim
)
return
dnn_input_dim
,
lr_input_dim
,
train_file_path
return
dnn_input_dim
,
lr_input_dim
,
train_file_path
def
gen_fake_line
(
dnn_data_num
=
7
,
dnn_data_range
=
1e5
,
lr_data_num
=
5
,
lr_data_range
=
1e5
):
line
=
""
# for deep data
for
index
in
range
(
dnn_data_num
):
data
=
str
(
random
.
randint
(
0
,
dnn_data_range
-
1
))
if
index
<
dnn_data_num
-
1
:
data
+=
" "
line
+=
data
line
+=
"
\t
"
# for wide data
for
index
in
range
(
lr_data_num
):
data
=
str
(
random
.
randint
(
0
,
lr_data_range
-
1
))
+
":"
+
str
(
1
)
if
index
<
lr_data_num
-
1
:
data
+=
" "
line
+=
data
line
+=
"
\t
"
# for label
line
+=
str
(
random
.
randint
(
0
,
1
))
line
+=
"
\n
"
return
line
def
prepare_fake_data
(
file_nums
=
8
,
file_lines
=
1000
):
"""
Create fake data with same type as avazu_ctr_data
"""
file_dir
=
tempfile
.
mkdtemp
()
warnings
.
warn
(
"Fake data write in {}"
.
format
(
file_dir
))
for
file_index
in
range
(
file_nums
):
with
open
(
os
.
path
.
join
(
file_dir
,
"ctr_train_data_part_{}"
.
format
(
file_index
)),
'w+'
)
as
fin
:
file_str
=
""
for
line_index
in
range
(
file_lines
):
file_str
+=
gen_fake_line
()
fin
.
write
(
file_str
)
warnings
.
warn
(
"Write done ctr_train_data_part_{}"
.
format
(
file_index
))
file_list
=
[
os
.
path
.
join
(
file_dir
,
x
)
for
x
in
os
.
listdir
(
file_dir
)]
assert
len
(
file_list
)
==
file_nums
return
file_list
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
pairwise_reader
=
DatasetCtrReader
()
pairwise_reader
=
DatasetCtrReader
()
pairwise_reader
.
run_from_stdin
()
pairwise_reader
.
run_from_stdin
()
python/paddle/fluid/tests/unittests/dist_fleet_heter_ctr.py
0 → 100644
浏览文件 @
7f2aa2db
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Distribute CTR model for test fleet api
"""
from
__future__
import
print_function
import
shutil
import
tempfile
import
time
import
paddle
import
paddle.fluid
as
fluid
import
os
import
numpy
as
np
import
ctr_dataset_reader
from
test_dist_fleet_heter_base
import
runtime_main
,
FleetDistHeterRunnerBase
from
dist_fleet_ctr
import
TestDistCTR2x2
,
fake_ctr_reader
from
paddle.distributed.fleet.base.util_factory
import
fleet_util
# Fix seed for test
fluid
.
default_startup_program
().
random_seed
=
1
fluid
.
default_main_program
().
random_seed
=
1
class
TestHeterPsCTR2x2
(
FleetDistHeterRunnerBase
):
"""
For test CTR model, using Fleet api
"""
def
net
(
self
,
args
,
batch_size
=
4
,
lr
=
0.01
):
"""
network definition
Args:
batch_size(int): the size of mini-batch for training
lr(float): learning rate of training
Returns:
avg_cost: LoDTensor of cost.
"""
dnn_input_dim
,
lr_input_dim
=
int
(
1e5
),
int
(
1e5
)
dnn_data
=
fluid
.
layers
.
data
(
name
=
"dnn_data"
,
shape
=
[
-
1
,
1
],
dtype
=
"int64"
,
lod_level
=
1
,
append_batch_size
=
False
)
lr_data
=
fluid
.
layers
.
data
(
name
=
"lr_data"
,
shape
=
[
-
1
,
1
],
dtype
=
"int64"
,
lod_level
=
1
,
append_batch_size
=
False
)
label
=
fluid
.
layers
.
data
(
name
=
"click"
,
shape
=
[
-
1
,
1
],
dtype
=
"float32"
,
lod_level
=
0
,
append_batch_size
=
False
)
datas
=
[
dnn_data
,
lr_data
,
label
]
if
args
.
reader
==
"pyreader"
:
self
.
reader
=
fluid
.
io
.
PyReader
(
feed_list
=
datas
,
capacity
=
64
,
iterable
=
False
,
use_double_buffer
=
False
)
# build dnn model
dnn_layer_dims
=
[
128
,
64
,
32
,
1
]
dnn_embedding
=
fluid
.
layers
.
embedding
(
is_distributed
=
False
,
input
=
dnn_data
,
size
=
[
dnn_input_dim
,
dnn_layer_dims
[
0
]],
param_attr
=
fluid
.
ParamAttr
(
name
=
"deep_embedding"
,
initializer
=
fluid
.
initializer
.
Constant
(
value
=
0.01
)),
is_sparse
=
True
)
dnn_pool
=
fluid
.
layers
.
sequence_pool
(
input
=
dnn_embedding
,
pool_type
=
"sum"
)
dnn_out
=
dnn_pool
# build lr model
lr_embbding
=
fluid
.
layers
.
embedding
(
is_distributed
=
False
,
input
=
lr_data
,
size
=
[
lr_input_dim
,
1
],
param_attr
=
fluid
.
ParamAttr
(
name
=
"wide_embedding"
,
initializer
=
fluid
.
initializer
.
Constant
(
value
=
0.01
)),
is_sparse
=
True
)
lr_pool
=
fluid
.
layers
.
sequence_pool
(
input
=
lr_embbding
,
pool_type
=
"sum"
)
with
fluid
.
device_guard
(
"gpu"
):
for
i
,
dim
in
enumerate
(
dnn_layer_dims
[
1
:]):
fc
=
fluid
.
layers
.
fc
(
input
=
dnn_out
,
size
=
dim
,
act
=
"relu"
,
param_attr
=
fluid
.
ParamAttr
(
initializer
=
fluid
.
initializer
.
Constant
(
value
=
0.01
)),
name
=
'dnn-fc-%d'
%
i
)
dnn_out
=
fc
merge_layer
=
fluid
.
layers
.
concat
(
input
=
[
dnn_out
,
lr_pool
],
axis
=
1
)
label
=
fluid
.
layers
.
cast
(
label
,
dtype
=
"int64"
)
predict
=
fluid
.
layers
.
fc
(
input
=
merge_layer
,
size
=
2
,
act
=
'softmax'
)
cost
=
fluid
.
layers
.
cross_entropy
(
input
=
predict
,
label
=
label
)
avg_cost
=
fluid
.
layers
.
mean
(
x
=
cost
)
fluid
.
layers
.
Print
(
avg_cost
,
message
=
"avg_cost"
)
self
.
feeds
=
datas
self
.
train_file_path
=
[
"fake1"
,
"fake2"
]
self
.
avg_cost
=
avg_cost
self
.
predict
=
predict
return
avg_cost
def
check_model_right
(
self
,
dirname
):
model_filename
=
os
.
path
.
join
(
dirname
,
"__model__"
)
with
open
(
model_filename
,
"rb"
)
as
f
:
program_desc_str
=
f
.
read
()
program
=
fluid
.
Program
.
parse_from_string
(
program_desc_str
)
with
open
(
os
.
path
.
join
(
dirname
,
"__model__.proto"
),
"w"
)
as
wn
:
wn
.
write
(
str
(
program
))
def
do_pyreader_training
(
self
,
fleet
):
"""
do training using dataset, using fetch handler to catch variable
Args:
fleet(Fleet api): the fleet object of Parameter Server, define distribute training role
"""
exe
=
fluid
.
Executor
(
fluid
.
CPUPlace
())
fleet
.
init_worker
()
exe
.
run
(
fluid
.
default_startup_program
())
batch_size
=
4
train_reader
=
paddle
.
batch
(
fake_ctr_reader
(),
batch_size
=
batch_size
)
self
.
reader
.
decorate_sample_list_generator
(
train_reader
)
for
epoch_id
in
range
(
1
):
self
.
reader
.
start
()
try
:
pass_start
=
time
.
time
()
while
True
:
exe
.
run
(
program
=
fluid
.
default_main_program
())
pass_time
=
time
.
time
()
-
pass_start
except
fluid
.
core
.
EOFException
:
self
.
reader
.
reset
()
fleet
.
stop_worker
()
def
do_dataset_training
(
self
,
fleet
):
train_file_list
=
ctr_dataset_reader
.
prepare_fake_data
()
exe
=
fluid
.
Executor
(
fluid
.
CPUPlace
())
fleet
.
init_worker
()
exe
.
run
(
fluid
.
default_startup_program
())
thread_num
=
1
batch_size
=
128
filelist
=
fleet_util
.
get_file_shard
(
train_file_list
)
print
(
"filelist: {}"
.
format
(
filelist
))
# config dataset
dataset
=
paddle
.
distributed
.
fleet
.
DatasetFactory
().
create_dataset
()
dataset
.
set_batch_size
(
batch_size
)
dataset
.
set_use_var
(
self
.
feeds
)
pipe_command
=
'python ctr_dataset_reader.py'
dataset
.
set_pipe_command
(
pipe_command
)
dataset
.
set_filelist
(
filelist
)
dataset
.
set_thread
(
thread_num
)
for
epoch_id
in
range
(
1
):
pass_start
=
time
.
time
()
dataset
.
set_filelist
(
filelist
)
exe
.
train_from_dataset
(
program
=
fluid
.
default_main_program
(),
dataset
=
dataset
,
fetch_list
=
[
self
.
avg_cost
],
fetch_info
=
[
"cost"
],
print_period
=
2
,
debug
=
int
(
os
.
getenv
(
"Debug"
,
"0"
)))
pass_time
=
time
.
time
()
-
pass_start
print
(
"do_dataset_training done. using time {}"
.
format
(
pass_time
))
if
os
.
getenv
(
"SAVE_MODEL"
)
==
"1"
:
model_dir
=
tempfile
.
mkdtemp
()
fleet
.
save_inference_model
(
exe
,
model_dir
,
[
feed
.
name
for
feed
in
self
.
feeds
],
self
.
avg_cost
)
self
.
check_model_right
(
model_dir
)
shutil
.
rmtree
(
model_dir
)
fleet
.
stop_worker
()
print
(
"do_dataset_training stop worker."
)
if
__name__
==
"__main__"
:
runtime_main
(
TestHeterPsCTR2x2
)
python/paddle/fluid/tests/unittests/test_dist_fleet_heter_base.py
0 → 100644
浏览文件 @
7f2aa2db
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
"""
high level unit test for distribute fleet.
"""
import
os
import
sys
import
subprocess
import
six
import
shutil
import
numpy
as
np
import
argparse
from
contextlib
import
closing
import
socket
import
time
import
tempfile
import
unittest
import
paddle
import
paddle.fluid
as
fluid
import
paddle.distributed.fleet.base.role_maker
as
role_maker
from
paddle.distributed.fleet.base.util_factory
import
fleet_util
from
paddle.distributed.fleet
import
fleet
__all__
=
[
'FleetDistHeterRunnerBase'
,
'TestFleetHeterBase'
,
'runtime_main'
]
RUN_STEP
=
5
LEARNING_RATE
=
0.01
DIST_UT_PORT
=
0
class
FleetDistHeterRunnerBase
(
object
):
"""
run_pserver,run_trainer : after init role, using transpiler split program
net : implment by child class, the network of model
do training : exe run program
"""
def
build_role
(
self
,
args
):
environs
=
{}
environs
[
"PADDLE_PSERVERS_IP_PORT_LIST"
]
=
args
.
endpoints
environs
[
"PADDLE_TRAINER_ENDPOINTS"
]
=
args
.
trainer_endpoints
environs
[
"PADDLE_HETER_TRAINER_IP_PORT_LIST"
]
=
args
.
heter_trainer_endpoints
environs
[
"PADDLE_HETER_TRAINER_DEVICE"
]
=
args
.
heter_trainer_device
environs
[
"TRAINING_ROLE"
]
=
args
.
role
.
upper
()
environs
[
"PADDLE_TRAINERS_NUM"
]
=
args
.
trainers
environs
[
"PADDLE_TRAINER_ID"
]
=
args
.
current_id
if
args
.
role
.
upper
()
==
"PSERVER"
:
environs
[
"POD_IP"
]
=
args
.
endpoints
.
split
(
","
)[
int
(
args
.
current_id
)].
split
(
":"
)[
0
]
environs
[
"PADDLE_PORT"
]
=
args
.
endpoints
.
split
(
","
)[
int
(
args
.
current_id
)].
split
(
":"
)[
1
]
elif
args
.
role
.
upper
()
==
"HETER_TRAINER"
:
environs
[
"POD_IP"
]
=
args
.
heter_trainer_endpoints
.
split
(
","
)[
int
(
args
.
current_id
)].
split
(
":"
)[
0
]
environs
[
"PADDLE_PORT"
]
=
args
.
heter_trainer_endpoints
.
split
(
","
)[
int
(
args
.
current_id
)].
split
(
":"
)[
1
]
environs
[
"FLAGS_selected_gpus"
]
=
args
.
current_id
for
k
,
v
in
environs
.
items
():
os
.
environ
[
k
]
=
str
(
v
)
self
.
role
=
role_maker
.
PaddleCloudRoleMaker
()
return
self
.
role
def
build_strategy
(
self
,
args
):
self
.
strategy
=
paddle
.
distributed
.
fleet
.
DistributedStrategy
()
self
.
strategy
.
a_sync
=
True
return
self
.
strategy
def
build_optimizer
(
self
,
avg_cost
,
strategy
):
optimizer
=
fluid
.
optimizer
.
SGD
(
LEARNING_RATE
)
optimizer
=
fleet
.
distributed_optimizer
(
optimizer
,
strategy
=
strategy
)
optimizer
.
minimize
(
avg_cost
)
def
run_pserver
(
self
,
args
):
fleet
.
init_server
()
fleet
.
run_server
()
def
run_dataset_trainer
(
self
,
args
):
out
=
self
.
do_dataset_training
(
fleet
)
def
run_pyreader_trainer
(
self
,
args
):
out
=
self
.
do_pyreader_training
(
fleet
)
def
net
(
self
,
args
,
batch_size
=
4
,
lr
=
0.01
):
raise
NotImplementedError
(
"get_model should be implemented by child classes."
)
def
do_dataset_training
(
self
,
fleet
):
raise
NotImplementedError
(
"do_dataset_training should be implemented by child classes."
)
def
do_pyreader_training
(
self
,
fleet
):
raise
NotImplementedError
(
"do_pyreader_training should be implemented by child classes."
)
class
TestFleetHeterBase
(
unittest
.
TestCase
):
"""
start_pserver,start_trainer : add start cmd to test
run_cluster : using multi process to test distribute program
"""
def
_setup_config
(
self
):
raise
NotImplementedError
(
"tests should have _setup_config implemented"
)
def
tearDown
(
self
):
t
=
time
.
time
()
-
self
.
startTime
print
(
'%s: %.3f'
%
(
self
.
__class__
.
__name__
,
t
))
def
setUp
(
self
):
self
.
startTime
=
time
.
time
()
self
.
_mode
=
"async"
self
.
_reader
=
"pyreader"
self
.
_trainers
=
2
self
.
_pservers
=
2
self
.
_port_set
=
set
()
self
.
_heter_device
=
"gpu"
global
DIST_UT_PORT
if
DIST_UT_PORT
==
0
and
os
.
getenv
(
"PADDLE_DIST_UT_PORT"
):
DIST_UT_PORT
=
int
(
os
.
getenv
(
"PADDLE_DIST_UT_PORT"
))
if
DIST_UT_PORT
:
print
(
"set begin_port:"
,
DIST_UT_PORT
)
self
.
_ps_endpoints
=
"127.0.0.1:%s,127.0.0.1:%s"
%
(
DIST_UT_PORT
,
DIST_UT_PORT
+
1
)
self
.
_tr_endpoints
=
"127.0.0.1:%s,127.0.0.1:%s"
%
(
DIST_UT_PORT
+
2
,
DIST_UT_PORT
+
3
)
self
.
_heter_endpoints
=
"127.0.0.1:%s,127.0.0.1:%s"
%
(
DIST_UT_PORT
+
4
,
DIST_UT_PORT
+
5
)
DIST_UT_PORT
+=
6
else
:
self
.
_ps_endpoints
=
"127.0.0.1:%s,127.0.0.1:%s"
%
(
self
.
_find_free_port
(),
self
.
_find_free_port
())
self
.
_tr_endpoints
=
"127.0.0.1:%s,127.0.0.1:%s"
%
(
self
.
_find_free_port
(),
self
.
_find_free_port
())
self
.
_heter_endpoints
=
"127.0.0.1:%s,127.0.0.1:%s"
%
(
self
.
_find_free_port
(),
self
.
_find_free_port
())
self
.
_python_interp
=
sys
.
executable
self
.
_geo_sgd_need_push_nums
=
5
self
.
_grad_clip_mode
=
0
self
.
_setup_config
()
def
_find_free_port
(
self
):
def
__free_port
():
with
closing
(
socket
.
socket
(
socket
.
AF_INET
,
socket
.
SOCK_STREAM
))
as
s
:
s
.
bind
((
''
,
0
))
return
s
.
getsockname
()[
1
]
while
True
:
port
=
__free_port
()
if
port
not
in
self
.
_port_set
:
self
.
_port_set
.
add
(
port
)
return
port
def
_start_pserver
(
self
,
cmd
,
required_envs
):
ps0_cmd
,
ps1_cmd
=
cmd
.
format
(
0
),
cmd
.
format
(
1
)
ps0_pipe
=
open
(
tempfile
.
gettempdir
()
+
"/ps0_err.log"
,
"wb+"
)
ps1_pipe
=
open
(
tempfile
.
gettempdir
()
+
"/ps1_err.log"
,
"wb+"
)
ps0_proc
=
subprocess
.
Popen
(
ps0_cmd
.
strip
().
split
(
" "
),
stdout
=
subprocess
.
PIPE
,
stderr
=
ps0_pipe
,
env
=
required_envs
)
ps1_proc
=
subprocess
.
Popen
(
ps1_cmd
.
strip
().
split
(
" "
),
stdout
=
subprocess
.
PIPE
,
stderr
=
ps1_pipe
,
env
=
required_envs
)
return
ps0_proc
,
ps1_proc
,
ps0_pipe
,
ps1_pipe
def
_start_trainer
(
self
,
cmd
,
required_envs
):
tr0_cmd
,
tr1_cmd
=
cmd
.
format
(
0
),
cmd
.
format
(
1
)
tr0_pipe
=
open
(
tempfile
.
gettempdir
()
+
"/tr0_err.log"
,
"wb+"
)
tr1_pipe
=
open
(
tempfile
.
gettempdir
()
+
"/tr1_err.log"
,
"wb+"
)
tr0_out
=
open
(
tempfile
.
gettempdir
()
+
"/tr0_out.log"
,
"wb+"
)
tr1_out
=
open
(
tempfile
.
gettempdir
()
+
"/tr1_out.log"
,
"wb+"
)
tr0_proc
=
subprocess
.
Popen
(
tr0_cmd
.
strip
().
split
(
" "
),
stdout
=
tr0_out
,
stderr
=
tr0_pipe
,
env
=
required_envs
)
tr1_proc
=
subprocess
.
Popen
(
tr1_cmd
.
strip
().
split
(
" "
),
stdout
=
tr1_out
,
stderr
=
tr1_pipe
,
env
=
required_envs
)
return
tr0_proc
,
tr1_proc
,
tr0_pipe
,
tr1_pipe
def
_start_heter_trainer
(
self
,
cmd
,
required_envs
):
heter0_cmd
,
heter1_cmd
=
cmd
.
format
(
0
),
cmd
.
format
(
1
)
heter0_pipe
=
open
(
tempfile
.
gettempdir
()
+
"/heter0_err.log"
,
"wb+"
)
heter1_pipe
=
open
(
tempfile
.
gettempdir
()
+
"/heter1_err.log"
,
"wb+"
)
heter0_out
=
open
(
tempfile
.
gettempdir
()
+
"/heter0_out.log"
,
"wb+"
)
heter1_out
=
open
(
tempfile
.
gettempdir
()
+
"/heter1_out.log"
,
"wb+"
)
heter0_proc
=
subprocess
.
Popen
(
heter0_cmd
.
strip
().
split
(
" "
),
stdout
=
heter0_out
,
stderr
=
heter0_pipe
,
env
=
required_envs
)
heter1_proc
=
subprocess
.
Popen
(
heter1_cmd
.
strip
().
split
(
" "
),
stdout
=
heter1_out
,
stderr
=
heter1_pipe
,
env
=
required_envs
)
return
heter0_proc
,
heter1_proc
,
heter0_pipe
,
heter1_pipe
def
_run_cluster
(
self
,
model
,
envs
):
env
=
{
'GRAD_CLIP'
:
str
(
self
.
_grad_clip_mode
)}
python_path
=
self
.
_python_interp
gloo_path
=
tempfile
.
mkdtemp
()
if
os
.
getenv
(
'WITH_COVERAGE'
,
'OFF'
)
==
'ON'
:
envs
[
'COVERAGE_FILE'
]
=
os
.
getenv
(
'COVERAGE_FILE'
,
''
)
python_path
+=
" -m coverage run --branch -p"
env
.
update
(
envs
)
tr_cmd
=
"{0} {1} --role trainer --endpoints {2} --trainer_endpoints {3} --current_id {{}} --trainers {4} --mode {5} --geo_sgd_need_push_nums {6} --reader {7} --gloo_path {8} --heter_trainer_endpoints {9} --heter_trainer_device {10}"
.
format
(
python_path
,
model
,
self
.
_ps_endpoints
,
self
.
_tr_endpoints
,
self
.
_trainers
,
self
.
_mode
,
self
.
_geo_sgd_need_push_nums
,
self
.
_reader
,
gloo_path
,
self
.
_heter_endpoints
,
self
.
_heter_device
)
ps_cmd
=
"{0} {1} --role pserver --endpoints {2} --trainer_endpoints {3} --current_id {{}} --trainers {4} --mode {5} --geo_sgd_need_push_nums {6} --reader {7} --gloo_path {8} --heter_trainer_endpoints {9} --heter_trainer_device {10}"
.
format
(
python_path
,
model
,
self
.
_ps_endpoints
,
self
.
_tr_endpoints
,
self
.
_trainers
,
self
.
_mode
,
self
.
_geo_sgd_need_push_nums
,
self
.
_reader
,
gloo_path
,
self
.
_heter_endpoints
,
self
.
_heter_device
)
heter_cmd
=
"{0} {1} --role heter_trainer --endpoints {2} --trainer_endpoints {3} --current_id {{}} --trainers {4} --mode {5} --geo_sgd_need_push_nums {6} --reader {7} --gloo_path {8} --heter_trainer_endpoints {9} --heter_trainer_device {10}"
.
format
(
python_path
,
model
,
self
.
_ps_endpoints
,
self
.
_tr_endpoints
,
self
.
_trainers
,
self
.
_mode
,
self
.
_geo_sgd_need_push_nums
,
self
.
_reader
,
gloo_path
,
self
.
_heter_endpoints
,
self
.
_heter_device
)
# Run dist train to compare with local results
ps0
,
ps1
,
ps0_pipe
,
ps1_pipe
=
self
.
_start_pserver
(
ps_cmd
,
env
)
tr0
,
tr1
,
tr0_pipe
,
tr1_pipe
=
self
.
_start_trainer
(
tr_cmd
,
env
)
heter0
,
heter1
,
heter0_pipe
,
heter1_pipe
=
self
.
_start_heter_trainer
(
heter_cmd
,
env
)
# Wait until trainer process terminate
while
True
:
stat0
=
tr0
.
poll
()
time
.
sleep
(
0.1
)
if
stat0
is
not
None
:
break
while
True
:
stat1
=
tr1
.
poll
()
time
.
sleep
(
0.1
)
if
stat1
is
not
None
:
break
tr0_out
,
tr0_err
=
tr0
.
communicate
()
tr1_out
,
tr1_err
=
tr1
.
communicate
()
print
(
"tr end communicate"
)
tr0_ret
=
tr0
.
returncode
tr1_ret
=
tr0
.
returncode
print
(
"tr get returncode: {}"
.
format
(
tr0_ret
))
if
tr0_ret
!=
0
:
print
(
"========================Error tr0_err begin==========================="
)
os
.
system
(
"cat {}"
.
format
(
tempfile
.
gettempdir
()
+
"/tr0_err.log"
))
print
(
"========================Error tr0_err end==========================="
)
if
tr1_ret
!=
0
:
print
(
"========================Error tr1_err begin==========================="
)
os
.
system
(
"cat {}"
.
format
(
tempfile
.
gettempdir
()
+
"/tr1_err.log"
))
print
(
"========================Error tr1_err end==========================="
)
self
.
assertEqual
(
tr0_ret
,
0
,
"something wrong in tr0, please check"
)
self
.
assertEqual
(
tr1_ret
,
0
,
"something wrong in tr1, please check"
)
# close trainer file
tr0_pipe
.
close
()
tr1_pipe
.
close
()
ps0_pipe
.
close
()
ps1_pipe
.
close
()
heter0_pipe
.
close
()
heter1_pipe
.
close
()
ps0
.
terminate
()
ps1
.
terminate
()
heter0
.
terminate
()
heter1
.
terminate
()
shutil
.
rmtree
(
gloo_path
)
return
0
,
0
def
check_with_place
(
self
,
model_file
,
delta
=
1e-3
,
check_error_log
=
False
,
need_envs
=
{}):
required_envs
=
{
"PATH"
:
os
.
getenv
(
"PATH"
,
""
),
"PYTHONPATH"
:
os
.
getenv
(
"PYTHONPATH"
,
""
),
"LD_LIBRARY_PATH"
:
os
.
getenv
(
"LD_LIBRARY_PATH"
,
""
),
"FLAGS_rpc_deadline"
:
"5000"
,
# 5sec to fail fast
"http_proxy"
:
""
}
required_envs
.
update
(
need_envs
)
if
check_error_log
:
required_envs
[
"GLOG_v"
]
=
"3"
required_envs
[
"GLOG_logtostderr"
]
=
"1"
tr0_losses
,
tr1_losses
=
self
.
_run_cluster
(
model_file
,
required_envs
)
def
runtime_main
(
test_class
):
parser
=
argparse
.
ArgumentParser
(
description
=
'Run Fleet test.'
)
parser
.
add_argument
(
'--role'
,
type
=
str
,
required
=
True
,
choices
=
[
'pserver'
,
'trainer'
,
'heter_trainer'
])
parser
.
add_argument
(
'--endpoints'
,
type
=
str
,
required
=
False
,
default
=
""
)
parser
.
add_argument
(
'--trainer_endpoints'
,
type
=
str
,
required
=
False
,
default
=
""
)
parser
.
add_argument
(
'--heter_trainer_endpoints'
,
type
=
str
,
required
=
False
,
default
=
""
)
parser
.
add_argument
(
'--heter_trainer_device'
,
type
=
str
,
required
=
False
,
default
=
"gpu"
)
parser
.
add_argument
(
'--gloo_path'
,
type
=
str
,
required
=
False
,
default
=
""
)
parser
.
add_argument
(
'--current_id'
,
type
=
int
,
required
=
False
,
default
=
0
)
parser
.
add_argument
(
'--trainers'
,
type
=
int
,
required
=
False
,
default
=
1
)
parser
.
add_argument
(
'--mode'
,
type
=
str
,
required
=
False
,
default
=
'async'
)
parser
.
add_argument
(
'--geo_sgd_need_push_nums'
,
type
=
int
,
required
=
False
,
default
=
2
)
parser
.
add_argument
(
'--reader'
,
type
=
str
,
required
=
False
,
default
=
'dataset'
)
args
=
parser
.
parse_args
()
model
=
test_class
()
role
=
model
.
build_role
(
args
)
fleet
.
init
(
role
)
strategy
=
model
.
build_strategy
(
args
)
avg_cost
=
model
.
net
(
args
)
model
.
build_optimizer
(
avg_cost
,
strategy
)
fleet_util
.
_set_strategy
(
strategy
)
fleet_util
.
_set_role_maker
(
role
)
if
args
.
role
==
"pserver"
or
args
.
role
==
"heter_trainer"
:
model
.
run_pserver
(
args
)
else
:
if
args
.
reader
==
"dataset"
:
model
.
run_dataset_trainer
(
args
)
else
:
model
.
run_pyreader_trainer
(
args
)
python/paddle/fluid/tests/unittests/test_dist_fleet_heter_ctr.py
0 → 100644
浏览文件 @
7f2aa2db
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
os
import
unittest
import
tempfile
from
test_dist_fleet_heter_base
import
TestFleetHeterBase
class
TestDistHeterDatasetAsync2x2
(
TestFleetHeterBase
):
def
_setup_config
(
self
):
self
.
_mode
=
"async"
self
.
_reader
=
"dataset"
def
check_with_place
(
self
,
model_file
,
delta
=
1e-3
,
check_error_log
=
False
,
need_envs
=
{}):
required_envs
=
{
"PATH"
:
os
.
getenv
(
"PATH"
,
""
),
"PYTHONPATH"
:
os
.
getenv
(
"PYTHONPATH"
,
""
),
"LD_LIBRARY_PATH"
:
os
.
getenv
(
"LD_LIBRARY_PATH"
,
""
),
"FLAGS_rpc_deadline"
:
"5000"
,
# 5sec to fail fast
"http_proxy"
:
""
,
"CPU_NUM"
:
"1"
}
required_envs
.
update
(
need_envs
)
if
check_error_log
:
required_envs
[
"GLOG_v"
]
=
"4"
required_envs
[
"GLOG_logtostderr"
]
=
"1"
tr0_losses
,
tr1_losses
=
self
.
_run_cluster
(
model_file
,
required_envs
)
def
test_dist_train
(
self
):
self
.
check_with_place
(
"dist_fleet_heter_ctr.py"
,
delta
=
1e-5
,
check_error_log
=
True
)
if
__name__
==
"__main__"
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/test_dist_fleet_heter_program.py
0 → 100644
浏览文件 @
7f2aa2db
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
unittest
import
paddle
import
os
import
math
import
paddle.fluid
as
fluid
import
paddle.distributed.fleet.base.role_maker
as
role_maker
from
paddle.distributed.fleet.base.util_factory
import
fleet_util
from
paddle.distributed.fleet
import
fleet
class
TestDistFleetHeterProgram
(
unittest
.
TestCase
):
def
build_role
(
self
):
environs
=
{}
environs
[
"PADDLE_PSERVERS_IP_PORT_LIST"
]
=
"127.0.0.1:36012,127.0.0.1:36013"
environs
[
"PADDLE_TRAINER_ENDPOINTS"
]
=
"127.0.0.1:36014,127.0.0.1:36015"
environs
[
"PADDLE_HETER_TRAINER_IP_PORT_LIST"
]
=
"127.0.0.1:36016,127.0.0.1:36017"
environs
[
"PADDLE_HETER_TRAINER_DEVICE"
]
=
"gpu"
environs
[
"TRAINING_ROLE"
]
=
"HETER_TRAINER"
environs
[
"PADDLE_TRAINERS_NUM"
]
=
2
environs
[
"PADDLE_TRAINER_ID"
]
=
0
environs
[
"POD_IP"
]
=
"127.0.0.1"
environs
[
"PADDLE_PORT"
]
=
"36016"
environs
[
"FLAGS_selected_gpus"
]
=
0
for
k
,
v
in
environs
.
items
():
os
.
environ
[
k
]
=
str
(
v
)
self
.
role
=
role_maker
.
PaddleCloudRoleMaker
()
return
self
.
role
def
build_strategy
(
self
):
self
.
strategy
=
paddle
.
distributed
.
fleet
.
DistributedStrategy
()
self
.
strategy
.
a_sync
=
True
return
self
.
strategy
def
build_input
(
self
):
dense_input
=
fluid
.
layers
.
data
(
name
=
"dense_input"
,
shape
=
[
10
],
dtype
=
"float32"
)
sparse_input_ids
=
[
fluid
.
layers
.
data
(
name
=
"C"
+
str
(
i
),
shape
=
[
1
],
lod_level
=
1
,
dtype
=
"int64"
)
for
i
in
range
(
1
,
27
)
]
label
=
fluid
.
layers
.
data
(
name
=
"label"
,
shape
=
[
1
],
dtype
=
"float32"
)
inputs
=
[
dense_input
]
+
sparse_input_ids
+
[
label
]
return
inputs
def
build_net
(
self
,
inputs
):
def
embedding_layer
(
input
):
return
fluid
.
layers
.
embedding
(
input
=
input
,
is_sparse
=
True
,
size
=
[
100001
,
10
],
param_attr
=
fluid
.
ParamAttr
(
name
=
"SparseFeatFactors"
,
initializer
=
fluid
.
initializer
.
Uniform
()),
)
sparse_embed_seq
=
list
(
map
(
embedding_layer
,
inputs
[
1
:
-
1
]))
concated
=
fluid
.
layers
.
concat
(
sparse_embed_seq
+
inputs
[
0
:
1
],
axis
=
1
)
with
fluid
.
device_guard
(
"gpu"
):
fc1
=
fluid
.
layers
.
fc
(
input
=
concated
,
size
=
400
,
act
=
"relu"
,
param_attr
=
fluid
.
ParamAttr
(
initializer
=
fluid
.
initializer
.
Normal
(
scale
=
1
/
math
.
sqrt
(
concated
.
shape
[
1
]))),
name
=
"fc1"
)
with
fluid
.
device_guard
(
"cpu"
):
fc2
=
fluid
.
layers
.
fc
(
input
=
fc1
,
size
=
400
,
act
=
"relu"
,
param_attr
=
fluid
.
ParamAttr
(
initializer
=
fluid
.
initializer
.
Normal
(
scale
=
1
/
math
.
sqrt
(
fc1
.
shape
[
1
]))),
name
=
"fc2"
)
with
fluid
.
device_guard
(
"gpu"
):
fc3
=
fluid
.
layers
.
fc
(
input
=
fc2
,
size
=
400
,
act
=
"relu"
,
param_attr
=
fluid
.
ParamAttr
(
initializer
=
fluid
.
initializer
.
Normal
(
scale
=
1
/
math
.
sqrt
(
fc2
.
shape
[
1
]))),
name
=
"fc3"
)
with
fluid
.
device_guard
(
"cpu"
):
predict
=
fluid
.
layers
.
fc
(
input
=
fc3
,
size
=
2
,
act
=
"softmax"
,
param_attr
=
fluid
.
ParamAttr
(
initializer
=
fluid
.
initializer
.
Normal
(
scale
=
1
/
math
.
sqrt
(
fc3
.
shape
[
1
]))),
)
with
fluid
.
device_guard
(
"gpu"
):
labels
=
fluid
.
layers
.
cast
(
inputs
[
-
1
],
dtype
=
"int64"
)
cost
=
fluid
.
layers
.
cross_entropy
(
input
=
predict
,
label
=
labels
)
avg_cost
=
fluid
.
layers
.
reduce_sum
(
cost
)
return
avg_cost
def
build_optimizer
(
self
,
avg_cost
,
strategy
):
optimizer
=
fluid
.
optimizer
.
SGD
(
1e-2
)
optimizer
=
fleet
.
distributed_optimizer
(
optimizer
,
strategy
=
strategy
)
optimizer
.
minimize
(
avg_cost
)
def
test
(
self
):
role
=
self
.
build_role
()
fleet
.
init
(
role
)
strategy
=
self
.
build_strategy
()
inputs
=
self
.
build_input
()
avg_cost
=
self
.
build_net
(
inputs
)
self
.
build_optimizer
(
avg_cost
,
strategy
)
if
__name__
==
"__main__"
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录