Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
2e834eab
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
2e834eab
编写于
1月 14, 2020
作者:
1
123malin
提交者:
GitHub
1月 14, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Bug fix for sparse recorder (#21969) (#22245)
* test=develop, bug fix for sparse recorder
上级
681d908e
变更
22
隐藏空白更改
内联
并排
Showing
22 changed file
with
291 addition
and
62 deletion
+291
-62
paddle/fluid/operators/collective/c_gen_nccl_id_op.cc
paddle/fluid/operators/collective/c_gen_nccl_id_op.cc
+1
-1
paddle/fluid/operators/distributed/brpc/brpc_server.cc
paddle/fluid/operators/distributed/brpc/brpc_server.cc
+3
-1
paddle/fluid/operators/distributed/grpc/grpc_server.cc
paddle/fluid/operators/distributed/grpc/grpc_server.cc
+6
-6
paddle/fluid/operators/distributed/request_handler.h
paddle/fluid/operators/distributed/request_handler.h
+6
-4
paddle/fluid/operators/distributed/request_handler_impl.cc
paddle/fluid/operators/distributed/request_handler_impl.cc
+11
-7
paddle/fluid/operators/distributed/request_handler_impl.h
paddle/fluid/operators/distributed/request_handler_impl.h
+11
-9
paddle/fluid/operators/distributed/rpc_server_test.cc
paddle/fluid/operators/distributed/rpc_server_test.cc
+4
-2
paddle/fluid/operators/distributed_ops/fl_listen_and_serv_op.cc
.../fluid/operators/distributed_ops/fl_listen_and_serv_op.cc
+2
-2
paddle/fluid/operators/distributed_ops/gen_nccl_id_op.cc
paddle/fluid/operators/distributed_ops/gen_nccl_id_op.cc
+1
-1
paddle/fluid/operators/distributed_ops/listen_and_serv_op.cc
paddle/fluid/operators/distributed_ops/listen_and_serv_op.cc
+19
-13
paddle/fluid/operators/distributed_ops/send_recv_op_test.cc
paddle/fluid/operators/distributed_ops/send_recv_op_test.cc
+5
-3
paddle/fluid/operators/distributed_ops/test_send_nccl_id.cc
paddle/fluid/operators/distributed_ops/test_send_nccl_id.cc
+2
-1
python/paddle/fluid/layers/io.py
python/paddle/fluid/layers/io.py
+3
-1
python/paddle/fluid/tests/unittests/test_dist_transpiler.py
python/paddle/fluid/tests/unittests/test_dist_transpiler.py
+2
-2
python/paddle/fluid/tests/unittests/test_dist_transpiler_config.py
...ddle/fluid/tests/unittests/test_dist_transpiler_config.py
+181
-0
python/paddle/fluid/tests/unittests/test_listen_and_serv.sh
python/paddle/fluid/tests/unittests/test_listen_and_serv.sh
+1
-1
python/paddle/fluid/tests/unittests/test_listen_and_serv_op.py
...n/paddle/fluid/tests/unittests/test_listen_and_serv_op.py
+10
-2
python/paddle/fluid/tests/unittests/test_lookup_remote_table_op.py
...ddle/fluid/tests/unittests/test_lookup_remote_table_op.py
+2
-1
python/paddle/fluid/tests/unittests/test_nce_remote_table_op.py
.../paddle/fluid/tests/unittests/test_nce_remote_table_op.py
+2
-1
python/paddle/fluid/tests/unittests/test_recv_save_op.py
python/paddle/fluid/tests/unittests/test_recv_save_op.py
+2
-1
python/paddle/fluid/transpiler/distribute_transpiler.py
python/paddle/fluid/transpiler/distribute_transpiler.py
+15
-1
python/paddle/fluid/transpiler/geo_sgd_transpiler.py
python/paddle/fluid/transpiler/geo_sgd_transpiler.py
+2
-2
未找到文件。
paddle/fluid/operators/collective/c_gen_nccl_id_op.cc
浏览文件 @
2e834eab
...
@@ -90,7 +90,7 @@ class CGenNCCLIdOp : public framework::OperatorBase {
...
@@ -90,7 +90,7 @@ class CGenNCCLIdOp : public framework::OperatorBase {
// NOTE: Can not use unique_ptr here because the default
// NOTE: Can not use unique_ptr here because the default
// deleter will call GRPC Server's base class's dtor and
// deleter will call GRPC Server's base class's dtor and
// that will cause a wired crash.
// that will cause a wired crash.
distributed
::
RequestSendHandler
rpc_h
(
true
);
distributed
::
RequestSendHandler
rpc_h
(
distributed
::
DistributedMode
::
kSync
);
std
::
unique_ptr
<
distributed
::
RPCServer
>
rpc_service
(
std
::
unique_ptr
<
distributed
::
RPCServer
>
rpc_service
(
new
RPCSERVER_T
(
endpoint
,
1
));
new
RPCSERVER_T
(
endpoint
,
1
));
...
...
paddle/fluid/operators/distributed/brpc/brpc_server.cc
浏览文件 @
2e834eab
...
@@ -13,6 +13,8 @@
...
@@ -13,6 +13,8 @@
// limitations under the License.
// limitations under the License.
#include "paddle/fluid/operators/distributed/brpc/brpc_server.h"
#include "paddle/fluid/operators/distributed/brpc/brpc_server.h"
#include <memory>
#include <unordered_map>
#include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/operators/distributed/brpc/brpc_sendrecvop_utils.h"
#include "paddle/fluid/operators/distributed/brpc/brpc_sendrecvop_utils.h"
#include "paddle/fluid/operators/distributed/brpc/brpc_variable_response.h"
#include "paddle/fluid/operators/distributed/brpc/brpc_variable_response.h"
...
@@ -100,7 +102,7 @@ class BRPCServiceImpl : public SendRecvService {
...
@@ -100,7 +102,7 @@ class BRPCServiceImpl : public SendRecvService {
distributed
::
BRPCVariableResponse
resp
(
request_send_h_
->
scope
(),
distributed
::
BRPCVariableResponse
resp
(
request_send_h_
->
scope
(),
request_send_h_
->
dev_ctx
(),
request_send_h_
->
dev_ctx
(),
!
request_send_h_
->
sync
_mode
());
request_send_h_
->
distributed
_mode
());
PADDLE_ENFORCE
(
resp
.
Parse
(
cntl
->
request_attachment
(),
*
request
)
==
0
,
PADDLE_ENFORCE
(
resp
.
Parse
(
cntl
->
request_attachment
(),
*
request
)
==
0
,
"parse iobuf to tensor error!"
);
"parse iobuf to tensor error!"
);
...
...
paddle/fluid/operators/distributed/grpc/grpc_server.cc
浏览文件 @
2e834eab
...
@@ -90,9 +90,9 @@ class RequestSend final : public RequestBase {
...
@@ -90,9 +90,9 @@ class RequestSend final : public RequestBase {
::
grpc
::
ServerCompletionQueue
*
cq
,
::
grpc
::
ServerCompletionQueue
*
cq
,
RequestHandler
*
request_handler
,
int
req_id
)
RequestHandler
*
request_handler
,
int
req_id
)
:
RequestBase
(
service
,
cq
,
request_handler
,
req_id
),
responder_
(
&
ctx_
)
{
:
RequestBase
(
service
,
cq
,
request_handler
,
req_id
),
responder_
(
&
ctx_
)
{
request_
.
reset
(
new
GRPCVariableResponse
(
request_handler
->
scope
(),
request_
.
reset
(
new
GRPCVariableResponse
(
request_handler
->
dev_ctx
(),
request_handler
->
scope
(),
request_handler
->
dev_ctx
(),
!
request_handler
->
sync
_mode
()));
request_handler
->
distributed
_mode
()));
int
method_id
=
static_cast
<
int
>
(
distributed
::
GrpcMethod
::
kSendVariable
);
int
method_id
=
static_cast
<
int
>
(
distributed
::
GrpcMethod
::
kSendVariable
);
service_
->
RequestAsyncUnary
(
service_
->
RequestAsyncUnary
(
method_id
,
&
ctx_
,
request_
.
get
(),
&
responder_
,
cq_
,
cq_
,
method_id
,
&
ctx_
,
request_
.
get
(),
&
responder_
,
cq_
,
cq_
,
...
@@ -401,9 +401,9 @@ class RequestNotify final : public RequestBase {
...
@@ -401,9 +401,9 @@ class RequestNotify final : public RequestBase {
::
grpc
::
ServerCompletionQueue
*
cq
,
::
grpc
::
ServerCompletionQueue
*
cq
,
RequestHandler
*
request_handler
,
int
req_id
)
RequestHandler
*
request_handler
,
int
req_id
)
:
RequestBase
(
service
,
cq
,
request_handler
,
req_id
),
responder_
(
&
ctx_
)
{
:
RequestBase
(
service
,
cq
,
request_handler
,
req_id
),
responder_
(
&
ctx_
)
{
request_
.
reset
(
new
GRPCVariableResponse
(
request_handler
->
scope
(),
request_
.
reset
(
new
GRPCVariableResponse
(
request_handler
->
dev_ctx
(),
request_handler
->
scope
(),
request_handler
->
dev_ctx
(),
!
request_handler
->
sync
_mode
()));
request_handler
->
distributed
_mode
()));
int
method_id
=
static_cast
<
int
>
(
distributed
::
GrpcMethod
::
kRequestNotify
);
int
method_id
=
static_cast
<
int
>
(
distributed
::
GrpcMethod
::
kRequestNotify
);
service_
->
RequestAsyncUnary
(
service_
->
RequestAsyncUnary
(
method_id
,
&
ctx_
,
request_
.
get
(),
&
responder_
,
cq_
,
cq_
,
method_id
,
&
ctx_
,
request_
.
get
(),
&
responder_
,
cq_
,
cq_
,
...
...
paddle/fluid/operators/distributed/request_handler.h
浏览文件 @
2e834eab
...
@@ -68,6 +68,8 @@ constexpr char kCheckPointNotifyRPC[] = "CheckPointNotifyRPC";
...
@@ -68,6 +68,8 @@ constexpr char kCheckPointNotifyRPC[] = "CheckPointNotifyRPC";
#define CHECKPOINT_SAVE_MESSAGE "SAVE@CHECKPOINTNOTIFY"
#define CHECKPOINT_SAVE_MESSAGE "SAVE@CHECKPOINTNOTIFY"
#define CHECKPOINT_LOAD_MESSAGE "LOAD@CHECKPOINTNOTIFY"
#define CHECKPOINT_LOAD_MESSAGE "LOAD@CHECKPOINTNOTIFY"
enum
DistributedMode
{
kSync
=
0
,
kAsync
=
1
,
kHalfAsync
=
2
,
kGeo
=
3
};
class
RPCServer
;
class
RPCServer
;
class
VarHandle
{
class
VarHandle
{
...
@@ -151,8 +153,8 @@ typedef std::shared_ptr<VarHandle> VarHandlePtr;
...
@@ -151,8 +153,8 @@ typedef std::shared_ptr<VarHandle> VarHandlePtr;
class
RequestHandler
{
class
RequestHandler
{
public:
public:
explicit
RequestHandler
(
bool
sync
_mode
)
explicit
RequestHandler
(
int
distributed
_mode
)
:
sync_mode_
(
sync
_mode
),
:
distributed_mode_
(
distributed
_mode
),
dev_ctx_
(
nullptr
),
dev_ctx_
(
nullptr
),
executor_
(
nullptr
),
executor_
(
nullptr
),
scope_
(
nullptr
),
scope_
(
nullptr
),
...
@@ -198,7 +200,7 @@ class RequestHandler {
...
@@ -198,7 +200,7 @@ class RequestHandler {
void
SetRPCServer
(
RPCServer
*
rpc_server
)
{
rpc_server_
=
rpc_server
;
}
void
SetRPCServer
(
RPCServer
*
rpc_server
)
{
rpc_server_
=
rpc_server
;
}
// Get attributes.
// Get attributes.
bool
sync_mode
()
{
return
sync
_mode_
;
}
int
distributed_mode
()
{
return
distributed
_mode_
;
}
framework
::
Scope
*
scope
()
{
return
scope_
;
}
framework
::
Scope
*
scope
()
{
return
scope_
;
}
const
platform
::
DeviceContext
*
dev_ctx
()
{
return
dev_ctx_
;
}
const
platform
::
DeviceContext
*
dev_ctx
()
{
return
dev_ctx_
;
}
framework
::
ProgramDesc
*
program
()
{
return
program_
;
}
framework
::
ProgramDesc
*
program
()
{
return
program_
;
}
...
@@ -225,7 +227,7 @@ class RequestHandler {
...
@@ -225,7 +227,7 @@ class RequestHandler {
const
std
::
string
&
table_name
=
""
)
=
0
;
const
std
::
string
&
table_name
=
""
)
=
0
;
protected:
protected:
const
bool
sync
_mode_
;
const
int
distributed
_mode_
;
const
platform
::
DeviceContext
*
dev_ctx_
;
const
platform
::
DeviceContext
*
dev_ctx_
;
framework
::
Executor
*
executor_
;
framework
::
Executor
*
executor_
;
...
...
paddle/fluid/operators/distributed/request_handler_impl.cc
浏览文件 @
2e834eab
...
@@ -61,7 +61,7 @@ bool RequestSendHandler::Handle(const std::string& varname,
...
@@ -61,7 +61,7 @@ bool RequestSendHandler::Handle(const std::string& varname,
rpc_server_
->
Complete
();
rpc_server_
->
Complete
();
}
else
{
}
else
{
// Async
// Async
if
(
!
sync_mode_
)
{
if
(
distributed_mode_
!=
DistributedMode
::
kSync
)
{
VLOG
(
3
)
<<
"async process var: "
<<
varname
;
VLOG
(
3
)
<<
"async process var: "
<<
varname
;
if
(
varname
==
BATCH_BARRIER_MESSAGE
)
{
if
(
varname
==
BATCH_BARRIER_MESSAGE
)
{
PADDLE_THROW
(
PADDLE_THROW
(
...
@@ -82,7 +82,8 @@ bool RequestSendHandler::Handle(const std::string& varname,
...
@@ -82,7 +82,8 @@ bool RequestSendHandler::Handle(const std::string& varname,
scope
->
Rename
(
varname
,
run_varname
);
scope
->
Rename
(
varname
,
run_varname
);
}
}
if
(
AsyncSparseParamUpdateRecorder
::
GetInstance
()
->
HasGrad
(
run_varname
))
{
if
(
distributed_mode_
==
DistributedMode
::
kGeo
&&
AsyncSparseParamUpdateRecorder
::
GetInstance
()
->
HasGrad
(
run_varname
))
{
auto
&
grad_slr
=
auto
&
grad_slr
=
scope
->
FindVar
(
run_varname
)
->
Get
<
framework
::
SelectedRows
>
();
scope
->
FindVar
(
run_varname
)
->
Get
<
framework
::
SelectedRows
>
();
AsyncSparseParamUpdateRecorder
::
GetInstance
()
->
Update
(
run_varname
,
AsyncSparseParamUpdateRecorder
::
GetInstance
()
->
Update
(
run_varname
,
...
@@ -116,7 +117,7 @@ bool RequestGetHandler::Handle(const std::string& varname,
...
@@ -116,7 +117,7 @@ bool RequestGetHandler::Handle(const std::string& varname,
<<
" out_var_name: "
<<
out_var_name
<<
" trainer_id: "
<<
trainer_id
<<
" out_var_name: "
<<
out_var_name
<<
" trainer_id: "
<<
trainer_id
<<
" table_name: "
<<
table_name
;
<<
" table_name: "
<<
table_name
;
if
(
sync_mode_
)
{
if
(
distributed_mode_
==
DistributedMode
::
kSync
)
{
if
(
varname
==
FETCH_BARRIER_MESSAGE
)
{
if
(
varname
==
FETCH_BARRIER_MESSAGE
)
{
VLOG
(
3
)
<<
"sync: recv fetch barrier message"
;
VLOG
(
3
)
<<
"sync: recv fetch barrier message"
;
rpc_server_
->
IncreaseBatchBarrier
(
kRequestGet
);
rpc_server_
->
IncreaseBatchBarrier
(
kRequestGet
);
...
@@ -140,10 +141,13 @@ bool RequestGetHandler::Handle(const std::string& varname,
...
@@ -140,10 +141,13 @@ bool RequestGetHandler::Handle(const std::string& varname,
framework
::
TensorCopy
(
t_orig
,
dev_ctx_
->
GetPlace
(),
t
);
framework
::
TensorCopy
(
t_orig
,
dev_ctx_
->
GetPlace
(),
t
);
}
}
VLOG
(
1
)
<<
"Table name empty? "
<<
table_name
.
empty
();
VLOG
(
1
)
<<
"Table name empty? "
<<
table_name
.
empty
();
VLOG
(
1
)
<<
"AsyncSparseParamUpdateRecorder "
<<
varname
<<
" exist "
if
(
distributed_mode_
==
DistributedMode
::
kGeo
)
{
<<
AsyncSparseParamUpdateRecorder
::
GetInstance
()
->
HasParam
(
VLOG
(
1
)
<<
"AsyncSparseParamUpdateRecorder "
<<
varname
<<
" exist "
varname
);
<<
AsyncSparseParamUpdateRecorder
::
GetInstance
()
->
HasParam
(
if
(
AsyncSparseParamUpdateRecorder
::
GetInstance
()
->
HasParam
(
varname
)
&&
varname
);
}
if
(
distributed_mode_
==
DistributedMode
::
kGeo
&&
AsyncSparseParamUpdateRecorder
::
GetInstance
()
->
HasParam
(
varname
)
&&
!
table_name
.
empty
())
{
!
table_name
.
empty
())
{
std
::
vector
<
int64_t
>
updated_rows
;
std
::
vector
<
int64_t
>
updated_rows
;
AsyncSparseParamUpdateRecorder
::
GetInstance
()
->
GetAndClear
(
AsyncSparseParamUpdateRecorder
::
GetInstance
()
->
GetAndClear
(
...
...
paddle/fluid/operators/distributed/request_handler_impl.h
浏览文件 @
2e834eab
...
@@ -38,8 +38,8 @@ namespace distributed {
...
@@ -38,8 +38,8 @@ namespace distributed {
class
RequestSendHandler
final
:
public
RequestHandler
{
class
RequestSendHandler
final
:
public
RequestHandler
{
public:
public:
explicit
RequestSendHandler
(
bool
sync
_mode
,
bool
enable_dc_asgd
=
false
)
explicit
RequestSendHandler
(
int
distributed
_mode
,
bool
enable_dc_asgd
=
false
)
:
RequestHandler
(
sync
_mode
)
{
:
RequestHandler
(
distributed
_mode
)
{
enable_dc_asgd_
=
enable_dc_asgd
;
enable_dc_asgd_
=
enable_dc_asgd
;
}
}
virtual
~
RequestSendHandler
()
{}
virtual
~
RequestSendHandler
()
{}
...
@@ -54,8 +54,8 @@ class RequestSendHandler final : public RequestHandler {
...
@@ -54,8 +54,8 @@ class RequestSendHandler final : public RequestHandler {
class
RequestGetHandler
final
:
public
RequestHandler
{
class
RequestGetHandler
final
:
public
RequestHandler
{
public:
public:
explicit
RequestGetHandler
(
bool
sync
_mode
,
bool
enable_dc_asgd
=
false
)
explicit
RequestGetHandler
(
int
distributed
_mode
,
bool
enable_dc_asgd
=
false
)
:
RequestHandler
(
sync
_mode
)
{
:
RequestHandler
(
distributed
_mode
)
{
enable_dc_asgd_
=
enable_dc_asgd
;
enable_dc_asgd_
=
enable_dc_asgd
;
}
}
virtual
~
RequestGetHandler
()
{}
virtual
~
RequestGetHandler
()
{}
...
@@ -89,7 +89,8 @@ static inline void BuildVar(const std::string& param_name,
...
@@ -89,7 +89,8 @@ static inline void BuildVar(const std::string& param_name,
class
RequestPrefetchHandler
final
:
public
RequestHandler
{
class
RequestPrefetchHandler
final
:
public
RequestHandler
{
public:
public:
explicit
RequestPrefetchHandler
(
bool
sync_mode
)
:
RequestHandler
(
sync_mode
)
{}
explicit
RequestPrefetchHandler
(
int
distributed_mode
)
:
RequestHandler
(
distributed_mode
)
{}
virtual
~
RequestPrefetchHandler
()
{}
virtual
~
RequestPrefetchHandler
()
{}
bool
Handle
(
const
std
::
string
&
varname
,
framework
::
Scope
*
scope
,
bool
Handle
(
const
std
::
string
&
varname
,
framework
::
Scope
*
scope
,
framework
::
Variable
*
var
,
framework
::
Variable
**
outvar
,
framework
::
Variable
*
var
,
framework
::
Variable
**
outvar
,
...
@@ -113,8 +114,9 @@ class RequestPrefetchHandler final : public RequestHandler {
...
@@ -113,8 +114,9 @@ class RequestPrefetchHandler final : public RequestHandler {
class
RequestCheckpointHandler
final
:
public
RequestHandler
{
class
RequestCheckpointHandler
final
:
public
RequestHandler
{
public:
public:
explicit
RequestCheckpointHandler
(
bool
sync_mode
,
int
checkpoint_notify_id
)
explicit
RequestCheckpointHandler
(
int
distributed_mode
,
:
RequestHandler
(
sync_mode
)
{
int
checkpoint_notify_id
)
:
RequestHandler
(
distributed_mode
)
{
this
->
checkpoint_notify_id
=
checkpoint_notify_id
;
this
->
checkpoint_notify_id
=
checkpoint_notify_id
;
}
}
virtual
~
RequestCheckpointHandler
()
{}
virtual
~
RequestCheckpointHandler
()
{}
...
@@ -129,8 +131,8 @@ class RequestCheckpointHandler final : public RequestHandler {
...
@@ -129,8 +131,8 @@ class RequestCheckpointHandler final : public RequestHandler {
class
RequestNotifyHandler
final
:
public
RequestHandler
{
class
RequestNotifyHandler
final
:
public
RequestHandler
{
public:
public:
explicit
RequestNotifyHandler
(
bool
sync
_mode
,
int
lr_decay_block_id
)
explicit
RequestNotifyHandler
(
int
distributed
_mode
,
int
lr_decay_block_id
)
:
RequestHandler
(
sync
_mode
)
{
:
RequestHandler
(
distributed
_mode
)
{
this
->
lr_decay_block_id
=
lr_decay_block_id
;
this
->
lr_decay_block_id
=
lr_decay_block_id
;
}
}
virtual
~
RequestNotifyHandler
()
{}
virtual
~
RequestNotifyHandler
()
{}
...
...
paddle/fluid/operators/distributed/rpc_server_test.cc
浏览文件 @
2e834eab
...
@@ -131,7 +131,8 @@ void StartServer(const std::string& rpc_name) {
...
@@ -131,7 +131,8 @@ void StartServer(const std::string& rpc_name) {
TEST
(
PREFETCH
,
CPU
)
{
TEST
(
PREFETCH
,
CPU
)
{
setenv
(
"http_proxy"
,
""
,
1
);
setenv
(
"http_proxy"
,
""
,
1
);
setenv
(
"https_proxy"
,
""
,
1
);
setenv
(
"https_proxy"
,
""
,
1
);
g_req_handler
.
reset
(
new
distributed
::
RequestPrefetchHandler
(
true
));
g_req_handler
.
reset
(
new
distributed
::
RequestPrefetchHandler
(
distributed
::
DistributedMode
::
kSync
));
g_rpc_service
.
reset
(
new
RPCSERVER_T
(
"127.0.0.1:0"
,
1
));
g_rpc_service
.
reset
(
new
RPCSERVER_T
(
"127.0.0.1:0"
,
1
));
distributed
::
RPCClient
*
client
=
distributed
::
RPCClient
*
client
=
distributed
::
RPCClient
::
GetInstance
<
RPCCLIENT_T
>
(
0
);
distributed
::
RPCClient
::
GetInstance
<
RPCCLIENT_T
>
(
0
);
...
@@ -173,7 +174,8 @@ TEST(PREFETCH, CPU) {
...
@@ -173,7 +174,8 @@ TEST(PREFETCH, CPU) {
TEST
(
COMPLETE
,
CPU
)
{
TEST
(
COMPLETE
,
CPU
)
{
setenv
(
"http_proxy"
,
""
,
1
);
setenv
(
"http_proxy"
,
""
,
1
);
setenv
(
"https_proxy"
,
""
,
1
);
setenv
(
"https_proxy"
,
""
,
1
);
g_req_handler
.
reset
(
new
distributed
::
RequestSendHandler
(
true
));
g_req_handler
.
reset
(
new
distributed
::
RequestSendHandler
(
distributed
::
DistributedMode
::
kSync
));
g_rpc_service
.
reset
(
new
RPCSERVER_T
(
"127.0.0.1:0"
,
2
));
g_rpc_service
.
reset
(
new
RPCSERVER_T
(
"127.0.0.1:0"
,
2
));
distributed
::
RPCClient
*
client
=
distributed
::
RPCClient
*
client
=
distributed
::
RPCClient
::
GetInstance
<
RPCCLIENT_T
>
(
0
);
distributed
::
RPCClient
::
GetInstance
<
RPCCLIENT_T
>
(
0
);
...
...
paddle/fluid/operators/distributed_ops/fl_listen_and_serv_op.cc
浏览文件 @
2e834eab
...
@@ -199,9 +199,9 @@ void FlListenAndServOp::RunImpl(const framework::Scope &scope,
...
@@ -199,9 +199,9 @@ void FlListenAndServOp::RunImpl(const framework::Scope &scope,
rpc_service_
.
reset
(
new
RPCSERVER_T
(
endpoint
,
fan_in
));
rpc_service_
.
reset
(
new
RPCSERVER_T
(
endpoint
,
fan_in
));
request_send_handler_
.
reset
(
request_send_handler_
.
reset
(
new
distributed
::
RequestSendHandler
(
sync_mode
,
false
));
new
distributed
::
RequestSendHandler
(
!
sync_mode
,
false
));
request_get_handler_
.
reset
(
request_get_handler_
.
reset
(
new
distributed
::
RequestGetHandler
(
sync_mode
,
false
));
new
distributed
::
RequestGetHandler
(
!
sync_mode
,
false
));
rpc_service_
->
RegisterRPC
(
distributed
::
kRequestSend
,
rpc_service_
->
RegisterRPC
(
distributed
::
kRequestSend
,
request_send_handler_
.
get
(),
request_send_handler_
.
get
(),
...
...
paddle/fluid/operators/distributed_ops/gen_nccl_id_op.cc
浏览文件 @
2e834eab
...
@@ -184,7 +184,7 @@ class GenNCCLIdOp : public framework::OperatorBase {
...
@@ -184,7 +184,7 @@ class GenNCCLIdOp : public framework::OperatorBase {
// NOTE: Can not use unique_ptr here because the default
// NOTE: Can not use unique_ptr here because the default
// deleter will call GRPC Server's base class's dtor and
// deleter will call GRPC Server's base class's dtor and
// that will cause a wired crash.
// that will cause a wired crash.
distributed
::
RequestSendHandler
rpc_h
(
true
);
distributed
::
RequestSendHandler
rpc_h
(
distributed
::
DistributedMode
::
kSync
);
std
::
unique_ptr
<
distributed
::
RPCServer
>
rpc_service
(
std
::
unique_ptr
<
distributed
::
RPCServer
>
rpc_service
(
new
RPCSERVER_T
(
endpoint
,
1
));
new
RPCSERVER_T
(
endpoint
,
1
));
...
...
paddle/fluid/operators/distributed_ops/listen_and_serv_op.cc
浏览文件 @
2e834eab
...
@@ -338,7 +338,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
...
@@ -338,7 +338,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
auto
&
dev_ctx
=
*
pool
.
Get
(
dev_place
);
auto
&
dev_ctx
=
*
pool
.
Get
(
dev_place
);
framework
::
Scope
&
recv_scope
=
scope
.
NewScope
();
framework
::
Scope
&
recv_scope
=
scope
.
NewScope
();
bool
sync_mode
=
Attr
<
bool
>
(
"sync
_mode"
);
int
distributed_mode
=
Attr
<
int
>
(
"distributed
_mode"
);
bool
dc_sgd
=
Attr
<
bool
>
(
"dc_asgd"
);
bool
dc_sgd
=
Attr
<
bool
>
(
"dc_asgd"
);
auto
fan_in
=
Attr
<
int
>
(
"Fanin"
);
auto
fan_in
=
Attr
<
int
>
(
"Fanin"
);
auto
pserver_id
=
Attr
<
int
>
(
"pserver_id"
);
auto
pserver_id
=
Attr
<
int
>
(
"pserver_id"
);
...
@@ -349,8 +349,9 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
...
@@ -349,8 +349,9 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
int
checkpoint_block_id
=
Attr
<
int
>
(
kCheckpointBlockId
);
int
checkpoint_block_id
=
Attr
<
int
>
(
kCheckpointBlockId
);
int
lr_decay_block_id
=
Attr
<
int
>
(
kLRDecayBlockId
);
int
lr_decay_block_id
=
Attr
<
int
>
(
kLRDecayBlockId
);
VLOG
(
4
)
<<
"pserver_id: "
<<
pserver_id
<<
", sync_mode:"
<<
sync_mode
VLOG
(
4
)
<<
"pserver_id: "
<<
pserver_id
<<
", fan_in:"
<<
fan_in
<<
", end_point:"
<<
endpoint
<<
", distributed_mode:"
<<
distributed_mode
<<
", fan_in:"
<<
fan_in
<<
", end_point:"
<<
endpoint
<<
", checkpoint_block_id: "
<<
checkpoint_block_id
<<
", checkpoint_block_id: "
<<
checkpoint_block_id
<<
", lr_decay_block_id: "
<<
lr_decay_block_id
;
<<
", lr_decay_block_id: "
<<
lr_decay_block_id
;
...
@@ -361,17 +362,17 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
...
@@ -361,17 +362,17 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
auto
rpc_prefetch_thread_num
=
Attr
<
int
>
(
"rpc_prefetch_thread_num"
);
auto
rpc_prefetch_thread_num
=
Attr
<
int
>
(
"rpc_prefetch_thread_num"
);
request_send_handler_
.
reset
(
request_send_handler_
.
reset
(
new
distributed
::
RequestSendHandler
(
sync
_mode
,
dc_sgd
));
new
distributed
::
RequestSendHandler
(
distributed
_mode
,
dc_sgd
));
request_get_handler_
.
reset
(
request_get_handler_
.
reset
(
new
distributed
::
RequestGetHandler
(
sync
_mode
,
dc_sgd
));
new
distributed
::
RequestGetHandler
(
distributed
_mode
,
dc_sgd
));
request_prefetch_handler_
.
reset
(
request_prefetch_handler_
.
reset
(
new
distributed
::
RequestPrefetchHandler
(
sync
_mode
));
new
distributed
::
RequestPrefetchHandler
(
distributed
_mode
));
request_checkpoint_handler_
.
reset
(
new
distributed
::
RequestCheckpointHandler
(
request_checkpoint_handler_
.
reset
(
new
distributed
::
RequestCheckpointHandler
(
sync
_mode
,
checkpoint_block_id
));
distributed
_mode
,
checkpoint_block_id
));
request_get_no_barrier_handler_
.
reset
(
request_get_no_barrier_handler_
.
reset
(
new
distributed
::
RequestGetNoBarrierHandler
());
new
distributed
::
RequestGetNoBarrierHandler
());
request_notify_handler_
.
reset
(
request_notify_handler_
.
reset
(
new
distributed
::
RequestNotifyHandler
(
new
distributed
::
RequestNotifyHandler
(
sync
_mode
,
lr_decay_block_id
));
distributed
_mode
,
lr_decay_block_id
));
rpc_service_
->
RegisterRPC
(
distributed
::
kRequestSend
,
rpc_service_
->
RegisterRPC
(
distributed
::
kRequestSend
,
request_send_handler_
.
get
(),
rpc_send_thread_num
);
request_send_handler_
.
get
(),
rpc_send_thread_num
);
...
@@ -469,7 +470,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
...
@@ -469,7 +470,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
signal
(
SIGINT
,
SignalHandler
::
StopAndExit
);
signal
(
SIGINT
,
SignalHandler
::
StopAndExit
);
signal
(
SIGTERM
,
SignalHandler
::
StopAndExit
);
signal
(
SIGTERM
,
SignalHandler
::
StopAndExit
);
if
(
sync_mode
)
{
if
(
distributed_mode
==
distributed
::
DistributedMode
::
kSync
)
{
// start the server listening after all member initialized.
// start the server listening after all member initialized.
server_thread_
.
reset
(
new
std
::
thread
(
RunServer
,
rpc_service_
));
server_thread_
.
reset
(
new
std
::
thread
(
RunServer
,
rpc_service_
));
VLOG
(
3
)
<<
"wait server thread to become ready..."
;
VLOG
(
3
)
<<
"wait server thread to become ready..."
;
...
@@ -483,8 +484,10 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
...
@@ -483,8 +484,10 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
RunSyncLoop
(
&
executor
,
program
,
&
recv_scope
,
&
dev_ctx
,
RunSyncLoop
(
&
executor
,
program
,
&
recv_scope
,
&
dev_ctx
,
prefetch_block_id_list
,
checkpoint_block_id
);
prefetch_block_id_list
,
checkpoint_block_id
);
}
else
{
}
else
{
distributed
::
AsyncSparseParamUpdateRecorder
::
Init
(
if
(
distributed_mode
==
distributed
::
DistributedMode
::
kGeo
)
{
fan_in
,
sparse_grad_name_to_param_name
);
distributed
::
AsyncSparseParamUpdateRecorder
::
Init
(
fan_in
,
sparse_grad_name_to_param_name
);
}
VLOG
(
2
)
<<
"RunAsyncLoop"
;
VLOG
(
2
)
<<
"RunAsyncLoop"
;
auto
grad_to_block_id_str
=
auto
grad_to_block_id_str
=
...
@@ -530,7 +533,10 @@ class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker {
...
@@ -530,7 +533,10 @@ class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker {
"['param1@GRAD.block0:1', 'param2@GRAD.blockn:2'] "
"['param1@GRAD.block0:1', 'param2@GRAD.blockn:2'] "
"a map from grad name to it's optimize block id"
)
"a map from grad name to it's optimize block id"
)
.
SetDefault
({});
.
SetDefault
({});
AddAttr
<
bool
>
(
"sync_mode"
,
"if works at sync_mode or not"
).
SetDefault
(
true
);
AddAttr
<
int
>
(
"distributed_mode"
,
"indicate distriubte training mode, 0 is sync, 1 is "
"fully-async, 2 is half-async, 3 is geo"
)
.
SetDefault
(
0
);
AddAttr
<
bool
>
(
"dc_asgd"
,
"set to true will enable DC-ASGD training."
)
AddAttr
<
bool
>
(
"dc_asgd"
,
"set to true will enable DC-ASGD training."
)
.
SetDefault
(
false
);
.
SetDefault
(
false
);
AddAttr
<
std
::
vector
<
framework
::
BlockDesc
*>>
(
AddAttr
<
std
::
vector
<
framework
::
BlockDesc
*>>
(
...
...
paddle/fluid/operators/distributed_ops/send_recv_op_test.cc
浏览文件 @
2e834eab
...
@@ -32,9 +32,11 @@ USE_OP(sum);
...
@@ -32,9 +32,11 @@ USE_OP(sum);
namespace
f
=
paddle
::
framework
;
namespace
f
=
paddle
::
framework
;
namespace
p
=
paddle
::
platform
;
namespace
p
=
paddle
::
platform
;
namespace
m
=
paddle
::
operators
::
math
;
namespace
m
=
paddle
::
operators
::
math
;
namespace
d
=
paddle
::
operators
::
distributed
// global for simplicity.
// global for simplicity.
std
::
unique_ptr
<
f
::
OperatorBase
>
listen_and_serv_op
;
std
::
unique_ptr
<
f
::
OperatorBase
>
listen_and_serv_op
;
int
selected_port
;
int
selected_port
;
void
InitTensorsInScope
(
const
p
::
CPUPlace
&
place
,
f
::
Scope
*
scope
)
{
void
InitTensorsInScope
(
const
p
::
CPUPlace
&
place
,
f
::
Scope
*
scope
)
{
...
@@ -145,7 +147,7 @@ void StartServerNet(bool is_sparse, std::atomic<bool> *initialized) {
...
@@ -145,7 +147,7 @@ void StartServerNet(bool is_sparse, std::atomic<bool> *initialized) {
attrs
.
insert
({
"optimize_blocks"
,
optimize_blocks
});
attrs
.
insert
({
"optimize_blocks"
,
optimize_blocks
});
attrs
.
insert
({
"PrefetchBlock"
,
prefetch_block
});
attrs
.
insert
({
"PrefetchBlock"
,
prefetch_block
});
attrs
.
insert
({
"grad_to_block_id"
,
std
::
vector
<
std
::
string
>
({
""
})});
attrs
.
insert
({
"grad_to_block_id"
,
std
::
vector
<
std
::
string
>
({
""
})});
attrs
.
insert
({
"
sync_mode"
,
true
});
attrs
.
insert
({
"
distributed_mode"
,
d
::
DistributedMode
::
kSync
});
VLOG
(
4
)
<<
"before init op"
;
VLOG
(
4
)
<<
"before init op"
;
listen_and_serv_op
=
listen_and_serv_op
=
f
::
OpRegistry
::
CreateOp
(
"listen_and_serv"
,
{{
"X"
,
{
"x1"
}}},
{},
attrs
);
f
::
OpRegistry
::
CreateOp
(
"listen_and_serv"
,
{{
"X"
,
{
"x1"
}}},
{},
attrs
);
...
...
paddle/fluid/operators/distributed_ops/test_send_nccl_id.cc
浏览文件 @
2e834eab
...
@@ -72,7 +72,8 @@ void StartServer() {
...
@@ -72,7 +72,8 @@ void StartServer() {
}
}
TEST
(
SendNcclId
,
RPCServer
)
{
TEST
(
SendNcclId
,
RPCServer
)
{
g_req_handler
.
reset
(
new
distributed
::
RequestSendHandler
(
true
));
g_req_handler
.
reset
(
new
distributed
::
RequestSendHandler
(
distributed
::
DistributedMode
::
kSync
));
g_rpc_service
.
reset
(
new
RPCSERVER_T
(
"127.0.0.1:0"
,
1
));
g_rpc_service
.
reset
(
new
RPCSERVER_T
(
"127.0.0.1:0"
,
1
));
std
::
thread
server_thread
(
StartServer
);
std
::
thread
server_thread
(
StartServer
);
...
...
python/paddle/fluid/layers/io.py
浏览文件 @
2e834eab
...
@@ -29,6 +29,7 @@ from ..framework import convert_np_dtype_to_dtype_, default_main_program, \
...
@@ -29,6 +29,7 @@ from ..framework import convert_np_dtype_to_dtype_, default_main_program, \
default_startup_program
,
program_guard
,
Program
,
Variable
default_startup_program
,
program_guard
,
Program
,
Variable
from
..layer_helper
import
LayerHelper
from
..layer_helper
import
LayerHelper
from
..unique_name
import
generate
as
unique_name
from
..unique_name
import
generate
as
unique_name
from
..transpiler.distribute_transpiler
import
DistributedMode
import
logging
import
logging
__all__
=
[
__all__
=
[
...
@@ -240,7 +241,8 @@ class ListenAndServ(object):
...
@@ -240,7 +241,8 @@ class ListenAndServ(object):
'optimize_blocks'
:
[
'optimize_blocks'
:
[
current_block
current_block
],
# did not support multiple optimize blocks in layers
],
# did not support multiple optimize blocks in layers
'sync_mode'
:
True
,
# did not support async now in layers
'distributed_mode'
:
DistributedMode
.
SYNC
,
# did not support async now in layers
'grad_to_block_id'
:
[
""
]
'grad_to_block_id'
:
[
""
]
})
})
...
...
python/paddle/fluid/tests/unittests/test_dist_transpiler.py
浏览文件 @
2e834eab
...
@@ -62,10 +62,10 @@ class TranspilerTest(unittest.TestCase):
...
@@ -62,10 +62,10 @@ class TranspilerTest(unittest.TestCase):
self
.
origin_prog
=
main
.
clone
()
self
.
origin_prog
=
main
.
clone
()
return
main
return
main
def
get_trainer
(
self
,
config
=
None
):
def
get_trainer
(
self
,
config
=
None
,
sync_mode
=
True
):
src
=
fluid
.
default_startup_program
().
clone
()
src
=
fluid
.
default_startup_program
().
clone
()
t
=
self
.
_transpiler_instance
(
config
)
t
=
self
.
_transpiler_instance
(
config
,
sync_mode
=
True
)
trainer_main
=
t
.
get_trainer_program
(
wait_port
=
False
)
trainer_main
=
t
.
get_trainer_program
(
wait_port
=
False
)
trainer_startup
=
fluid
.
default_startup_program
()
trainer_startup
=
fluid
.
default_startup_program
()
...
...
python/paddle/fluid/tests/unittests/test_dist_transpiler_config.py
0 → 100644
浏览文件 @
2e834eab
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
unittest
import
paddle.fluid
as
fluid
import
gc
gc
.
set_debug
(
gc
.
DEBUG_COLLECTABLE
)
class
TranspilerTest
(
unittest
.
TestCase
):
def
setUp
(
self
):
self
.
trainer_id
=
0
self
.
trainers
=
2
self
.
pservers
=
2
# NOTE: we do not actually bind this port
self
.
pserver_eps
=
"127.0.0.1:6174,127.0.0.1:6175"
self
.
pserver1_ep
=
"127.0.0.1:6174"
self
.
pserver2_ep
=
"127.0.0.1:6175"
self
.
sync_mode
=
True
self
.
transpiler
=
None
def
net_conf
(
self
):
x
=
fluid
.
layers
.
data
(
name
=
'x'
,
shape
=
[
1000
],
dtype
=
'float32'
)
y_predict
=
fluid
.
layers
.
fc
(
input
=
x
,
size
=
1000
,
act
=
None
,
param_attr
=
fluid
.
ParamAttr
(
name
=
'fc_w'
),
bias_attr
=
fluid
.
ParamAttr
(
name
=
'fc_b'
))
y
=
fluid
.
layers
.
data
(
name
=
'y'
,
shape
=
[
1
],
dtype
=
'float32'
)
cost
=
fluid
.
layers
.
square_error_cost
(
input
=
y_predict
,
label
=
y
)
avg_cost
=
fluid
.
layers
.
mean
(
cost
)
sgd_optimizer
=
fluid
.
optimizer
.
SGD
(
learning_rate
=
0.1
)
sgd_optimizer
.
minimize
(
avg_cost
)
def
get_main_program
(
self
):
main
=
fluid
.
Program
()
main
.
random_seed
=
1
with
fluid
.
program_guard
(
main
):
self
.
net_conf
()
self
.
origin_prog
=
main
.
clone
()
return
main
def
get_trainer
(
self
,
config
=
None
,
sync_mode
=
True
):
src
=
fluid
.
default_startup_program
().
clone
()
t
=
self
.
_transpiler_instance
(
config
,
sync_mode
=
True
)
trainer_main
=
t
.
get_trainer_program
(
wait_port
=
False
)
trainer_startup
=
fluid
.
default_startup_program
()
assert
(
src
.
num_blocks
==
1
)
assert
(
trainer_startup
.
num_blocks
==
src
.
num_blocks
)
return
trainer_main
,
trainer_startup
def
get_pserver
(
self
,
ep
,
config
=
None
,
sync_mode
=
True
):
t
=
self
.
_transpiler_instance
(
config
,
sync_mode
)
pserver
=
t
.
get_pserver_program
(
ep
)
startup
=
t
.
get_startup_program
(
ep
,
pserver
)
return
pserver
,
startup
def
_transpiler_instance
(
self
,
config
=
None
,
sync_mode
=
True
):
if
not
self
.
transpiler
:
main
=
self
.
get_main_program
()
self
.
transpiler
=
fluid
.
DistributeTranspiler
(
config
=
config
)
self
.
transpiler
.
transpile
(
self
.
trainer_id
,
program
=
main
,
pservers
=
self
.
pserver_eps
,
trainers
=
self
.
trainers
,
sync_mode
=
sync_mode
)
return
self
.
transpiler
def
transpiler_test_impl
(
self
):
pass
def
test_transpiler
(
self
):
main
=
fluid
.
Program
()
startup
=
fluid
.
Program
()
with
fluid
.
unique_name
.
guard
():
with
fluid
.
program_guard
(
main
,
startup
):
self
.
transpiler_test_impl
()
# NOTE: run gc.collect to eliminate pybind side objects to
# prevent random double-deallocate when inherited in python.
del
self
.
transpiler
del
main
del
startup
gc
.
collect
()
class
TestBasicModelAsync
(
TranspilerTest
):
def
transpiler_test_impl
(
self
):
config
=
fluid
.
DistributeTranspilerConfig
()
config
.
sync_mode
=
False
config
.
runtime_split_send_recv
=
True
pserver
,
startup
=
self
.
get_pserver
(
self
.
pserver1_ep
,
config
,
False
)
pserver2
,
startup2
=
self
.
get_pserver
(
self
.
pserver2_ep
,
config
,
False
)
trainer
,
_
=
self
.
get_trainer
(
config
,
False
)
self
.
assertEqual
([
op
.
type
for
op
in
trainer
.
global_block
().
ops
],
[
'mul'
,
'elementwise_add'
,
'elementwise_sub'
,
'square'
,
'mean'
,
'fill_constant'
,
'mean_grad'
,
'square_grad'
,
'elementwise_sub_grad'
,
'elementwise_add_grad'
,
'send'
,
'mul_grad'
,
'send'
,
'recv'
,
'recv'
])
self
.
assertEqual
(
len
(
pserver
.
blocks
),
3
)
# block0: listen_and_serv
self
.
assertEqual
([
op
.
type
for
op
in
pserver
.
blocks
[
0
].
ops
],
[
"listen_and_serv"
])
self
.
assertEqual
(
pserver
.
blocks
[
0
].
ops
[
0
].
attr
(
"distributed_mode"
),
1
)
# block1~2: optimize pass
self
.
assertEqual
([
op
.
type
for
op
in
pserver
.
blocks
[
2
].
ops
],
[
"sgd"
])
class
TestBasicModelHalfAsync
(
TranspilerTest
):
def
transpiler_test_impl
(
self
):
config
=
fluid
.
DistributeTranspilerConfig
()
config
.
sync_mode
=
False
config
.
runtime_split_send_recv
=
False
pserver
,
startup
=
self
.
get_pserver
(
self
.
pserver1_ep
,
config
,
False
)
pserver2
,
startup2
=
self
.
get_pserver
(
self
.
pserver2_ep
,
config
,
False
)
trainer
,
_
=
self
.
get_trainer
(
config
,
False
)
self
.
assertEqual
([
op
.
type
for
op
in
trainer
.
global_block
().
ops
],
[
'mul'
,
'elementwise_add'
,
'elementwise_sub'
,
'square'
,
'mean'
,
'fill_constant'
,
'mean_grad'
,
'square_grad'
,
'elementwise_sub_grad'
,
'elementwise_add_grad'
,
'send'
,
'mul_grad'
,
'split_byref'
,
'send'
,
'recv'
,
'recv'
,
'concat'
])
self
.
assertEqual
(
len
(
pserver
.
blocks
),
3
)
# block0: listen_and_serv
self
.
assertEqual
([
op
.
type
for
op
in
pserver
.
blocks
[
0
].
ops
],
[
"listen_and_serv"
])
self
.
assertEqual
(
pserver
.
blocks
[
0
].
ops
[
0
].
attr
(
"distributed_mode"
),
2
)
# block1~2: optimize pass
self
.
assertEqual
([
op
.
type
for
op
in
pserver
.
blocks
[
2
].
ops
],
[
"sgd"
])
class
TestBasicModelSync
(
TranspilerTest
):
def
transpiler_test_impl
(
self
):
config
=
fluid
.
DistributeTranspilerConfig
()
config
.
sync_mode
=
True
config
.
runtime_split_send_recv
=
False
pserver
,
startup
=
self
.
get_pserver
(
self
.
pserver1_ep
,
config
,
True
)
pserver2
,
startup2
=
self
.
get_pserver
(
self
.
pserver2_ep
,
config
,
True
)
trainer
,
_
=
self
.
get_trainer
(
config
,
True
)
self
.
assertEqual
([
op
.
type
for
op
in
trainer
.
global_block
().
ops
],
[
'mul'
,
'elementwise_add'
,
'elementwise_sub'
,
'square'
,
'mean'
,
'fill_constant'
,
'mean_grad'
,
'square_grad'
,
'elementwise_sub_grad'
,
'elementwise_add_grad'
,
'send'
,
'mul_grad'
,
'split_byref'
,
'send'
,
'send_barrier'
,
'recv'
,
'recv'
,
'fetch_barrier'
,
'concat'
])
self
.
assertEqual
(
len
(
pserver
.
blocks
),
3
)
# block0: listen_and_serv
self
.
assertEqual
([
op
.
type
for
op
in
pserver
.
blocks
[
0
].
ops
],
[
"listen_and_serv"
])
self
.
assertEqual
(
pserver
.
blocks
[
0
].
ops
[
0
].
attr
(
"distributed_mode"
),
0
)
# block1~2: optimize pass
self
.
assertEqual
([
op
.
type
for
op
in
pserver
.
blocks
[
2
].
ops
],
[
"sum"
,
"scale"
,
"sgd"
])
if
__name__
==
"__main__"
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/test_listen_and_serv.sh
浏览文件 @
2e834eab
...
@@ -8,7 +8,7 @@ flag1=test_handle_signal_in_serv_op.flag
...
@@ -8,7 +8,7 @@ flag1=test_handle_signal_in_serv_op.flag
flag2
=
test_list_and_serv_run_empty_optimize_block.flag
flag2
=
test_list_and_serv_run_empty_optimize_block.flag
for
i
in
{
1..10
}
;
do
for
i
in
{
1..10
}
;
do
sleep
3
s
sleep
6
s
if
[[
-f
"
${
flag1
}
"
&&
-f
"
${
flag2
}
"
]]
;
then
if
[[
-f
"
${
flag1
}
"
&&
-f
"
${
flag2
}
"
]]
;
then
echo
"test_listen_and_serv_op exit"
echo
"test_listen_and_serv_op exit"
exit
0
exit
0
...
...
python/paddle/fluid/tests/unittests/test_listen_and_serv_op.py
浏览文件 @
2e834eab
...
@@ -52,7 +52,11 @@ def run_pserver(use_cuda, sync_mode, ip, port, trainers, trainer_id):
...
@@ -52,7 +52,11 @@ def run_pserver(use_cuda, sync_mode, ip, port, trainers, trainer_id):
config
=
fluid
.
DistributeTranspilerConfig
()
config
=
fluid
.
DistributeTranspilerConfig
()
config
.
sync_mode
=
sync_mode
config
.
sync_mode
=
sync_mode
t
=
fluid
.
DistributeTranspiler
(
config
=
config
)
t
=
fluid
.
DistributeTranspiler
(
config
=
config
)
t
.
transpile
(
trainer_id
,
pservers
=
pserver_endpoints
,
trainers
=
trainers
)
t
.
transpile
(
trainer_id
,
pservers
=
pserver_endpoints
,
trainers
=
trainers
,
sync_mode
=
sync_mode
)
pserver_prog
=
t
.
get_pserver_program
(
current_endpoint
)
pserver_prog
=
t
.
get_pserver_program
(
current_endpoint
)
pserver_startup
=
t
.
get_startup_program
(
current_endpoint
,
pserver_prog
)
pserver_startup
=
t
.
get_startup_program
(
current_endpoint
,
pserver_prog
)
exe
.
run
(
pserver_startup
)
exe
.
run
(
pserver_startup
)
...
@@ -86,7 +90,11 @@ def run_pserver_with_empty_block(use_cuda, sync_mode, ip, port, trainers,
...
@@ -86,7 +90,11 @@ def run_pserver_with_empty_block(use_cuda, sync_mode, ip, port, trainers,
config
.
slice_var_up
=
False
config
.
slice_var_up
=
False
t
=
fluid
.
DistributeTranspiler
(
config
=
config
)
t
=
fluid
.
DistributeTranspiler
(
config
=
config
)
t
.
transpile
(
trainer_id
,
pservers
=
pserver_endpoints
,
trainers
=
trainers
)
t
.
transpile
(
trainer_id
,
pservers
=
pserver_endpoints
,
trainers
=
trainers
,
sync_mode
=
sync_mode
)
pserver_prog
=
t
.
get_pserver_program
(
ps2
)
pserver_prog
=
t
.
get_pserver_program
(
ps2
)
# pserver2 have no parameter
# pserver2 have no parameter
...
...
python/paddle/fluid/tests/unittests/test_lookup_remote_table_op.py
浏览文件 @
2e834eab
...
@@ -25,6 +25,7 @@ import paddle.fluid as fluid
...
@@ -25,6 +25,7 @@ import paddle.fluid as fluid
import
paddle.fluid.core
as
core
import
paddle.fluid.core
as
core
from
paddle.fluid.op
import
Operator
from
paddle.fluid.op
import
Operator
from
paddle.fluid.framework
import
Program
,
program_guard
from
paddle.fluid.framework
import
Program
,
program_guard
from
paddle.fluid.transpiler.distribute_transpiler
import
DistributedMode
from
dist_test_utils
import
*
from
dist_test_utils
import
*
...
@@ -53,7 +54,7 @@ def run_pserver(pserver_id, use_cuda, sync_mode):
...
@@ -53,7 +54,7 @@ def run_pserver(pserver_id, use_cuda, sync_mode):
"optimize_blocks"
:
[
optimize_block
],
"optimize_blocks"
:
[
optimize_block
],
"endpoint"
:
'127.0.0.1:0'
,
"endpoint"
:
'127.0.0.1:0'
,
"Fanin"
:
1
,
"Fanin"
:
1
,
"
sync_mode"
:
True
,
"
distributed_mode"
:
DistributedMode
.
SYNC
,
"grad_to_block_id"
:
[]
"grad_to_block_id"
:
[]
})
})
...
...
python/paddle/fluid/tests/unittests/test_nce_remote_table_op.py
浏览文件 @
2e834eab
...
@@ -26,6 +26,7 @@ import paddle.fluid.core as core
...
@@ -26,6 +26,7 @@ import paddle.fluid.core as core
from
paddle.fluid.op
import
Operator
from
paddle.fluid.op
import
Operator
from
paddle.fluid.framework
import
Program
,
program_guard
from
paddle.fluid.framework
import
Program
,
program_guard
from
dist_test_utils
import
*
from
dist_test_utils
import
*
from
paddle.fluid.transpiler.distribute_transpiler
import
DistributedMode
def
nce
(
input
,
weight
,
bias
,
sample_weight
,
labels
,
num_classes
,
def
nce
(
input
,
weight
,
bias
,
sample_weight
,
labels
,
num_classes
,
...
@@ -92,7 +93,7 @@ def run_pserver(pserver_id, use_cuda, sync_mode):
...
@@ -92,7 +93,7 @@ def run_pserver(pserver_id, use_cuda, sync_mode):
"optimize_blocks"
:
[
optimize_block
],
"optimize_blocks"
:
[
optimize_block
],
"endpoint"
:
'127.0.0.1:0'
,
"endpoint"
:
'127.0.0.1:0'
,
"Fanin"
:
1
,
"Fanin"
:
1
,
"
sync_mode"
:
True
,
"
distributed_mode"
:
DistributedMode
.
SYNC
,
"grad_to_block_id"
:
[]
"grad_to_block_id"
:
[]
})
})
...
...
python/paddle/fluid/tests/unittests/test_recv_save_op.py
浏览文件 @
2e834eab
...
@@ -29,6 +29,7 @@ from paddle.fluid.op import Operator
...
@@ -29,6 +29,7 @@ from paddle.fluid.op import Operator
from
paddle.fluid.framework
import
Program
,
program_guard
from
paddle.fluid.framework
import
Program
,
program_guard
from
paddle.fluid.transpiler.details
import
VarStruct
,
VarsDistributed
from
paddle.fluid.transpiler.details
import
VarStruct
,
VarsDistributed
from
dist_test_utils
import
*
from
dist_test_utils
import
*
from
paddle.fluid.transpiler.distribute_transpiler
import
DistributedMode
def
run_pserver
(
pserver_id
):
def
run_pserver
(
pserver_id
):
...
@@ -56,7 +57,7 @@ def run_pserver(pserver_id):
...
@@ -56,7 +57,7 @@ def run_pserver(pserver_id):
"optimize_blocks"
:
[
optimize_block
],
"optimize_blocks"
:
[
optimize_block
],
"endpoint"
:
'127.0.0.1:0'
,
"endpoint"
:
'127.0.0.1:0'
,
"Fanin"
:
1
,
"Fanin"
:
1
,
"
sync_mode"
:
True
,
"
distributed_mode"
:
DistributedMode
.
SYNC
,
"grad_to_block_id"
:
[]
"grad_to_block_id"
:
[]
})
})
...
...
python/paddle/fluid/transpiler/distribute_transpiler.py
浏览文件 @
2e834eab
...
@@ -65,6 +65,13 @@ LR_SCHED_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.LRSched
...
@@ -65,6 +65,13 @@ LR_SCHED_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.LRSched
PRINT_LOG
=
False
PRINT_LOG
=
False
class
DistributedMode
:
SYNC
=
0
ASYNC
=
1
HALF_ASYNC
=
2
GEO
=
3
def
log
(
*
args
):
def
log
(
*
args
):
if
PRINT_LOG
:
if
PRINT_LOG
:
print
(
args
)
print
(
args
)
...
@@ -313,6 +320,13 @@ class DistributeTranspiler(object):
...
@@ -313,6 +320,13 @@ class DistributeTranspiler(object):
if
self
.
config
.
split_method
is
None
:
if
self
.
config
.
split_method
is
None
:
self
.
config
.
split_method
=
RoundRobin
self
.
config
.
split_method
=
RoundRobin
if
self
.
config
.
sync_mode
:
self
.
distributed_mode
=
DistributedMode
.
SYNC
elif
self
.
config
.
runtime_split_send_recv
:
self
.
distributed_mode
=
DistributedMode
.
ASYNC
else
:
self
.
distributed_mode
=
DistributedMode
.
HALF_ASYNC
global
PRINT_LOG
global
PRINT_LOG
if
self
.
config
.
print_log
:
if
self
.
config
.
print_log
:
PRINT_LOG
=
True
PRINT_LOG
=
True
...
@@ -1333,7 +1347,7 @@ class DistributeTranspiler(object):
...
@@ -1333,7 +1347,7 @@ class DistributeTranspiler(object):
"endpoint"
:
endpoint
,
"endpoint"
:
endpoint
,
"pserver_id"
:
self
.
pserver_endpoints
.
index
(
endpoint
),
"pserver_id"
:
self
.
pserver_endpoints
.
index
(
endpoint
),
"Fanin"
:
self
.
trainer_num
,
"Fanin"
:
self
.
trainer_num
,
"
sync_mode"
:
self
.
sync
_mode
,
"
distributed_mode"
:
self
.
distributed
_mode
,
"grad_to_block_id"
:
grad_to_block_id
,
"grad_to_block_id"
:
grad_to_block_id
,
"sparse_grad_to_param"
:
sparse_grad_to_param
,
"sparse_grad_to_param"
:
sparse_grad_to_param
,
"lr_decay_block_id"
:
lr_decay_block_id
,
"lr_decay_block_id"
:
lr_decay_block_id
,
...
...
python/paddle/fluid/transpiler/geo_sgd_transpiler.py
浏览文件 @
2e834eab
...
@@ -38,7 +38,7 @@ from ..framework import Program, default_main_program, \
...
@@ -38,7 +38,7 @@ from ..framework import Program, default_main_program, \
from
.details
import
wait_server_ready
,
VarsDistributed
from
.details
import
wait_server_ready
,
VarsDistributed
from
.details
import
delete_ops
from
.details
import
delete_ops
from
..distribute_lookup_table
import
find_distributed_lookup_table
from
..distribute_lookup_table
import
find_distributed_lookup_table
from
.distribute_transpiler
import
DistributeTranspiler
,
DistributeTranspilerConfig
,
slice_variable
,
same_or_split_var
,
ServerRuntimeConfig
from
.distribute_transpiler
import
DistributeTranspiler
,
DistributeTranspilerConfig
,
slice_variable
,
same_or_split_var
,
ServerRuntimeConfig
,
DistributedMode
RPC_OP_ROLE_ATTR_NAME
=
op_role_attr_name
=
core
.
op_proto_and_checker_maker
.
kOpRoleAttrName
(
RPC_OP_ROLE_ATTR_NAME
=
op_role_attr_name
=
core
.
op_proto_and_checker_maker
.
kOpRoleAttrName
(
)
)
...
@@ -247,7 +247,7 @@ class GeoSgdTranspiler(DistributeTranspiler):
...
@@ -247,7 +247,7 @@ class GeoSgdTranspiler(DistributeTranspiler):
"optimize_blocks"
:
optimize_block
,
"optimize_blocks"
:
optimize_block
,
"endpoint"
:
endpoint
,
"endpoint"
:
endpoint
,
"Fanin"
:
self
.
trainer_num
,
"Fanin"
:
self
.
trainer_num
,
"
sync_mode"
:
self
.
sync_mode
,
"
distributed_mode"
:
DistributedMode
.
GEO
,
"grad_to_block_id"
:
param_to_block_id
,
"grad_to_block_id"
:
param_to_block_id
,
"sparse_grad_to_param"
:
sparse_grad_to_param
,
"sparse_grad_to_param"
:
sparse_grad_to_param
,
"rpc_get_thread_num"
:
self
.
server_config
.
_rpc_get_thread_num
,
"rpc_get_thread_num"
:
self
.
server_config
.
_rpc_get_thread_num
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录