Commit 7fb817d4 (unverified)

add distributed_strategy (#21710)

* add distributed_strategy

Authored by 123malin on Jan 06, 2020; committed via GitHub on Jan 06, 2020.
Parent commit: ad8a9cb8

Showing 15 changed files with 841 additions and 160 deletions (+841 / -160).
paddle/fluid/operators/distributed/communicator.cc                                                  +52  -51
paddle/fluid/operators/distributed/communicator.h                                                   +24  -11
paddle/fluid/operators/distributed_ops/listen_and_serv_op.cc                                        +13   -7
paddle/fluid/pybind/communicator_py.cc                                                               +8   -4
python/paddle/fluid/communicator.py                                                                  +9   -3
python/paddle/fluid/incubate/fleet/parameter_server/distribute_transpiler/__init__.py               +52  -30
python/paddle/fluid/incubate/fleet/parameter_server/distribute_transpiler/distributed_strategy.py  +228   -0
python/paddle/fluid/tests/unittests/ctr_dataset_reader.py                                           +18   -0
python/paddle/fluid/tests/unittests/dist_fleet_ctr.py                                               +62   -7
python/paddle/fluid/tests/unittests/test_dist_fleet_base.py                                         +65  -32
python/paddle/fluid/tests/unittests/test_dist_fleet_ctr.py                                          +96   -2
python/paddle/fluid/tests/unittests/test_dist_fleet_geo.py                                           +4   -3
python/paddle/fluid/tests/unittests/test_distributed_strategy.py                                   +169   -0
python/paddle/fluid/transpiler/distribute_transpiler.py                                             +34   -8
python/paddle/fluid/transpiler/geo_sgd_transpiler.py                                                 +7   -2
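Taken together, the change introduces a `DistributedStrategy` family (sync / half-async / async / geo) that bundles the transpiler, communicator, and executor configuration, and wires it through `fleet.distributed_optimizer`. A minimal usage sketch, following the pattern the new unit tests in this commit use; the role-maker setup and the network producing `avg_cost` are assumed to exist elsewhere and are hypothetical placeholders here:

# Sketch only: choosing a distributed strategy for fleet training,
# mirroring the flow added by this commit's tests.
import paddle.fluid as fluid
from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler import fleet
from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler.distributed_strategy import StrategyFactory

# pick one of: create_sync_strategy / create_half_async_strategy /
#              create_async_strategy / create_geo_strategy
strategy = StrategyFactory.create_geo_strategy(update_frequency=100)

# avg_cost is assumed to be the mean loss of a network built elsewhere
optimizer = fluid.optimizer.SGD(learning_rate=0.01)
optimizer = fleet.distributed_optimizer(optimizer, strategy)
optimizer.minimize(avg_cost)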
paddle/fluid/operators/distributed/communicator.cc

@@ -63,6 +63,43 @@ inline void VSUB(int n, const T *x, const T *y, T *z) {
   }
 }

+void Communicator::SetEnvFlagsDefault() {
+  env_flags_dict.clear();
+  env_flags_dict.insert(std::pair<std::string, int>(
+      "independent_recv_thread", FLAGS_communicator_independent_recv_thread));
+  env_flags_dict.insert(std::pair<std::string, int>(
+      "send_queue_size", FLAGS_communicator_send_queue_size));
+  env_flags_dict.insert(std::pair<std::string, int>(
+      "min_send_grad_num_before_recv",
+      FLAGS_communicator_min_send_grad_num_before_recv));
+  env_flags_dict.insert(std::pair<std::string, int>(
+      "thread_pool_size", FLAGS_communicator_thread_pool_size));
+  env_flags_dict.insert(std::pair<std::string, int>(
+      "send_wait_times", FLAGS_communicator_send_wait_times));
+  env_flags_dict.insert(std::pair<std::string, int>(
+      "max_merge_var_num", FLAGS_communicator_max_merge_var_num));
+  env_flags_dict.insert(
+      std::pair<std::string, int>("fake_rpc", FLAGS_communicator_fake_rpc));
+  env_flags_dict.insert(std::pair<std::string, int>(
+      "merge_sparse_grad", FLAGS_communicator_merge_sparse_grad));
+  env_flags_dict.insert(std::pair<std::string, int>(
+      "is_sgd_optimizer", FLAGS_communicator_is_sgd_optimizer));
+  return;
+}
+
+Communicator::Communicator() { SetEnvFlagsDefault(); }
+
+Communicator::Communicator(const std::map<std::string, int> &env_flags) {
+  SetEnvFlagsDefault();
+  for (auto &iter : env_flags) {
+    std::string flag_name = iter.first;
+    int val_ = iter.second;
+    env_flags_dict.at(flag_name) = val_;
+  }
+  return;
+}
+
 std::once_flag Communicator::init_flag_;
 std::shared_ptr<Communicator> Communicator::communicator_(nullptr);

@@ -73,25 +110,6 @@ void AsyncCommunicator::InitImpl(const RpcCtxMap &send_varname_to_ctx,
   recv_varname_to_ctx_ = std::move(recv_varname_to_ctx);
   recv_scope_ = std::move(recv_scope);
   // get all send information from graph, build vars_to_send
-  VLOG(0) << "communicator_independent_recv_thread: "
-          << FLAGS_communicator_independent_recv_thread;
-  VLOG(0) << "communicator_send_queue_size: "
-          << FLAGS_communicator_send_queue_size;
-  VLOG(0) << "communicator_min_send_grad_num_before_recv: "
-          << FLAGS_communicator_min_send_grad_num_before_recv;
-  VLOG(0) << "communicator_thread_pool_size: "
-          << FLAGS_communicator_thread_pool_size;
-  VLOG(0) << "communicator_send_wait_times: "
-          << FLAGS_communicator_send_wait_times;
-  VLOG(0) << "communicator_max_merge_var_num: "
-          << FLAGS_communicator_max_merge_var_num;
-  VLOG(0) << "communicator_fake_rpc: " << FLAGS_communicator_fake_rpc;
-  VLOG(0) << "communicator_merge_sparse_grad: "
-          << FLAGS_communicator_merge_sparse_grad;
-  VLOG(0) << "communicator_is_sgd_optimizer: "
-          << FLAGS_communicator_is_sgd_optimizer;
   if (send_varname_to_ctx.size() == 0) {
     VLOG(0) << "nothing need to be send, will not start send_thread";
   } else {

@@ -99,17 +117,17 @@ void AsyncCommunicator::InitImpl(const RpcCtxMap &send_varname_to_ctx,
     for (auto &iter : send_varname_to_ctx_) {
       send_varname_to_queue_[iter.first] =
           std::make_shared<BlockingQueue<std::shared_ptr<Variable>>>(
-              FLAGS_communicator_send_queue_size);
+              env_flags_dict["send_queue_size"]);
     }
     send_threadpool_.reset(
-        new ::ThreadPool(FLAGS_communicator_thread_pool_size));
+        new ::ThreadPool(env_flags_dict["thread_pool_size"]));
   }
   if (recv_varname_to_ctx.size() == 0) {
     VLOG(0) << "nothing need to be received, will not start recv_thread";
   } else {
     recv_threadpool_.reset(
-        new ::ThreadPool(FLAGS_communicator_thread_pool_size));
+        new ::ThreadPool(env_flags_dict["thread_pool_size"]));
   }
 }

@@ -132,7 +150,7 @@ void AsyncCommunicator::InitImpl(const paddle::framework::ProgramDesc &program,
         auto trainer_id = boost::get<int>(op->GetNullableAttr("trainer_id"));
         auto merge_add = boost::get<bool>(op->GetNullableAttr("merge_add"));
         if (!merge_add) {
-          merge_add = FLAGS_communicator_is_sgd_optimizer;
+          merge_add = static_cast<bool>(env_flags_dict["is_sgd_optimizer"]);
         }
         auto use_send_handler =
             boost::get<bool>(op->GetNullableAttr("use_send_handler"));

@@ -194,10 +212,10 @@ void AsyncCommunicator::SendThread() {
           std::vector<std::shared_ptr<Variable>> vars;
           int merged_var_num = 0;
           int wait_times = 0;
-          while (merged_var_num < FLAGS_communicator_max_merge_var_num) {
+          while (merged_var_num < env_flags_dict["max_merge_var_num"]) {
             if (var_queue->Size() == 0) {
               VLOG(4) << "wait_times -> " << wait_times;
-              if (wait_times >= FLAGS_communicator_send_wait_times) {
+              if (wait_times >= env_flags_dict["send_wait_times"]) {
                 break;
               }
               std::this_thread::sleep_for(std::chrono::milliseconds(10));

@@ -226,7 +244,7 @@ void AsyncCommunicator::SendThread() {
           VLOG(4) << "merge " << merged_var_num << " " << var_name
                   << " use time " << after_merge - before_merge;
           auto send_functor = distributed::ParameterSend<float>();
-          if (!FLAGS_communicator_fake_rpc) {
+          if (!env_flags_dict["fake_rpc"]) {
             send_functor(ctx, *send_scope_, true, 1);
           }
           auto after_send = GetCurrentUS();

@@ -255,7 +273,7 @@ void AsyncCommunicator::RecvThread() {
   VLOG(3) << "RecvThread start!";
   while (running_) {
     int grad_num = grad_num_.load();
-    if (grad_num > FLAGS_communicator_min_send_grad_num_before_recv) {
+    if (grad_num > env_flags_dict["min_send_grad_num_before_recv"]) {
       VLOG(1) << "current grad num " << grad_num;
       RecvAll();
       grad_num_.store(0);

@@ -273,10 +291,10 @@ void AsyncCommunicator::Send(const std::string &var_name,
   auto *grad_var = scope.FindVar(var_name);
   PADDLE_ENFORCE(grad_var->IsInitialized(), "grad var should be inited");
   if (grad_var->IsType<framework::SelectedRows>() &&
-      !FLAGS_communicator_merge_sparse_grad) {
+      !env_flags_dict["merge_sparse_grad"]) {
     auto send_functor = distributed::ParameterSend<float>();
     auto &ctx = send_varname_to_ctx_.at(var_name);
-    if (!FLAGS_communicator_fake_rpc) {
+    if (!env_flags_dict["fake_rpc"]) {
       send_functor(ctx, scope, true, 1);
     }
   } else {

@@ -289,7 +307,7 @@ void AsyncCommunicator::Send(const std::string &var_name,
 }

 void AsyncCommunicator::Recv() {
-  if (FLAGS_communicator_independent_recv_thread) {
+  if (env_flags_dict["independent_recv_thread"]) {
     return;
   }

@@ -313,7 +331,7 @@ void AsyncCommunicator::RecvAll() {
     auto &var_name = iter.first;
     VLOG(4) << "recv var " << var_name;
     auto recv_functor = distributed::ParameterRecv<float>();
-    if (!FLAGS_communicator_fake_rpc) {
+    if (!env_flags_dict["fake_rpc"]) {
       recv_functor(iter.second, *recv_scope_);
     }
   };

@@ -336,7 +354,7 @@ void AsyncCommunicator::Start() {
   // start send and recv thread
   send_thread_.reset(
       new std::thread(std::bind(&AsyncCommunicator::SendThread, this)));
-  if (FLAGS_communicator_independent_recv_thread) {
+  if (env_flags_dict["independent_recv_thread"]) {
     recv_thread_.reset(
         new std::thread(std::bind(&AsyncCommunicator::RecvThread, this)));
   }

@@ -396,25 +414,8 @@ void GeoSgdCommunicator::InitImpl(
   geo_need_push_nums_ = std::move(geo_need_push_nums);

   // get all send information from graph, build vars_to_send
-  VLOG(0) << "communicator_independent_recv_thread: "
-          << FLAGS_communicator_independent_recv_thread;
-  VLOG(0) << "communicator_send_queue_size: "
-          << FLAGS_communicator_send_queue_size;
-  VLOG(0) << "communicator_min_send_grad_num_before_recv: "
-          << FLAGS_communicator_min_send_grad_num_before_recv;
-  VLOG(0) << "communicator_thread_pool_size: "
-          << FLAGS_communicator_thread_pool_size;
-  VLOG(0) << "communicator_send_wait_times: "
-          << FLAGS_communicator_send_wait_times;
-  VLOG(0) << "communicator_max_merge_var_num: "
-          << FLAGS_communicator_max_merge_var_num;
-  VLOG(0) << "communicator_fake_rpc: " << FLAGS_communicator_fake_rpc;
-  VLOG(0) << "communicator_merge_sparse_grad: "
-          << FLAGS_communicator_merge_sparse_grad;
   VLOG(0) << "Trainer nums: " << trainer_nums_;
   VLOG(0) << "geo_sgd_push_before_local_train_nums: " << geo_need_push_nums_;
   VLOG(0) << "communicator_merge_sparse_bucket "
           << FLAGS_communicator_merge_sparse_bucket;

   // process var info from transpiler
   for (auto &iter : vars_info) {

@@ -461,7 +462,7 @@ void GeoSgdCommunicator::InitImpl(
     LOG(WARNING) << "no var need to send and recv!!";
   }

-  send_threadpool_.reset(new ::ThreadPool(FLAGS_communicator_thread_pool_size));
+  send_threadpool_.reset(new ::ThreadPool(env_flags_dict["thread_pool_size"]));

   need_push_queue_ =
       std::make_shared<BlockingQueue<std::shared_ptr<SparseIdsMap>>>(
           geo_need_push_nums);

@@ -570,7 +571,7 @@ void GeoSgdCommunicator::SendThread() {
         VLOG(4) << "ids_send_vec_ pushed";
       } else if (need_push_queue_->Size() == 0) {
         VLOG(4) << "wait_times -> " << wait_times;
-        if (wait_times >= FLAGS_communicator_send_wait_times) {
+        if (wait_times >= env_flags_dict["send_wait_times"]) {
          break;
         }
         std::this_thread::sleep_for(std::chrono::milliseconds(10));
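On the C++ side, the communicator now keeps its runtime knobs in `env_flags_dict`, seeded from the `FLAGS_communicator_*` gflags by `SetEnvFlagsDefault()` and optionally overridden by the map passed in from Python. A sketch of the kind of override map the new constructor accepts; the keys mirror `env_flags_dict`, and the values shown are the defaults that `TrainerRuntimeConfig` (added later in this diff) would produce:

# Sketch: an env_flags override map for the C++ Communicator.
# Any key omitted keeps its FLAGS_communicator_* default.
env_flags = {
    "independent_recv_thread": 1,
    "send_queue_size": 20,
    "min_send_grad_num_before_recv": 20,
    "thread_pool_size": 5,
    "send_wait_times": 5,
    "max_merge_var_num": 20,
    "fake_rpc": 0,
    "merge_sparse_grad": 1,
    "is_sgd_optimizer": 1,
}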
paddle/fluid/operators/distributed/communicator.h

@@ -174,9 +174,12 @@ using RpcCtxMap = std::unordered_map<std::string, RpcContext>;
 class Communicator {
  public:
-  Communicator() {}
+  Communicator();
+  explicit Communicator(const std::map<std::string, int>& env_flags);
   virtual ~Communicator() {}
+  virtual void SetEnvFlagsDefault();

   virtual void Start() = 0;
   virtual void Stop() = 0;
   virtual bool IsRunning() { return running_; }

@@ -221,9 +224,10 @@ class Communicator {
   template <typename T>
   static Communicator *InitInstance(
-      const paddle::framework::ProgramDesc &program, Scope *recv_scope) {
+      const paddle::framework::ProgramDesc &program, Scope *recv_scope,
+      const std::map<std::string, int>& env_flags) {
     std::call_once(init_flag_, &Communicator::InitWithProgram<T>, program,
-                   recv_scope);
+                   recv_scope, std::ref(env_flags));
     return communicator_.get();
   }

@@ -232,10 +236,12 @@ class Communicator {
       const paddle::framework::ProgramDesc &program, Scope *training_scope,
       std::map<std::string, std::map<std::string, std::vector<std::string>>>&
           vars_info,
-      const int &trainers, const int &geo_need_push_nums) {
+      const int &trainers, const int &geo_need_push_nums,
+      const std::map<std::string, int>& env_flags) {
     std::call_once(init_flag_, &Communicator::InitWithTranspilerInfo<T>,
                    program, training_scope, std::ref(vars_info),
-                   std::ref(trainers), std::ref(geo_need_push_nums));
+                   std::ref(trainers), std::ref(geo_need_push_nums),
+                   std::ref(env_flags));
     return communicator_.get();
   }

@@ -253,9 +259,10 @@ class Communicator {
   template <typename T>
   static void InitWithProgram(const paddle::framework::ProgramDesc &program,
-                              Scope *recv_scope) {
+                              Scope *recv_scope,
+                              const std::map<std::string, int>& env_flags) {
     if (communicator_.get() == nullptr) {
-      communicator_.reset(new T());
+      communicator_.reset(new T(std::ref(env_flags)));
       communicator_->InitImpl(program, recv_scope);
     }
   }

@@ -265,9 +272,10 @@ class Communicator {
       const paddle::framework::ProgramDesc &program, Scope *training_scope,
       std::map<std::string, std::map<std::string, std::vector<std::string>>>&
           vars_info,
-      const int &trainers, const int &geo_need_push_nums) {
+      const int &trainers, const int &geo_need_push_nums,
+      const std::map<std::string, int>& env_flags) {
     if (communicator_.get() == nullptr) {
-      communicator_.reset(new T());
+      communicator_.reset(new T(std::ref(env_flags)));
       communicator_->InitImpl(program, training_scope, std::ref(vars_info),
                               std::ref(trainers), std::ref(geo_need_push_nums));
     }

@@ -277,6 +285,7 @@ class Communicator {
   bool running_ = false;
   static std::shared_ptr<Communicator> communicator_;
   static std::once_flag init_flag_;
+  std::unordered_map<std::string, int> env_flags_dict;
 };

 using SparseIdsMap =

@@ -284,7 +293,9 @@ using SparseIdsMap =
 class AsyncCommunicator : public Communicator {
  public:
-  AsyncCommunicator() {}
+  AsyncCommunicator() : Communicator() {}
+  explicit AsyncCommunicator(const std::map<std::string, int>& env_flags)
+      : Communicator(env_flags) {}
   ~AsyncCommunicator();
   void Start() override;
   void Stop() override;

@@ -331,7 +342,9 @@ class AsyncCommunicator : public Communicator {
 class GeoSgdCommunicator : public Communicator {
  public:
-  GeoSgdCommunicator() {}
+  GeoSgdCommunicator() : Communicator() {}
+  explicit GeoSgdCommunicator(const std::map<std::string, int>& env_flags)
+      : Communicator(env_flags) {}
   ~GeoSgdCommunicator();
   void InitImpl(const paddle::framework::ProgramDesc& program,
                 Scope* training_scope,
paddle/fluid/operators/distributed_ops/listen_and_serv_op.cc

@@ -356,6 +356,10 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
   rpc_service_.reset(new RPCSERVER_T(endpoint, fan_in));

+  auto rpc_get_thread_num = Attr<int>("rpc_get_thread_num");
+  auto rpc_send_thread_num = Attr<int>("rpc_send_thread_num");
+  auto rpc_prefetch_thread_num = Attr<int>("rpc_prefetch_thread_num");
+
   request_send_handler_.reset(
       new distributed::RequestSendHandler(sync_mode, dc_sgd));
   request_get_handler_.reset(

@@ -370,21 +374,18 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
       new distributed::RequestNotifyHandler(sync_mode, lr_decay_block_id));

   rpc_service_->RegisterRPC(distributed::kRequestSend,
-                            request_send_handler_.get(),
-                            FLAGS_rpc_send_thread_num);
+                            request_send_handler_.get(), rpc_send_thread_num);
   rpc_service_->RegisterRPC(distributed::kRequestGet,
-                            request_get_handler_.get(),
-                            FLAGS_rpc_get_thread_num);
+                            request_get_handler_.get(), rpc_get_thread_num);
   rpc_service_->RegisterRPC(distributed::kRequestPrefetch,
                             request_prefetch_handler_.get(),
-                            FLAGS_rpc_prefetch_thread_num);
+                            rpc_prefetch_thread_num);
   rpc_service_->RegisterRPC(distributed::kRequestCheckpoint,
                             request_checkpoint_handler_.get());
   rpc_service_->RegisterRPC(distributed::kRequestGetNoBarrier,
                             request_get_no_barrier_handler_.get());
   rpc_service_->RegisterRPC(distributed::kRequestNotify,
-                            request_notify_handler_.get(),
-                            FLAGS_rpc_send_thread_num);
+                            request_notify_handler_.get(), rpc_send_thread_num);

   auto optimize_blocks =
       Attr<std::vector<framework::BlockDesc *>>(kOptimizeBlocks);

@@ -549,6 +550,11 @@ class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker {
         .SetDefault(-1);
     AddAttr<int>(kLRDecayBlockId, "BolckID to run lr decay on pserer.")
         .SetDefault(-1);
+    AddAttr<int>("rpc_get_thread_num", "pserver get thread num.").SetDefault(1);
+    AddAttr<int>("rpc_send_thread_num", "pserver send thread num.")
+        .SetDefault(1);
+    AddAttr<int>("rpc_prefetch_thread_num", "pserver prefetch thread num.")
+        .SetDefault(1);
   }
 };
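With this change, the pserver reads its RPC thread counts from op attributes rather than global gflags; on the Python side those attributes are filled from the new `ServerRuntimeConfig` (see distribute_transpiler.py later in this diff). A hedged sketch of how a user could tune them through a strategy, following the new half-async test:

# Sketch: overriding pserver RPC thread counts via ServerRuntimeConfig.
# The values end up in the listen_and_serv op attributes shown above.
from paddle.fluid.transpiler.distribute_transpiler import ServerRuntimeConfig
from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler.distributed_strategy import StrategyFactory

strategy = StrategyFactory.create_half_async_strategy()

server_config = ServerRuntimeConfig()   # defaults come from FLAGS_rpc_*_thread_num env vars
server_config._rpc_send_thread_num = 24
strategy.set_server_runtime_config(server_config)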
paddle/fluid/pybind/communicator_py.cc

@@ -39,19 +39,23 @@ void BindCommunicator(py::module* m) {
   // Communicator is already used by nccl, change to DistCommunicator
   py::class_<Communicator, std::shared_ptr<Communicator>>(*m,
                                                           "DistCommunicator")
-      .def(py::init([](const ProgramDesc& program, Scope* param_scope) {
+      .def(py::init([](const ProgramDesc& program, Scope* param_scope,
+                       std::map<std::string, int>& env_flags) {
         VLOG(0) << "using communicator";
-        Communicator::InitInstance<AsyncCommunicator>(program, param_scope);
+        Communicator::InitInstance<AsyncCommunicator>(program, param_scope,
+                                                      env_flags);
         return Communicator::GetInstantcePtr();
       }))
       .def(py::init([](
           const ProgramDesc& program, Scope* training_scope,
           std::map<std::string,
                    std::map<std::string, std::vector<std::string>>>& vars_info,
-          int& trainers, int& geo_need_push_nums) {
+          int& trainers, int& geo_need_push_nums,
+          std::map<std::string, int>& env_flags) {
         VLOG(0) << "using geo sgd communicator";
         Communicator::InitInstance<GeoSgdCommunicator>(
-            program, training_scope, vars_info, trainers, geo_need_push_nums);
+            program, training_scope, vars_info, trainers, geo_need_push_nums,
+            env_flags);
         return Communicator::GetInstantcePtr();
       }))
       .def("stop", &Communicator::Stop)
python/paddle/fluid/communicator.py

@@ -28,7 +28,8 @@ class Communicator(object):
                  program,
                  vars_info=None,
                  trainers=None,
-                 geo_sgd_need_push_nums=None):
+                 geo_sgd_need_push_nums=None,
+                 env_flags=None):
         """
         Communicator is used for async distribute training in distribute_transpiler mode.
         It's a wrapper of a cpp class Communicator and should be used inside fleet API.

@@ -56,14 +57,19 @@ class Communicator(object):
             if op.type == "recv":
                 op._set_attr('do_not_run', True)
+        # Todo: Add check
+        if env_flags is None:
+            env_flags = {}
+
         if vars_info and trainers and geo_sgd_need_push_nums:
             # for geo sgd
             self.communicator_ = core.DistCommunicator(
                 program.desc,
-                global_scope(), vars_info, trainers, geo_sgd_need_push_nums)
+                global_scope(), vars_info, trainers, geo_sgd_need_push_nums,
+                env_flags)
         else:
             self.communicator_ = core.DistCommunicator(program.desc,
-                                                       global_scope())
+                                                       global_scope(), env_flags)

     def start(self):
         """
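The Python wrapper now forwards an `env_flags` dict straight to the C++ `DistCommunicator`. Inside fleet this dict is produced by `TrainerRuntimeConfig.get_communicator_flags()`; a standalone sketch, where `main_program` is assumed to be an already-transpiled trainer program (a hypothetical placeholder here):

# Sketch only: constructing the async communicator with explicit flag overrides.
# Normally fleet builds this internally from TrainerRuntimeConfig.
from paddle.fluid.communicator import Communicator

comm = Communicator(
    main_program,                       # assumed: the transpiled trainer program
    env_flags={"send_queue_size": 40,   # overrides FLAGS_communicator_send_queue_size
               "thread_pool_size": 10})
comm.start()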
python/paddle/fluid/incubate/fleet/parameter_server/distribute_transpiler/__init__.py

@@ -11,11 +11,13 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
 import warnings
 """
 Convert the fluid program to distributed data-parallelism programs.
 """
 from .distributed_strategy import *
 import paddle.fluid.io as io
 from paddle.fluid.communicator import Communicator
 from paddle.fluid.framework import default_main_program

@@ -26,8 +28,7 @@ from paddle.fluid.executor import Executor
 from paddle.fluid.parallel_executor import ParallelExecutor
 from paddle.fluid.optimizer import Optimizer
 from paddle.fluid.transpiler.distribute_transpiler import DistributeTranspiler as OriginTranspiler
-from paddle.fluid.transpiler.geo_sgd_transpiler import GeoSgdTranspiler
-from paddle.fluid.transpiler.distribute_transpiler import DistributeTranspilerConfig
+from paddle.fluid.transpiler.distribute_transpiler import DistributeTranspilerConfig, ServerRuntimeConfig
 from paddle.fluid.incubate.fleet.base.fleet_base import DistributedOptimizer
 from paddle.fluid.incubate.fleet.base.fleet_base import Fleet

@@ -66,15 +67,24 @@ class DistributedTranspiler(Fleet):
         from paddle.fluid.transpiler.details.checkport import wait_server_ready
         wait_server_ready(fleet.server_endpoints(to_string=False))
-        if not self._transpile_config.sync_mode:
-            if self._transpile_config.geo_sgd_mode:
-                self._communicator = Communicator(
-                    self.main_program, self.vars_info,
-                    fleet.worker_num(),
-                    self._transpile_config.geo_sgd_need_push_nums)
-            else:
-                self._communicator = Communicator(self.main_program)
+        program_config = self._transpile_config.get_program_config()
+        trainer_communicator_config = self._transpile_config.get_trainer_runtime_config(
+        )
+        print(trainer_communicator_config)
+        need_communicator_flag = False
+        if isinstance(self._transpile_config, GeoStrategy):
+            need_communicator_flag = True
+            self._communicator = Communicator(
+                self.main_program, self.vars_info,
+                fleet.worker_num(), program_config.geo_sgd_need_push_nums,
+                trainer_communicator_config.get_communicator_flags())
+        elif isinstance(self._transpile_config, AsyncStrategy):
+            need_communicator_flag = True
+            self._communicator = Communicator(
+                self.main_program,
+                env_flags=trainer_communicator_config.get_communicator_flags())
+        if need_communicator_flag:
             if not self._communicator.is_running():
                 self._communicator.start()
             else:

@@ -129,7 +139,8 @@ class DistributedTranspiler(Fleet):
         Returns:
             None
         """
-        if not self._transpile_config.sync_mode:
+        if isinstance(self._transpile_config, GeoStrategy) or isinstance(
+                self._transpile_config, AsyncStrategy):
             self._communicator.stop()
         self._executor.close()
         if isinstance(self._role_maker, MPISymetricRoleMaker):

@@ -239,36 +250,44 @@ class DistributedTranspiler(Fleet):
         io.save_persistables(executor, dirname, main_program, None)

     def _transpile(self, config):
-        if not isinstance(config, DistributeTranspilerConfig):
+        if isinstance(config, DistributeTranspilerConfig):
+            self._transpile_config = DistributedStrategy()
+            self._transpile_config.set_program_config(config)
+        elif isinstance(config, DistributedStrategy):
+            self._transpile_config = config
+        else:
             raise TypeError(
-                "config must be an instance of DistributeTranspilerConfig")
+                "config must be an instance of DistributeTranspilerConfig or DistributedStrategy"
+            )

-        if not config.sync_mode:
-            config.runtime_split_send_recv = True
+        program_config = self._transpile_config.get_program_config()

         # _origin_program is a deep copy for default_main_program, for inference
         self._origin_program = default_main_program().clone(for_test=False)

-        self._transpile_config = config
-        if config.geo_sgd_mode:
-            self._transpiler = GeoSgdTranspiler(config)
+        if program_config.geo_sgd_mode:
+            from paddle.fluid.transpiler.geo_sgd_transpiler import GeoSgdTranspiler
+            self._transpiler = GeoSgdTranspiler(program_config)
         else:
-            self._transpiler = OriginTranspiler(config)
+            self._transpiler = OriginTranspiler(program_config)
+        self._transpiler._set_server_config(
+            self._transpile_config.get_server_runtime_config())

         if self.is_worker():
             self._transpiler.transpile(
                 trainer_id=fleet.worker_index(),
                 pservers=fleet.server_endpoints(to_string=True),
                 trainers=fleet.worker_num(),
-                sync_mode=config.sync_mode)
+                sync_mode=program_config.sync_mode)

             if isinstance(self._role_maker, MPISymetricRoleMaker):
-                config.wait_port = False
+                program_config.wait_port = False
+                self._transpile_config.set_program_config(program_config)

             self.main_program = self._transpiler.get_trainer_program(
-                wait_port=config.wait_port)
+                wait_port=program_config.wait_port)
             self.startup_program = default_startup_program()
-            if self._transpile_config.geo_sgd_mode:
+            if program_config.geo_sgd_mode:
                 self.vars_info = self._transpiler._get_vars_info()
                 self.startup_program = self._transpiler.trainer_startup_program
         else:

@@ -276,7 +295,7 @@ class DistributedTranspiler(Fleet):
                 trainer_id=fleet.worker_index(),
                 pservers=fleet.server_endpoints(to_string=True),
                 trainers=fleet.worker_num(),
-                sync_mode=config.sync_mode,
+                sync_mode=program_config.sync_mode,
                 current_endpoint=self.server_endpoints()[self.server_index()])
             self.main_program, self.startup_program = \
                 self._transpiler.get_pserver_programs(

@@ -308,14 +327,17 @@ class TranspilerOptimizer(DistributedOptimizer):
         super(TranspilerOptimizer, self).__init__(optimizer, strategy)

         if strategy:
-            if not isinstance(strategy, DistributeTranspilerConfig):
+            if isinstance(strategy, DistributedStrategy):
+                self._strategy = strategy
+            elif isinstance(strategy, DistributeTranspilerConfig):
+                self._strategy = DistributedStrategy()
+                self._strategy.set_program_config(strategy)
+            else:
                 raise TypeError(
-                    "In {} mode, strategy must be an instance of DistributeTranspilerConfig".
+                    "In {} mode, strategy must be an instance of DistributeTranspilerConfig or DistributedStrategy".
                     format(fleet._mode))
-            else:
-                self._strategy = strategy
         else:
-            self._strategy = DistributeTranspilerConfig()
+            self._strategy = DistributedStrategy()

     def backward(self,
                  loss,
python/paddle/fluid/incubate/fleet/parameter_server/distribute_transpiler/distributed_strategy.py  (new file, mode 100644)

# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

__all__ = [
    "TrainerRuntimeConfig", "DistributedStrategy", "SyncStrategy",
    "AsyncStrategy", "HalfAsyncStrategy", "GeoStrategy", "StrategyFactory"
]

import os
import paddle.fluid as fluid
from paddle.fluid.transpiler.distribute_transpiler import DistributeTranspilerConfig, ServerRuntimeConfig


class TrainerRuntimeConfig(object):
    def __init__(self):
        self.max_merge_var_num = int(
            os.getenv("FLAGS_communicator_max_merge_var_num", "20"))
        self.send_queue_size = int(
            os.getenv("FLAGS_communicator_send_queue_size", "20"))
        self.independent_recv_thread = int(
            os.getenv("FLAGS_communicator_independent_recv_thread", "1"))
        self.min_send_grad_num_before_recv = int(
            os.getenv("FLAGS_communicator_min_send_grad_num_before_recv", "20"))
        self.thread_pool_size = int(
            os.getenv("FLAGS_communicator_thread_pool_size", "5"))
        self.send_wait_times = int(
            os.getenv("FLAGS_communicator_send_wait_times", "5"))
        self.fake_rpc = int(os.getenv("FLAGS_communicator_fake_rpc", "0"))
        self.merge_sparse_grad = int(
            os.getenv("FLAGS_communicator_merge_sparse_grad", "1"))
        self.is_sgd_optimizer = int(
            os.getenv("FLAGS_communicator_is_sgd_optimizer", "1"))

        # not used
        self._rpc_deadline = int(os.getenv("FLAGS_rpc_deadline", "180000"))
        self._rpc_retry_times = int(os.getenv("FLAGS_rpc_retry_times", "3"))

    def get_communicator_flags(self):
        _communicator_flags = dict()
        _communicator_flags["max_merge_var_num"] = self.max_merge_var_num
        _communicator_flags["send_queue_size"] = self.send_queue_size
        _communicator_flags["independent_recv_thread"] = self.independent_recv_thread
        _communicator_flags["min_send_grad_num_before_recv"] = self.min_send_grad_num_before_recv
        _communicator_flags["thread_pool_size"] = self.thread_pool_size
        _communicator_flags["send_wait_times"] = self.send_wait_times
        _communicator_flags["fake_rpc"] = self.fake_rpc
        _communicator_flags["merge_sparse_grad"] = self.merge_sparse_grad
        _communicator_flags["is_sgd_optimizer"] = self.is_sgd_optimizer
        return _communicator_flags

    def __repr__(self):
        _str = "please check that TrainerRuntimeConfig is as expected:\n"
        _communicator_flags = self.get_communicator_flags()
        for key in _communicator_flags:
            _str += "communicator_{}: {}\n".format(key,
                                                   _communicator_flags[key])
        return _str


class DistributedStrategy(object):
    def __init__(self):
        self._program_config = DistributeTranspilerConfig()
        self._trainer_runtime_config = TrainerRuntimeConfig()
        self._server_runtime_config = ServerRuntimeConfig()
        self._execute_strategy = fluid.ExecutionStrategy()
        self._build_strategy = fluid.BuildStrategy()
        num_threads = int(os.getenv("CPU_NUM", "1"))
        self._execute_strategy.num_threads = num_threads
        if num_threads > 1:
            self._build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce

    def get_program_config(self):
        return self._program_config

    def set_program_config(self, config):
        if isinstance(config, DistributeTranspilerConfig):
            self._program_config = config
        elif isinstance(config, dict):
            for key in config:
                if hasattr(self._program_config, key):
                    setattr(self._program_config, key, config[key])
                else:
                    raise ValueError(
                        "DistributeTranspilerConfig doesn't have key: {}".
                        format(key))
        else:
            raise TypeError(
                "program_config only accept input type: dict or DistributeTranspilerConfig"
            )

    def get_trainer_runtime_config(self):
        return self._trainer_runtime_config

    def set_trainer_runtime_config(self, config):
        if isinstance(config, TrainerRuntimeConfig):
            self._trainer_runtime_config = config
        elif isinstance(config, dict):
            for key in config:
                if hasattr(self._trainer_runtime_config, key):
                    setattr(self._trainer_runtime_config, key, config[key])
                else:
                    raise ValueError(
                        "TrainerRuntimeConfig doesn't have key: {}".format(key))
        else:
            raise TypeError(
                "trainer_runtime_config only accept input type: dict or TrainerRuntimeConfig"
            )

    def get_server_runtime_config(self):
        return self._server_runtime_config

    def set_server_runtime_config(self, config):
        if isinstance(config, ServerRuntimeConfig):
            self._server_runtime_config = config
        elif isinstance(config, dict):
            for key in config:
                if hasattr(self._server_runtime_config, key):
                    setattr(self._server_runtime_config, key, config[key])
                else:
                    raise ValueError(
                        "ServerRuntimeConfig doesn't have key: {}".format(key))
        else:
            raise TypeError(
                "server_runtime_config only accept input type: dict or ServerRuntimeConfig"
            )

    def get_execute_strategy(self):
        return self._execute_strategy

    def set_execute_strategy(self, config):
        if isinstance(config, fluid.ExecutionStrategy):
            self._execute_strategy = config
        elif isinstance(config, dict):
            for key in config:
                if hasattr(self._execute_strategy, key):
                    setattr(self._execute_strategy, key, config[key])
                else:
                    raise ValueError(
                        "ExecutionStrategy doesn't have key: {}".format(key))
        else:
            raise TypeError(
                "execute_strategy only accept input type: dict or ExecutionStrategy"
            )

    def get_build_strategy(self):
        return self._build_strategy

    def set_build_strategy(self, config):
        if isinstance(config, fluid.BuildStrategy):
            self._build_strategy = config
        elif isinstance(config, dict):
            for key in config:
                if hasattr(self._build_strategy, key):
                    setattr(self._build_strategy, key, config[key])
                else:
                    raise ValueError(
                        "BuildStrategy doesn't have key: {}".format(key))
        else:
            raise TypeError(
                "build_strategy only accept input type: dict or BuildStrategy")


class SyncStrategy(DistributedStrategy):
    def __init__(self):
        super(SyncStrategy, self).__init__()
        self._program_config.sync_mode = True
        self._program_config.runtime_split_send_recv = False
        self._build_strategy.async_mode = False


class AsyncStrategy(DistributedStrategy):
    def __init__(self):
        super(AsyncStrategy, self).__init__()
        self._program_config.sync_mode = False
        self._program_config.runtime_split_send_recv = True
        self._build_strategy.async_mode = True


class HalfAsyncStrategy(DistributedStrategy):
    def __init__(self):
        super(HalfAsyncStrategy, self).__init__()
        self._program_config.sync_mode = False
        self._program_config.runtime_split_send_recv = False
        self._build_strategy.async_mode = False


class GeoStrategy(DistributedStrategy):
    def __init__(self, update_frequency=100):
        super(GeoStrategy, self).__init__()
        self._program_config.sync_mode = False
        self._program_config.runtime_split_send_recv = True
        self._program_config.geo_sgd_mode = True
        self._program_config.geo_sgd_need_push_nums = update_frequency
        self._build_strategy.async_mode = True


class StrategyFactory(object):
    def __init_(self):
        pass

    @staticmethod
    def create_sync_strategy():
        return SyncStrategy()

    @staticmethod
    def create_half_async_strategy():
        return HalfAsyncStrategy()

    @staticmethod
    def create_async_strategy():
        return AsyncStrategy()

    @staticmethod
    def create_geo_strategy(update_frequency=100):
        return GeoStrategy(update_frequency)
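Each strategy bundles a program config, trainer/server runtime configs, and build/execution strategies, and every one of them can be set either from its config object or from a plain dict (unknown keys raise ValueError, wrong types raise TypeError). A short sketch based on the accessors above and on the new unit tests:

# Sketch: tuning a GeoStrategy through the setter API added in this file.
from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler.distributed_strategy import StrategyFactory

strategy = StrategyFactory.create_geo_strategy(update_frequency=5)

# dict keys must be existing attributes of the underlying config objects
strategy.set_program_config({'min_block_size': 8192})
strategy.set_trainer_runtime_config({'send_queue_size': 100})
strategy.set_execute_strategy({'num_threads': 8})
strategy.set_build_strategy({'memory_optimize': True})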
python/paddle/fluid/tests/unittests/ctr_dataset_reader.py

@@ -61,6 +61,24 @@ def load_lr_input_record(sent):
     return res


+class CtrReader(object):
+    def __init__(self):
+        pass
+
+    def _reader_creator(self, filelist):
+        def reader():
+            for file in filelist:
+                with open(file, 'r') as f:
+                    for line in f:
+                        fs = line.strip().split('\t')
+                        dnn_input = load_dnn_input_record(fs[0])
+                        lr_input = load_lr_input_record(fs[1])
+                        click = [int(fs[2])]
+                        yield [dnn_input] + [lr_input] + [click]
+
+        return reader
+
+
 class DatasetCtrReader(data_generator.MultiSlotDataGenerator):
     def generate_sample(self, line):
         def get_rand(low=0.0, high=1.0):
python/paddle/fluid/tests/unittests/dist_fleet_ctr.py

@@ -21,8 +21,10 @@ import shutil
 import tempfile
 import time

+import paddle
 import paddle.fluid as fluid
 import os
+import numpy as np

 import ctr_dataset_reader
 from test_dist_fleet_base import runtime_main, FleetDistRunnerBase

@@ -131,7 +133,7 @@ class TestDistCTR2x2(FleetDistRunnerBase):
         with open(os.path.join(dirname, "__model__.proto"), "w") as wn:
             wn.write(str(program))

-    def do_training(self, fleet):
+    def do_pyreader_training(self, fleet):
         """
         do training using dataset, using fetch handler to catch variable
         Args:

@@ -146,13 +148,63 @@ class TestDistCTR2x2(FleetDistRunnerBase):
         exe.run(fleet.startup_program)

         thread_num = 2
         batch_size = 128
         filelist = []
         for _ in range(thread_num):
             filelist.append(train_file_path)

+        train_reader = paddle.batch(
+            paddle.reader.shuffle(
+                ctr_dataset_reader.CtrReader()._reader_creator(filelist),
+                buf_size=batch_size * 100),
+            batch_size=batch_size)
+        self.reader.decorate_sample_list_generator(train_reader)
+        compiled_prog = fluid.compiler.CompiledProgram(
+            fleet.main_program).with_data_parallel(
+                loss_name=self.avg_cost.name,
+                build_strategy=self.strategy.get_build_strategy(),
+                exec_strategy=self.strategy.get_execute_strategy())
+
+        for epoch_id in range(1):
+            self.reader.start()
+            try:
+                pass_start = time.time()
+                while True:
+                    loss_val = exe.run(program=compiled_prog,
+                                       fetch_list=[self.avg_cost.name])
+                    loss_val = np.mean(loss_val)
+                    print("TRAIN ---> pass: {} loss: {}\n".format(epoch_id,
+                                                                  loss_val))
+                pass_time = time.time() - pass_start
+            except fluid.core.EOFException:
+                self.reader.reset()
+
+        model_dir = tempfile.mkdtemp()
+        fleet.save_inference_model(
+            exe, model_dir, [feed.name for feed in self.feeds], self.avg_cost)
+        self.check_model_right(model_dir)
+        shutil.rmtree(model_dir)
+        fleet.stop_worker()
+
+    def do_dataset_training(self, fleet):
+        dnn_input_dim, lr_input_dim, train_file_path = ctr_dataset_reader.prepare_data(
+        )
+
+        exe = fluid.Executor(fluid.CPUPlace())
+
+        fleet.init_worker()
+        exe.run(fleet.startup_program)
+
+        thread_num = 2
+        batch_size = 128
+        filelist = []
+        for _ in range(thread_num):
+            filelist.append(train_file_path)
+
         # config dataset
         dataset = fluid.DatasetFactory().create_dataset()
-        dataset.set_batch_size(128)
+        dataset.set_batch_size(batch_size)
         dataset.set_use_var(self.feeds)
         pipe_command = 'python ctr_dataset_reader.py'
         dataset.set_pipe_command(pipe_command)

@@ -172,11 +224,14 @@ class TestDistCTR2x2(FleetDistRunnerBase):
             debug=False)
         pass_time = time.time() - pass_start

+        res_dict = dict()
+        res_dict['loss'] = self.avg_cost
+
         class FH(fluid.executor.FetchHandler):
-            def handler(self, fetch_target_vars):
-                for i in range(len(fetch_target_vars)):
-                    print("{}: \n{}\n".format(self.fetch_target_names[0],
-                                              fetch_target_vars[0]))
+            def handle(self, res_dict):
+                for key in res_dict:
+                    v = res_dict[key]
+                    print("{}: \n{}\n".format(key, v))

         for epoch_id in range(1):
             pass_start = time.time()

@@ -184,7 +239,7 @@ class TestDistCTR2x2(FleetDistRunnerBase):
             exe.train_from_dataset(
                 program=fleet.main_program,
                 dataset=dataset,
-                fetch_handler=FH([self.avg_cost.name], period_secs=2),
+                fetch_handler=FH(var_dict=res_dict, period_secs=2),
                 debug=False)
             pass_time = time.time() - pass_start
python/paddle/fluid/tests/unittests/test_dist_fleet_base.py

@@ -37,6 +37,7 @@ import paddle.fluid as fluid
 import paddle.fluid.incubate.fleet.base.role_maker as role_maker
 from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler import fleet
 from paddle.fluid.transpiler.distribute_transpiler import DistributeTranspilerConfig
+from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler.distributed_strategy import StrategyFactory

 RUN_STEP = 5
 LEARNING_RATE = 0.01

@@ -50,6 +51,19 @@ class FleetDistRunnerBase(object):
         do training : exe run program
     """

+    def generate_strategy(self, args):
+        self.strategy = None
+        if args.mode == "async":
+            self.strategy = StrategyFactory.create_async_strategy()
+        elif args.mode == "sync":
+            self.strategy = StrategyFactory.create_sync_strategy()
+        elif args.mode == "half_async":
+            self.strategy = StrategyFactory.create_half_async_strategy()
+        elif args.mode == "geo":
+            self.strategy = StrategyFactory.create_geo_strategy(
+                args.geo_sgd_need_push_nums)
+        return self.strategy
+
     def run_pserver(self, args):
         if args.role.upper() != "PSERVER":
             raise ValueError("args role must be PSERVER")

@@ -62,10 +76,7 @@ class FleetDistRunnerBase(object):
         fleet.init(role)

-        strategy = DistributeTranspilerConfig()
-        strategy.sync_mode = args.sync_mode
-        strategy.geo_sgd_mode = args.geo_sgd_mode
-        strategy.geo_sgd_need_push_nums = args.geo_sgd_need_push_nums
+        strategy = self.generate_strategy(args)

         avg_cost = self.net()

@@ -76,7 +87,28 @@ class FleetDistRunnerBase(object):
         fleet.init_server()
         fleet.run_server()

-    def run_trainer(self, args):
+    def run_dataset_trainer(self, args):
         if args.role.upper() != "TRAINER":
             raise ValueError("args role must be TRAINER")

+        role = role_maker.UserDefinedRoleMaker(
+            current_id=args.current_id,
+            role=role_maker.Role.WORKER,
+            worker_num=args.trainers,
+            server_endpoints=args.endpoints.split(","))
+
+        fleet.init(role)
+        strategy = self.generate_strategy(args)
+
+        avg_cost = self.net()
+
+        optimizer = fluid.optimizer.SGD(LEARNING_RATE)
+        optimizer = fleet.distributed_optimizer(optimizer, strategy)
+        optimizer.minimize(avg_cost)
+
+        out = self.do_dataset_training(fleet)
+
+    def run_pyreader_trainer(self, args):
+        if args.role.upper() != "TRAINER":
+            raise ValueError("args role must be TRAINER")

@@ -88,26 +120,33 @@ class FleetDistRunnerBase(object):
         fleet.init(role)

-        strategy = DistributeTranspilerConfig()
-        strategy.sync_mode = args.sync_mode
-        strategy.geo_sgd_mode = args.geo_sgd_mode
-        strategy.geo_sgd_need_push_nums = args.geo_sgd_need_push_nums
+        strategy = self.generate_strategy(args)

         avg_cost = self.net()

+        self.reader = fluid.io.PyReader(
+            feed_list=self.feeds,
+            capacity=64,
+            iterable=False,
+            use_double_buffer=False)
+
         optimizer = fluid.optimizer.SGD(LEARNING_RATE)
         optimizer = fleet.distributed_optimizer(optimizer, strategy)
         optimizer.minimize(avg_cost)

-        out = self.do_training(fleet)
+        out = self.do_pyreader_training(fleet)

     def net(self, batch_size=4, lr=0.01):
         raise NotImplementedError(
             "get_model should be implemented by child classes.")

-    def do_training(self, fleet):
+    def do_dataset_training(self, fleet):
         raise NotImplementedError(
-            "do_training should be implemented by child classes.")
+            "do_dataset_training should be implemented by child classes.")
+
+    def do_pyreader_training(self, fleet):
+        raise NotImplementedError(
+            "do_pyreader_training should be implemented by child classes.")


 class TestFleetBase(unittest.TestCase):

@@ -120,7 +159,8 @@ class TestFleetBase(unittest.TestCase):
         raise NotImplementedError("tests should have _setup_config implemented")

     def setUp(self):
-        self._sync_mode = True
+        self._mode = "sync"
+        self._reader = "pyreader"
         self._trainers = 2
         self._pservers = 2
         self._port_set = set()

@@ -139,7 +179,6 @@ class TestFleetBase(unittest.TestCase):
             self._find_free_port(), self._find_free_port())
         self._python_interp = sys.executable
-        self._geo_sgd = False
         self._geo_sgd_need_push_nums = 5
         self._setup_config()

@@ -203,21 +242,13 @@ class TestFleetBase(unittest.TestCase):
             envs['COVERAGE_FILE'] = os.getenv('COVERAGE_FILE', '')
             python_path += " -m coverage run --branch -p"

-        tr_cmd = "{0} {1} --role trainer --endpoints {2} --current_id {{}} --trainers {3}".format(
-            python_path, model, self._ps_endpoints, self._trainers)
-
-        ps_cmd = "{0} {1} --role pserver --endpoints {2} --current_id {{}} --trainers {3}".format(
-            python_path, model, self._ps_endpoints, self._trainers)
-
-        if self._sync_mode:
-            tr_cmd += " --sync_mode"
-            ps_cmd += " --sync_mode"
-
-        if self._geo_sgd:
-            tr_cmd += " --geo_sgd_mode {0} --geo_sgd_need_push_nums {1}".format(
-                self._geo_sgd, self._geo_sgd_need_push_nums)
-            ps_cmd += " --geo_sgd_mode {0} --geo_sgd_need_push_nums {1}".format(
-                self._geo_sgd, self._geo_sgd_need_push_nums)
+        tr_cmd = "{0} {1} --role trainer --endpoints {2} --current_id {{}} --trainers {3} --mode {4} --geo_sgd_need_push_nums {5} --reader {6}".format(
+            python_path, model, self._ps_endpoints, self._trainers, self._mode,
+            self._geo_sgd_need_push_nums, self._reader)
+
+        ps_cmd = "{0} {1} --role pserver --endpoints {2} --current_id {{}} --trainers {3} --mode {4} --geo_sgd_need_push_nums {5} --reader {6}".format(
+            python_path, model, self._ps_endpoints, self._trainers, self._mode,
+            self._geo_sgd_need_push_nums, self._reader)

         # Run dist train to compare with local results
         ps0, ps1, ps0_pipe, ps1_pipe = self._start_pserver(ps_cmd, env)

@@ -301,15 +332,17 @@ def runtime_main(test_class):
     parser.add_argument('--endpoints', type=str, required=False, default="")
     parser.add_argument('--current_id', type=int, required=False, default=0)
     parser.add_argument('--trainers', type=int, required=False, default=1)
-    parser.add_argument('--sync_mode', action='store_true')
-    parser.add_argument('--geo_sgd_mode', type=bool, required=False, default=False)
+    parser.add_argument('--mode', type=str, required=False, default='geo')
     parser.add_argument(
         '--geo_sgd_need_push_nums', type=int, required=False, default=2)
+    parser.add_argument('--reader', type=str, required=False, default='dataset')
     args = parser.parse_args()

     model = test_class()
     if args.role == "pserver":
         model.run_pserver(args)
     else:
-        model.run_trainer(args)
+        if args.reader == "dataset":
+            model.run_dataset_trainer(args)
+        else:
+            model.run_pyreader_trainer(args)
python/paddle/fluid/tests/unittests/test_dist_fleet_ctr.py

@@ -19,9 +19,103 @@ import unittest
 from test_dist_fleet_base import TestFleetBase


-class TestDistMnist2x2(TestFleetBase):
+class TestDistMnistSync2x2(TestFleetBase):
     def _setup_config(self):
-        self._sync_mode = False
+        self._mode = "sync"
+        self._reader = "pyreader"
+
+    def check_with_place(self,
+                         model_file,
+                         delta=1e-3,
+                         check_error_log=False,
+                         need_envs={}):
+        required_envs = {
+            "PATH": os.getenv("PATH", ""),
+            "PYTHONPATH": os.getenv("PYTHONPATH", ""),
+            "LD_LIBRARY_PATH": os.getenv("LD_LIBRARY_PATH", ""),
+            "FLAGS_rpc_deadline": "5000",  # 5sec to fail fast
+            "http_proxy": ""
+        }
+
+        required_envs.update(need_envs)
+
+        if check_error_log:
+            required_envs["GLOG_v"] = "3"
+            required_envs["GLOG_logtostderr"] = "1"
+
+        tr0_losses, tr1_losses = self._run_cluster(model_file, required_envs)
+
+    def test_dist_train(self):
+        self.check_with_place(
+            "dist_fleet_ctr.py", delta=1e-5, check_error_log=True)
+
+
+class TestDistMnistHalfAsync2x2(TestFleetBase):
+    def _setup_config(self):
+        self._mode = "half_async"
+        self._reader = "pyreader"
+
+    def check_with_place(self,
+                         model_file,
+                         delta=1e-3,
+                         check_error_log=False,
+                         need_envs={}):
+        required_envs = {
+            "PATH": os.getenv("PATH", ""),
+            "PYTHONPATH": os.getenv("PYTHONPATH", ""),
+            "LD_LIBRARY_PATH": os.getenv("LD_LIBRARY_PATH", ""),
+            "FLAGS_rpc_deadline": "5000",  # 5sec to fail fast
+            "http_proxy": ""
+        }
+
+        required_envs.update(need_envs)
+
+        if check_error_log:
+            required_envs["GLOG_v"] = "3"
+            required_envs["GLOG_logtostderr"] = "1"
+
+        tr0_losses, tr1_losses = self._run_cluster(model_file, required_envs)
+
+    def test_dist_train(self):
+        self.check_with_place(
+            "dist_fleet_ctr.py", delta=1e-5, check_error_log=True)
+
+
+class TestDistMnistAsync2x2(TestFleetBase):
+    def _setup_config(self):
+        self._mode = "async"
+        self._reader = "pyreader"
+
+    def check_with_place(self,
+                         model_file,
+                         delta=1e-3,
+                         check_error_log=False,
+                         need_envs={}):
+        required_envs = {
+            "PATH": os.getenv("PATH", ""),
+            "PYTHONPATH": os.getenv("PYTHONPATH", ""),
+            "LD_LIBRARY_PATH": os.getenv("LD_LIBRARY_PATH", ""),
+            "FLAGS_rpc_deadline": "5000",  # 5sec to fail fast
+            "http_proxy": ""
+        }
+
+        required_envs.update(need_envs)
+
+        if check_error_log:
+            required_envs["GLOG_v"] = "3"
+            required_envs["GLOG_logtostderr"] = "1"
+
+        tr0_losses, tr1_losses = self._run_cluster(model_file, required_envs)
+
+    def test_dist_train(self):
+        self.check_with_place(
+            "dist_fleet_ctr.py", delta=1e-5, check_error_log=True)
+
+
+class TestDistMnistAsyncDataset2x2(TestFleetBase):
+    def _setup_config(self):
+        self._mode = "async"
+        self._reader = "dataset"

     def check_with_place(self,
                          model_file,
(remainder of this hunk collapsed in the page view)
python/paddle/fluid/tests/unittests/test_dist_fleet_geo.py

@@ -19,15 +19,16 @@ import unittest
 import paddle.fluid as fluid
 import paddle.fluid.incubate.fleet.base.role_maker as role_maker
 from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler import fleet
-from paddle.fluid.transpiler.distribute_transpiler import DistributeTranspilerConfig
+from paddle.fluid.transpiler.distribute_transpiler import DistributeTranspilerConfig, ServerRuntimeConfig
 from paddle.fluid.transpiler.geo_sgd_transpiler import GeoSgdTranspiler
 from test_dist_fleet_base import TestFleetBase
 from dist_simnet_bow import train_network


 class TestDistGeoCtr_2x2(TestFleetBase):
     def _setup_config(self):
-        self._sync_mode = False
-        self._geo_sgd = True
+        self._mode = "geo"
+        self._reader = "dataset"
         self._geo_sgd_need_push_nums = 5

     def check_with_place(self,
python/paddle/fluid/tests/unittests/test_distributed_strategy.py  (new file, mode 100644)

# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import paddle.fluid as fluid
from paddle.fluid.transpiler.distribute_transpiler import DistributeTranspilerConfig, ServerRuntimeConfig
from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler.distributed_strategy import TrainerRuntimeConfig, StrategyFactory
import os


class TestStrategyFactor(unittest.TestCase):
    def test_sync_strategy(self):
        os.environ['CPU_NUM'] = "2"
        strategy = StrategyFactory.create_sync_strategy()
        self.assertEqual(strategy._program_config.sync_mode, True)
        self.assertEqual(strategy._program_config.runtime_split_send_recv, False)
        self.assertEqual(strategy._build_strategy.async_mode, False)
        self.assertEqual(strategy._execute_strategy.num_threads, 2)

        # test set_program_config using DistributeTranspilerConfig()
        program_config_class = DistributeTranspilerConfig()
        program_config_class.min_block_size = 81920
        strategy.set_program_config(program_config_class)
        program_config = strategy.get_program_config()
        self.assertEqual(program_config.min_block_size, 81920)

        # test set_program_config using dict
        program_config_dict = dict()
        program_config_dict['min_block_size'] = 8192
        strategy.set_program_config(program_config_dict)
        program_config = strategy.get_program_config()
        self.assertEqual(program_config.min_block_size, 8192)

        # test set_program_config exception
        program_config_dict['unknown'] = None
        self.assertRaises(Exception, strategy.set_program_config,
                          program_config_dict)
        program_config_illegal = None
        self.assertRaises(Exception, strategy.set_program_config,
                          program_config_illegal)

    def test_geo_strategy(self):
        strategy = StrategyFactory.create_geo_strategy(5)
        self.assertEqual(strategy._program_config.sync_mode, False)
        self.assertEqual(strategy._program_config.runtime_split_send_recv, True)
        self.assertEqual(strategy._program_config.geo_sgd_mode, True)
        self.assertEqual(strategy._program_config.geo_sgd_need_push_nums, 5)
        self.assertEqual(strategy._build_strategy.async_mode, True)

        # test set_build_strategy using fluid.BuildStrategy
        build_strategy_class = fluid.BuildStrategy()
        build_strategy_class.memory_optimize = False
        strategy.set_build_strategy(build_strategy_class)
        build_strategy = strategy.get_build_strategy()
        self.assertEqual(build_strategy.memory_optimize, False)

        # test set_build_strategy using dict
        build_strategy_dict = dict()
        build_strategy_dict['memory_optimize'] = True
        strategy.set_build_strategy(build_strategy_dict)
        build_strategy = strategy.get_build_strategy()
        self.assertEqual(build_strategy.memory_optimize, True)

        # test set_build_strategy exception
        build_strategy_dict['unknown'] = None
        self.assertRaises(Exception, strategy.set_build_strategy,
                          build_strategy_dict)
        build_strategy_illegal = None
        self.assertRaises(Exception, strategy.set_build_strategy,
                          build_strategy_illegal)

    def test_async_strategy(self):
        strategy = StrategyFactory.create_async_strategy()
        self.assertEqual(strategy._program_config.sync_mode, False)
        self.assertEqual(strategy._program_config.runtime_split_send_recv, True)
        self.assertEqual(strategy._build_strategy.async_mode, True)

        # test set_trainer_runtime_config using TrainerRuntimeConfig
        trainer_runtime_config_class = TrainerRuntimeConfig()
        trainer_runtime_config_class.send_queue_size = 50
        print(trainer_runtime_config_class)
        strategy.set_trainer_runtime_config(trainer_runtime_config_class)
        trainer_runtime_config = strategy.get_trainer_runtime_config()
        self.assertEqual(trainer_runtime_config.send_queue_size, 50)

        # test set_trainer_runtime_config using dict
        trainer_runtime_config_dict = dict()
        trainer_runtime_config_dict['send_queue_size'] = 100
        strategy.set_trainer_runtime_config(trainer_runtime_config_dict)
        trainer_runtime_config = strategy.get_trainer_runtime_config()
        trainer_communicator_flags = trainer_runtime_config.get_communicator_flags(
        )
        self.assertIn('send_queue_size', trainer_communicator_flags)
        self.assertEqual(trainer_communicator_flags['send_queue_size'], 100)

        # test set_trainer_runtime_config exception
        trainer_runtime_config_dict['unknown'] = None
        self.assertRaises(Exception, strategy.set_trainer_runtime_config,
                          trainer_runtime_config_dict)
        trainer_runtime_config_illegal = None
        self.assertRaises(Exception, strategy.set_trainer_runtime_config,
                          trainer_runtime_config_illegal)

        # test set_execute_strategy using fluid.ExecutionStrategy
        exec_strategy_class = fluid.ExecutionStrategy()
        exec_strategy_class.num_threads = 4
        strategy.set_execute_strategy(exec_strategy_class)
        exec_strategy = strategy.get_execute_strategy()
        self.assertEqual(exec_strategy.num_threads, 4)

        # test set_execute_strategy using dict
        exec_strategy_dict = dict()
        exec_strategy_dict['num_threads'] = 8
        strategy.set_execute_strategy(exec_strategy_dict)
        exec_strategy = strategy.get_execute_strategy()
        self.assertEqual(exec_strategy.num_threads, 8)

        # test set_execute_strategy exception
        exec_strategy_dict['unknown'] = None
        self.assertRaises(Exception, strategy.set_execute_strategy,
                          exec_strategy_dict)
        exec_strategy_illegal = None
        self.assertRaises(Exception, strategy.set_execute_strategy,
                          exec_strategy_illegal)

    def test_half_async_strategy(self):
        strategy = StrategyFactory.create_half_async_strategy()
        self.assertEqual(strategy._program_config.sync_mode, False)
        self.assertEqual(strategy._program_config.runtime_split_send_recv, False)
        self.assertEqual(strategy._build_strategy.async_mode, False)

        # test set_server_runtime_config using ServerRuntimeConfig
        server_runtime_config_class = ServerRuntimeConfig()
        server_runtime_config_class._rpc_send_thread_num = 24
        strategy.set_server_runtime_config(server_runtime_config_class)
        server_runtime_config = strategy.get_server_runtime_config()
        self.assertEqual(server_runtime_config._rpc_send_thread_num, 24)

        # test set_server_runtime_config using dict
        server_runtime_config_dict = dict()
        server_runtime_config_dict['_rpc_send_thread_num'] = 20
        strategy.set_server_runtime_config(server_runtime_config_dict)
        server_runtime_config = strategy.get_server_runtime_config()
        self.assertEqual(server_runtime_config._rpc_send_thread_num, 20)

        # test set_server_runtime_config exception
        server_runtime_config_dict['unknown'] = None
        self.assertRaises(Exception, strategy.set_server_runtime_config,
                          server_runtime_config_dict)
        server_runtime_config_illegal = None
        self.assertRaises(Exception, strategy.set_server_runtime_config,
                          server_runtime_config_illegal)


if __name__ == '__main__':
    unittest.main()
python/paddle/fluid/transpiler/distribute_transpiler.py

@@ -30,6 +30,7 @@ Steps to transpile pserver:
 5. add listen_and_serv op
 """

+import os
 import sys
 import math
 from functools import reduce

@@ -177,8 +178,8 @@ class DistributeTranspilerConfig(object):
     print_log = False
     wait_port = True
     # split the send recv var in runtime
-    _runtime_split_send_recv = False
-    _sync_mode = True
+    __runtime_split_send_recv = False
+    __sync_mode = True

     # Geo-sgd algorithm
     geo_sgd_mode = False

@@ -200,31 +201,41 @@ class DistributeTranspilerConfig(object):
     @property
     def runtime_split_send_recv(self):
-        return self._runtime_split_send_recv
+        return self.__runtime_split_send_recv

     @runtime_split_send_recv.setter
     def runtime_split_send_recv(self, value):
         if value is None:
             raise ValueError("runtime_split_send_recv can't be None")
-        if value and self._sync_mode:
+        if value and self.__sync_mode:
             raise ValueError(
                 "if you want to set runtime_split_send_recv to be true, make ensure config.sync_mode is false at first"
             )
-        self._runtime_split_send_recv = value
+        self.__runtime_split_send_recv = value

     @property
     def sync_mode(self):
-        return self._sync_mode
+        return self.__sync_mode

     @sync_mode.setter
     def sync_mode(self, value):
         if value is None:
             raise ValueError("sync_mode can't be None")
-        if value and self._runtime_split_send_recv:
+        if value and self.__runtime_split_send_recv:
             raise ValueError(
                 "if you want to set sync_mode to be true, make ensure config.runtime_split_send_recv is false at first"
             )
-        self._sync_mode = value
+        self.__sync_mode = value


+class ServerRuntimeConfig(object):
+    def __init__(self):
+        self._rpc_send_thread_num = int(
+            os.getenv("FLAGS_rpc_send_thread_num", "12"))
+        self._rpc_get_thread_num = int(
+            os.getenv("FLAGS_rpc_get_thread_num", "12"))
+        self._rpc_prefetch_thread_num = int(
+            os.getenv("FLAGS_rpc_prefetch_thread_num", "12"))
+
+
 class DistributeTranspiler(object):

@@ -295,6 +306,7 @@ class DistributeTranspiler(object):
             self.config = config
         else:
             self.config = DistributeTranspilerConfig()
+        self._set_server_config()

         if self.config.split_method is None:
             self.config.split_method = RoundRobin

@@ -306,6 +318,16 @@ class DistributeTranspiler(object):
         assert (self.config.split_method.__bases__[0] == PSDispatcher)
         self.counter_var = None

+    def _set_server_config(self, server_config=None):
+        if server_config is None:
+            self.server_config = ServerRuntimeConfig()
+        elif isinstance(server_config, ServerRuntimeConfig):
+            self.server_config = server_config
+        else:
+            raise TypeError(
+                "In DistributeTranspiler, server_config must be an instance of ServerRuntimeConfig"
+            )
+
     def _transpile_nccl2(self,
                          trainer_id,
                          trainers,

@@ -1313,6 +1335,10 @@ class DistributeTranspiler(object):
             "grad_to_block_id": grad_to_block_id,
             "sparse_grad_to_param": sparse_grad_to_param,
             "lr_decay_block_id": lr_decay_block_id,
+            "rpc_get_thread_num": self.server_config._rpc_get_thread_num,
+            "rpc_send_thread_num": self.server_config._rpc_send_thread_num,
+            "rpc_prefetch_thread_num":
+            self.server_config._rpc_prefetch_thread_num
         }

         if self.has_distributed_lookup_table:
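The name-mangled fields above make `sync_mode` and `runtime_split_send_recv` mutually exclusive: enabling one while the other is set raises a ValueError. A small illustration of the behavior these setters enforce:

# Sketch: the mutual exclusion enforced by DistributeTranspilerConfig.
from paddle.fluid.transpiler.distribute_transpiler import DistributeTranspilerConfig

config = DistributeTranspilerConfig()      # sync_mode defaults to True
try:
    config.runtime_split_send_recv = True  # rejected while sync_mode is True
except ValueError as e:
    print(e)

config.sync_mode = False                   # allowed: split send/recv is still off
config.runtime_split_send_recv = True      # now accepted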
python/paddle/fluid/transpiler/geo_sgd_transpiler.py

@@ -38,7 +38,7 @@ from ..framework import Program, default_main_program, \
 from .details import wait_server_ready, VarsDistributed
 from .details import delete_ops
 from ..distribute_lookup_table import find_distributed_lookup_table
-from .distribute_transpiler import DistributeTranspiler, DistributeTranspilerConfig, slice_variable, same_or_split_var
+from .distribute_transpiler import DistributeTranspiler, DistributeTranspilerConfig, slice_variable, same_or_split_var, ServerRuntimeConfig

 RPC_OP_ROLE_ATTR_NAME = op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName(
 )

@@ -51,6 +51,7 @@ class GeoSgdTranspiler(DistributeTranspiler):
             self.config = config
         else:
             self.config = DistributeTranspilerConfig()
+        self._set_server_config()

         if self.config.split_method is None:
             self.config.split_method = RoundRobin

@@ -241,7 +242,11 @@ class GeoSgdTranspiler(DistributeTranspiler):
             "Fanin": self.trainer_num,
             "sync_mode": self.sync_mode,
             "grad_to_block_id": param_to_block_id,
-            "sparse_grad_to_param": sparse_grad_to_param
+            "sparse_grad_to_param": sparse_grad_to_param,
+            "rpc_get_thread_num": self.server_config._rpc_get_thread_num,
+            "rpc_send_thread_num": self.server_config._rpc_send_thread_num,
+            "rpc_prefetch_thread_num":
+            self.server_config._rpc_prefetch_thread_num
         }

         # step5 append the listen_and_serv op