BaiXuePrincess / Paddle (forked from PaddlePaddle / Paddle)

Commit 7fb817d4 (unverified)
Authored by 123malin on Jan 06, 2020; committed via GitHub on Jan 06, 2020
add distributed_strategy (#21710)
* add distributed_strategy
Parent: ad8a9cb8

Showing 15 changed files with 841 additions and 160 deletions (+841, -160)
paddle/fluid/operators/distributed/communicator.cc                                                  +52  -51
paddle/fluid/operators/distributed/communicator.h                                                   +24  -11
paddle/fluid/operators/distributed_ops/listen_and_serv_op.cc                                        +13   -7
paddle/fluid/pybind/communicator_py.cc                                                               +8   -4
python/paddle/fluid/communicator.py                                                                  +9   -3
python/paddle/fluid/incubate/fleet/parameter_server/distribute_transpiler/__init__.py               +52  -30
python/paddle/fluid/incubate/fleet/parameter_server/distribute_transpiler/distributed_strategy.py  +228   -0
python/paddle/fluid/tests/unittests/ctr_dataset_reader.py                                           +18   -0
python/paddle/fluid/tests/unittests/dist_fleet_ctr.py                                               +62   -7
python/paddle/fluid/tests/unittests/test_dist_fleet_base.py                                         +65  -32
python/paddle/fluid/tests/unittests/test_dist_fleet_ctr.py                                          +96   -2
python/paddle/fluid/tests/unittests/test_dist_fleet_geo.py                                           +4   -3
python/paddle/fluid/tests/unittests/test_distributed_strategy.py                                   +169   -0
python/paddle/fluid/transpiler/distribute_transpiler.py                                             +34   -8
python/paddle/fluid/transpiler/geo_sgd_transpiler.py                                                 +7   -2
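The core of this commit is a new user-facing configuration API (DistributedStrategy and StrategyFactory) that bundles the transpiler program config, the trainer-side communicator flags, the pserver-side RPC thread counts, and the build/execution strategies into one object that is handed to the fleet optimizer. The sketch below mirrors the flow in the updated test_dist_fleet_base.py; it is illustrative only, the endpoints and the tiny network are placeholders rather than part of this commit, and actually training would still require reachable parameter servers.

    # Minimal usage sketch based on the tests added in this commit (not a full cluster setup).
    import paddle.fluid as fluid
    import paddle.fluid.incubate.fleet.base.role_maker as role_maker
    from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler import fleet
    from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler.distributed_strategy import StrategyFactory

    role = role_maker.UserDefinedRoleMaker(
        current_id=0,
        role=role_maker.Role.WORKER,
        worker_num=2,
        server_endpoints=["127.0.0.1:36001", "127.0.0.1:36002"])  # placeholder endpoints
    fleet.init(role)

    # Pick one of: create_sync_strategy / create_half_async_strategy /
    # create_async_strategy / create_geo_strategy(update_frequency)
    strategy = StrategyFactory.create_geo_strategy(update_frequency=100)

    # Placeholder network; any fluid program with a loss works here.
    x = fluid.layers.data(name="x", shape=[13], dtype="float32")
    y = fluid.layers.data(name="y", shape=[1], dtype="float32")
    pred = fluid.layers.fc(input=x, size=1)
    avg_cost = fluid.layers.mean(fluid.layers.square_error_cost(input=pred, label=y))

    optimizer = fluid.optimizer.SGD(learning_rate=0.01)
    optimizer = fleet.distributed_optimizer(optimizer, strategy)
    optimizer.minimize(avg_cost)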
paddle/fluid/operators/distributed/communicator.cc

@@ -63,6 +63,43 @@ inline void VSUB(int n, const T *x, const T *y, T *z) {
   }
 }
 
+void Communicator::SetEnvFlagsDefault() {
+  env_flags_dict.clear();
+  env_flags_dict.insert(std::pair<std::string, int>(
+      "independent_recv_thread", FLAGS_communicator_independent_recv_thread));
+  env_flags_dict.insert(std::pair<std::string, int>(
+      "send_queue_size", FLAGS_communicator_send_queue_size));
+  env_flags_dict.insert(std::pair<std::string, int>(
+      "min_send_grad_num_before_recv",
+      FLAGS_communicator_min_send_grad_num_before_recv));
+  env_flags_dict.insert(std::pair<std::string, int>(
+      "thread_pool_size", FLAGS_communicator_thread_pool_size));
+  env_flags_dict.insert(std::pair<std::string, int>(
+      "send_wait_times", FLAGS_communicator_send_wait_times));
+  env_flags_dict.insert(std::pair<std::string, int>(
+      "max_merge_var_num", FLAGS_communicator_max_merge_var_num));
+  env_flags_dict.insert(
+      std::pair<std::string, int>("fake_rpc", FLAGS_communicator_fake_rpc));
+  env_flags_dict.insert(std::pair<std::string, int>(
+      "merge_sparse_grad", FLAGS_communicator_merge_sparse_grad));
+  env_flags_dict.insert(std::pair<std::string, int>(
+      "is_sgd_optimizer", FLAGS_communicator_is_sgd_optimizer));
+  return;
+}
+
+Communicator::Communicator() { SetEnvFlagsDefault(); }
+
+Communicator::Communicator(const std::map<std::string, int> &env_flags) {
+  SetEnvFlagsDefault();
+  for (auto &iter : env_flags) {
+    std::string flag_name = iter.first;
+    int val_ = iter.second;
+    env_flags_dict.at(flag_name) = val_;
+  }
+  return;
+}
+
 std::once_flag Communicator::init_flag_;
 std::shared_ptr<Communicator> Communicator::communicator_(nullptr);

@@ -73,25 +110,6 @@ void AsyncCommunicator::InitImpl(const RpcCtxMap &send_varname_to_ctx,
   recv_varname_to_ctx_ = std::move(recv_varname_to_ctx);
   recv_scope_ = std::move(recv_scope);
 
-  // get all send information from graph, build vars_to_send
-  VLOG(0) << "communicator_independent_recv_thread: "
-          << FLAGS_communicator_independent_recv_thread;
-  VLOG(0) << "communicator_send_queue_size: "
-          << FLAGS_communicator_send_queue_size;
-  VLOG(0) << "communicator_min_send_grad_num_before_recv: "
-          << FLAGS_communicator_min_send_grad_num_before_recv;
-  VLOG(0) << "communicator_thread_pool_size: "
-          << FLAGS_communicator_thread_pool_size;
-  VLOG(0) << "communicator_send_wait_times: "
-          << FLAGS_communicator_send_wait_times;
-  VLOG(0) << "communicator_max_merge_var_num: "
-          << FLAGS_communicator_max_merge_var_num;
-  VLOG(0) << "communicator_fake_rpc: " << FLAGS_communicator_fake_rpc;
-  VLOG(0) << "communicator_merge_sparse_grad: "
-          << FLAGS_communicator_merge_sparse_grad;
-  VLOG(0) << "communicator_is_sgd_optimizer: "
-          << FLAGS_communicator_is_sgd_optimizer;
-
   if (send_varname_to_ctx.size() == 0) {
     VLOG(0) << "nothing need to be send, will not start send_thread";
   } else {

@@ -99,17 +117,17 @@ void AsyncCommunicator::InitImpl(const RpcCtxMap &send_varname_to_ctx,
     for (auto &iter : send_varname_to_ctx_) {
       send_varname_to_queue_[iter.first] =
           std::make_shared<BlockingQueue<std::shared_ptr<Variable>>>(
-              FLAGS_communicator_send_queue_size);
+              env_flags_dict["send_queue_size"]);
     }
     send_threadpool_.reset(
-        new ::ThreadPool(FLAGS_communicator_thread_pool_size));
+        new ::ThreadPool(env_flags_dict["thread_pool_size"]));
   }
   if (recv_varname_to_ctx.size() == 0) {
     VLOG(0) << "nothing need to be received, will not start recv_thread";
   } else {
     recv_threadpool_.reset(
-        new ::ThreadPool(FLAGS_communicator_thread_pool_size));
+        new ::ThreadPool(env_flags_dict["thread_pool_size"]));
   }
 }

@@ -132,7 +150,7 @@ void AsyncCommunicator::InitImpl(const paddle::framework::ProgramDesc &program,
       auto trainer_id = boost::get<int>(op->GetNullableAttr("trainer_id"));
       auto merge_add = boost::get<bool>(op->GetNullableAttr("merge_add"));
       if (!merge_add) {
-        merge_add = FLAGS_communicator_is_sgd_optimizer;
+        merge_add = static_cast<bool>(env_flags_dict["is_sgd_optimizer"]);
       }
       auto use_send_handler =
           boost::get<bool>(op->GetNullableAttr("use_send_handler"));

@@ -194,10 +212,10 @@ void AsyncCommunicator::SendThread() {
           std::vector<std::shared_ptr<Variable>> vars;
           int merged_var_num = 0;
           int wait_times = 0;
-          while (merged_var_num < FLAGS_communicator_max_merge_var_num) {
+          while (merged_var_num < env_flags_dict["max_merge_var_num"]) {
             if (var_queue->Size() == 0) {
               VLOG(4) << "wait_times -> " << wait_times;
-              if (wait_times >= FLAGS_communicator_send_wait_times) {
+              if (wait_times >= env_flags_dict["send_wait_times"]) {
                 break;
               }
               std::this_thread::sleep_for(std::chrono::milliseconds(10));

@@ -226,7 +244,7 @@ void AsyncCommunicator::SendThread() {
           VLOG(4) << "merge " << merged_var_num << " " << var_name
                   << " use time " << after_merge - before_merge;
           auto send_functor = distributed::ParameterSend<float>();
-          if (!FLAGS_communicator_fake_rpc) {
+          if (!env_flags_dict["fake_rpc"]) {
             send_functor(ctx, *send_scope_, true, 1);
           }
           auto after_send = GetCurrentUS();

@@ -255,7 +273,7 @@ void AsyncCommunicator::RecvThread() {
   VLOG(3) << "RecvThread start!";
   while (running_) {
     int grad_num = grad_num_.load();
-    if (grad_num > FLAGS_communicator_min_send_grad_num_before_recv) {
+    if (grad_num > env_flags_dict["min_send_grad_num_before_recv"]) {
       VLOG(1) << "current grad num " << grad_num;
       RecvAll();
       grad_num_.store(0);

@@ -273,10 +291,10 @@ void AsyncCommunicator::Send(const std::string &var_name,
   auto *grad_var = scope.FindVar(var_name);
   PADDLE_ENFORCE(grad_var->IsInitialized(), "grad var should be inited");
   if (grad_var->IsType<framework::SelectedRows>() &&
-      !FLAGS_communicator_merge_sparse_grad) {
+      !env_flags_dict["merge_sparse_grad"]) {
     auto send_functor = distributed::ParameterSend<float>();
     auto &ctx = send_varname_to_ctx_.at(var_name);
-    if (!FLAGS_communicator_fake_rpc) {
+    if (!env_flags_dict["fake_rpc"]) {
       send_functor(ctx, scope, true, 1);
     }
   } else {

@@ -289,7 +307,7 @@ void AsyncCommunicator::Send(const std::string &var_name,
 }
 
 void AsyncCommunicator::Recv() {
-  if (FLAGS_communicator_independent_recv_thread) {
+  if (env_flags_dict["independent_recv_thread"]) {
     return;
   }

@@ -313,7 +331,7 @@ void AsyncCommunicator::RecvAll() {
     auto &var_name = iter.first;
     VLOG(4) << "recv var " << var_name;
     auto recv_functor = distributed::ParameterRecv<float>();
-    if (!FLAGS_communicator_fake_rpc) {
+    if (!env_flags_dict["fake_rpc"]) {
       recv_functor(iter.second, *recv_scope_);
     }
   };

@@ -336,7 +354,7 @@ void AsyncCommunicator::Start() {
     // start send and recv thread
     send_thread_.reset(
         new std::thread(std::bind(&AsyncCommunicator::SendThread, this)));
-    if (FLAGS_communicator_independent_recv_thread) {
+    if (env_flags_dict["independent_recv_thread"]) {
       recv_thread_.reset(
           new std::thread(std::bind(&AsyncCommunicator::RecvThread, this)));
     }

@@ -396,25 +414,8 @@ void GeoSgdCommunicator::InitImpl(
   geo_need_push_nums_ = std::move(geo_need_push_nums);
 
-  // get all send information from graph, build vars_to_send
-  VLOG(0) << "communicator_independent_recv_thread: "
-          << FLAGS_communicator_independent_recv_thread;
-  VLOG(0) << "communicator_send_queue_size: "
-          << FLAGS_communicator_send_queue_size;
-  VLOG(0) << "communicator_min_send_grad_num_before_recv: "
-          << FLAGS_communicator_min_send_grad_num_before_recv;
-  VLOG(0) << "communicator_thread_pool_size: "
-          << FLAGS_communicator_thread_pool_size;
-  VLOG(0) << "communicator_send_wait_times: "
-          << FLAGS_communicator_send_wait_times;
-  VLOG(0) << "communicator_max_merge_var_num: "
-          << FLAGS_communicator_max_merge_var_num;
-  VLOG(0) << "communicator_fake_rpc: " << FLAGS_communicator_fake_rpc;
-  VLOG(0) << "communicator_merge_sparse_grad: "
-          << FLAGS_communicator_merge_sparse_grad;
   VLOG(0) << "Trainer nums: " << trainer_nums_;
   VLOG(0) << "geo_sgd_push_before_local_train_nums: " << geo_need_push_nums_;
   VLOG(0) << "communicator_merge_sparse_bucket "
           << FLAGS_communicator_merge_sparse_bucket;
 
   // process var info from transpiler
   for (auto &iter : vars_info) {

@@ -461,7 +462,7 @@ void GeoSgdCommunicator::InitImpl(
     LOG(WARNING) << "no var need to send and recv!!";
   }
 
-  send_threadpool_.reset(new ::ThreadPool(FLAGS_communicator_thread_pool_size));
+  send_threadpool_.reset(new ::ThreadPool(env_flags_dict["thread_pool_size"]));
   need_push_queue_ =
       std::make_shared<BlockingQueue<std::shared_ptr<SparseIdsMap>>>(
           geo_need_push_nums);

@@ -570,7 +571,7 @@ void GeoSgdCommunicator::SendThread() {
       VLOG(4) << "ids_send_vec_ pushed";
     } else if (need_push_queue_->Size() == 0) {
       VLOG(4) << "wait_times -> " << wait_times;
-      if (wait_times >= FLAGS_communicator_send_wait_times) {
+      if (wait_times >= env_flags_dict["send_wait_times"]) {
         break;
       }
       std::this_thread::sleep_for(std::chrono::milliseconds(10));
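On the C++ side the communicator no longer reads the FLAGS_communicator_* gflags at every use site: SetEnvFlagsDefault() snapshots them into env_flags_dict, and the new constructor lets Python override individual entries through a std::map<std::string, int>. The keys it understands are exactly the ones produced by TrainerRuntimeConfig.get_communicator_flags() in the new distributed_strategy.py. For reference, a hand-written override dict would look like the sketch below; the values shown are only examples, the real defaults come from the gflags/environment.

    # Keys recognized by Communicator::SetEnvFlagsDefault(); values are illustrative.
    communicator_env_flags = {
        "independent_recv_thread": 1,
        "send_queue_size": 20,
        "min_send_grad_num_before_recv": 20,
        "thread_pool_size": 5,
        "send_wait_times": 5,
        "max_merge_var_num": 20,
        "fake_rpc": 0,
        "merge_sparse_grad": 1,
        "is_sgd_optimizer": 1,
    }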
paddle/fluid/operators/distributed/communicator.h

@@ -174,9 +174,12 @@ using RpcCtxMap = std::unordered_map<std::string, RpcContext>;
 class Communicator {
  public:
-  Communicator() {}
+  Communicator();
+  explicit Communicator(const std::map<std::string, int>& env_flags);
   virtual ~Communicator() {}
 
+  virtual void SetEnvFlagsDefault();
   virtual void Start() = 0;
   virtual void Stop() = 0;
   virtual bool IsRunning() { return running_; }

@@ -221,9 +224,10 @@ class Communicator {
   template <typename T>
   static Communicator *InitInstance(
-      const paddle::framework::ProgramDesc &program, Scope *recv_scope) {
+      const paddle::framework::ProgramDesc &program, Scope *recv_scope,
+      const std::map<std::string, int> &env_flags) {
     std::call_once(init_flag_, &Communicator::InitWithProgram<T>, program,
-                   recv_scope);
+                   recv_scope, std::ref(env_flags));
     return communicator_.get();
   }

@@ -232,10 +236,12 @@ class Communicator {
       const paddle::framework::ProgramDesc &program, Scope *training_scope,
       std::map<std::string, std::map<std::string, std::vector<std::string>>>
           &vars_info,
-      const int &trainers, const int &geo_need_push_nums) {
+      const int &trainers, const int &geo_need_push_nums,
+      const std::map<std::string, int> &env_flags) {
     std::call_once(init_flag_, &Communicator::InitWithTranspilerInfo<T>,
                    program, training_scope, std::ref(vars_info),
-                   std::ref(trainers), std::ref(geo_need_push_nums));
+                   std::ref(trainers), std::ref(geo_need_push_nums),
+                   std::ref(env_flags));
     return communicator_.get();
   }

@@ -253,9 +259,10 @@ class Communicator {
   template <typename T>
   static void InitWithProgram(const paddle::framework::ProgramDesc &program,
-                              Scope *recv_scope) {
+                              Scope *recv_scope,
+                              const std::map<std::string, int> &env_flags) {
     if (communicator_.get() == nullptr) {
-      communicator_.reset(new T());
+      communicator_.reset(new T(std::ref(env_flags)));
       communicator_->InitImpl(program, recv_scope);
     }
   }

@@ -265,9 +272,10 @@ class Communicator {
       const paddle::framework::ProgramDesc &program, Scope *training_scope,
       std::map<std::string, std::map<std::string, std::vector<std::string>>>
          &vars_info,
-      const int &trainers, const int &geo_need_push_nums) {
+      const int &trainers, const int &geo_need_push_nums,
+      const std::map<std::string, int> &env_flags) {
     if (communicator_.get() == nullptr) {
-      communicator_.reset(new T());
+      communicator_.reset(new T(std::ref(env_flags)));
       communicator_->InitImpl(program, training_scope, std::ref(vars_info),
                               std::ref(trainers), std::ref(geo_need_push_nums));
     }

@@ -277,6 +285,7 @@ class Communicator {
   bool running_ = false;
   static std::shared_ptr<Communicator> communicator_;
   static std::once_flag init_flag_;
+  std::unordered_map<std::string, int> env_flags_dict;
 };
 
 using SparseIdsMap =

@@ -284,7 +293,9 @@ using SparseIdsMap =
 class AsyncCommunicator : public Communicator {
  public:
-  AsyncCommunicator() {}
+  AsyncCommunicator() : Communicator() {}
+  explicit AsyncCommunicator(const std::map<std::string, int> &env_flags)
+      : Communicator(env_flags) {}
   ~AsyncCommunicator();
   void Start() override;
   void Stop() override;

@@ -331,7 +342,9 @@ class AsyncCommunicator : public Communicator {
 class GeoSgdCommunicator : public Communicator {
  public:
-  GeoSgdCommunicator() {}
+  GeoSgdCommunicator() : Communicator() {}
+  explicit GeoSgdCommunicator(const std::map<std::string, int> &env_flags)
+      : Communicator(env_flags) {}
   ~GeoSgdCommunicator();
   void InitImpl(const paddle::framework::ProgramDesc &program,
                 Scope *training_scope,
paddle/fluid/operators/distributed_ops/listen_and_serv_op.cc

@@ -356,6 +356,10 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
   rpc_service_.reset(new RPCSERVER_T(endpoint, fan_in));
 
+  auto rpc_get_thread_num = Attr<int>("rpc_get_thread_num");
+  auto rpc_send_thread_num = Attr<int>("rpc_send_thread_num");
+  auto rpc_prefetch_thread_num = Attr<int>("rpc_prefetch_thread_num");
+
   request_send_handler_.reset(
       new distributed::RequestSendHandler(sync_mode, dc_sgd));
   request_get_handler_.reset(

@@ -370,21 +374,18 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
       new distributed::RequestNotifyHandler(sync_mode, lr_decay_block_id));
 
   rpc_service_->RegisterRPC(distributed::kRequestSend,
-                            request_send_handler_.get(),
-                            FLAGS_rpc_send_thread_num);
+                            request_send_handler_.get(), rpc_send_thread_num);
   rpc_service_->RegisterRPC(distributed::kRequestGet,
-                            request_get_handler_.get(),
-                            FLAGS_rpc_get_thread_num);
+                            request_get_handler_.get(), rpc_get_thread_num);
   rpc_service_->RegisterRPC(distributed::kRequestPrefetch,
                             request_prefetch_handler_.get(),
-                            FLAGS_rpc_prefetch_thread_num);
+                            rpc_prefetch_thread_num);
   rpc_service_->RegisterRPC(distributed::kRequestCheckpoint,
                             request_checkpoint_handler_.get());
   rpc_service_->RegisterRPC(distributed::kRequestGetNoBarrier,
                             request_get_no_barrier_handler_.get());
   rpc_service_->RegisterRPC(distributed::kRequestNotify,
-                            request_notify_handler_.get(),
-                            FLAGS_rpc_send_thread_num);
+                            request_notify_handler_.get(), rpc_send_thread_num);
 
   auto optimize_blocks =
       Attr<std::vector<framework::BlockDesc *>>(kOptimizeBlocks);

@@ -549,6 +550,11 @@ class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker {
         .SetDefault(-1);
     AddAttr<int>(kLRDecayBlockId, "BolckID to run lr decay on pserer.")
         .SetDefault(-1);
+    AddAttr<int>("rpc_get_thread_num", "pserver get thread num.").SetDefault(1);
+    AddAttr<int>("rpc_send_thread_num", "pserver send thread num.")
+        .SetDefault(1);
+    AddAttr<int>("rpc_prefetch_thread_num", "pserver prefetch thread num.")
+        .SetDefault(1);
   }
 };
paddle/fluid/pybind/communicator_py.cc

@@ -39,19 +39,23 @@ void BindCommunicator(py::module* m) {
   // Communicator is already used by nccl, change to DistCommunicator
   py::class_<Communicator, std::shared_ptr<Communicator>>(*m,
                                                           "DistCommunicator")
-      .def(py::init([](const ProgramDesc& program, Scope* param_scope) {
+      .def(py::init([](const ProgramDesc& program, Scope* param_scope,
+                       std::map<std::string, int>& env_flags) {
        VLOG(0) << "using communicator";
-        Communicator::InitInstance<AsyncCommunicator>(program, param_scope);
+        Communicator::InitInstance<AsyncCommunicator>(program, param_scope,
+                                                      env_flags);
        return Communicator::GetInstantcePtr();
      }))
      .def(py::init([](
           const ProgramDesc& program, Scope* training_scope,
           std::map<std::string,
                    std::map<std::string, std::vector<std::string>>>& vars_info,
-          int& trainers, int& geo_need_push_nums) {
+          int& trainers, int& geo_need_push_nums,
+          std::map<std::string, int>& env_flags) {
        VLOG(0) << "using geo sgd communicator";
        Communicator::InitInstance<GeoSgdCommunicator>(
-            program, training_scope, vars_info, trainers, geo_need_push_nums);
+            program, training_scope, vars_info, trainers, geo_need_push_nums,
+            env_flags);
        return Communicator::GetInstantcePtr();
      }))
      .def("stop", &Communicator::Stop)
python/paddle/fluid/communicator.py

@@ -28,7 +28,8 @@ class Communicator(object):
                  program,
                  vars_info=None,
                  trainers=None,
-                 geo_sgd_need_push_nums=None):
+                 geo_sgd_need_push_nums=None,
+                 env_flags=None):
         """
         Communicator is used for async distribute training in distribute_transpiler mode.
         It's a wrapper of a cpp class Communicator and should be used inside fleet API.

@@ -56,14 +57,19 @@ class Communicator(object):
             if op.type == "recv":
                 op._set_attr('do_not_run', True)
         # Todo: Add check
+        if env_flags is None:
+            env_flags = {}
+
         if vars_info and trainers and geo_sgd_need_push_nums:
             # for geo sgd
             self.communicator_ = core.DistCommunicator(
                 program.desc,
-                global_scope(), vars_info, trainers, geo_sgd_need_push_nums)
+                global_scope(), vars_info, trainers, geo_sgd_need_push_nums,
+                env_flags)
         else:
             self.communicator_ = core.DistCommunicator(program.desc,
-                                                       global_scope())
+                                                       global_scope(), env_flags)
 
     def start(self):
         """
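With the extra env_flags argument, the Python Communicator wrapper forwards the flag dictionary straight into core.DistCommunicator. Inside fleet this happens automatically in init_worker() (see the __init__.py changes below); the sketch here only illustrates the shape of the call, it assumes the program passed in has already been processed by the distribute transpiler, and it is not a standalone runnable cluster script.

    # Sketch only: fleet normally builds this call itself in init_worker().
    import paddle.fluid as fluid
    from paddle.fluid.communicator import Communicator
    from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler.distributed_strategy import AsyncStrategy

    strategy = AsyncStrategy()
    flags = strategy.get_trainer_runtime_config().get_communicator_flags()

    # In real use this is the transpiled trainer main program, not the default program.
    trainer_program = fluid.default_main_program()
    comm = Communicator(trainer_program, env_flags=flags)
    # comm.start() / comm.stop() would then bracket the training loop.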
python/paddle/fluid/incubate/fleet/parameter_server/distribute_transpiler/__init__.py

@@ -11,11 +11,13 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import os
+import warnings
 """
 Convert the fluid program to distributed data-parallelism programs.
 """
+from .distributed_strategy import *
 import paddle.fluid.io as io
 from paddle.fluid.communicator import Communicator
 from paddle.fluid.framework import default_main_program

@@ -26,8 +28,7 @@ from paddle.fluid.executor import Executor
 from paddle.fluid.parallel_executor import ParallelExecutor
 from paddle.fluid.optimizer import Optimizer
 from paddle.fluid.transpiler.distribute_transpiler import DistributeTranspiler as OriginTranspiler
-from paddle.fluid.transpiler.geo_sgd_transpiler import GeoSgdTranspiler
-from paddle.fluid.transpiler.distribute_transpiler import DistributeTranspilerConfig
+from paddle.fluid.transpiler.distribute_transpiler import DistributeTranspilerConfig, ServerRuntimeConfig
 from paddle.fluid.incubate.fleet.base.fleet_base import DistributedOptimizer
 from paddle.fluid.incubate.fleet.base.fleet_base import Fleet

@@ -66,15 +67,24 @@ class DistributedTranspiler(Fleet):
         from paddle.fluid.transpiler.details.checkport import wait_server_ready
         wait_server_ready(fleet.server_endpoints(to_string=False))
-        if not self._transpile_config.sync_mode:
-            if self._transpile_config.geo_sgd_mode:
-                self._communicator = Communicator(
-                    self.main_program, self.vars_info,
-                    fleet.worker_num(),
-                    self._transpile_config.geo_sgd_need_push_nums)
-            else:
-                self._communicator = Communicator(self.main_program)
+        program_config = self._transpile_config.get_program_config()
+        trainer_communicator_config = self._transpile_config.get_trainer_runtime_config(
+        )
+        print(trainer_communicator_config)
+        need_communicator_flag = False
+        if isinstance(self._transpile_config, GeoStrategy):
+            need_communicator_flag = True
+            self._communicator = Communicator(
+                self.main_program, self.vars_info,
+                fleet.worker_num(), program_config.geo_sgd_need_push_nums,
+                trainer_communicator_config.get_communicator_flags())
+        elif isinstance(self._transpile_config, AsyncStrategy):
+            need_communicator_flag = True
+            self._communicator = Communicator(
+                self.main_program,
+                env_flags=trainer_communicator_config.get_communicator_flags())
+        if need_communicator_flag:
             if not self._communicator.is_running():
                 self._communicator.start()
             else:

@@ -129,7 +139,8 @@ class DistributedTranspiler(Fleet):
         Returns:
             None
         """
-        if not self._transpile_config.sync_mode:
+        if isinstance(self._transpile_config, GeoStrategy) or isinstance(
+                self._transpile_config, AsyncStrategy):
             self._communicator.stop()
         self._executor.close()
 
         if isinstance(self._role_maker, MPISymetricRoleMaker):

@@ -239,36 +250,44 @@ class DistributedTranspiler(Fleet):
         io.save_persistables(executor, dirname, main_program, None)
 
     def _transpile(self, config):
-        if not isinstance(config, DistributeTranspilerConfig):
+        if isinstance(config, DistributeTranspilerConfig):
+            self._transpile_config = DistributedStrategy()
+            self._transpile_config.set_program_config(config)
+        elif isinstance(config, DistributedStrategy):
+            self._transpile_config = config
+        else:
             raise TypeError(
-                "config must be an instance of DistributeTranspilerConfig")
-
-        if not config.sync_mode:
-            config.runtime_split_send_recv = True
+                "config must be an instance of DistributeTranspilerConfig or DistributedStrategy"
+            )
 
+        program_config = self._transpile_config.get_program_config()
+
         # _origin_program is a deep copy for default_main_program, for inference
         self._origin_program = default_main_program().clone(for_test=False)
-        self._transpile_config = config
-        if config.geo_sgd_mode:
-            self._transpiler = GeoSgdTranspiler(config)
+
+        if program_config.geo_sgd_mode:
+            from paddle.fluid.transpiler.geo_sgd_transpiler import GeoSgdTranspiler
+            self._transpiler = GeoSgdTranspiler(program_config)
         else:
-            self._transpiler = OriginTranspiler(config)
+            self._transpiler = OriginTranspiler(program_config)
+        self._transpiler._set_server_config(
+            self._transpile_config.get_server_runtime_config())
 
         if self.is_worker():
             self._transpiler.transpile(
                 trainer_id=fleet.worker_index(),
                 pservers=fleet.server_endpoints(to_string=True),
                 trainers=fleet.worker_num(),
-                sync_mode=config.sync_mode)
+                sync_mode=program_config.sync_mode)
 
             if isinstance(self._role_maker, MPISymetricRoleMaker):
-                config.wait_port = False
+                program_config.wait_port = False
+                self._transpile_config.set_program_config(program_config)
 
             self.main_program = self._transpiler.get_trainer_program(
-                wait_port=config.wait_port)
+                wait_port=program_config.wait_port)
             self.startup_program = default_startup_program()
-            if self._transpile_config.geo_sgd_mode:
+            if program_config.geo_sgd_mode:
                 self.vars_info = self._transpiler._get_vars_info()
                 self.startup_program = self._transpiler.trainer_startup_program
         else:

@@ -276,7 +295,7 @@ class DistributedTranspiler(Fleet):
                 trainer_id=fleet.worker_index(),
                 pservers=fleet.server_endpoints(to_string=True),
                 trainers=fleet.worker_num(),
-                sync_mode=config.sync_mode,
+                sync_mode=program_config.sync_mode,
                 current_endpoint=self.server_endpoints()[self.server_index()])
             self.main_program, self.startup_program = \
                 self._transpiler.get_pserver_programs(

@@ -308,14 +327,17 @@ class TranspilerOptimizer(DistributedOptimizer):
         super(TranspilerOptimizer, self).__init__(optimizer, strategy)
 
         if strategy:
-            if not isinstance(strategy, DistributeTranspilerConfig):
+            if isinstance(strategy, DistributedStrategy):
+                self._strategy = strategy
+            elif isinstance(strategy, DistributeTranspilerConfig):
+                self._strategy = DistributedStrategy()
+                self._strategy.set_program_config(strategy)
+            else:
                 raise TypeError(
-                    "In {} mode, strategy must be an instance of DistributeTranspilerConfig".
+                    "In {} mode, strategy must be an instance of DistributeTranspilerConfig or DistributedStrategy".
                     format(fleet._mode))
-            else:
-                self._strategy = strategy
         else:
-            self._strategy = DistributeTranspilerConfig()
+            self._strategy = DistributedStrategy()
 
     def backward(self,
                  loss,
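Both fleet._transpile() and TranspilerOptimizer now normalize whatever they are given: a plain DistributeTranspilerConfig is wrapped into a DistributedStrategy via set_program_config(), while a DistributedStrategy is used as-is, so existing user code keeps working. A small sketch of the two equivalent call styles (values are arbitrary examples):

    # Both of these end up as a DistributedStrategy inside fleet.distributed_optimizer().
    from paddle.fluid.transpiler.distribute_transpiler import DistributeTranspilerConfig
    from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler.distributed_strategy import StrategyFactory

    # Old style: still accepted, wrapped by set_program_config() internally.
    old_style = DistributeTranspilerConfig()
    old_style.sync_mode = False
    old_style.runtime_split_send_recv = True

    # New style: carries program, trainer, server, build and execution config together.
    new_style = StrategyFactory.create_async_strategy()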
python/paddle/fluid/incubate/fleet/parameter_server/distribute_transpiler/distributed_strategy.py
(new file mode 100644)

# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

__all__ = [
    "TrainerRuntimeConfig", "DistributedStrategy", "SyncStrategy",
    "AsyncStrategy", "HalfAsyncStrategy", "GeoStrategy", "StrategyFactory"
]

import os
import paddle.fluid as fluid
from paddle.fluid.transpiler.distribute_transpiler import DistributeTranspilerConfig, ServerRuntimeConfig


class TrainerRuntimeConfig(object):
    def __init__(self):
        self.max_merge_var_num = int(
            os.getenv("FLAGS_communicator_max_merge_var_num", "20"))
        self.send_queue_size = int(
            os.getenv("FLAGS_communicator_send_queue_size", "20"))
        self.independent_recv_thread = int(
            os.getenv("FLAGS_communicator_independent_recv_thread", "1"))
        self.min_send_grad_num_before_recv = int(
            os.getenv("FLAGS_communicator_min_send_grad_num_before_recv", "20"))
        self.thread_pool_size = int(
            os.getenv("FLAGS_communicator_thread_pool_size", "5"))
        self.send_wait_times = int(
            os.getenv("FLAGS_communicator_send_wait_times", "5"))
        self.fake_rpc = int(os.getenv("FLAGS_communicator_fake_rpc", "0"))
        self.merge_sparse_grad = int(
            os.getenv("FLAGS_communicator_merge_sparse_grad", "1"))
        self.is_sgd_optimizer = int(
            os.getenv("FLAGS_communicator_is_sgd_optimizer", "1"))

        # not used
        self._rpc_deadline = int(os.getenv("FLAGS_rpc_deadline", "180000"))
        self._rpc_retry_times = int(os.getenv("FLAGS_rpc_retry_times", "3"))

    def get_communicator_flags(self):
        _communicator_flags = dict()
        _communicator_flags["max_merge_var_num"] = self.max_merge_var_num
        _communicator_flags["send_queue_size"] = self.send_queue_size
        _communicator_flags["independent_recv_thread"] = self.independent_recv_thread
        _communicator_flags["min_send_grad_num_before_recv"] = self.min_send_grad_num_before_recv
        _communicator_flags["thread_pool_size"] = self.thread_pool_size
        _communicator_flags["send_wait_times"] = self.send_wait_times
        _communicator_flags["fake_rpc"] = self.fake_rpc
        _communicator_flags["merge_sparse_grad"] = self.merge_sparse_grad
        _communicator_flags["is_sgd_optimizer"] = self.is_sgd_optimizer
        return _communicator_flags

    def __repr__(self):
        _str = "please check that TrainerRuntimeConfig is as expected:\n"
        _communicator_flags = self.get_communicator_flags()
        for key in _communicator_flags:
            _str += "communicator_{}: {}\n".format(key,
                                                   _communicator_flags[key])
        return _str


class DistributedStrategy(object):
    def __init__(self):
        self._program_config = DistributeTranspilerConfig()
        self._trainer_runtime_config = TrainerRuntimeConfig()
        self._server_runtime_config = ServerRuntimeConfig()
        self._execute_strategy = fluid.ExecutionStrategy()
        self._build_strategy = fluid.BuildStrategy()
        num_threads = int(os.getenv("CPU_NUM", "1"))
        self._execute_strategy.num_threads = num_threads
        if num_threads > 1:
            self._build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce

    def get_program_config(self):
        return self._program_config

    def set_program_config(self, config):
        if isinstance(config, DistributeTranspilerConfig):
            self._program_config = config
        elif isinstance(config, dict):
            for key in config:
                if hasattr(self._program_config, key):
                    setattr(self._program_config, key, config[key])
                else:
                    raise ValueError(
                        "DistributeTranspilerConfig doesn't have key: {}".
                        format(key))
        else:
            raise TypeError(
                "program_config only accept input type: dict or DistributeTranspilerConfig"
            )

    def get_trainer_runtime_config(self):
        return self._trainer_runtime_config

    def set_trainer_runtime_config(self, config):
        if isinstance(config, TrainerRuntimeConfig):
            self._trainer_runtime_config = config
        elif isinstance(config, dict):
            for key in config:
                if hasattr(self._trainer_runtime_config, key):
                    setattr(self._trainer_runtime_config, key, config[key])
                else:
                    raise ValueError(
                        "TrainerRuntimeConfig doesn't have key: {}".format(key))
        else:
            raise TypeError(
                "trainer_runtime_config only accept input type: dict or TrainerRuntimeConfig"
            )

    def get_server_runtime_config(self):
        return self._server_runtime_config

    def set_server_runtime_config(self, config):
        if isinstance(config, ServerRuntimeConfig):
            self._server_runtime_config = config
        elif isinstance(config, dict):
            for key in config:
                if hasattr(self._server_runtime_config, key):
                    setattr(self._server_runtime_config, key, config[key])
                else:
                    raise ValueError(
                        "ServerRuntimeConfig doesn't have key: {}".format(key))
        else:
            raise TypeError(
                "server_runtime_config only accept input type: dict or ServerRuntimeConfig"
            )

    def get_execute_strategy(self):
        return self._execute_strategy

    def set_execute_strategy(self, config):
        if isinstance(config, fluid.ExecutionStrategy):
            self._execute_strategy = config
        elif isinstance(config, dict):
            for key in config:
                if hasattr(self._execute_strategy, key):
                    setattr(self._execute_strategy, key, config[key])
                else:
                    raise ValueError(
                        "ExecutionStrategy doesn't have key: {}".format(key))
        else:
            raise TypeError(
                "execute_strategy only accept input type: dict or ExecutionStrategy"
            )

    def get_build_strategy(self):
        return self._build_strategy

    def set_build_strategy(self, config):
        if isinstance(config, fluid.BuildStrategy):
            self._build_strategy = config
        elif isinstance(config, dict):
            for key in config:
                if hasattr(self._build_strategy, key):
                    setattr(self._build_strategy, key, config[key])
                else:
                    raise ValueError(
                        "BuildStrategy doesn't have key: {}".format(key))
        else:
            raise TypeError(
                "build_strategy only accept input type: dict or BuildStrategy")


class SyncStrategy(DistributedStrategy):
    def __init__(self):
        super(SyncStrategy, self).__init__()
        self._program_config.sync_mode = True
        self._program_config.runtime_split_send_recv = False
        self._build_strategy.async_mode = False


class AsyncStrategy(DistributedStrategy):
    def __init__(self):
        super(AsyncStrategy, self).__init__()
        self._program_config.sync_mode = False
        self._program_config.runtime_split_send_recv = True
        self._build_strategy.async_mode = True


class HalfAsyncStrategy(DistributedStrategy):
    def __init__(self):
        super(HalfAsyncStrategy, self).__init__()
        self._program_config.sync_mode = False
        self._program_config.runtime_split_send_recv = False
        self._build_strategy.async_mode = False


class GeoStrategy(DistributedStrategy):
    def __init__(self, update_frequency=100):
        super(GeoStrategy, self).__init__()
        self._program_config.sync_mode = False
        self._program_config.runtime_split_send_recv = True
        self._program_config.geo_sgd_mode = True
        self._program_config.geo_sgd_need_push_nums = update_frequency
        self._build_strategy.async_mode = True


class StrategyFactory(object):
    def __init_(self):
        pass

    @staticmethod
    def create_sync_strategy():
        return SyncStrategy()

    @staticmethod
    def create_half_async_strategy():
        return HalfAsyncStrategy()

    @staticmethod
    def create_async_strategy():
        return AsyncStrategy()

    @staticmethod
    def create_geo_strategy(update_frequency=100):
        return GeoStrategy(update_frequency)
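Each DistributedStrategy holds five sub-configs (program, trainer runtime, server runtime, build, execution), and every set_* method accepts either the matching config object or a plain dict whose keys must already exist on that object; unknown keys raise ValueError. A short sketch of the dict-override path (values are arbitrary examples):

    from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler.distributed_strategy import StrategyFactory

    strategy = StrategyFactory.create_async_strategy()

    # Override individual communicator flags; keys must be attributes of TrainerRuntimeConfig.
    strategy.set_trainer_runtime_config({"send_queue_size": 64, "thread_pool_size": 10})

    # Override pserver RPC thread counts; keys must be attributes of ServerRuntimeConfig.
    strategy.set_server_runtime_config({"_rpc_send_thread_num": 24})

    print(strategy.get_trainer_runtime_config().get_communicator_flags())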
python/paddle/fluid/tests/unittests/ctr_dataset_reader.py

@@ -61,6 +61,24 @@ def load_lr_input_record(sent):
     return res
 
 
+class CtrReader(object):
+    def __init__(self):
+        pass
+
+    def _reader_creator(self, filelist):
+        def reader():
+            for file in filelist:
+                with open(file, 'r') as f:
+                    for line in f:
+                        fs = line.strip().split('\t')
+                        dnn_input = load_dnn_input_record(fs[0])
+                        lr_input = load_lr_input_record(fs[1])
+                        click = [int(fs[2])]
+                        yield [dnn_input] + [lr_input] + [click]
+
+        return reader
+
+
 class DatasetCtrReader(data_generator.MultiSlotDataGenerator):
     def generate_sample(self, line):
         def get_rand(low=0.0, high=1.0):
python/paddle/fluid/tests/unittests/dist_fleet_ctr.py

@@ -21,8 +21,10 @@ import shutil
 import tempfile
 import time
 
+import paddle
 import paddle.fluid as fluid
 import os
+import numpy as np
 
 import ctr_dataset_reader
 from test_dist_fleet_base import runtime_main, FleetDistRunnerBase

@@ -131,7 +133,7 @@ class TestDistCTR2x2(FleetDistRunnerBase):
         with open(os.path.join(dirname, "__model__.proto"), "w") as wn:
             wn.write(str(program))
 
-    def do_training(self, fleet):
+    def do_pyreader_training(self, fleet):
         """
         do training using dataset, using fetch handler to catch variable
         Args:

@@ -146,13 +148,63 @@ class TestDistCTR2x2(FleetDistRunnerBase):
         exe.run(fleet.startup_program)
 
         thread_num = 2
         batch_size = 128
         filelist = []
         for _ in range(thread_num):
             filelist.append(train_file_path)
 
+        train_reader = paddle.batch(
+            paddle.reader.shuffle(
+                ctr_dataset_reader.CtrReader()._reader_creator(filelist),
+                buf_size=batch_size * 100),
+            batch_size=batch_size)
+        self.reader.decorate_sample_list_generator(train_reader)
+
+        compiled_prog = fluid.compiler.CompiledProgram(
+            fleet.main_program).with_data_parallel(
+                loss_name=self.avg_cost.name,
+                build_strategy=self.strategy.get_build_strategy(),
+                exec_strategy=self.strategy.get_execute_strategy())
+
+        for epoch_id in range(1):
+            self.reader.start()
+            try:
+                pass_start = time.time()
+                while True:
+                    loss_val = exe.run(program=compiled_prog,
+                                       fetch_list=[self.avg_cost.name])
+                    loss_val = np.mean(loss_val)
+                    print("TRAIN ---> pass: {} loss: {}\n".format(epoch_id,
+                                                                  loss_val))
+                pass_time = time.time() - pass_start
+            except fluid.core.EOFException:
+                self.reader.reset()
+
+        model_dir = tempfile.mkdtemp()
+        fleet.save_inference_model(
+            exe, model_dir, [feed.name for feed in self.feeds], self.avg_cost)
+        self.check_model_right(model_dir)
+        shutil.rmtree(model_dir)
+        fleet.stop_worker()
+
+    def do_dataset_training(self, fleet):
+        dnn_input_dim, lr_input_dim, train_file_path = ctr_dataset_reader.prepare_data(
+        )
+
+        exe = fluid.Executor(fluid.CPUPlace())
+
+        fleet.init_worker()
+        exe.run(fleet.startup_program)
+
+        thread_num = 2
+        batch_size = 128
+        filelist = []
+        for _ in range(thread_num):
+            filelist.append(train_file_path)
+
         # config dataset
         dataset = fluid.DatasetFactory().create_dataset()
-        dataset.set_batch_size(128)
+        dataset.set_batch_size(batch_size)
         dataset.set_use_var(self.feeds)
         pipe_command = 'python ctr_dataset_reader.py'
         dataset.set_pipe_command(pipe_command)

@@ -172,11 +224,14 @@ class TestDistCTR2x2(FleetDistRunnerBase):
             debug=False)
         pass_time = time.time() - pass_start
 
+        res_dict = dict()
+        res_dict['loss'] = self.avg_cost
+
         class FH(fluid.executor.FetchHandler):
-            def handler(self, fetch_target_vars):
-                for i in range(len(fetch_target_vars)):
-                    print("{}: \n {}\n".format(self.fetch_target_names[0],
-                                               fetch_target_vars[0]))
+            def handle(self, res_dict):
+                for key in res_dict:
+                    v = res_dict[key]
+                    print("{}: \n {}\n".format(key, v))
 
         for epoch_id in range(1):
             pass_start = time.time()

@@ -184,7 +239,7 @@ class TestDistCTR2x2(FleetDistRunnerBase):
             exe.train_from_dataset(
                 program=fleet.main_program,
                 dataset=dataset,
-                fetch_handler=FH([self.avg_cost.name], period_secs=2),
+                fetch_handler=FH(var_dict=res_dict, period_secs=2),
                 debug=False)
             pass_time = time.time() - pass_start
python/paddle/fluid/tests/unittests/test_dist_fleet_base.py

@@ -37,6 +37,7 @@ import paddle.fluid as fluid
 import paddle.fluid.incubate.fleet.base.role_maker as role_maker
 from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler import fleet
 from paddle.fluid.transpiler.distribute_transpiler import DistributeTranspilerConfig
+from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler.distributed_strategy import StrategyFactory
 
 RUN_STEP = 5
 LEARNING_RATE = 0.01

@@ -50,6 +51,19 @@ class FleetDistRunnerBase(object):
         do training : exe run program
     """
 
+    def generate_strategy(self, args):
+        self.strategy = None
+        if args.mode == "async":
+            self.strategy = StrategyFactory.create_async_strategy()
+        elif args.mode == "sync":
+            self.strategy = StrategyFactory.create_sync_strategy()
+        elif args.mode == "half_async":
+            self.strategy = StrategyFactory.create_half_async_strategy()
+        elif args.mode == "geo":
+            self.strategy = StrategyFactory.create_geo_strategy(
+                args.geo_sgd_need_push_nums)
+        return self.strategy
+
     def run_pserver(self, args):
         if args.role.upper() != "PSERVER":
             raise ValueError("args role must be PSERVER")

@@ -62,10 +76,7 @@ class FleetDistRunnerBase(object):
         fleet.init(role)
 
-        strategy = DistributeTranspilerConfig()
-        strategy.sync_mode = args.sync_mode
-        strategy.geo_sgd_mode = args.geo_sgd_mode
-        strategy.geo_sgd_need_push_nums = args.geo_sgd_need_push_nums
+        strategy = self.generate_strategy(args)
 
         avg_cost = self.net()

@@ -76,7 +87,28 @@ class FleetDistRunnerBase(object):
         fleet.init_server()
         fleet.run_server()
 
-    def run_trainer(self, args):
+    def run_dataset_trainer(self, args):
+        if args.role.upper() != "TRAINER":
+            raise ValueError("args role must be TRAINER")
+
+        role = role_maker.UserDefinedRoleMaker(
+            current_id=args.current_id,
+            role=role_maker.Role.WORKER,
+            worker_num=args.trainers,
+            server_endpoints=args.endpoints.split(","))
+
+        fleet.init(role)
+        strategy = self.generate_strategy(args)
+
+        avg_cost = self.net()
+
+        optimizer = fluid.optimizer.SGD(LEARNING_RATE)
+        optimizer = fleet.distributed_optimizer(optimizer, strategy)
+        optimizer.minimize(avg_cost)
+
+        out = self.do_dataset_training(fleet)
+
+    def run_pyreader_trainer(self, args):
         if args.role.upper() != "TRAINER":
             raise ValueError("args role must be TRAINER")

@@ -88,26 +120,33 @@ class FleetDistRunnerBase(object):
         fleet.init(role)
 
-        strategy = DistributeTranspilerConfig()
-        strategy.sync_mode = args.sync_mode
-        strategy.geo_sgd_mode = args.geo_sgd_mode
-        strategy.geo_sgd_need_push_nums = args.geo_sgd_need_push_nums
+        strategy = self.generate_strategy(args)
 
         avg_cost = self.net()
+        self.reader = fluid.io.PyReader(
+            feed_list=self.feeds,
+            capacity=64,
+            iterable=False,
+            use_double_buffer=False)
 
         optimizer = fluid.optimizer.SGD(LEARNING_RATE)
         optimizer = fleet.distributed_optimizer(optimizer, strategy)
         optimizer.minimize(avg_cost)
 
-        out = self.do_training(fleet)
+        out = self.do_pyreader_training(fleet)
 
     def net(self, batch_size=4, lr=0.01):
         raise NotImplementedError(
             "get_model should be implemented by child classes.")
 
-    def do_training(self, fleet):
+    def do_dataset_training(self, fleet):
         raise NotImplementedError(
-            "do_training should be implemented by child classes.")
+            "do_dataset_training should be implemented by child classes.")
+
+    def do_pyreader_training(self, fleet):
+        raise NotImplementedError(
+            "do_pyreader_training should be implemented by child classes.")
 
 
 class TestFleetBase(unittest.TestCase):

@@ -120,7 +159,8 @@ class TestFleetBase(unittest.TestCase):
         raise NotImplementedError("tests should have _setup_config implemented")
 
     def setUp(self):
-        self._sync_mode = True
+        self._mode = "sync"
+        self._reader = "pyreader"
         self._trainers = 2
         self._pservers = 2
         self._port_set = set()

@@ -139,7 +179,6 @@ class TestFleetBase(unittest.TestCase):
             self._find_free_port(), self._find_free_port())
 
         self._python_interp = sys.executable
-        self._geo_sgd = False
         self._geo_sgd_need_push_nums = 5
         self._setup_config()

@@ -203,21 +242,13 @@ class TestFleetBase(unittest.TestCase):
             envs['COVERAGE_FILE'] = os.getenv('COVERAGE_FILE', '')
             python_path += " -m coverage run --branch -p"
 
-        tr_cmd = "{0} {1} --role trainer --endpoints {2} --current_id {{}} --trainers {3}".format(
-            python_path, model, self._ps_endpoints, self._trainers)
-
-        ps_cmd = "{0} {1} --role pserver --endpoints {2} --current_id {{}} --trainers {3}".format(
-            python_path, model, self._ps_endpoints, self._trainers)
-
-        if self._sync_mode:
-            tr_cmd += " --sync_mode"
-            ps_cmd += " --sync_mode"
-
-        if self._geo_sgd:
-            tr_cmd += " --geo_sgd_mode {0} --geo_sgd_need_push_nums {1}".format(
-                self._geo_sgd, self._geo_sgd_need_push_nums)
-            ps_cmd += " --geo_sgd_mode {0} --geo_sgd_need_push_nums {1}".format(
-                self._geo_sgd, self._geo_sgd_need_push_nums)
+        tr_cmd = "{0} {1} --role trainer --endpoints {2} --current_id {{}} --trainers {3} --mode {4} --geo_sgd_need_push_nums {5} --reader {6}".format(
+            python_path, model, self._ps_endpoints, self._trainers, self._mode,
+            self._geo_sgd_need_push_nums, self._reader)
+
+        ps_cmd = "{0} {1} --role pserver --endpoints {2} --current_id {{}} --trainers {3} --mode {4} --geo_sgd_need_push_nums {5} --reader {6}".format(
+            python_path, model, self._ps_endpoints, self._trainers, self._mode,
+            self._geo_sgd_need_push_nums, self._reader)
 
         # Run dist train to compare with local results
         ps0, ps1, ps0_pipe, ps1_pipe = self._start_pserver(ps_cmd, env)

@@ -301,15 +332,17 @@ def runtime_main(test_class):
     parser.add_argument('--endpoints', type=str, required=False, default="")
     parser.add_argument('--current_id', type=int, required=False, default=0)
     parser.add_argument('--trainers', type=int, required=False, default=1)
-    parser.add_argument('--sync_mode', action='store_true')
-    parser.add_argument('--geo_sgd_mode', type=bool, required=False, default=False)
+    parser.add_argument('--mode', type=str, required=False, default='geo')
     parser.add_argument(
         '--geo_sgd_need_push_nums', type=int, required=False, default=2)
+    parser.add_argument('--reader', type=str, required=False, default='dataset')
     args = parser.parse_args()
 
     model = test_class()
     if args.role == "pserver":
         model.run_pserver(args)
     else:
-        model.run_trainer(args)
+        if args.reader == "dataset":
+            model.run_dataset_trainer(args)
+        else:
+            model.run_pyreader_trainer(args)
python/paddle/fluid/tests/unittests/test_dist_fleet_ctr.py

@@ -19,9 +19,103 @@ import unittest
 from test_dist_fleet_base import TestFleetBase
 
 
-class TestDistMnist2x2(TestFleetBase):
+class TestDistMnistSync2x2(TestFleetBase):
     def _setup_config(self):
-        self._sync_mode = False
+        self._mode = "sync"
+        self._reader = "pyreader"
 
     def check_with_place(self,
                          model_file,
                          delta=1e-3,
                          check_error_log=False,
                          need_envs={}):
         required_envs = {
             "PATH": os.getenv("PATH", ""),
             "PYTHONPATH": os.getenv("PYTHONPATH", ""),
             "LD_LIBRARY_PATH": os.getenv("LD_LIBRARY_PATH", ""),
             "FLAGS_rpc_deadline": "5000",  # 5sec to fail fast
             "http_proxy": ""
         }
 
         required_envs.update(need_envs)
 
         if check_error_log:
             required_envs["GLOG_v"] = "3"
             required_envs["GLOG_logtostderr"] = "1"
 
         tr0_losses, tr1_losses = self._run_cluster(model_file, required_envs)
 
     def test_dist_train(self):
         self.check_with_place(
             "dist_fleet_ctr.py", delta=1e-5, check_error_log=True)
 
 
+class TestDistMnistHalfAsync2x2(TestFleetBase):
+    def _setup_config(self):
+        self._mode = "half_async"
+        self._reader = "pyreader"
+
+    def check_with_place(self,
+                         model_file,
+                         delta=1e-3,
+                         check_error_log=False,
+                         need_envs={}):
+        required_envs = {
+            "PATH": os.getenv("PATH", ""),
+            "PYTHONPATH": os.getenv("PYTHONPATH", ""),
+            "LD_LIBRARY_PATH": os.getenv("LD_LIBRARY_PATH", ""),
+            "FLAGS_rpc_deadline": "5000",  # 5sec to fail fast
+            "http_proxy": ""
+        }
+
+        required_envs.update(need_envs)
+
+        if check_error_log:
+            required_envs["GLOG_v"] = "3"
+            required_envs["GLOG_logtostderr"] = "1"
+
+        tr0_losses, tr1_losses = self._run_cluster(model_file, required_envs)
+
+    def test_dist_train(self):
+        self.check_with_place(
+            "dist_fleet_ctr.py", delta=1e-5, check_error_log=True)
+
+
+class TestDistMnistAsync2x2(TestFleetBase):
+    def _setup_config(self):
+        self._mode = "async"
+        self._reader = "pyreader"
+
+    def check_with_place(self,
+                         model_file,
+                         delta=1e-3,
+                         check_error_log=False,
+                         need_envs={}):
+        required_envs = {
+            "PATH": os.getenv("PATH", ""),
+            "PYTHONPATH": os.getenv("PYTHONPATH", ""),
+            "LD_LIBRARY_PATH": os.getenv("LD_LIBRARY_PATH", ""),
+            "FLAGS_rpc_deadline": "5000",  # 5sec to fail fast
+            "http_proxy": ""
+        }
+
+        required_envs.update(need_envs)
+
+        if check_error_log:
+            required_envs["GLOG_v"] = "3"
+            required_envs["GLOG_logtostderr"] = "1"
+
+        tr0_losses, tr1_losses = self._run_cluster(model_file, required_envs)
+
+    def test_dist_train(self):
+        self.check_with_place(
+            "dist_fleet_ctr.py", delta=1e-5, check_error_log=True)
+
+
+class TestDistMnistAsyncDataset2x2(TestFleetBase):
+    def _setup_config(self):
+        self._mode = "async"
+        self._reader = "dataset"
+
     def check_with_place(self,
                          model_file,
python/paddle/fluid/tests/unittests/test_dist_fleet_geo.py

@@ -19,15 +19,16 @@ import unittest
 import paddle.fluid as fluid
 import paddle.fluid.incubate.fleet.base.role_maker as role_maker
 from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler import fleet
-from paddle.fluid.transpiler.distribute_transpiler import DistributeTranspilerConfig
+from paddle.fluid.transpiler.distribute_transpiler import DistributeTranspilerConfig, ServerRuntimeConfig
+from paddle.fluid.transpiler.geo_sgd_transpiler import GeoSgdTranspiler
 from test_dist_fleet_base import TestFleetBase
 from dist_simnet_bow import train_network
 
 
 class TestDistGeoCtr_2x2(TestFleetBase):
     def _setup_config(self):
-        self._sync_mode = False
-        self._geo_sgd = True
+        self._mode = "geo"
+        self._reader = "dataset"
         self._geo_sgd_need_push_nums = 5
 
     def check_with_place(self,
python/paddle/fluid/tests/unittests/test_distributed_strategy.py
(new file mode 100644)

# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import paddle.fluid as fluid
from paddle.fluid.transpiler.distribute_transpiler import DistributeTranspilerConfig, ServerRuntimeConfig
from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler.distributed_strategy import TrainerRuntimeConfig, StrategyFactory
import os


class TestStrategyFactor(unittest.TestCase):
    def test_sync_strategy(self):
        os.environ['CPU_NUM'] = "2"
        strategy = StrategyFactory.create_sync_strategy()
        self.assertEqual(strategy._program_config.sync_mode, True)
        self.assertEqual(strategy._program_config.runtime_split_send_recv, False)
        self.assertEqual(strategy._build_strategy.async_mode, False)
        self.assertEqual(strategy._execute_strategy.num_threads, 2)

        # test set_program_config using DistributeTranspilerConfig()
        program_config_class = DistributeTranspilerConfig()
        program_config_class.min_block_size = 81920
        strategy.set_program_config(program_config_class)
        program_config = strategy.get_program_config()
        self.assertEqual(program_config.min_block_size, 81920)

        # test set_program_config using dict
        program_config_dict = dict()
        program_config_dict['min_block_size'] = 8192
        strategy.set_program_config(program_config_dict)
        program_config = strategy.get_program_config()
        self.assertEqual(program_config.min_block_size, 8192)

        # test set_program_config exception
        program_config_dict['unknown'] = None
        self.assertRaises(Exception, strategy.set_program_config,
                          program_config_dict)
        program_config_illegal = None
        self.assertRaises(Exception, strategy.set_program_config,
                          program_config_illegal)

    def test_geo_strategy(self):
        strategy = StrategyFactory.create_geo_strategy(5)
        self.assertEqual(strategy._program_config.sync_mode, False)
        self.assertEqual(strategy._program_config.runtime_split_send_recv, True)
        self.assertEqual(strategy._program_config.geo_sgd_mode, True)
        self.assertEqual(strategy._program_config.geo_sgd_need_push_nums, 5)
        self.assertEqual(strategy._build_strategy.async_mode, True)

        # test set_build_strategy using fluid.BuildStrategy
        build_strategy_class = fluid.BuildStrategy()
        build_strategy_class.memory_optimize = False
        strategy.set_build_strategy(build_strategy_class)
        build_strategy = strategy.get_build_strategy()
        self.assertEqual(build_strategy.memory_optimize, False)

        # test set_build_strategy using dict
        build_strategy_dict = dict()
        build_strategy_dict['memory_optimize'] = True
        strategy.set_build_strategy(build_strategy_dict)
        build_strategy = strategy.get_build_strategy()
        self.assertEqual(build_strategy.memory_optimize, True)

        # test set_build_strategy exception
        build_strategy_dict['unknown'] = None
        self.assertRaises(Exception, strategy.set_build_strategy,
                          build_strategy_dict)
        build_strategy_illegal = None
        self.assertRaises(Exception, strategy.set_build_strategy,
                          build_strategy_illegal)

    def test_async_strategy(self):
        strategy = StrategyFactory.create_async_strategy()
        self.assertEqual(strategy._program_config.sync_mode, False)
        self.assertEqual(strategy._program_config.runtime_split_send_recv, True)
        self.assertEqual(strategy._build_strategy.async_mode, True)

        # test set_trainer_runtime_config using TrainerRuntimeConfig
        trainer_runtime_config_class = TrainerRuntimeConfig()
        trainer_runtime_config_class.send_queue_size = 50
        print(trainer_runtime_config_class)
        strategy.set_trainer_runtime_config(trainer_runtime_config_class)
        trainer_runtime_config = strategy.get_trainer_runtime_config()
        self.assertEqual(trainer_runtime_config.send_queue_size, 50)

        # test set_trainer_runtime_config using dict
        trainer_runtime_config_dict = dict()
        trainer_runtime_config_dict['send_queue_size'] = 100
        strategy.set_trainer_runtime_config(trainer_runtime_config_dict)
        trainer_runtime_config = strategy.get_trainer_runtime_config()
        trainer_communicator_flags = trainer_runtime_config.get_communicator_flags(
        )
        self.assertIn('send_queue_size', trainer_communicator_flags)
        self.assertEqual(trainer_communicator_flags['send_queue_size'], 100)

        # test set_trainer_runtime_config exception
        trainer_runtime_config_dict['unknown'] = None
        self.assertRaises(Exception, strategy.set_trainer_runtime_config,
                          trainer_runtime_config_dict)
        trainer_runtime_config_illegal = None
        self.assertRaises(Exception, strategy.set_trainer_runtime_config,
                          trainer_runtime_config_illegal)

        # test set_execute_strategy using fluid.ExecutionStrategy
        exec_strategy_class = fluid.ExecutionStrategy()
        exec_strategy_class.num_threads = 4
        strategy.set_execute_strategy(exec_strategy_class)
        exec_strategy = strategy.get_execute_strategy()
        self.assertEqual(exec_strategy.num_threads, 4)

        # test set_execute_strategy using dict
        exec_strategy_dict = dict()
        exec_strategy_dict['num_threads'] = 8
        strategy.set_execute_strategy(exec_strategy_dict)
        exec_strategy = strategy.get_execute_strategy()
        self.assertEqual(exec_strategy.num_threads, 8)

        # test set_execute_strategy exception
        exec_strategy_dict['unknown'] = None
        self.assertRaises(Exception, strategy.set_execute_strategy,
                          exec_strategy_dict)
        exec_strategy_illegal = None
        self.assertRaises(Exception, strategy.set_execute_strategy,
                          exec_strategy_illegal)

    def test_half_async_strategy(self):
        strategy = StrategyFactory.create_half_async_strategy()
        self.assertEqual(strategy._program_config.sync_mode, False)
        self.assertEqual(strategy._program_config.runtime_split_send_recv, False)
        self.assertEqual(strategy._build_strategy.async_mode, False)

        # test set_server_runtime_config using ServerRuntimeConfig
        server_runtime_config_class = ServerRuntimeConfig()
        server_runtime_config_class._rpc_send_thread_num = 24
        strategy.set_server_runtime_config(server_runtime_config_class)
        server_runtime_config = strategy.get_server_runtime_config()
        self.assertEqual(server_runtime_config._rpc_send_thread_num, 24)

        # test set_server_runtime_config using dict
        server_runtime_config_dict = dict()
        server_runtime_config_dict['_rpc_send_thread_num'] = 20
        strategy.set_server_runtime_config(server_runtime_config_dict)
        server_runtime_config = strategy.get_server_runtime_config()
        self.assertEqual(server_runtime_config._rpc_send_thread_num, 20)

        # test set_server_runtime_config exception
        server_runtime_config_dict['unknown'] = None
        self.assertRaises(Exception, strategy.set_server_runtime_config,
                          server_runtime_config_dict)
        server_runtime_config_illegal = None
        self.assertRaises(Exception, strategy.set_server_runtime_config,
                          server_runtime_config_illegal)


if __name__ == '__main__':
    unittest.main()
python/paddle/fluid/transpiler/distribute_transpiler.py

@@ -30,6 +30,7 @@ Steps to transpile pserver:
 5. add listen_and_serv op
 """
 
+import os
 import sys
 import math
 from functools import reduce

@@ -177,8 +178,8 @@ class DistributeTranspilerConfig(object):
     print_log = False
     wait_port = True
     # split the send recv var in runtime
-    _runtime_split_send_recv = False
-    _sync_mode = True
+    __runtime_split_send_recv = False
+    __sync_mode = True
 
     # Geo-sgd algorithm
     geo_sgd_mode = False

@@ -200,31 +201,41 @@ class DistributeTranspilerConfig(object):
     @property
     def runtime_split_send_recv(self):
-        return self._runtime_split_send_recv
+        return self.__runtime_split_send_recv
 
     @runtime_split_send_recv.setter
     def runtime_split_send_recv(self, value):
         if value is None:
             raise ValueError("runtime_split_send_recv can't be None")
-        if value and self._sync_mode:
+        if value and self.__sync_mode:
             raise ValueError(
                 "if you want to set runtime_split_send_recv to be true, make ensure config.sync_mode is false at first"
             )
-        self._runtime_split_send_recv = value
+        self.__runtime_split_send_recv = value
 
     @property
     def sync_mode(self):
-        return self._sync_mode
+        return self.__sync_mode
 
     @sync_mode.setter
     def sync_mode(self, value):
         if value is None:
             raise ValueError("sync_mode can't be None")
-        if value and self._runtime_split_send_recv:
+        if value and self.__runtime_split_send_recv:
             raise ValueError(
                 "if you want to set sync_mode to be true, make ensure config.runtime_split_send_recv is false at first"
             )
-        self._sync_mode = value
+        self.__sync_mode = value
 
 
+class ServerRuntimeConfig(object):
+    def __init__(self):
+        self._rpc_send_thread_num = int(
+            os.getenv("FLAGS_rpc_send_thread_num", "12"))
+        self._rpc_get_thread_num = int(
+            os.getenv("FLAGS_rpc_get_thread_num", "12"))
+        self._rpc_prefetch_thread_num = int(
+            os.getenv("FLAGS_rpc_prefetch_thread_num", "12"))
+
+
 class DistributeTranspiler(object):

@@ -295,6 +306,7 @@ class DistributeTranspiler(object):
             self.config = config
         else:
             self.config = DistributeTranspilerConfig()
+        self._set_server_config()
 
         if self.config.split_method is None:
             self.config.split_method = RoundRobin

@@ -306,6 +318,16 @@ class DistributeTranspiler(object):
         assert (self.config.split_method.__bases__[0] == PSDispatcher)
         self.counter_var = None
 
+    def _set_server_config(self, server_config=None):
+        if server_config is None:
+            self.server_config = ServerRuntimeConfig()
+        elif isinstance(server_config, ServerRuntimeConfig):
+            self.server_config = server_config
+        else:
+            raise TypeError(
+                "In DistributeTranspiler, server_config must be an instance of ServerRuntimeConfig"
+            )
+
     def _transpile_nccl2(self,
                          trainer_id,
                          trainers,

@@ -1313,6 +1335,10 @@ class DistributeTranspiler(object):
             "grad_to_block_id": grad_to_block_id,
             "sparse_grad_to_param": sparse_grad_to_param,
             "lr_decay_block_id": lr_decay_block_id,
+            "rpc_get_thread_num": self.server_config._rpc_get_thread_num,
+            "rpc_send_thread_num": self.server_config._rpc_send_thread_num,
+            "rpc_prefetch_thread_num":
+            self.server_config._rpc_prefetch_thread_num
         }
 
         if self.has_distributed_lookup_table:
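ServerRuntimeConfig seeds its three thread counts from the FLAGS_rpc_*_thread_num environment variables (default 12 each), and DistributeTranspiler._set_server_config() attaches it to the transpiler so the values end up as attributes on the listen_and_serv op rather than being read from gflags at runtime. A small sketch:

    import os
    os.environ["FLAGS_rpc_send_thread_num"] = "24"  # must be set before the config is created

    from paddle.fluid.transpiler.distribute_transpiler import ServerRuntimeConfig

    server_conf = ServerRuntimeConfig()
    print(server_conf._rpc_send_thread_num,      # 24, from the environment variable above
          server_conf._rpc_get_thread_num,       # 12, default
          server_conf._rpc_prefetch_thread_num)  # 12, default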
python/paddle/fluid/transpiler/geo_sgd_transpiler.py

@@ -38,7 +38,7 @@ from ..framework import Program, default_main_program, \
 from .details import wait_server_ready, VarsDistributed
 from .details import delete_ops
 from ..distribute_lookup_table import find_distributed_lookup_table
-from .distribute_transpiler import DistributeTranspiler, DistributeTranspilerConfig, slice_variable, same_or_split_var
+from .distribute_transpiler import DistributeTranspiler, DistributeTranspilerConfig, slice_variable, same_or_split_var, ServerRuntimeConfig
 
 RPC_OP_ROLE_ATTR_NAME = op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName(
 )

@@ -51,6 +51,7 @@ class GeoSgdTranspiler(DistributeTranspiler):
             self.config = config
         else:
             self.config = DistributeTranspilerConfig()
+        self._set_server_config()
 
         if self.config.split_method is None:
             self.config.split_method = RoundRobin

@@ -241,7 +242,11 @@ class GeoSgdTranspiler(DistributeTranspiler):
             "Fanin": self.trainer_num,
             "sync_mode": self.sync_mode,
             "grad_to_block_id": param_to_block_id,
-            "sparse_grad_to_param": sparse_grad_to_param
+            "sparse_grad_to_param": sparse_grad_to_param,
+            "rpc_get_thread_num": self.server_config._rpc_get_thread_num,
+            "rpc_send_thread_num": self.server_config._rpc_send_thread_num,
+            "rpc_prefetch_thread_num":
+            self.server_config._rpc_prefetch_thread_num
         }
         # step5 append the listen_and_serv op