Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
985bceac
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
985bceac
编写于
1月 13, 2020
作者:
1
123malin
提交者:
GitHub
1月 13, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Bug fix for sparse recorder (#21969)
* test=develop, bug fix for sparse recorder
上级
7e2af4c9
变更
22
隐藏空白更改
内联
并排
Showing
22 changed file
with
291 addition
and
62 deletion
+291
-62
paddle/fluid/operators/collective/c_gen_nccl_id_op.cc
paddle/fluid/operators/collective/c_gen_nccl_id_op.cc
+1
-1
paddle/fluid/operators/distributed/brpc/brpc_server.cc
paddle/fluid/operators/distributed/brpc/brpc_server.cc
+3
-1
paddle/fluid/operators/distributed/grpc/grpc_server.cc
paddle/fluid/operators/distributed/grpc/grpc_server.cc
+6
-6
paddle/fluid/operators/distributed/request_handler.h
paddle/fluid/operators/distributed/request_handler.h
+6
-4
paddle/fluid/operators/distributed/request_handler_impl.cc
paddle/fluid/operators/distributed/request_handler_impl.cc
+11
-7
paddle/fluid/operators/distributed/request_handler_impl.h
paddle/fluid/operators/distributed/request_handler_impl.h
+11
-9
paddle/fluid/operators/distributed/rpc_server_test.cc
paddle/fluid/operators/distributed/rpc_server_test.cc
+4
-2
paddle/fluid/operators/distributed_ops/fl_listen_and_serv_op.cc
.../fluid/operators/distributed_ops/fl_listen_and_serv_op.cc
+2
-2
paddle/fluid/operators/distributed_ops/gen_nccl_id_op.cc
paddle/fluid/operators/distributed_ops/gen_nccl_id_op.cc
+1
-1
paddle/fluid/operators/distributed_ops/listen_and_serv_op.cc
paddle/fluid/operators/distributed_ops/listen_and_serv_op.cc
+19
-13
paddle/fluid/operators/distributed_ops/send_recv_op_test.cc
paddle/fluid/operators/distributed_ops/send_recv_op_test.cc
+5
-3
paddle/fluid/operators/distributed_ops/test_send_nccl_id.cc
paddle/fluid/operators/distributed_ops/test_send_nccl_id.cc
+2
-1
python/paddle/fluid/layers/io.py
python/paddle/fluid/layers/io.py
+3
-1
python/paddle/fluid/tests/unittests/test_dist_transpiler.py
python/paddle/fluid/tests/unittests/test_dist_transpiler.py
+2
-2
python/paddle/fluid/tests/unittests/test_dist_transpiler_config.py
...ddle/fluid/tests/unittests/test_dist_transpiler_config.py
+181
-0
python/paddle/fluid/tests/unittests/test_listen_and_serv.sh
python/paddle/fluid/tests/unittests/test_listen_and_serv.sh
+1
-1
python/paddle/fluid/tests/unittests/test_listen_and_serv_op.py
...n/paddle/fluid/tests/unittests/test_listen_and_serv_op.py
+10
-2
python/paddle/fluid/tests/unittests/test_lookup_remote_table_op.py
...ddle/fluid/tests/unittests/test_lookup_remote_table_op.py
+2
-1
python/paddle/fluid/tests/unittests/test_nce_remote_table_op.py
.../paddle/fluid/tests/unittests/test_nce_remote_table_op.py
+2
-1
python/paddle/fluid/tests/unittests/test_recv_save_op.py
python/paddle/fluid/tests/unittests/test_recv_save_op.py
+2
-1
python/paddle/fluid/transpiler/distribute_transpiler.py
python/paddle/fluid/transpiler/distribute_transpiler.py
+15
-1
python/paddle/fluid/transpiler/geo_sgd_transpiler.py
python/paddle/fluid/transpiler/geo_sgd_transpiler.py
+2
-2
未找到文件。
paddle/fluid/operators/collective/c_gen_nccl_id_op.cc
浏览文件 @
985bceac
...
@@ -90,7 +90,7 @@ class CGenNCCLIdOp : public framework::OperatorBase {
...
@@ -90,7 +90,7 @@ class CGenNCCLIdOp : public framework::OperatorBase {
// NOTE: Can not use unique_ptr here because the default
// NOTE: Can not use unique_ptr here because the default
// deleter will call GRPC Server's base class's dtor and
// deleter will call GRPC Server's base class's dtor and
// that will cause a wired crash.
// that will cause a wired crash.
distributed
::
RequestSendHandler
rpc_h
(
true
);
distributed
::
RequestSendHandler
rpc_h
(
distributed
::
DistributedMode
::
kSync
);
std
::
unique_ptr
<
distributed
::
RPCServer
>
rpc_service
(
std
::
unique_ptr
<
distributed
::
RPCServer
>
rpc_service
(
new
RPCSERVER_T
(
endpoint
,
1
));
new
RPCSERVER_T
(
endpoint
,
1
));
...
...
paddle/fluid/operators/distributed/brpc/brpc_server.cc
浏览文件 @
985bceac
...
@@ -13,6 +13,8 @@
...
@@ -13,6 +13,8 @@
// limitations under the License.
// limitations under the License.
#include "paddle/fluid/operators/distributed/brpc/brpc_server.h"
#include "paddle/fluid/operators/distributed/brpc/brpc_server.h"
#include <memory>
#include <unordered_map>
#include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/operators/distributed/brpc/brpc_sendrecvop_utils.h"
#include "paddle/fluid/operators/distributed/brpc/brpc_sendrecvop_utils.h"
#include "paddle/fluid/operators/distributed/brpc/brpc_variable_response.h"
#include "paddle/fluid/operators/distributed/brpc/brpc_variable_response.h"
...
@@ -100,7 +102,7 @@ class BRPCServiceImpl : public SendRecvService {
...
@@ -100,7 +102,7 @@ class BRPCServiceImpl : public SendRecvService {
distributed
::
BRPCVariableResponse
resp
(
request_send_h_
->
scope
(),
distributed
::
BRPCVariableResponse
resp
(
request_send_h_
->
scope
(),
request_send_h_
->
dev_ctx
(),
request_send_h_
->
dev_ctx
(),
!
request_send_h_
->
sync
_mode
());
request_send_h_
->
distributed
_mode
());
PADDLE_ENFORCE
(
resp
.
Parse
(
cntl
->
request_attachment
(),
*
request
)
==
0
,
PADDLE_ENFORCE
(
resp
.
Parse
(
cntl
->
request_attachment
(),
*
request
)
==
0
,
"parse iobuf to tensor error!"
);
"parse iobuf to tensor error!"
);
...
...
paddle/fluid/operators/distributed/grpc/grpc_server.cc
浏览文件 @
985bceac
...
@@ -90,9 +90,9 @@ class RequestSend final : public RequestBase {
...
@@ -90,9 +90,9 @@ class RequestSend final : public RequestBase {
::
grpc
::
ServerCompletionQueue
*
cq
,
::
grpc
::
ServerCompletionQueue
*
cq
,
RequestHandler
*
request_handler
,
int
req_id
)
RequestHandler
*
request_handler
,
int
req_id
)
:
RequestBase
(
service
,
cq
,
request_handler
,
req_id
),
responder_
(
&
ctx_
)
{
:
RequestBase
(
service
,
cq
,
request_handler
,
req_id
),
responder_
(
&
ctx_
)
{
request_
.
reset
(
new
GRPCVariableResponse
(
request_handler
->
scope
(),
request_
.
reset
(
new
GRPCVariableResponse
(
request_handler
->
dev_ctx
(),
request_handler
->
scope
(),
request_handler
->
dev_ctx
(),
!
request_handler
->
sync
_mode
()));
request_handler
->
distributed
_mode
()));
int
method_id
=
static_cast
<
int
>
(
distributed
::
GrpcMethod
::
kSendVariable
);
int
method_id
=
static_cast
<
int
>
(
distributed
::
GrpcMethod
::
kSendVariable
);
service_
->
RequestAsyncUnary
(
service_
->
RequestAsyncUnary
(
method_id
,
&
ctx_
,
request_
.
get
(),
&
responder_
,
cq_
,
cq_
,
method_id
,
&
ctx_
,
request_
.
get
(),
&
responder_
,
cq_
,
cq_
,
...
@@ -401,9 +401,9 @@ class RequestNotify final : public RequestBase {
...
@@ -401,9 +401,9 @@ class RequestNotify final : public RequestBase {
::
grpc
::
ServerCompletionQueue
*
cq
,
::
grpc
::
ServerCompletionQueue
*
cq
,
RequestHandler
*
request_handler
,
int
req_id
)
RequestHandler
*
request_handler
,
int
req_id
)
:
RequestBase
(
service
,
cq
,
request_handler
,
req_id
),
responder_
(
&
ctx_
)
{
:
RequestBase
(
service
,
cq
,
request_handler
,
req_id
),
responder_
(
&
ctx_
)
{
request_
.
reset
(
new
GRPCVariableResponse
(
request_handler
->
scope
(),
request_
.
reset
(
new
GRPCVariableResponse
(
request_handler
->
dev_ctx
(),
request_handler
->
scope
(),
request_handler
->
dev_ctx
(),
!
request_handler
->
sync
_mode
()));
request_handler
->
distributed
_mode
()));
int
method_id
=
static_cast
<
int
>
(
distributed
::
GrpcMethod
::
kRequestNotify
);
int
method_id
=
static_cast
<
int
>
(
distributed
::
GrpcMethod
::
kRequestNotify
);
service_
->
RequestAsyncUnary
(
service_
->
RequestAsyncUnary
(
method_id
,
&
ctx_
,
request_
.
get
(),
&
responder_
,
cq_
,
cq_
,
method_id
,
&
ctx_
,
request_
.
get
(),
&
responder_
,
cq_
,
cq_
,
...
...
paddle/fluid/operators/distributed/request_handler.h
浏览文件 @
985bceac
...
@@ -68,6 +68,8 @@ constexpr char kCheckPointNotifyRPC[] = "CheckPointNotifyRPC";
...
@@ -68,6 +68,8 @@ constexpr char kCheckPointNotifyRPC[] = "CheckPointNotifyRPC";
#define CHECKPOINT_SAVE_MESSAGE "SAVE@CHECKPOINTNOTIFY"
#define CHECKPOINT_SAVE_MESSAGE "SAVE@CHECKPOINTNOTIFY"
#define CHECKPOINT_LOAD_MESSAGE "LOAD@CHECKPOINTNOTIFY"
#define CHECKPOINT_LOAD_MESSAGE "LOAD@CHECKPOINTNOTIFY"
enum
DistributedMode
{
kSync
=
0
,
kAsync
=
1
,
kHalfAsync
=
2
,
kGeo
=
3
};
class
RPCServer
;
class
RPCServer
;
class
VarHandle
{
class
VarHandle
{
...
@@ -151,8 +153,8 @@ typedef std::shared_ptr<VarHandle> VarHandlePtr;
...
@@ -151,8 +153,8 @@ typedef std::shared_ptr<VarHandle> VarHandlePtr;
class
RequestHandler
{
class
RequestHandler
{
public:
public:
explicit
RequestHandler
(
bool
sync
_mode
)
explicit
RequestHandler
(
int
distributed
_mode
)
:
sync_mode_
(
sync
_mode
),
:
distributed_mode_
(
distributed
_mode
),
dev_ctx_
(
nullptr
),
dev_ctx_
(
nullptr
),
executor_
(
nullptr
),
executor_
(
nullptr
),
scope_
(
nullptr
),
scope_
(
nullptr
),
...
@@ -198,7 +200,7 @@ class RequestHandler {
...
@@ -198,7 +200,7 @@ class RequestHandler {
void
SetRPCServer
(
RPCServer
*
rpc_server
)
{
rpc_server_
=
rpc_server
;
}
void
SetRPCServer
(
RPCServer
*
rpc_server
)
{
rpc_server_
=
rpc_server
;
}
// Get attributes.
// Get attributes.
bool
sync_mode
()
{
return
sync
_mode_
;
}
int
distributed_mode
()
{
return
distributed
_mode_
;
}
framework
::
Scope
*
scope
()
{
return
scope_
;
}
framework
::
Scope
*
scope
()
{
return
scope_
;
}
const
platform
::
DeviceContext
*
dev_ctx
()
{
return
dev_ctx_
;
}
const
platform
::
DeviceContext
*
dev_ctx
()
{
return
dev_ctx_
;
}
framework
::
ProgramDesc
*
program
()
{
return
program_
;
}
framework
::
ProgramDesc
*
program
()
{
return
program_
;
}
...
@@ -225,7 +227,7 @@ class RequestHandler {
...
@@ -225,7 +227,7 @@ class RequestHandler {
const
std
::
string
&
table_name
=
""
)
=
0
;
const
std
::
string
&
table_name
=
""
)
=
0
;
protected:
protected:
const
bool
sync
_mode_
;
const
int
distributed
_mode_
;
const
platform
::
DeviceContext
*
dev_ctx_
;
const
platform
::
DeviceContext
*
dev_ctx_
;
framework
::
Executor
*
executor_
;
framework
::
Executor
*
executor_
;
...
...
paddle/fluid/operators/distributed/request_handler_impl.cc
浏览文件 @
985bceac
...
@@ -61,7 +61,7 @@ bool RequestSendHandler::Handle(const std::string& varname,
...
@@ -61,7 +61,7 @@ bool RequestSendHandler::Handle(const std::string& varname,
rpc_server_
->
Complete
();
rpc_server_
->
Complete
();
}
else
{
}
else
{
// Async
// Async
if
(
!
sync_mode_
)
{
if
(
distributed_mode_
!=
DistributedMode
::
kSync
)
{
VLOG
(
3
)
<<
"async process var: "
<<
varname
;
VLOG
(
3
)
<<
"async process var: "
<<
varname
;
if
(
varname
==
BATCH_BARRIER_MESSAGE
)
{
if
(
varname
==
BATCH_BARRIER_MESSAGE
)
{
PADDLE_THROW
(
PADDLE_THROW
(
...
@@ -82,7 +82,8 @@ bool RequestSendHandler::Handle(const std::string& varname,
...
@@ -82,7 +82,8 @@ bool RequestSendHandler::Handle(const std::string& varname,
scope
->
Rename
(
varname
,
run_varname
);
scope
->
Rename
(
varname
,
run_varname
);
}
}
if
(
AsyncSparseParamUpdateRecorder
::
GetInstance
()
->
HasGrad
(
run_varname
))
{
if
(
distributed_mode_
==
DistributedMode
::
kGeo
&&
AsyncSparseParamUpdateRecorder
::
GetInstance
()
->
HasGrad
(
run_varname
))
{
auto
&
grad_slr
=
auto
&
grad_slr
=
scope
->
FindVar
(
run_varname
)
->
Get
<
framework
::
SelectedRows
>
();
scope
->
FindVar
(
run_varname
)
->
Get
<
framework
::
SelectedRows
>
();
AsyncSparseParamUpdateRecorder
::
GetInstance
()
->
Update
(
run_varname
,
AsyncSparseParamUpdateRecorder
::
GetInstance
()
->
Update
(
run_varname
,
...
@@ -116,7 +117,7 @@ bool RequestGetHandler::Handle(const std::string& varname,
...
@@ -116,7 +117,7 @@ bool RequestGetHandler::Handle(const std::string& varname,
<<
" out_var_name: "
<<
out_var_name
<<
" trainer_id: "
<<
trainer_id
<<
" out_var_name: "
<<
out_var_name
<<
" trainer_id: "
<<
trainer_id
<<
" table_name: "
<<
table_name
;
<<
" table_name: "
<<
table_name
;
if
(
sync_mode_
)
{
if
(
distributed_mode_
==
DistributedMode
::
kSync
)
{
if
(
varname
==
FETCH_BARRIER_MESSAGE
)
{
if
(
varname
==
FETCH_BARRIER_MESSAGE
)
{
VLOG
(
3
)
<<
"sync: recv fetch barrier message"
;
VLOG
(
3
)
<<
"sync: recv fetch barrier message"
;
rpc_server_
->
IncreaseBatchBarrier
(
kRequestGet
);
rpc_server_
->
IncreaseBatchBarrier
(
kRequestGet
);
...
@@ -140,10 +141,13 @@ bool RequestGetHandler::Handle(const std::string& varname,
...
@@ -140,10 +141,13 @@ bool RequestGetHandler::Handle(const std::string& varname,
framework
::
TensorCopy
(
t_orig
,
dev_ctx_
->
GetPlace
(),
t
);
framework
::
TensorCopy
(
t_orig
,
dev_ctx_
->
GetPlace
(),
t
);
}
}
VLOG
(
1
)
<<
"Table name empty? "
<<
table_name
.
empty
();
VLOG
(
1
)
<<
"Table name empty? "
<<
table_name
.
empty
();
VLOG
(
1
)
<<
"AsyncSparseParamUpdateRecorder "
<<
varname
<<
" exist "
if
(
distributed_mode_
==
DistributedMode
::
kGeo
)
{
<<
AsyncSparseParamUpdateRecorder
::
GetInstance
()
->
HasParam
(
VLOG
(
1
)
<<
"AsyncSparseParamUpdateRecorder "
<<
varname
<<
" exist "
varname
);
<<
AsyncSparseParamUpdateRecorder
::
GetInstance
()
->
HasParam
(
if
(
AsyncSparseParamUpdateRecorder
::
GetInstance
()
->
HasParam
(
varname
)
&&
varname
);
}
if
(
distributed_mode_
==
DistributedMode
::
kGeo
&&
AsyncSparseParamUpdateRecorder
::
GetInstance
()
->
HasParam
(
varname
)
&&
!
table_name
.
empty
())
{
!
table_name
.
empty
())
{
std
::
vector
<
int64_t
>
updated_rows
;
std
::
vector
<
int64_t
>
updated_rows
;
AsyncSparseParamUpdateRecorder
::
GetInstance
()
->
GetAndClear
(
AsyncSparseParamUpdateRecorder
::
GetInstance
()
->
GetAndClear
(
...
...
paddle/fluid/operators/distributed/request_handler_impl.h
浏览文件 @
985bceac
...
@@ -38,8 +38,8 @@ namespace distributed {
...
@@ -38,8 +38,8 @@ namespace distributed {
class
RequestSendHandler
final
:
public
RequestHandler
{
class
RequestSendHandler
final
:
public
RequestHandler
{
public:
public:
explicit
RequestSendHandler
(
bool
sync
_mode
,
bool
enable_dc_asgd
=
false
)
explicit
RequestSendHandler
(
int
distributed
_mode
,
bool
enable_dc_asgd
=
false
)
:
RequestHandler
(
sync
_mode
)
{
:
RequestHandler
(
distributed
_mode
)
{
enable_dc_asgd_
=
enable_dc_asgd
;
enable_dc_asgd_
=
enable_dc_asgd
;
}
}
virtual
~
RequestSendHandler
()
{}
virtual
~
RequestSendHandler
()
{}
...
@@ -54,8 +54,8 @@ class RequestSendHandler final : public RequestHandler {
...
@@ -54,8 +54,8 @@ class RequestSendHandler final : public RequestHandler {
class
RequestGetHandler
final
:
public
RequestHandler
{
class
RequestGetHandler
final
:
public
RequestHandler
{
public:
public:
explicit
RequestGetHandler
(
bool
sync
_mode
,
bool
enable_dc_asgd
=
false
)
explicit
RequestGetHandler
(
int
distributed
_mode
,
bool
enable_dc_asgd
=
false
)
:
RequestHandler
(
sync
_mode
)
{
:
RequestHandler
(
distributed
_mode
)
{
enable_dc_asgd_
=
enable_dc_asgd
;
enable_dc_asgd_
=
enable_dc_asgd
;
}
}
virtual
~
RequestGetHandler
()
{}
virtual
~
RequestGetHandler
()
{}
...
@@ -89,7 +89,8 @@ static inline void BuildVar(const std::string& param_name,
...
@@ -89,7 +89,8 @@ static inline void BuildVar(const std::string& param_name,
class
RequestPrefetchHandler
final
:
public
RequestHandler
{
class
RequestPrefetchHandler
final
:
public
RequestHandler
{
public:
public:
explicit
RequestPrefetchHandler
(
bool
sync_mode
)
:
RequestHandler
(
sync_mode
)
{}
explicit
RequestPrefetchHandler
(
int
distributed_mode
)
:
RequestHandler
(
distributed_mode
)
{}
virtual
~
RequestPrefetchHandler
()
{}
virtual
~
RequestPrefetchHandler
()
{}
bool
Handle
(
const
std
::
string
&
varname
,
framework
::
Scope
*
scope
,
bool
Handle
(
const
std
::
string
&
varname
,
framework
::
Scope
*
scope
,
framework
::
Variable
*
var
,
framework
::
Variable
**
outvar
,
framework
::
Variable
*
var
,
framework
::
Variable
**
outvar
,
...
@@ -113,8 +114,9 @@ class RequestPrefetchHandler final : public RequestHandler {
...
@@ -113,8 +114,9 @@ class RequestPrefetchHandler final : public RequestHandler {
class
RequestCheckpointHandler
final
:
public
RequestHandler
{
class
RequestCheckpointHandler
final
:
public
RequestHandler
{
public:
public:
explicit
RequestCheckpointHandler
(
bool
sync_mode
,
int
checkpoint_notify_id
)
explicit
RequestCheckpointHandler
(
int
distributed_mode
,
:
RequestHandler
(
sync_mode
)
{
int
checkpoint_notify_id
)
:
RequestHandler
(
distributed_mode
)
{
this
->
checkpoint_notify_id
=
checkpoint_notify_id
;
this
->
checkpoint_notify_id
=
checkpoint_notify_id
;
}
}
virtual
~
RequestCheckpointHandler
()
{}
virtual
~
RequestCheckpointHandler
()
{}
...
@@ -129,8 +131,8 @@ class RequestCheckpointHandler final : public RequestHandler {
...
@@ -129,8 +131,8 @@ class RequestCheckpointHandler final : public RequestHandler {
class
RequestNotifyHandler
final
:
public
RequestHandler
{
class
RequestNotifyHandler
final
:
public
RequestHandler
{
public:
public:
explicit
RequestNotifyHandler
(
bool
sync
_mode
,
int
lr_decay_block_id
)
explicit
RequestNotifyHandler
(
int
distributed
_mode
,
int
lr_decay_block_id
)
:
RequestHandler
(
sync
_mode
)
{
:
RequestHandler
(
distributed
_mode
)
{
this
->
lr_decay_block_id
=
lr_decay_block_id
;
this
->
lr_decay_block_id
=
lr_decay_block_id
;
}
}
virtual
~
RequestNotifyHandler
()
{}
virtual
~
RequestNotifyHandler
()
{}
...
...
paddle/fluid/operators/distributed/rpc_server_test.cc
浏览文件 @
985bceac
...
@@ -131,7 +131,8 @@ void StartServer(const std::string& rpc_name) {
...
@@ -131,7 +131,8 @@ void StartServer(const std::string& rpc_name) {
TEST
(
PREFETCH
,
CPU
)
{
TEST
(
PREFETCH
,
CPU
)
{
setenv
(
"http_proxy"
,
""
,
1
);
setenv
(
"http_proxy"
,
""
,
1
);
setenv
(
"https_proxy"
,
""
,
1
);
setenv
(
"https_proxy"
,
""
,
1
);
g_req_handler
.
reset
(
new
distributed
::
RequestPrefetchHandler
(
true
));
g_req_handler
.
reset
(
new
distributed
::
RequestPrefetchHandler
(
distributed
::
DistributedMode
::
kSync
));
g_rpc_service
.
reset
(
new
RPCSERVER_T
(
"127.0.0.1:0"
,
1
));
g_rpc_service
.
reset
(
new
RPCSERVER_T
(
"127.0.0.1:0"
,
1
));
distributed
::
RPCClient
*
client
=
distributed
::
RPCClient
*
client
=
distributed
::
RPCClient
::
GetInstance
<
RPCCLIENT_T
>
(
0
);
distributed
::
RPCClient
::
GetInstance
<
RPCCLIENT_T
>
(
0
);
...
@@ -173,7 +174,8 @@ TEST(PREFETCH, CPU) {
...
@@ -173,7 +174,8 @@ TEST(PREFETCH, CPU) {
TEST
(
COMPLETE
,
CPU
)
{
TEST
(
COMPLETE
,
CPU
)
{
setenv
(
"http_proxy"
,
""
,
1
);
setenv
(
"http_proxy"
,
""
,
1
);
setenv
(
"https_proxy"
,
""
,
1
);
setenv
(
"https_proxy"
,
""
,
1
);
g_req_handler
.
reset
(
new
distributed
::
RequestSendHandler
(
true
));
g_req_handler
.
reset
(
new
distributed
::
RequestSendHandler
(
distributed
::
DistributedMode
::
kSync
));
g_rpc_service
.
reset
(
new
RPCSERVER_T
(
"127.0.0.1:0"
,
2
));
g_rpc_service
.
reset
(
new
RPCSERVER_T
(
"127.0.0.1:0"
,
2
));
distributed
::
RPCClient
*
client
=
distributed
::
RPCClient
*
client
=
distributed
::
RPCClient
::
GetInstance
<
RPCCLIENT_T
>
(
0
);
distributed
::
RPCClient
::
GetInstance
<
RPCCLIENT_T
>
(
0
);
...
...
paddle/fluid/operators/distributed_ops/fl_listen_and_serv_op.cc
浏览文件 @
985bceac
...
@@ -199,9 +199,9 @@ void FlListenAndServOp::RunImpl(const framework::Scope &scope,
...
@@ -199,9 +199,9 @@ void FlListenAndServOp::RunImpl(const framework::Scope &scope,
rpc_service_
.
reset
(
new
RPCSERVER_T
(
endpoint
,
fan_in
));
rpc_service_
.
reset
(
new
RPCSERVER_T
(
endpoint
,
fan_in
));
request_send_handler_
.
reset
(
request_send_handler_
.
reset
(
new
distributed
::
RequestSendHandler
(
sync_mode
,
false
));
new
distributed
::
RequestSendHandler
(
!
sync_mode
,
false
));
request_get_handler_
.
reset
(
request_get_handler_
.
reset
(
new
distributed
::
RequestGetHandler
(
sync_mode
,
false
));
new
distributed
::
RequestGetHandler
(
!
sync_mode
,
false
));
rpc_service_
->
RegisterRPC
(
distributed
::
kRequestSend
,
rpc_service_
->
RegisterRPC
(
distributed
::
kRequestSend
,
request_send_handler_
.
get
(),
request_send_handler_
.
get
(),
...
...
paddle/fluid/operators/distributed_ops/gen_nccl_id_op.cc
浏览文件 @
985bceac
...
@@ -184,7 +184,7 @@ class GenNCCLIdOp : public framework::OperatorBase {
...
@@ -184,7 +184,7 @@ class GenNCCLIdOp : public framework::OperatorBase {
// NOTE: Can not use unique_ptr here because the default
// NOTE: Can not use unique_ptr here because the default
// deleter will call GRPC Server's base class's dtor and
// deleter will call GRPC Server's base class's dtor and
// that will cause a wired crash.
// that will cause a wired crash.
distributed
::
RequestSendHandler
rpc_h
(
true
);
distributed
::
RequestSendHandler
rpc_h
(
distributed
::
DistributedMode
::
kSync
);
std
::
unique_ptr
<
distributed
::
RPCServer
>
rpc_service
(
std
::
unique_ptr
<
distributed
::
RPCServer
>
rpc_service
(
new
RPCSERVER_T
(
endpoint
,
1
));
new
RPCSERVER_T
(
endpoint
,
1
));
...
...
paddle/fluid/operators/distributed_ops/listen_and_serv_op.cc
浏览文件 @
985bceac
...
@@ -338,7 +338,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
...
@@ -338,7 +338,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
auto
&
dev_ctx
=
*
pool
.
Get
(
dev_place
);
auto
&
dev_ctx
=
*
pool
.
Get
(
dev_place
);
framework
::
Scope
&
recv_scope
=
scope
.
NewScope
();
framework
::
Scope
&
recv_scope
=
scope
.
NewScope
();
bool
sync_mode
=
Attr
<
bool
>
(
"sync
_mode"
);
int
distributed_mode
=
Attr
<
int
>
(
"distributed
_mode"
);
bool
dc_sgd
=
Attr
<
bool
>
(
"dc_asgd"
);
bool
dc_sgd
=
Attr
<
bool
>
(
"dc_asgd"
);
auto
fan_in
=
Attr
<
int
>
(
"Fanin"
);
auto
fan_in
=
Attr
<
int
>
(
"Fanin"
);
auto
pserver_id
=
Attr
<
int
>
(
"pserver_id"
);
auto
pserver_id
=
Attr
<
int
>
(
"pserver_id"
);
...
@@ -349,8 +349,9 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
...
@@ -349,8 +349,9 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
int
checkpoint_block_id
=
Attr
<
int
>
(
kCheckpointBlockId
);
int
checkpoint_block_id
=
Attr
<
int
>
(
kCheckpointBlockId
);
int
lr_decay_block_id
=
Attr
<
int
>
(
kLRDecayBlockId
);
int
lr_decay_block_id
=
Attr
<
int
>
(
kLRDecayBlockId
);
VLOG
(
4
)
<<
"pserver_id: "
<<
pserver_id
<<
", sync_mode:"
<<
sync_mode
VLOG
(
4
)
<<
"pserver_id: "
<<
pserver_id
<<
", fan_in:"
<<
fan_in
<<
", end_point:"
<<
endpoint
<<
", distributed_mode:"
<<
distributed_mode
<<
", fan_in:"
<<
fan_in
<<
", end_point:"
<<
endpoint
<<
", checkpoint_block_id: "
<<
checkpoint_block_id
<<
", checkpoint_block_id: "
<<
checkpoint_block_id
<<
", lr_decay_block_id: "
<<
lr_decay_block_id
;
<<
", lr_decay_block_id: "
<<
lr_decay_block_id
;
...
@@ -361,17 +362,17 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
...
@@ -361,17 +362,17 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
auto
rpc_prefetch_thread_num
=
Attr
<
int
>
(
"rpc_prefetch_thread_num"
);
auto
rpc_prefetch_thread_num
=
Attr
<
int
>
(
"rpc_prefetch_thread_num"
);
request_send_handler_
.
reset
(
request_send_handler_
.
reset
(
new
distributed
::
RequestSendHandler
(
sync
_mode
,
dc_sgd
));
new
distributed
::
RequestSendHandler
(
distributed
_mode
,
dc_sgd
));
request_get_handler_
.
reset
(
request_get_handler_
.
reset
(
new
distributed
::
RequestGetHandler
(
sync
_mode
,
dc_sgd
));
new
distributed
::
RequestGetHandler
(
distributed
_mode
,
dc_sgd
));
request_prefetch_handler_
.
reset
(
request_prefetch_handler_
.
reset
(
new
distributed
::
RequestPrefetchHandler
(
sync
_mode
));
new
distributed
::
RequestPrefetchHandler
(
distributed
_mode
));
request_checkpoint_handler_
.
reset
(
new
distributed
::
RequestCheckpointHandler
(
request_checkpoint_handler_
.
reset
(
new
distributed
::
RequestCheckpointHandler
(
sync
_mode
,
checkpoint_block_id
));
distributed
_mode
,
checkpoint_block_id
));
request_get_no_barrier_handler_
.
reset
(
request_get_no_barrier_handler_
.
reset
(
new
distributed
::
RequestGetNoBarrierHandler
());
new
distributed
::
RequestGetNoBarrierHandler
());
request_notify_handler_
.
reset
(
request_notify_handler_
.
reset
(
new
distributed
::
RequestNotifyHandler
(
new
distributed
::
RequestNotifyHandler
(
sync
_mode
,
lr_decay_block_id
));
distributed
_mode
,
lr_decay_block_id
));
rpc_service_
->
RegisterRPC
(
distributed
::
kRequestSend
,
rpc_service_
->
RegisterRPC
(
distributed
::
kRequestSend
,
request_send_handler_
.
get
(),
rpc_send_thread_num
);
request_send_handler_
.
get
(),
rpc_send_thread_num
);
...
@@ -469,7 +470,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
...
@@ -469,7 +470,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
signal
(
SIGINT
,
SignalHandler
::
StopAndExit
);
signal
(
SIGINT
,
SignalHandler
::
StopAndExit
);
signal
(
SIGTERM
,
SignalHandler
::
StopAndExit
);
signal
(
SIGTERM
,
SignalHandler
::
StopAndExit
);
if
(
sync_mode
)
{
if
(
distributed_mode
==
distributed
::
DistributedMode
::
kSync
)
{
// start the server listening after all member initialized.
// start the server listening after all member initialized.
server_thread_
.
reset
(
new
std
::
thread
(
RunServer
,
rpc_service_
));
server_thread_
.
reset
(
new
std
::
thread
(
RunServer
,
rpc_service_
));
VLOG
(
3
)
<<
"wait server thread to become ready..."
;
VLOG
(
3
)
<<
"wait server thread to become ready..."
;
...
@@ -483,8 +484,10 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
...
@@ -483,8 +484,10 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
RunSyncLoop
(
&
executor
,
program
,
&
recv_scope
,
&
dev_ctx
,
RunSyncLoop
(
&
executor
,
program
,
&
recv_scope
,
&
dev_ctx
,
prefetch_block_id_list
,
checkpoint_block_id
);
prefetch_block_id_list
,
checkpoint_block_id
);
}
else
{
}
else
{
distributed
::
AsyncSparseParamUpdateRecorder
::
Init
(
if
(
distributed_mode
==
distributed
::
DistributedMode
::
kGeo
)
{
fan_in
,
sparse_grad_name_to_param_name
);
distributed
::
AsyncSparseParamUpdateRecorder
::
Init
(
fan_in
,
sparse_grad_name_to_param_name
);
}
VLOG
(
2
)
<<
"RunAsyncLoop"
;
VLOG
(
2
)
<<
"RunAsyncLoop"
;
auto
grad_to_block_id_str
=
auto
grad_to_block_id_str
=
...
@@ -530,7 +533,10 @@ class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker {
...
@@ -530,7 +533,10 @@ class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker {
"['param1@GRAD.block0:1', 'param2@GRAD.blockn:2'] "
"['param1@GRAD.block0:1', 'param2@GRAD.blockn:2'] "
"a map from grad name to it's optimize block id"
)
"a map from grad name to it's optimize block id"
)
.
SetDefault
({});
.
SetDefault
({});
AddAttr
<
bool
>
(
"sync_mode"
,
"if works at sync_mode or not"
).
SetDefault
(
true
);
AddAttr
<
int
>
(
"distributed_mode"
,
"indicate distriubte training mode, 0 is sync, 1 is "
"fully-async, 2 is half-async, 3 is geo"
)
.
SetDefault
(
0
);
AddAttr
<
bool
>
(
"dc_asgd"
,
"set to true will enable DC-ASGD training."
)
AddAttr
<
bool
>
(
"dc_asgd"
,
"set to true will enable DC-ASGD training."
)
.
SetDefault
(
false
);
.
SetDefault
(
false
);
AddAttr
<
std
::
vector
<
framework
::
BlockDesc
*>>
(
AddAttr
<
std
::
vector
<
framework
::
BlockDesc
*>>
(
...
...
paddle/fluid/operators/distributed_ops/send_recv_op_test.cc
浏览文件 @
985bceac
...
@@ -32,9 +32,11 @@ USE_OP(sum);
...
@@ -32,9 +32,11 @@ USE_OP(sum);
namespace
f
=
paddle
::
framework
;
namespace
f
=
paddle
::
framework
;
namespace
p
=
paddle
::
platform
;
namespace
p
=
paddle
::
platform
;
namespace
m
=
paddle
::
operators
::
math
;
namespace
m
=
paddle
::
operators
::
math
;
namespace
d
=
paddle
::
operators
::
distributed
// global for simplicity.
// global for simplicity.
std
::
unique_ptr
<
f
::
OperatorBase
>
listen_and_serv_op
;
std
::
unique_ptr
<
f
::
OperatorBase
>
listen_and_serv_op
;
int
selected_port
;
int
selected_port
;
void
InitTensorsInScope
(
const
p
::
CPUPlace
&
place
,
f
::
Scope
*
scope
)
{
void
InitTensorsInScope
(
const
p
::
CPUPlace
&
place
,
f
::
Scope
*
scope
)
{
...
@@ -145,7 +147,7 @@ void StartServerNet(bool is_sparse, std::atomic<bool> *initialized) {
...
@@ -145,7 +147,7 @@ void StartServerNet(bool is_sparse, std::atomic<bool> *initialized) {
attrs
.
insert
({
"optimize_blocks"
,
optimize_blocks
});
attrs
.
insert
({
"optimize_blocks"
,
optimize_blocks
});
attrs
.
insert
({
"PrefetchBlock"
,
prefetch_block
});
attrs
.
insert
({
"PrefetchBlock"
,
prefetch_block
});
attrs
.
insert
({
"grad_to_block_id"
,
std
::
vector
<
std
::
string
>
({
""
})});
attrs
.
insert
({
"grad_to_block_id"
,
std
::
vector
<
std
::
string
>
({
""
})});
attrs
.
insert
({
"
sync_mode"
,
true
});
attrs
.
insert
({
"
distributed_mode"
,
d
::
DistributedMode
::
kSync
});
VLOG
(
4
)
<<
"before init op"
;
VLOG
(
4
)
<<
"before init op"
;
listen_and_serv_op
=
listen_and_serv_op
=
f
::
OpRegistry
::
CreateOp
(
"listen_and_serv"
,
{{
"X"
,
{
"x1"
}}},
{},
attrs
);
f
::
OpRegistry
::
CreateOp
(
"listen_and_serv"
,
{{
"X"
,
{
"x1"
}}},
{},
attrs
);
...
...
paddle/fluid/operators/distributed_ops/test_send_nccl_id.cc
浏览文件 @
985bceac
...
@@ -72,7 +72,8 @@ void StartServer() {
...
@@ -72,7 +72,8 @@ void StartServer() {
}
}
TEST
(
SendNcclId
,
RPCServer
)
{
TEST
(
SendNcclId
,
RPCServer
)
{
g_req_handler
.
reset
(
new
distributed
::
RequestSendHandler
(
true
));
g_req_handler
.
reset
(
new
distributed
::
RequestSendHandler
(
distributed
::
DistributedMode
::
kSync
));
g_rpc_service
.
reset
(
new
RPCSERVER_T
(
"127.0.0.1:0"
,
1
));
g_rpc_service
.
reset
(
new
RPCSERVER_T
(
"127.0.0.1:0"
,
1
));
std
::
thread
server_thread
(
StartServer
);
std
::
thread
server_thread
(
StartServer
);
...
...
python/paddle/fluid/layers/io.py
浏览文件 @
985bceac
...
@@ -29,6 +29,7 @@ from ..framework import convert_np_dtype_to_dtype_, default_main_program, \
...
@@ -29,6 +29,7 @@ from ..framework import convert_np_dtype_to_dtype_, default_main_program, \
default_startup_program
,
program_guard
,
Program
,
Variable
default_startup_program
,
program_guard
,
Program
,
Variable
from
..layer_helper
import
LayerHelper
from
..layer_helper
import
LayerHelper
from
..unique_name
import
generate
as
unique_name
from
..unique_name
import
generate
as
unique_name
from
..transpiler.distribute_transpiler
import
DistributedMode
import
logging
import
logging
__all__
=
[
__all__
=
[
...
@@ -240,7 +241,8 @@ class ListenAndServ(object):
...
@@ -240,7 +241,8 @@ class ListenAndServ(object):
'optimize_blocks'
:
[
'optimize_blocks'
:
[
current_block
current_block
],
# did not support multiple optimize blocks in layers
],
# did not support multiple optimize blocks in layers
'sync_mode'
:
True
,
# did not support async now in layers
'distributed_mode'
:
DistributedMode
.
SYNC
,
# did not support async now in layers
'grad_to_block_id'
:
[
""
]
'grad_to_block_id'
:
[
""
]
})
})
...
...
python/paddle/fluid/tests/unittests/test_dist_transpiler.py
浏览文件 @
985bceac
...
@@ -62,10 +62,10 @@ class TranspilerTest(unittest.TestCase):
...
@@ -62,10 +62,10 @@ class TranspilerTest(unittest.TestCase):
self
.
origin_prog
=
main
.
clone
()
self
.
origin_prog
=
main
.
clone
()
return
main
return
main
def
get_trainer
(
self
,
config
=
None
):
def
get_trainer
(
self
,
config
=
None
,
sync_mode
=
True
):
src
=
fluid
.
default_startup_program
().
clone
()
src
=
fluid
.
default_startup_program
().
clone
()
t
=
self
.
_transpiler_instance
(
config
)
t
=
self
.
_transpiler_instance
(
config
,
sync_mode
=
True
)
trainer_main
=
t
.
get_trainer_program
(
wait_port
=
False
)
trainer_main
=
t
.
get_trainer_program
(
wait_port
=
False
)
trainer_startup
=
fluid
.
default_startup_program
()
trainer_startup
=
fluid
.
default_startup_program
()
...
...
python/paddle/fluid/tests/unittests/test_dist_transpiler_config.py
0 → 100644
浏览文件 @
985bceac
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
unittest
import
paddle.fluid
as
fluid
import
gc
gc
.
set_debug
(
gc
.
DEBUG_COLLECTABLE
)
class
TranspilerTest
(
unittest
.
TestCase
):
def
setUp
(
self
):
self
.
trainer_id
=
0
self
.
trainers
=
2
self
.
pservers
=
2
# NOTE: we do not actually bind this port
self
.
pserver_eps
=
"127.0.0.1:6174,127.0.0.1:6175"
self
.
pserver1_ep
=
"127.0.0.1:6174"
self
.
pserver2_ep
=
"127.0.0.1:6175"
self
.
sync_mode
=
True
self
.
transpiler
=
None
def
net_conf
(
self
):
x
=
fluid
.
layers
.
data
(
name
=
'x'
,
shape
=
[
1000
],
dtype
=
'float32'
)
y_predict
=
fluid
.
layers
.
fc
(
input
=
x
,
size
=
1000
,
act
=
None
,
param_attr
=
fluid
.
ParamAttr
(
name
=
'fc_w'
),
bias_attr
=
fluid
.
ParamAttr
(
name
=
'fc_b'
))
y
=
fluid
.
layers
.
data
(
name
=
'y'
,
shape
=
[
1
],
dtype
=
'float32'
)
cost
=
fluid
.
layers
.
square_error_cost
(
input
=
y_predict
,
label
=
y
)
avg_cost
=
fluid
.
layers
.
mean
(
cost
)
sgd_optimizer
=
fluid
.
optimizer
.
SGD
(
learning_rate
=
0.1
)
sgd_optimizer
.
minimize
(
avg_cost
)
def
get_main_program
(
self
):
main
=
fluid
.
Program
()
main
.
random_seed
=
1
with
fluid
.
program_guard
(
main
):
self
.
net_conf
()
self
.
origin_prog
=
main
.
clone
()
return
main
def
get_trainer
(
self
,
config
=
None
,
sync_mode
=
True
):
src
=
fluid
.
default_startup_program
().
clone
()
t
=
self
.
_transpiler_instance
(
config
,
sync_mode
=
True
)
trainer_main
=
t
.
get_trainer_program
(
wait_port
=
False
)
trainer_startup
=
fluid
.
default_startup_program
()
assert
(
src
.
num_blocks
==
1
)
assert
(
trainer_startup
.
num_blocks
==
src
.
num_blocks
)
return
trainer_main
,
trainer_startup
def
get_pserver
(
self
,
ep
,
config
=
None
,
sync_mode
=
True
):
t
=
self
.
_transpiler_instance
(
config
,
sync_mode
)
pserver
=
t
.
get_pserver_program
(
ep
)
startup
=
t
.
get_startup_program
(
ep
,
pserver
)
return
pserver
,
startup
def
_transpiler_instance
(
self
,
config
=
None
,
sync_mode
=
True
):
if
not
self
.
transpiler
:
main
=
self
.
get_main_program
()
self
.
transpiler
=
fluid
.
DistributeTranspiler
(
config
=
config
)
self
.
transpiler
.
transpile
(
self
.
trainer_id
,
program
=
main
,
pservers
=
self
.
pserver_eps
,
trainers
=
self
.
trainers
,
sync_mode
=
sync_mode
)
return
self
.
transpiler
def
transpiler_test_impl
(
self
):
pass
def
test_transpiler
(
self
):
main
=
fluid
.
Program
()
startup
=
fluid
.
Program
()
with
fluid
.
unique_name
.
guard
():
with
fluid
.
program_guard
(
main
,
startup
):
self
.
transpiler_test_impl
()
# NOTE: run gc.collect to eliminate pybind side objects to
# prevent random double-deallocate when inherited in python.
del
self
.
transpiler
del
main
del
startup
gc
.
collect
()
class
TestBasicModelAsync
(
TranspilerTest
):
def
transpiler_test_impl
(
self
):
config
=
fluid
.
DistributeTranspilerConfig
()
config
.
sync_mode
=
False
config
.
runtime_split_send_recv
=
True
pserver
,
startup
=
self
.
get_pserver
(
self
.
pserver1_ep
,
config
,
False
)
pserver2
,
startup2
=
self
.
get_pserver
(
self
.
pserver2_ep
,
config
,
False
)
trainer
,
_
=
self
.
get_trainer
(
config
,
False
)
self
.
assertEqual
([
op
.
type
for
op
in
trainer
.
global_block
().
ops
],
[
'mul'
,
'elementwise_add'
,
'elementwise_sub'
,
'square'
,
'mean'
,
'fill_constant'
,
'mean_grad'
,
'square_grad'
,
'elementwise_sub_grad'
,
'elementwise_add_grad'
,
'send'
,
'mul_grad'
,
'send'
,
'recv'
,
'recv'
])
self
.
assertEqual
(
len
(
pserver
.
blocks
),
3
)
# block0: listen_and_serv
self
.
assertEqual
([
op
.
type
for
op
in
pserver
.
blocks
[
0
].
ops
],
[
"listen_and_serv"
])
self
.
assertEqual
(
pserver
.
blocks
[
0
].
ops
[
0
].
attr
(
"distributed_mode"
),
1
)
# block1~2: optimize pass
self
.
assertEqual
([
op
.
type
for
op
in
pserver
.
blocks
[
2
].
ops
],
[
"sgd"
])
class
TestBasicModelHalfAsync
(
TranspilerTest
):
def
transpiler_test_impl
(
self
):
config
=
fluid
.
DistributeTranspilerConfig
()
config
.
sync_mode
=
False
config
.
runtime_split_send_recv
=
False
pserver
,
startup
=
self
.
get_pserver
(
self
.
pserver1_ep
,
config
,
False
)
pserver2
,
startup2
=
self
.
get_pserver
(
self
.
pserver2_ep
,
config
,
False
)
trainer
,
_
=
self
.
get_trainer
(
config
,
False
)
self
.
assertEqual
([
op
.
type
for
op
in
trainer
.
global_block
().
ops
],
[
'mul'
,
'elementwise_add'
,
'elementwise_sub'
,
'square'
,
'mean'
,
'fill_constant'
,
'mean_grad'
,
'square_grad'
,
'elementwise_sub_grad'
,
'elementwise_add_grad'
,
'send'
,
'mul_grad'
,
'split_byref'
,
'send'
,
'recv'
,
'recv'
,
'concat'
])
self
.
assertEqual
(
len
(
pserver
.
blocks
),
3
)
# block0: listen_and_serv
self
.
assertEqual
([
op
.
type
for
op
in
pserver
.
blocks
[
0
].
ops
],
[
"listen_and_serv"
])
self
.
assertEqual
(
pserver
.
blocks
[
0
].
ops
[
0
].
attr
(
"distributed_mode"
),
2
)
# block1~2: optimize pass
self
.
assertEqual
([
op
.
type
for
op
in
pserver
.
blocks
[
2
].
ops
],
[
"sgd"
])
class
TestBasicModelSync
(
TranspilerTest
):
def
transpiler_test_impl
(
self
):
config
=
fluid
.
DistributeTranspilerConfig
()
config
.
sync_mode
=
True
config
.
runtime_split_send_recv
=
False
pserver
,
startup
=
self
.
get_pserver
(
self
.
pserver1_ep
,
config
,
True
)
pserver2
,
startup2
=
self
.
get_pserver
(
self
.
pserver2_ep
,
config
,
True
)
trainer
,
_
=
self
.
get_trainer
(
config
,
True
)
self
.
assertEqual
([
op
.
type
for
op
in
trainer
.
global_block
().
ops
],
[
'mul'
,
'elementwise_add'
,
'elementwise_sub'
,
'square'
,
'mean'
,
'fill_constant'
,
'mean_grad'
,
'square_grad'
,
'elementwise_sub_grad'
,
'elementwise_add_grad'
,
'send'
,
'mul_grad'
,
'split_byref'
,
'send'
,
'send_barrier'
,
'recv'
,
'recv'
,
'fetch_barrier'
,
'concat'
])
self
.
assertEqual
(
len
(
pserver
.
blocks
),
3
)
# block0: listen_and_serv
self
.
assertEqual
([
op
.
type
for
op
in
pserver
.
blocks
[
0
].
ops
],
[
"listen_and_serv"
])
self
.
assertEqual
(
pserver
.
blocks
[
0
].
ops
[
0
].
attr
(
"distributed_mode"
),
0
)
# block1~2: optimize pass
self
.
assertEqual
([
op
.
type
for
op
in
pserver
.
blocks
[
2
].
ops
],
[
"sum"
,
"scale"
,
"sgd"
])
if
__name__
==
"__main__"
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/test_listen_and_serv.sh
浏览文件 @
985bceac
...
@@ -8,7 +8,7 @@ flag1=test_handle_signal_in_serv_op.flag
...
@@ -8,7 +8,7 @@ flag1=test_handle_signal_in_serv_op.flag
flag2
=
test_list_and_serv_run_empty_optimize_block.flag
flag2
=
test_list_and_serv_run_empty_optimize_block.flag
for
i
in
{
1..10
}
;
do
for
i
in
{
1..10
}
;
do
sleep
3
s
sleep
6
s
if
[[
-f
"
${
flag1
}
"
&&
-f
"
${
flag2
}
"
]]
;
then
if
[[
-f
"
${
flag1
}
"
&&
-f
"
${
flag2
}
"
]]
;
then
echo
"test_listen_and_serv_op exit"
echo
"test_listen_and_serv_op exit"
exit
0
exit
0
...
...
python/paddle/fluid/tests/unittests/test_listen_and_serv_op.py
浏览文件 @
985bceac
...
@@ -52,7 +52,11 @@ def run_pserver(use_cuda, sync_mode, ip, port, trainers, trainer_id):
...
@@ -52,7 +52,11 @@ def run_pserver(use_cuda, sync_mode, ip, port, trainers, trainer_id):
config
=
fluid
.
DistributeTranspilerConfig
()
config
=
fluid
.
DistributeTranspilerConfig
()
config
.
sync_mode
=
sync_mode
config
.
sync_mode
=
sync_mode
t
=
fluid
.
DistributeTranspiler
(
config
=
config
)
t
=
fluid
.
DistributeTranspiler
(
config
=
config
)
t
.
transpile
(
trainer_id
,
pservers
=
pserver_endpoints
,
trainers
=
trainers
)
t
.
transpile
(
trainer_id
,
pservers
=
pserver_endpoints
,
trainers
=
trainers
,
sync_mode
=
sync_mode
)
pserver_prog
=
t
.
get_pserver_program
(
current_endpoint
)
pserver_prog
=
t
.
get_pserver_program
(
current_endpoint
)
pserver_startup
=
t
.
get_startup_program
(
current_endpoint
,
pserver_prog
)
pserver_startup
=
t
.
get_startup_program
(
current_endpoint
,
pserver_prog
)
exe
.
run
(
pserver_startup
)
exe
.
run
(
pserver_startup
)
...
@@ -86,7 +90,11 @@ def run_pserver_with_empty_block(use_cuda, sync_mode, ip, port, trainers,
...
@@ -86,7 +90,11 @@ def run_pserver_with_empty_block(use_cuda, sync_mode, ip, port, trainers,
config
.
slice_var_up
=
False
config
.
slice_var_up
=
False
t
=
fluid
.
DistributeTranspiler
(
config
=
config
)
t
=
fluid
.
DistributeTranspiler
(
config
=
config
)
t
.
transpile
(
trainer_id
,
pservers
=
pserver_endpoints
,
trainers
=
trainers
)
t
.
transpile
(
trainer_id
,
pservers
=
pserver_endpoints
,
trainers
=
trainers
,
sync_mode
=
sync_mode
)
pserver_prog
=
t
.
get_pserver_program
(
ps2
)
pserver_prog
=
t
.
get_pserver_program
(
ps2
)
# pserver2 have no parameter
# pserver2 have no parameter
...
...
python/paddle/fluid/tests/unittests/test_lookup_remote_table_op.py
浏览文件 @
985bceac
...
@@ -25,6 +25,7 @@ import paddle.fluid as fluid
...
@@ -25,6 +25,7 @@ import paddle.fluid as fluid
import
paddle.fluid.core
as
core
import
paddle.fluid.core
as
core
from
paddle.fluid.op
import
Operator
from
paddle.fluid.op
import
Operator
from
paddle.fluid.framework
import
Program
,
program_guard
from
paddle.fluid.framework
import
Program
,
program_guard
from
paddle.fluid.transpiler.distribute_transpiler
import
DistributedMode
from
dist_test_utils
import
*
from
dist_test_utils
import
*
...
@@ -53,7 +54,7 @@ def run_pserver(pserver_id, use_cuda, sync_mode):
...
@@ -53,7 +54,7 @@ def run_pserver(pserver_id, use_cuda, sync_mode):
"optimize_blocks"
:
[
optimize_block
],
"optimize_blocks"
:
[
optimize_block
],
"endpoint"
:
'127.0.0.1:0'
,
"endpoint"
:
'127.0.0.1:0'
,
"Fanin"
:
1
,
"Fanin"
:
1
,
"
sync_mode"
:
True
,
"
distributed_mode"
:
DistributedMode
.
SYNC
,
"grad_to_block_id"
:
[]
"grad_to_block_id"
:
[]
})
})
...
...
python/paddle/fluid/tests/unittests/test_nce_remote_table_op.py
浏览文件 @
985bceac
...
@@ -26,6 +26,7 @@ import paddle.fluid.core as core
...
@@ -26,6 +26,7 @@ import paddle.fluid.core as core
from
paddle.fluid.op
import
Operator
from
paddle.fluid.op
import
Operator
from
paddle.fluid.framework
import
Program
,
program_guard
from
paddle.fluid.framework
import
Program
,
program_guard
from
dist_test_utils
import
*
from
dist_test_utils
import
*
from
paddle.fluid.transpiler.distribute_transpiler
import
DistributedMode
def
nce
(
input
,
weight
,
bias
,
sample_weight
,
labels
,
num_classes
,
def
nce
(
input
,
weight
,
bias
,
sample_weight
,
labels
,
num_classes
,
...
@@ -92,7 +93,7 @@ def run_pserver(pserver_id, use_cuda, sync_mode):
...
@@ -92,7 +93,7 @@ def run_pserver(pserver_id, use_cuda, sync_mode):
"optimize_blocks"
:
[
optimize_block
],
"optimize_blocks"
:
[
optimize_block
],
"endpoint"
:
'127.0.0.1:0'
,
"endpoint"
:
'127.0.0.1:0'
,
"Fanin"
:
1
,
"Fanin"
:
1
,
"
sync_mode"
:
True
,
"
distributed_mode"
:
DistributedMode
.
SYNC
,
"grad_to_block_id"
:
[]
"grad_to_block_id"
:
[]
})
})
...
...
python/paddle/fluid/tests/unittests/test_recv_save_op.py
浏览文件 @
985bceac
...
@@ -29,6 +29,7 @@ from paddle.fluid.op import Operator
...
@@ -29,6 +29,7 @@ from paddle.fluid.op import Operator
from
paddle.fluid.framework
import
Program
,
program_guard
from
paddle.fluid.framework
import
Program
,
program_guard
from
paddle.fluid.transpiler.details
import
VarStruct
,
VarsDistributed
from
paddle.fluid.transpiler.details
import
VarStruct
,
VarsDistributed
from
dist_test_utils
import
*
from
dist_test_utils
import
*
from
paddle.fluid.transpiler.distribute_transpiler
import
DistributedMode
def
run_pserver
(
pserver_id
):
def
run_pserver
(
pserver_id
):
...
@@ -56,7 +57,7 @@ def run_pserver(pserver_id):
...
@@ -56,7 +57,7 @@ def run_pserver(pserver_id):
"optimize_blocks"
:
[
optimize_block
],
"optimize_blocks"
:
[
optimize_block
],
"endpoint"
:
'127.0.0.1:0'
,
"endpoint"
:
'127.0.0.1:0'
,
"Fanin"
:
1
,
"Fanin"
:
1
,
"
sync_mode"
:
True
,
"
distributed_mode"
:
DistributedMode
.
SYNC
,
"grad_to_block_id"
:
[]
"grad_to_block_id"
:
[]
})
})
...
...
python/paddle/fluid/transpiler/distribute_transpiler.py
浏览文件 @
985bceac
...
@@ -65,6 +65,13 @@ LR_SCHED_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.LRSched
...
@@ -65,6 +65,13 @@ LR_SCHED_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.LRSched
PRINT_LOG
=
False
PRINT_LOG
=
False
class
DistributedMode
:
SYNC
=
0
ASYNC
=
1
HALF_ASYNC
=
2
GEO
=
3
def
log
(
*
args
):
def
log
(
*
args
):
if
PRINT_LOG
:
if
PRINT_LOG
:
print
(
args
)
print
(
args
)
...
@@ -313,6 +320,13 @@ class DistributeTranspiler(object):
...
@@ -313,6 +320,13 @@ class DistributeTranspiler(object):
if
self
.
config
.
split_method
is
None
:
if
self
.
config
.
split_method
is
None
:
self
.
config
.
split_method
=
RoundRobin
self
.
config
.
split_method
=
RoundRobin
if
self
.
config
.
sync_mode
:
self
.
distributed_mode
=
DistributedMode
.
SYNC
elif
self
.
config
.
runtime_split_send_recv
:
self
.
distributed_mode
=
DistributedMode
.
ASYNC
else
:
self
.
distributed_mode
=
DistributedMode
.
HALF_ASYNC
global
PRINT_LOG
global
PRINT_LOG
if
self
.
config
.
print_log
:
if
self
.
config
.
print_log
:
PRINT_LOG
=
True
PRINT_LOG
=
True
...
@@ -1333,7 +1347,7 @@ class DistributeTranspiler(object):
...
@@ -1333,7 +1347,7 @@ class DistributeTranspiler(object):
"endpoint"
:
endpoint
,
"endpoint"
:
endpoint
,
"pserver_id"
:
self
.
pserver_endpoints
.
index
(
endpoint
),
"pserver_id"
:
self
.
pserver_endpoints
.
index
(
endpoint
),
"Fanin"
:
self
.
trainer_num
,
"Fanin"
:
self
.
trainer_num
,
"
sync_mode"
:
self
.
sync
_mode
,
"
distributed_mode"
:
self
.
distributed
_mode
,
"grad_to_block_id"
:
grad_to_block_id
,
"grad_to_block_id"
:
grad_to_block_id
,
"sparse_grad_to_param"
:
sparse_grad_to_param
,
"sparse_grad_to_param"
:
sparse_grad_to_param
,
"lr_decay_block_id"
:
lr_decay_block_id
,
"lr_decay_block_id"
:
lr_decay_block_id
,
...
...
python/paddle/fluid/transpiler/geo_sgd_transpiler.py
浏览文件 @
985bceac
...
@@ -38,7 +38,7 @@ from ..framework import Program, default_main_program, \
...
@@ -38,7 +38,7 @@ from ..framework import Program, default_main_program, \
from
.details
import
wait_server_ready
,
VarsDistributed
from
.details
import
wait_server_ready
,
VarsDistributed
from
.details
import
delete_ops
from
.details
import
delete_ops
from
..distribute_lookup_table
import
find_distributed_lookup_table
from
..distribute_lookup_table
import
find_distributed_lookup_table
from
.distribute_transpiler
import
DistributeTranspiler
,
DistributeTranspilerConfig
,
slice_variable
,
same_or_split_var
,
ServerRuntimeConfig
from
.distribute_transpiler
import
DistributeTranspiler
,
DistributeTranspilerConfig
,
slice_variable
,
same_or_split_var
,
ServerRuntimeConfig
,
DistributedMode
RPC_OP_ROLE_ATTR_NAME
=
op_role_attr_name
=
core
.
op_proto_and_checker_maker
.
kOpRoleAttrName
(
RPC_OP_ROLE_ATTR_NAME
=
op_role_attr_name
=
core
.
op_proto_and_checker_maker
.
kOpRoleAttrName
(
)
)
...
@@ -247,7 +247,7 @@ class GeoSgdTranspiler(DistributeTranspiler):
...
@@ -247,7 +247,7 @@ class GeoSgdTranspiler(DistributeTranspiler):
"optimize_blocks"
:
optimize_block
,
"optimize_blocks"
:
optimize_block
,
"endpoint"
:
endpoint
,
"endpoint"
:
endpoint
,
"Fanin"
:
self
.
trainer_num
,
"Fanin"
:
self
.
trainer_num
,
"
sync_mode"
:
self
.
sync_mode
,
"
distributed_mode"
:
DistributedMode
.
GEO
,
"grad_to_block_id"
:
param_to_block_id
,
"grad_to_block_id"
:
param_to_block_id
,
"sparse_grad_to_param"
:
sparse_grad_to_param
,
"sparse_grad_to_param"
:
sparse_grad_to_param
,
"rpc_get_thread_num"
:
self
.
server_config
.
_rpc_get_thread_num
,
"rpc_get_thread_num"
:
self
.
server_config
.
_rpc_get_thread_num
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录