Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
82bc814a
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
82bc814a
编写于
1月 17, 2020
作者:
T
tangwei12
提交者:
GitHub
1月 17, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
integrated HALF_ASYNC to communicator (#21869)
* add half_async in the communicator * fix DistributedStrategy
上级
1e932ecc
变更
30
隐藏空白更改
内联
并排
Showing
30 changed file
with
1029 addition
and
562 deletion
+1029
-562
paddle/fluid/framework/CMakeLists.txt
paddle/fluid/framework/CMakeLists.txt
+1
-1
paddle/fluid/framework/details/async_ssa_graph_executor.cc
paddle/fluid/framework/details/async_ssa_graph_executor.cc
+8
-19
paddle/fluid/framework/device_worker.h
paddle/fluid/framework/device_worker.h
+1
-0
paddle/fluid/framework/hogwild_worker.cc
paddle/fluid/framework/hogwild_worker.cc
+14
-0
paddle/fluid/framework/multi_trainer.cc
paddle/fluid/framework/multi_trainer.cc
+9
-0
paddle/fluid/framework/trainer_desc.proto
paddle/fluid/framework/trainer_desc.proto
+1
-0
paddle/fluid/operators/distributed/communicator.cc
paddle/fluid/operators/distributed/communicator.cc
+321
-168
paddle/fluid/operators/distributed/communicator.h
paddle/fluid/operators/distributed/communicator.h
+112
-102
paddle/fluid/operators/distributed/distributed.h
paddle/fluid/operators/distributed/distributed.h
+1
-0
paddle/fluid/operators/distributed_ops/send_barrier_op.cc
paddle/fluid/operators/distributed_ops/send_barrier_op.cc
+13
-0
paddle/fluid/operators/distributed_ops/send_op.cc
paddle/fluid/operators/distributed_ops/send_op.cc
+1
-6
paddle/fluid/pybind/communicator_py.cc
paddle/fluid/pybind/communicator_py.cc
+19
-20
python/paddle/fluid/__init__.py
python/paddle/fluid/__init__.py
+0
-11
python/paddle/fluid/communicator.py
python/paddle/fluid/communicator.py
+33
-20
python/paddle/fluid/executor.py
python/paddle/fluid/executor.py
+1
-0
python/paddle/fluid/incubate/fleet/parameter_server/distribute_transpiler/__init__.py
.../fleet/parameter_server/distribute_transpiler/__init__.py
+49
-27
python/paddle/fluid/incubate/fleet/parameter_server/distribute_transpiler/distributed_strategy.py
...eter_server/distribute_transpiler/distributed_strategy.py
+36
-33
python/paddle/fluid/tests/CMakeLists.txt
python/paddle/fluid/tests/CMakeLists.txt
+0
-4
python/paddle/fluid/tests/unittests/CMakeLists.txt
python/paddle/fluid/tests/unittests/CMakeLists.txt
+7
-0
python/paddle/fluid/tests/unittests/dist_fleet_ctr.py
python/paddle/fluid/tests/unittests/dist_fleet_ctr.py
+9
-5
python/paddle/fluid/tests/unittests/test_communicator_async.py
...n/paddle/fluid/tests/unittests/test_communicator_async.py
+8
-13
python/paddle/fluid/tests/unittests/test_communicator_geo.py
python/paddle/fluid/tests/unittests/test_communicator_geo.py
+83
-0
python/paddle/fluid/tests/unittests/test_communicator_half_async.py
...dle/fluid/tests/unittests/test_communicator_half_async.py
+177
-0
python/paddle/fluid/tests/unittests/test_dist_ctr.py
python/paddle/fluid/tests/unittests/test_dist_ctr.py
+3
-1
python/paddle/fluid/tests/unittests/test_dist_fleet_base.py
python/paddle/fluid/tests/unittests/test_dist_fleet_base.py
+30
-90
python/paddle/fluid/tests/unittests/test_dist_fleet_ctr.py
python/paddle/fluid/tests/unittests/test_dist_fleet_ctr.py
+15
-10
python/paddle/fluid/tests/unittests/test_distributed_strategy.py
...paddle/fluid/tests/unittests/test_distributed_strategy.py
+6
-5
python/paddle/fluid/tests/unittests/test_fleet_api_input.py
python/paddle/fluid/tests/unittests/test_fleet_api_input.py
+25
-3
python/paddle/fluid/trainer_desc.py
python/paddle/fluid/trainer_desc.py
+3
-0
python/paddle/fluid/transpiler/distribute_transpiler.py
python/paddle/fluid/transpiler/distribute_transpiler.py
+43
-24
未找到文件。
paddle/fluid/framework/CMakeLists.txt
浏览文件 @
82bc814a
...
...
@@ -192,7 +192,7 @@ if(WITH_DISTRIBUTE)
data_feed.cc device_worker.cc hogwild_worker.cc downpour_worker.cc
pull_dense_worker.cc section_worker.cc device_worker_factory.cc data_set.cc DEPS op_registry
device_context scope framework_proto trainer_desc_proto glog fs shell fleet_wrapper lodtensor_printer
lod_rank_table feed_fetch_method sendrecvop_rpc collective_helper
${
GLOB_DISTRIBUTE_DEPS
}
lod_rank_table feed_fetch_method sendrecvop_rpc co
mmunicator co
llective_helper
${
GLOB_DISTRIBUTE_DEPS
}
graph_to_program_pass variable_helper data_feed_proto
${
NGRAPH_EXE_DEPS
}
timer
)
set
(
DISTRIBUTE_COMPILE_FLAGS
"-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor"
)
set_source_files_properties
(
executor.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
...
...
paddle/fluid/framework/details/async_ssa_graph_executor.cc
浏览文件 @
82bc814a
...
...
@@ -48,7 +48,7 @@ void ProcessGraph(std::vector<ir::Graph *> graphs, Scope *scope) {
using
RpcCtxMap
=
operators
::
distributed
::
RpcCtxMap
;
VLOG
(
3
)
<<
"ProcessGraph"
;
RpcCtxMap
send_varname_to_ctx
;
RpcCtxMap
recv_varname_to_ctx
;
for
(
auto
&
node
:
graphs
[
0
]
->
Nodes
())
{
VLOG
(
3
)
<<
"node name "
<<
node
->
Name
();
if
(
node
&&
node
->
IsOp
())
{
...
...
@@ -74,30 +74,19 @@ void ProcessGraph(std::vector<ir::Graph *> graphs, Scope *scope) {
merge_add
,
use_send_handler
);
VLOG
(
3
)
<<
"find and init an send op: "
<<
send_varname_to_ctx
[
send_var_name
];
}
else
if
(
node
->
Name
()
==
"recv"
)
{
auto
recv_var_name
=
node
->
Op
()
->
Output
(
"Out"
)[
0
];
auto
recv_varnames
=
boost
::
get
<
std
::
vector
<
std
::
string
>>
(
node
->
Op
()
->
GetNullableAttr
(
"recv_varnames"
));
auto
epmap
=
boost
::
get
<
std
::
vector
<
std
::
string
>>
(
node
->
Op
()
->
GetNullableAttr
(
"epmap"
));
auto
trainer_id
=
boost
::
get
<
int
>
(
node
->
Op
()
->
GetNullableAttr
(
"trainer_id"
));
recv_varname_to_ctx
[
recv_var_name
]
=
operators
::
distributed
::
RpcContext
(
recv_var_name
,
recv_varnames
,
epmap
,
{},
trainer_id
);
VLOG
(
3
)
<<
"find and remove an recv op: "
<<
recv_varname_to_ctx
[
recv_var_name
];
}
}
}
// init communicator here
if
(
send_varname_to_ctx
.
size
()
>
0
)
{
VLOG
(
3
)
<<
"this is distribute mode, will use communicator"
;
auto
*
instance
=
operators
::
distributed
::
Communicator
::
InitInstance
<
operators
::
distributed
::
AsyncCommunicator
>
(
send_varname_to_ctx
,
recv_varname_to_ctx
,
scope
);
if
(
!
instance
->
IsRunning
())
instance
->
Start
();
auto
*
instance
=
operators
::
distributed
::
Communicator
::
GetInstance
();
auto
initialized
=
instance
?
true
:
false
;
PADDLE_ENFORCE_EQ
(
initialized
,
true
,
platform
::
errors
::
InvalidArgument
(
"Communicator is not Initialized, you may use "
"FleetAPI(https://github.com/PaddlePaddle/Fleet/tree/"
"develop/markdown_doc/transpiler)"
));
}
#endif
}
...
...
paddle/fluid/framework/device_worker.h
浏览文件 @
82bc814a
...
...
@@ -179,6 +179,7 @@ class HogwildWorker : public CPUWorkerBase {
void
CreateThreadScope
(
const
ProgramDesc
&
program
);
std
::
vector
<
std
::
string
>
op_names_
;
std
::
vector
<
OperatorBase
*>
ops_
;
bool
thread_barrier_
;
// Scope* thread_scope_;
HogwildWorkerParameter
param_
;
std
::
vector
<
std
::
string
>
skip_ops_
;
...
...
paddle/fluid/framework/hogwild_worker.cc
浏览文件 @
82bc814a
...
...
@@ -15,6 +15,7 @@ limitations under the License. */
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/device_worker.h"
#include "paddle/fluid/framework/device_worker_factory.h"
#include "paddle/fluid/operators/distributed/distributed.h"
#include "paddle/fluid/platform/cpu_helper.h"
#include "paddle/fluid/platform/lodtensor_printer.h"
...
...
@@ -29,6 +30,7 @@ void HogwildWorker::Initialize(const TrainerDesc &desc) {
skip_ops_
[
i
]
=
param_
.
skip_ops
(
i
);
}
use_cvm_
=
desc
.
use_cvm
();
thread_barrier_
=
desc
.
thread_barrier
();
}
void
HogwildWorker
::
CreateThreadOperators
(
const
ProgramDesc
&
program
)
{
...
...
@@ -158,6 +160,12 @@ void HogwildWorker::TrainFilesWithProfiler() {
thread_scope_
->
DropKids
();
timeline
.
Start
();
}
#ifdef PADDLE_WITH_DISTRIBUTE
if
(
thread_barrier_
)
{
operators
::
distributed
::
Communicator
::
GetInstance
()
->
BarrierTriggerDecrement
();
}
#endif
}
void
HogwildWorker
::
TrainFiles
()
{
...
...
@@ -183,6 +191,12 @@ void HogwildWorker::TrainFiles() {
PrintFetchVars
();
thread_scope_
->
DropKids
();
}
#ifdef PADDLE_WITH_DISTRIBUTE
if
(
thread_barrier_
)
{
operators
::
distributed
::
Communicator
::
GetInstance
()
->
BarrierTriggerDecrement
();
}
#endif
}
void
HogwildWorker
::
PrintFetchVars
()
{
...
...
paddle/fluid/framework/multi_trainer.cc
浏览文件 @
82bc814a
...
...
@@ -17,6 +17,7 @@ limitations under the License. */
#include "paddle/fluid/framework/data_feed_factory.h"
#include "paddle/fluid/framework/device_worker_factory.h"
#include "paddle/fluid/framework/trainer.h"
#include "paddle/fluid/operators/distributed/distributed.h"
namespace
paddle
{
namespace
framework
{
...
...
@@ -38,6 +39,14 @@ void MultiTrainer::Initialize(const TrainerDesc& trainer_desc,
thread_num_
=
readers
.
size
();
VLOG
(
3
)
<<
"worker thread num: "
<<
thread_num_
;
workers_
.
resize
(
thread_num_
);
#ifdef PADDLE_WITH_DISTRIBUTE
if
(
trainer_desc
.
thread_barrier
())
{
operators
::
distributed
::
Communicator
::
GetInstance
()
->
BarrierTriggerReset
(
thread_num_
);
}
#endif
for
(
int
i
=
0
;
i
<
thread_num_
;
++
i
)
{
workers_
[
i
]
=
DeviceWorkerFactory
::
CreateDeviceWorker
(
trainer_desc
.
device_worker_name
());
...
...
paddle/fluid/framework/trainer_desc.proto
浏览文件 @
82bc814a
...
...
@@ -47,6 +47,7 @@ message TrainerDesc {
// adjust ins weight
optional
AdjustInsWeightConfig
adjust_ins_weight_config
=
20
;
optional
bool
no_cvm
=
21
[
default
=
false
];
optional
bool
thread_barrier
=
22
;
// device worker parameters
optional
HogwildWorkerParameter
hogwild_param
=
101
;
...
...
paddle/fluid/operators/distributed/communicator.cc
浏览文件 @
82bc814a
...
...
@@ -27,29 +27,16 @@ limitations under the License. */
#include "paddle/fluid/operators/distributed/distributed.h"
#include "paddle/fluid/operators/distributed/parameter_recv.h"
#include "paddle/fluid/operators/distributed/parameter_send.h"
DECLARE_int32
(
communicator_max_merge_var_num
);
DECLARE_int32
(
communicator_send_queue_size
);
DEFINE_bool
(
communicator_independent_recv_thread
,
true
,
"use an independent to recv vars from parameter server"
);
DEFINE_int32
(
communicator_min_send_grad_num_before_recv
,
20
,
"max grad num to send before recv parameters"
);
DEFINE_int32
(
communicator_thread_pool_size
,
5
,
"thread num to do send or recv"
);
DEFINE_int32
(
communicator_send_wait_times
,
5
,
"times that send thread will wait if merge num does not reach "
"max_merge_var_num"
);
DEFINE_bool
(
communicator_fake_rpc
,
false
,
"fake mode does not really send any thing"
);
DEFINE_bool
(
communicator_merge_sparse_grad
,
true
,
"merge sparse gradient before sending"
);
DEFINE_int32
(
communicator_merge_sparse_bucket
,
2000
,
"number of threads for sparse var"
);
#include "paddle/fluid/string/split.h"
namespace
paddle
{
namespace
operators
{
namespace
distributed
{
using
Tree
=
std
::
map
<
std
::
string
,
std
::
map
<
std
::
string
,
std
::
vector
<
std
::
string
>>>
;
using
RpcCtxMap
=
operators
::
distributed
::
RpcCtxMap
;
inline
double
GetCurrentUS
()
{
struct
timeval
time
;
gettimeofday
(
&
time
,
NULL
);
...
...
@@ -63,41 +50,12 @@ inline void VSUB(int n, const T *x, const T *y, T *z) {
}
}
void
Communicator
::
SetEnvFlagsDefault
()
{
env_flags_dict
.
clear
();
env_flags_dict
.
insert
(
std
::
pair
<
std
::
string
,
int
>
(
"independent_recv_thread"
,
FLAGS_communicator_independent_recv_thread
));
env_flags_dict
.
insert
(
std
::
pair
<
std
::
string
,
int
>
(
"send_queue_size"
,
FLAGS_communicator_send_queue_size
));
env_flags_dict
.
insert
(
std
::
pair
<
std
::
string
,
int
>
(
"min_send_grad_num_before_recv"
,
FLAGS_communicator_min_send_grad_num_before_recv
));
env_flags_dict
.
insert
(
std
::
pair
<
std
::
string
,
int
>
(
"thread_pool_size"
,
FLAGS_communicator_thread_pool_size
));
env_flags_dict
.
insert
(
std
::
pair
<
std
::
string
,
int
>
(
"send_wait_times"
,
FLAGS_communicator_send_wait_times
));
env_flags_dict
.
insert
(
std
::
pair
<
std
::
string
,
int
>
(
"max_merge_var_num"
,
FLAGS_communicator_max_merge_var_num
));
env_flags_dict
.
insert
(
std
::
pair
<
std
::
string
,
int
>
(
"fake_rpc"
,
FLAGS_communicator_fake_rpc
));
env_flags_dict
.
insert
(
std
::
pair
<
std
::
string
,
int
>
(
"merge_sparse_grad"
,
FLAGS_communicator_merge_sparse_grad
));
env_flags_dict
.
insert
(
std
::
pair
<
std
::
string
,
int
>
(
"is_sgd_optimizer"
,
FLAGS_communicator_is_sgd_optimizer
));
return
;
}
Communicator
::
Communicator
()
{
SetEnvFlagsDefault
();
}
Communicator
::
Communicator
()
{}
Communicator
::
Communicator
(
const
std
::
map
<
std
::
string
,
int
>
&
env_flags
)
{
SetEnvFlagsDefault
();
for
(
auto
&
iter
:
env_flags
)
{
std
::
string
flag_name
=
iter
.
first
;
int
val_
=
iter
.
second
;
env_flags_dict
.
at
(
flag_name
)
=
val_
;
Communicator
::
Communicator
(
const
std
::
map
<
std
::
string
,
std
::
string
>
&
envs_
)
{
for
(
auto
&
iter
:
envs_
)
{
envs
[
iter
.
first
]
=
iter
.
second
;
}
return
;
}
std
::
once_flag
Communicator
::
init_flag_
;
...
...
@@ -106,6 +64,7 @@ std::shared_ptr<Communicator> Communicator::communicator_(nullptr);
void
AsyncCommunicator
::
InitImpl
(
const
RpcCtxMap
&
send_varname_to_ctx
,
const
RpcCtxMap
&
recv_varname_to_ctx
,
Scope
*
recv_scope
)
{
VLOG
(
0
)
<<
"AsyncCommunicator Initializing"
;
send_varname_to_ctx_
=
std
::
move
(
send_varname_to_ctx
);
recv_varname_to_ctx_
=
std
::
move
(
recv_varname_to_ctx
);
recv_scope_
=
std
::
move
(
recv_scope
);
...
...
@@ -117,24 +76,21 @@ void AsyncCommunicator::InitImpl(const RpcCtxMap &send_varname_to_ctx,
for
(
auto
&
iter
:
send_varname_to_ctx_
)
{
send_varname_to_queue_
[
iter
.
first
]
=
std
::
make_shared
<
BlockingQueue
<
std
::
shared_ptr
<
Variable
>>>
(
env_flags_dict
[
"send_queue_size"
]
);
send_queue_size_
);
}
send_threadpool_
.
reset
(
new
::
ThreadPool
(
env_flags_dict
[
"thread_pool_size"
]));
send_threadpool_
.
reset
(
new
::
ThreadPool
(
thread_pool_size_
));
}
if
(
recv_varname_to_ctx
.
size
()
==
0
)
{
VLOG
(
0
)
<<
"nothing need to be received, will not start recv_thread"
;
}
else
{
recv_threadpool_
.
reset
(
new
::
ThreadPool
(
env_flags_dict
[
"thread_pool_size"
]));
recv_threadpool_
.
reset
(
new
::
ThreadPool
(
thread_pool_size_
));
}
}
void
AsyncCommunicator
::
InitImpl
(
const
paddle
::
framework
::
ProgramDesc
&
program
,
Scope
*
param_scope
)
{
using
RpcCtxMap
=
operators
::
distributed
::
RpcCtxMap
;
VLOG
(
3
)
<<
"ProcessGraph"
;
VLOG
(
0
)
<<
"AsyncCommunicator Initializing"
;
RpcCtxMap
send_varname_to_ctx
;
RpcCtxMap
recv_varname_to_ctx
;
for
(
auto
*
op
:
program
.
Block
(
0
).
AllOps
())
{
...
...
@@ -150,7 +106,7 @@ void AsyncCommunicator::InitImpl(const paddle::framework::ProgramDesc &program,
auto
trainer_id
=
boost
::
get
<
int
>
(
op
->
GetNullableAttr
(
"trainer_id"
));
auto
merge_add
=
boost
::
get
<
bool
>
(
op
->
GetNullableAttr
(
"merge_add"
));
if
(
!
merge_add
)
{
merge_add
=
static_cast
<
bool
>
(
env_flags_dict
[
"is_sgd_optimizer"
])
;
merge_add
=
is_sgd_optimizer_
;
}
auto
use_send_handler
=
boost
::
get
<
bool
>
(
op
->
GetNullableAttr
(
"use_send_handler"
));
...
...
@@ -161,7 +117,9 @@ void AsyncCommunicator::InitImpl(const paddle::framework::ProgramDesc &program,
<<
send_varname_to_ctx
[
send_var_name
];
}
else
if
(
op
->
Type
()
==
"recv"
)
{
auto
do_not_run
=
boost
::
get
<
int
>
(
op
->
GetNullableAttr
(
"do_not_run"
));
PADDLE_ENFORCE_GT
(
do_not_run
,
0
,
"recv should not run!"
);
PADDLE_ENFORCE_GT
(
do_not_run
,
0
,
platform
::
errors
::
InvalidArgument
(
"recv op's attr `do_not_run` must be True!"
));
auto
recv_var_name
=
op
->
Output
(
"Out"
)[
0
];
auto
recv_varnames
=
boost
::
get
<
std
::
vector
<
std
::
string
>>
(
op
->
GetNullableAttr
(
"recv_varnames"
));
...
...
@@ -183,17 +141,9 @@ void AsyncCommunicator::InitImpl(const paddle::framework::ProgramDesc &program,
}
AsyncCommunicator
::~
AsyncCommunicator
()
{
if
(
FLAGS_v
>=
3
)
{
std
::
string
msg
(
"~Communicator"
);
fwrite
(
msg
.
c_str
(),
msg
.
length
(),
1
,
stdout
);
}
running_
=
false
;
if
(
send_thread_
)
send_thread_
->
join
();
if
(
recv_thread_
)
recv_thread_
->
join
();
if
(
FLAGS_v
>=
3
)
{
std
::
string
msg
(
"~Communicator done"
);
fwrite
(
msg
.
c_str
(),
msg
.
length
(),
1
,
stdout
);
}
}
void
AsyncCommunicator
::
SendThread
()
{
...
...
@@ -212,10 +162,10 @@ void AsyncCommunicator::SendThread() {
std
::
vector
<
std
::
shared_ptr
<
Variable
>>
vars
;
int
merged_var_num
=
0
;
int
wait_times
=
0
;
while
(
merged_var_num
<
env_flags_dict
[
"max_merge_var_num"
]
)
{
while
(
merged_var_num
<
max_merge_var_num_
)
{
if
(
var_queue
->
Size
()
==
0
)
{
VLOG
(
4
)
<<
"wait_times -> "
<<
wait_times
;
if
(
wait_times
>=
env_flags_dict
[
"send_wait_times"
]
)
{
if
(
wait_times
>=
send_wait_times_
)
{
break
;
}
std
::
this_thread
::
sleep_for
(
std
::
chrono
::
milliseconds
(
10
));
...
...
@@ -244,9 +194,7 @@ void AsyncCommunicator::SendThread() {
VLOG
(
4
)
<<
"merge "
<<
merged_var_num
<<
" "
<<
var_name
<<
" use time "
<<
after_merge
-
before_merge
;
auto
send_functor
=
distributed
::
ParameterSend
<
float
>
();
if
(
!
env_flags_dict
[
"fake_rpc"
])
{
send_functor
(
ctx
,
*
send_scope_
,
true
,
1
);
}
send_functor
(
ctx
,
*
send_scope_
,
true
,
1
);
auto
after_send
=
GetCurrentUS
();
VLOG
(
4
)
<<
"send "
<<
var_name
<<
" use time "
<<
after_send
-
after_merge
;
...
...
@@ -273,7 +221,7 @@ void AsyncCommunicator::RecvThread() {
VLOG
(
3
)
<<
"RecvThread start!"
;
while
(
running_
)
{
int
grad_num
=
grad_num_
.
load
();
if
(
grad_num
>
env_flags_dict
[
"min_send_grad_num_before_recv"
]
)
{
if
(
grad_num
>
min_send_grad_num_before_recv_
)
{
VLOG
(
1
)
<<
"current grad num "
<<
grad_num
;
RecvAll
();
grad_num_
.
store
(
0
);
...
...
@@ -284,30 +232,8 @@ void AsyncCommunicator::RecvThread() {
VLOG
(
0
)
<<
"communicator stopped, recv thread exit"
;
}
void
AsyncCommunicator
::
Send
(
const
std
::
string
&
var_name
,
const
framework
::
Scope
&
scope
)
{
VLOG
(
3
)
<<
"communicator send "
<<
var_name
;
// push var into send queue by var_name
auto
*
grad_var
=
scope
.
FindVar
(
var_name
);
PADDLE_ENFORCE
(
grad_var
->
IsInitialized
(),
"grad var should be inited"
);
if
(
grad_var
->
IsType
<
framework
::
SelectedRows
>
()
&&
!
env_flags_dict
[
"merge_sparse_grad"
])
{
auto
send_functor
=
distributed
::
ParameterSend
<
float
>
();
auto
&
ctx
=
send_varname_to_ctx_
.
at
(
var_name
);
if
(
!
env_flags_dict
[
"fake_rpc"
])
{
send_functor
(
ctx
,
scope
,
true
,
1
);
}
}
else
{
auto
tmp_grad_var
=
std
::
make_shared
<
Variable
>
();
framework
::
CopyVariable
(
*
grad_var
,
tmp_grad_var
.
get
());
auto
&
queue
=
send_varname_to_queue_
.
at
(
var_name
);
VLOG
(
3
)
<<
"send "
<<
var_name
<<
" queue size "
<<
queue
->
Size
();
queue
->
Push
(
tmp_grad_var
);
}
}
void
AsyncCommunicator
::
Recv
()
{
if
(
env_flags_dict
[
"independent_recv_thread"
]
)
{
if
(
independent_recv_thread_
)
{
return
;
}
...
...
@@ -331,9 +257,7 @@ void AsyncCommunicator::RecvAll() {
auto
&
var_name
=
iter
.
first
;
VLOG
(
4
)
<<
"recv var "
<<
var_name
;
auto
recv_functor
=
distributed
::
ParameterRecv
<
float
>
();
if
(
!
env_flags_dict
[
"fake_rpc"
])
{
recv_functor
(
iter
.
second
,
*
recv_scope_
);
}
recv_functor
(
iter
.
second
,
*
recv_scope_
);
};
task_futures
.
emplace_back
(
recv_threadpool_
->
enqueue
(
std
::
move
(
recv_task
)));
}
...
...
@@ -354,7 +278,7 @@ void AsyncCommunicator::Start() {
// start send and recv thread
send_thread_
.
reset
(
new
std
::
thread
(
std
::
bind
(
&
AsyncCommunicator
::
SendThread
,
this
)));
if
(
env_flags_dict
[
"independent_recv_thread"
]
)
{
if
(
independent_recv_thread_
)
{
recv_thread_
.
reset
(
new
std
::
thread
(
std
::
bind
(
&
AsyncCommunicator
::
RecvThread
,
this
)));
}
...
...
@@ -381,71 +305,65 @@ void AsyncCommunicator::Stop() {
VLOG
(
0
)
<<
"Communicator stop done"
;
}
void
AsyncCommunicator
::
Send
(
const
std
::
vector
<
std
::
string
>
&
sparse_var_names
,
const
std
::
vector
<
std
::
string
>
&
sparse_var_tables
,
const
framework
::
Scope
&
scope
)
{}
void
AsyncCommunicator
::
InitImpl
(
const
paddle
::
framework
::
ProgramDesc
&
program
,
Scope
*
param_scope
,
std
::
map
<
std
::
string
,
std
::
map
<
std
::
string
,
std
::
vector
<
std
::
string
>>>
&
vars_info
,
const
int
&
trainers
,
const
int
&
geo_need_push_nums
)
{}
void
AsyncCommunicator
::
Send
(
const
std
::
vector
<
std
::
string
>
&
var_names
,
const
std
::
vector
<
std
::
string
>
&
var_tables
,
const
framework
::
Scope
&
scope
)
{
PADDLE_ENFORCE_EQ
(
var_names
.
size
(),
1
,
platform
::
errors
::
InvalidArgument
(
"var_names.size() == 1 is permitted"
));
auto
var_name
=
var_names
[
0
];
// push var into send queue by var_name
auto
*
grad_var
=
scope
.
FindVar
(
var_name
);
PADDLE_ENFORCE_EQ
(
grad_var
->
IsInitialized
(),
true
,
platform
::
errors
::
InvalidArgument
(
"grad var should be inited"
));
auto
tmp_grad_var
=
std
::
make_shared
<
Variable
>
();
framework
::
CopyVariable
(
*
grad_var
,
tmp_grad_var
.
get
());
auto
&
queue
=
send_varname_to_queue_
.
at
(
var_name
);
VLOG
(
3
)
<<
"send "
<<
var_name
<<
" queue size "
<<
queue
->
Size
();
queue
->
Push
(
tmp_grad_var
);
}
GeoSgdCommunicator
::~
GeoSgdCommunicator
()
{
if
(
FLAGS_v
>=
3
)
{
std
::
string
msg
(
"~Geo Sgd Communicator"
);
fwrite
(
msg
.
c_str
(),
msg
.
length
(),
1
,
stdout
);
}
running_
=
false
;
if
(
send_thread_
)
send_thread_
->
join
();
if
(
FLAGS_v
>=
3
)
{
std
::
string
msg
(
"~Geo Sgd Communicator done"
);
fwrite
(
msg
.
c_str
(),
msg
.
length
(),
1
,
stdout
);
}
}
void
GeoSgdCommunicator
::
InitImpl
(
const
paddle
::
framework
::
ProgramDesc
&
program
,
Scope
*
training_scope
,
std
::
map
<
std
::
string
,
std
::
map
<
std
::
string
,
std
::
vector
<
std
::
string
>>>
&
vars_info
,
const
int
&
trainers
,
const
int
&
geo_need_push_nums
)
{
training_scope_
=
std
::
move
(
training_scope
);
trainer_nums_
=
std
::
move
(
trainers
)
;
geo_need_push_nums_
=
std
::
move
(
geo_need_push_nums
);
// get all send information from graph, build vars_to_send
VLOG
(
0
)
<<
"Trainer nums: "
<<
trainer_nums_
;
VLOG
(
0
)
<<
"geo_sgd_push_before_local_train_nums: "
<<
geo_need_push_nums_
;
// process var info from transpiler
for
(
auto
&
iter
:
vars_info
)
{
// change var name in delta scope: "var" -> "var.delta"
std
::
string
var_name
=
iter
.
first
;
void
GeoSgdCommunicator
::
InitImpl
(
const
paddle
::
framework
::
ProgramDesc
&
program
,
Scope
*
recv_scope
)
{
VLOG
(
0
)
<<
"GeoCommunicator Initializing"
;
training_scope_
=
std
::
move
(
recv_scope
);
auto
geo_send_varnames
=
envs
[
"geo_send_varnames"
]
;
auto
varnames
=
paddle
::
string
::
Split
(
geo_send_varnames
,
'#'
);
for
(
auto
&
var_name
:
varnames
)
{
auto
var_attr_str
=
envs
.
at
(
var_name
)
;
auto
var_attrs
=
paddle
::
string
::
Split
(
var_attr_str
,
'#'
)
;
auto
split_varnames
=
paddle
::
string
::
Split
(
var_attrs
[
0
],
'&'
);
auto
sections
=
paddle
::
string
::
Split
(
var_attrs
[
1
],
'&'
);
auto
endpoints
=
paddle
::
string
::
Split
(
var_attrs
[
2
],
'&'
);
bool
is_sparse
=
static_cast
<
bool
>
(
std
::
stoi
(
var_attrs
[
3
]));
std
::
string
send_var_name
=
VarToDeltaVar
(
var_name
);
std
::
vector
<
std
::
string
>
vars_names
=
iter
.
second
[
"var_names"
];
std
::
vector
<
std
::
string
>
send_var_names
;
for
(
auto
origin_var_name
:
vars_
names
)
{
for
(
auto
origin_var_name
:
split_var
names
)
{
send_var_names
.
push_back
(
VarToDeltaVar
(
origin_var_name
));
}
// get vars section for split
std
::
vector
<
std
::
string
>
vars_sections_str
=
iter
.
second
[
"sections"
];
std
::
vector
<
int64_t
>
vars_sections_int
=
{};
for
(
std
::
string
str
:
vars_sections_str
)
{
for
(
std
::
string
str
:
sections
)
{
int64_t
str2i
=
std
::
stol
(
str
.
c_str
());
vars_sections_int
.
push_back
(
str2i
);
}
std
::
vector
<
std
::
string
>
vars_epmap
=
iter
.
second
[
"epmap"
];
// record var is sparse or not
bool
is_sparse
=
iter
.
second
[
"is_sparse"
].
front
()
==
std
::
string
(
"True"
);
var_list_
[
var_name
]
=
is_sparse
;
send_varname_to_ctx_
[
send_var_name
]
=
operators
::
distributed
::
RpcContext
(
send_var_name
,
send_var_names
,
vars_epmap
,
vars_sections_int
,
0
);
send_var_name
,
send_var_names
,
endpoints
,
vars_sections_int
,
0
);
recv_varname_to_ctx_
[
var_name
]
=
operators
::
distributed
::
RpcContext
(
var_name
,
vars_names
,
vars_epmap
,
vars_sections_int
,
0
);
var_name
,
split_varnames
,
endpoints
,
vars_sections_int
,
0
);
absolute_section_
[
var_name
]
=
operators
::
ToAbsoluteSection
(
send_varname_to_ctx_
[
send_var_name
].
height_sections
);
...
...
@@ -454,18 +372,17 @@ void GeoSgdCommunicator::InitImpl(
for
(
int64_t
section
:
vars_sections_int
)
{
vars_first_dimension_
[
var_name
]
+=
section
;
}
send_var_nums_
+=
vars_names
.
size
();
send_var_nums_
+=
split_varnames
.
size
();
}
if
(
send_varname_to_ctx_
.
size
()
==
0
&&
recv_varname_to_ctx_
.
size
()
==
0
)
{
LOG
(
WARNING
)
<<
"no var need to send and recv!!"
;
}
send_threadpool_
.
reset
(
new
::
ThreadPool
(
env_flags_dict
[
"thread_pool_size"
]
));
send_threadpool_
.
reset
(
new
::
ThreadPool
(
thread_pool_size_
));
need_push_queue_
=
std
::
make_shared
<
BlockingQueue
<
std
::
shared_ptr
<
SparseIdsMap
>>>
(
geo_need_push_nums
);
geo_need_push_nums
_
);
delta_scope_
.
reset
(
new
Scope
());
old_scope_
.
reset
(
new
Scope
());
pserver_scope_
.
reset
(
new
Scope
());
...
...
@@ -499,11 +416,10 @@ void GeoSgdCommunicator::Stop() {
VLOG
(
0
)
<<
"Geo Sgd Communicator stop done"
;
}
void
GeoSgdCommunicator
::
Send
(
const
std
::
string
&
var_name
,
void
GeoSgdCommunicator
::
Send
(
const
std
::
vector
<
std
::
string
>
&
sparse_var_names
,
const
std
::
vector
<
std
::
string
>
&
sparse_var_tables
,
const
framework
::
Scope
&
scope
)
{
// when execute trainer startup program, recv parameter from pserver
// training_scope & pserver_scope param will copy it
if
(
var_name
==
"param_init"
)
{
if
(
sparse_var_names
.
size
()
==
1
&&
sparse_var_names
[
0
]
==
"param_init"
)
{
for
(
auto
&
iter
:
var_list_
)
{
// For sparse param, old_scope store LoDTensor,
// pserver_scope store SelectedRows.
...
...
@@ -518,12 +434,7 @@ void GeoSgdCommunicator::Send(const std::string &var_name,
GeoSgdDenseParamInit
(
training_scope_
,
old_scope_
.
get
(),
local_var_name
);
}
}
}
void
GeoSgdCommunicator
::
Send
(
const
std
::
vector
<
std
::
string
>
&
sparse_var_names
,
const
std
::
vector
<
std
::
string
>
&
sparse_var_tables
,
const
framework
::
Scope
&
scope
)
{
// SparseIdsMap = std::unordered_map<std::string,std::unordered_set<int64_t>>
std
::
shared_ptr
<
SparseIdsMap
>
ids_table
=
std
::
make_shared
<
SparseIdsMap
>
();
auto
before_run_send
=
GetCurrentUS
();
for
(
size_t
i
=
0
;
i
<
sparse_var_tables
.
size
();
i
++
)
{
...
...
@@ -563,7 +474,7 @@ void GeoSgdCommunicator::SendThread() {
task_futures
.
reserve
(
send_var_nums_
);
int
wait_times
=
0
;
while
(
ids_send_vec_
.
size
()
<
geo_need_push_nums_
)
{
while
(
ids_send_vec_
.
size
()
<
static_cast
<
size_t
>
(
geo_need_push_nums_
)
)
{
VLOG
(
4
)
<<
"ids_send_vec_ Size: "
<<
ids_send_vec_
.
size
();
if
(
need_push_queue_
->
Size
()
>
0
)
{
wait_times
=
0
;
...
...
@@ -571,7 +482,7 @@ void GeoSgdCommunicator::SendThread() {
VLOG
(
4
)
<<
"ids_send_vec_ pushed"
;
}
else
if
(
need_push_queue_
->
Size
()
==
0
)
{
VLOG
(
4
)
<<
"wait_times -> "
<<
wait_times
;
if
(
wait_times
>=
env_flags_dict
[
"send_wait_times"
]
)
{
if
(
wait_times
>=
send_wait_times_
)
{
break
;
}
std
::
this_thread
::
sleep_for
(
std
::
chrono
::
milliseconds
(
10
));
...
...
@@ -580,7 +491,7 @@ void GeoSgdCommunicator::SendThread() {
}
}
if
(
ids_send_vec_
.
size
()
>=
geo_need_push_nums_
)
{
if
(
ids_send_vec_
.
size
()
>=
static_cast
<
size_t
>
(
geo_need_push_nums_
)
)
{
auto
after_run_training
=
GetCurrentUS
();
VLOG
(
4
)
<<
"run Training use time "
<<
after_run_training
-
before_run_training
;
...
...
@@ -1039,12 +950,254 @@ void GeoSgdCommunicator::RpcRecv(const std::string &var_name,
void
GeoSgdCommunicator
::
Recv
()
{}
void
GeoSgdCommunicator
::
InitImpl
(
const
RpcCtxMap
&
send_varname_to_ctx
,
const
RpcCtxMap
&
recv_varname_to_ctx
,
Scope
*
recv_scope
)
{}
void
HalfAsyncCommunicator
::
InitImpl
(
const
RpcCtxMap
&
send_varname_to_ctx
,
const
RpcCtxMap
&
recv_varname_to_ctx
,
Scope
*
recv_scope
)
{
VLOG
(
0
)
<<
"HalfAsyncCommunicator Initializing"
;
send_varname_to_ctx_
=
std
::
move
(
send_varname_to_ctx
);
recv_varname_to_ctx_
=
std
::
move
(
recv_varname_to_ctx
);
recv_scope_
=
std
::
move
(
recv_scope
);
if
(
send_varname_to_ctx
.
size
()
==
0
)
{
VLOG
(
0
)
<<
"nothing need to be send, will not start send_thread"
;
}
else
{
send_scope_
.
reset
(
new
Scope
());
for
(
auto
&
iter
:
send_varname_to_ctx_
)
{
send_varname_to_queue_
[
iter
.
first
]
=
std
::
make_shared
<
BlockingQueue
<
std
::
shared_ptr
<
Variable
>>>
(
send_queue_size_
);
}
void
GeoSgdCommunicator
::
InitImpl
(
const
paddle
::
framework
::
ProgramDesc
&
program
,
Scope
*
recv_scope
)
{}
consume_threadpool_
.
reset
(
new
::
ThreadPool
(
thread_pool_size_
));
}
if
(
recv_varname_to_ctx
.
size
()
==
0
)
{
VLOG
(
0
)
<<
"nothing need to be received, will not start recv_thread"
;
}
else
{
recv_threadpool_
.
reset
(
new
::
ThreadPool
(
thread_pool_size_
));
}
}
void
HalfAsyncCommunicator
::
InitImpl
(
const
paddle
::
framework
::
ProgramDesc
&
program
,
Scope
*
param_scope
)
{
RpcCtxMap
send_varname_to_ctx
;
RpcCtxMap
recv_varname_to_ctx
;
for
(
auto
*
op
:
program
.
Block
(
0
).
AllOps
())
{
VLOG
(
3
)
<<
"node name "
<<
op
->
Type
();
if
(
op
->
Type
()
==
"send"
)
{
auto
send_var_name
=
op
->
Input
(
"X"
)[
0
];
auto
send_varnames
=
boost
::
get
<
std
::
vector
<
std
::
string
>>
(
op
->
GetNullableAttr
(
"send_varnames"
));
auto
epmap
=
boost
::
get
<
std
::
vector
<
std
::
string
>>
(
op
->
GetNullableAttr
(
"epmap"
));
auto
height_section
=
boost
::
get
<
std
::
vector
<
int64_t
>>
(
op
->
GetNullableAttr
(
"sections"
));
auto
trainer_id
=
boost
::
get
<
int
>
(
op
->
GetNullableAttr
(
"trainer_id"
));
send_varname_to_ctx
[
send_var_name
]
=
operators
::
distributed
::
RpcContext
(
send_var_name
,
send_varnames
,
epmap
,
height_section
,
trainer_id
);
VLOG
(
3
)
<<
"find and init an send op: "
<<
send_varname_to_ctx
[
send_var_name
];
}
else
if
(
op
->
Type
()
==
"recv"
)
{
auto
do_not_run
=
boost
::
get
<
int
>
(
op
->
GetNullableAttr
(
"do_not_run"
));
PADDLE_ENFORCE_GT
(
do_not_run
,
0
,
platform
::
errors
::
InvalidArgument
(
"recv op's attr `do_not_run` must be True!"
));
auto
recv_var_name
=
op
->
Output
(
"Out"
)[
0
];
auto
recv_varnames
=
boost
::
get
<
std
::
vector
<
std
::
string
>>
(
op
->
GetNullableAttr
(
"recv_varnames"
));
auto
epmap
=
boost
::
get
<
std
::
vector
<
std
::
string
>>
(
op
->
GetNullableAttr
(
"epmap"
));
auto
trainer_id
=
boost
::
get
<
int
>
(
op
->
GetNullableAttr
(
"trainer_id"
));
recv_varname_to_ctx
[
recv_var_name
]
=
operators
::
distributed
::
RpcContext
(
recv_var_name
,
recv_varnames
,
epmap
,
{},
trainer_id
);
}
}
// init communicator here
if
(
send_varname_to_ctx
.
size
()
==
0
&&
recv_varname_to_ctx
.
size
()
==
0
)
{
LOG
(
WARNING
)
<<
"no var need to send and recv!!"
;
}
operators
::
distributed
::
HalfAsyncCommunicator
::
InitImpl
(
send_varname_to_ctx
,
recv_varname_to_ctx
,
param_scope
);
}
HalfAsyncCommunicator
::~
HalfAsyncCommunicator
()
{
running_
=
false
;
if
(
consume_thread_
)
consume_thread_
->
join
();
}
void
HalfAsyncCommunicator
::
ConsumeThread
()
{
VLOG
(
3
)
<<
"ConsumeThread start!"
;
while
(
running_
)
{
while
(
running_
)
{
if
(
barrier_counter_
.
load
()
>=
barrier_trigger_
.
load
())
{
break
;
}
else
{
std
::
this_thread
::
sleep_for
(
std
::
chrono
::
milliseconds
(
10
));
}
}
std
::
vector
<
std
::
future
<
void
>>
task_futures
;
task_futures
.
reserve
(
send_varname_to_ctx_
.
size
());
VLOG
(
3
)
<<
"run send graph"
;
auto
before_run_send_graph
=
GetCurrentUS
();
for
(
auto
&
iter
:
send_varname_to_queue_
)
{
auto
&
var_name
=
iter
.
first
;
auto
&
var_queue
=
iter
.
second
;
if
(
var_queue
->
Size
()
>
0
)
{
auto
send_task
=
[
this
,
&
var_name
,
&
var_queue
]
{
VLOG
(
3
)
<<
var_name
<<
" merge and send"
;
std
::
vector
<
std
::
shared_ptr
<
Variable
>>
vars
;
size_t
merged_var_num
=
0
;
size_t
wait_times
=
0
;
while
(
merged_var_num
<
static_cast
<
size_t
>
(
max_merge_var_num_
))
{
if
(
var_queue
->
Size
()
==
0
)
{
VLOG
(
3
)
<<
"wait_times -> "
<<
wait_times
;
if
(
wait_times
>=
static_cast
<
size_t
>
(
send_wait_times_
))
{
break
;
}
std
::
this_thread
::
sleep_for
(
std
::
chrono
::
milliseconds
(
10
));
wait_times
++
;
continue
;
}
else
{
wait_times
=
0
;
vars
.
push_back
(
var_queue
->
Pop
());
merged_var_num
++
;
}
}
auto
before_merge
=
GetCurrentUS
();
MergeVars
<
float
>
(
var_name
,
vars
,
send_scope_
.
get
(),
false
);
auto
after_merge
=
GetCurrentUS
();
VLOG
(
3
)
<<
"merge "
<<
merged_var_num
<<
" "
<<
var_name
<<
" use time "
<<
after_merge
-
before_merge
;
auto
send_functor
=
distributed
::
ParameterSend
<
float
>
();
auto
&
ctx
=
send_varname_to_ctx_
.
at
(
var_name
);
send_functor
(
ctx
,
*
send_scope_
,
true
,
1
);
auto
after_send
=
GetCurrentUS
();
VLOG
(
3
)
<<
"send "
<<
var_name
<<
" use time "
<<
after_send
-
after_merge
;
};
task_futures
.
emplace_back
(
consume_threadpool_
->
enqueue
(
std
::
move
(
send_task
)));
}
else
{
VLOG
(
4
)
<<
var_name
<<
" queue empty"
;
}
}
for
(
auto
&
task_f
:
task_futures
)
{
task_f
.
wait
();
}
auto
after_run_send_graph
=
GetCurrentUS
();
VLOG
(
3
)
<<
"run send graph use time "
<<
after_run_send_graph
-
before_run_send_graph
;
Recv
();
BarrierWeakUp
();
}
VLOG
(
0
)
<<
"communicator stopped, send thread exit"
;
}
void
HalfAsyncCommunicator
::
Send
(
const
std
::
vector
<
std
::
string
>
&
var_names
,
const
std
::
vector
<
std
::
string
>
&
var_tables
,
const
framework
::
Scope
&
scope
)
{
PADDLE_ENFORCE_EQ
(
var_names
.
size
(),
1
,
platform
::
errors
::
InvalidArgument
(
"var_names.size() == 1 is permitted"
));
auto
var_name
=
var_names
[
0
];
VLOG
(
3
)
<<
"communicator send "
<<
var_name
;
// push var into send queue by var_name
auto
*
grad_var
=
scope
.
FindVar
(
var_name
);
PADDLE_ENFORCE_EQ
(
grad_var
->
IsInitialized
(),
true
,
platform
::
errors
::
InvalidArgument
(
"grad var should is not initialized."
));
auto
tmp_grad_var
=
std
::
make_shared
<
Variable
>
();
framework
::
CopyVariable
(
*
grad_var
,
tmp_grad_var
.
get
());
auto
&
queue
=
send_varname_to_queue_
.
at
(
var_name
);
VLOG
(
3
)
<<
"send "
<<
var_name
<<
" queue size "
<<
queue
->
Size
();
queue
->
Push
(
tmp_grad_var
);
}
void
HalfAsyncCommunicator
::
Recv
()
{
VLOG
(
3
)
<<
"parallel run recv graph"
;
if
(
!
running_
)
return
;
auto
before_send
=
GetCurrentUS
();
std
::
vector
<
std
::
future
<
void
>>
task_futures
;
task_futures
.
reserve
(
recv_varname_to_ctx_
.
size
());
for
(
auto
&
iter
:
recv_varname_to_ctx_
)
{
auto
recv_task
=
[
this
,
&
iter
]
{
auto
&
var_name
=
iter
.
first
;
VLOG
(
4
)
<<
"recv var "
<<
var_name
;
auto
recv_functor
=
distributed
::
ParameterRecv
<
float
>
();
recv_functor
(
iter
.
second
,
*
recv_scope_
);
};
task_futures
.
emplace_back
(
recv_threadpool_
->
enqueue
(
std
::
move
(
recv_task
)));
}
for
(
auto
&
task
:
task_futures
)
{
task
.
wait
();
}
auto
after_recv
=
GetCurrentUS
();
VLOG
(
3
)
<<
"run recv graph use time "
<<
after_recv
-
before_send
;
}
void
HalfAsyncCommunicator
::
Barrier
()
{
barrier_counter_
++
;
{
std
::
unique_lock
<
std
::
mutex
>
lk
(
barrier_mutex_
);
barrier_cond_
.
wait
(
lk
,
[
this
]
{
return
(
barrier_counter_
==
0
);
});
}
}
void
HalfAsyncCommunicator
::
BarrierTriggerDecrement
()
{
barrier_trigger_
--
;
VLOG
(
3
)
<<
"BarrierTriggerDecrement decrement barrier trigger to "
<<
barrier_trigger_
.
load
();
}
void
HalfAsyncCommunicator
::
BarrierTriggerReset
(
int
initial_val
)
{
barrier_trigger_
.
store
(
initial_val
);
VLOG
(
3
)
<<
"BarrierTriggerReset reset barrier trigger to "
<<
barrier_trigger_
.
load
();
}
void
HalfAsyncCommunicator
::
BarrierWeakUp
()
{
barrier_counter_
.
store
(
0
);
barrier_cond_
.
notify_all
();
}
void
HalfAsyncCommunicator
::
Start
()
{
VLOG
(
0
)
<<
"Communicator start"
;
if
(
!
communicator_
)
{
VLOG
(
0
)
<<
"Communicator is not inited, do nothing"
;
}
else
{
VLOG
(
1
)
<<
"start send thread and recv thread"
;
BarrierTriggerReset
(
max_merge_var_num_
);
running_
=
true
;
consume_thread_
.
reset
(
new
std
::
thread
(
std
::
bind
(
&
HalfAsyncCommunicator
::
ConsumeThread
,
this
)));
}
}
void
HalfAsyncCommunicator
::
Stop
()
{
VLOG
(
0
)
<<
"Communicator stop"
;
running_
=
false
;
if
(
!
communicator_
)
{
VLOG
(
0
)
<<
"Communicator is not inited, do nothing"
;
}
else
{
if
(
consume_thread_
)
{
VLOG
(
4
)
<<
"stop send thread"
;
consume_thread_
->
join
();
consume_thread_
.
reset
(
nullptr
);
}
}
VLOG
(
0
)
<<
"Communicator stop done"
;
}
}
// namespace distributed
}
// namespace operators
...
...
paddle/fluid/operators/distributed/communicator.h
浏览文件 @
82bc814a
...
...
@@ -175,117 +175,57 @@ using RpcCtxMap = std::unordered_map<std::string, RpcContext>;
class
Communicator
{
public:
Communicator
();
explicit
Communicator
(
const
std
::
map
<
std
::
string
,
int
>&
env_flag
s
);
explicit
Communicator
(
const
std
::
map
<
std
::
string
,
std
::
string
>&
env
s
);
virtual
~
Communicator
()
{}
virtual
void
SetEnvFlagsDefault
();
virtual
void
Start
()
=
0
;
virtual
void
Stop
()
=
0
;
virtual
bool
IsRunning
()
{
return
running_
;
}
virtual
void
Send
(
const
std
::
string
&
var_name
,
const
framework
::
Scope
&
scope
)
=
0
;
virtual
void
Send
(
const
std
::
vector
<
std
::
string
>&
sparse_var_names
,
const
std
::
vector
<
std
::
string
>&
sparse_var_tables
,
virtual
void
Send
(
const
std
::
vector
<
std
::
string
>&
var_names
,
const
std
::
vector
<
std
::
string
>&
var_tables
,
const
framework
::
Scope
&
scope
)
=
0
;
virtual
void
Recv
()
=
0
;
virtual
void
Barrier
()
{}
virtual
void
BarrierTriggerDecrement
()
{}
virtual
void
BarrierTriggerReset
(
int
init_counter
)
{}
virtual
void
InitImpl
(
const
RpcCtxMap
&
send_varname_to_ctx
,
const
RpcCtxMap
&
recv_varname_to_ctx
,
Scope
*
recv_scope
)
=
0
;
Scope
*
recv_scope
)
{}
virtual
void
InitImpl
(
const
paddle
::
framework
::
ProgramDesc
&
program
,
Scope
*
recv_scope
)
=
0
;
// for geo-sgd
virtual
void
InitImpl
(
const
paddle
::
framework
::
ProgramDesc
&
program
,
Scope
*
param_scope
,
std
::
map
<
std
::
string
,
std
::
map
<
std
::
string
,
std
::
vector
<
std
::
string
>>>&
vars_info
,
const
int
&
trainers
,
const
int
&
geo_need_push_nums
)
=
0
;
static
Communicator
*
GetInstance
()
{
return
communicator_
.
get
();
}
static
std
::
shared_ptr
<
Communicator
>
GetInstantcePtr
()
{
return
communicator_
;
}
template
<
typename
T
>
static
Communicator
*
InitInstance
(
const
RpcCtxMap
&
send_varname_to_ctx
,
const
RpcCtxMap
&
recv_varname_to_ctx
,
Scope
*
recv_scope
)
{
std
::
call_once
(
init_flag_
,
&
Communicator
::
InitWithRpcCtx
<
T
>
,
send_varname_to_ctx
,
recv_varname_to_ctx
,
recv_scope
);
return
communicator_
.
get
();
}
template
<
typename
T
>
static
Communicator
*
InitInstance
(
const
paddle
::
framework
::
ProgramDesc
&
program
,
Scope
*
recv_scope
,
const
std
::
map
<
std
::
string
,
int
>&
env_flag
s
)
{
const
std
::
map
<
std
::
string
,
std
::
string
>&
env
s
)
{
std
::
call_once
(
init_flag_
,
&
Communicator
::
InitWithProgram
<
T
>
,
program
,
recv_scope
,
std
::
ref
(
env
_flag
s
));
recv_scope
,
std
::
ref
(
envs
));
return
communicator_
.
get
();
}
template
<
typename
T
>
static
Communicator
*
InitInstance
(
const
paddle
::
framework
::
ProgramDesc
&
program
,
Scope
*
training_scope
,
std
::
map
<
std
::
string
,
std
::
map
<
std
::
string
,
std
::
vector
<
std
::
string
>>>&
vars_info
,
const
int
&
trainers
,
const
int
&
geo_need_push_nums
,
const
std
::
map
<
std
::
string
,
int
>&
env_flags
)
{
std
::
call_once
(
init_flag_
,
&
Communicator
::
InitWithTranspilerInfo
<
T
>
,
program
,
training_scope
,
std
::
ref
(
vars_info
),
std
::
ref
(
trainers
),
std
::
ref
(
geo_need_push_nums
),
std
::
ref
(
env_flags
));
return
communicator_
.
get
();
}
// Init is called by InitInstance.
template
<
typename
T
>
static
void
InitWithRpcCtx
(
const
RpcCtxMap
&
send_varname_to_ctx
,
const
RpcCtxMap
&
recv_varname_to_ctx
,
Scope
*
recv_scope
)
{
if
(
communicator_
.
get
()
==
nullptr
)
{
communicator_
.
reset
(
new
T
());
communicator_
->
InitImpl
(
send_varname_to_ctx
,
recv_varname_to_ctx
,
recv_scope
);
}
}
template
<
typename
T
>
static
void
InitWithProgram
(
const
paddle
::
framework
::
ProgramDesc
&
program
,
Scope
*
recv_scope
,
const
std
::
map
<
std
::
string
,
int
>&
env_flag
s
)
{
const
std
::
map
<
std
::
string
,
std
::
string
>&
env
s
)
{
if
(
communicator_
.
get
()
==
nullptr
)
{
communicator_
.
reset
(
new
T
(
std
::
ref
(
env
_flag
s
)));
communicator_
.
reset
(
new
T
(
std
::
ref
(
envs
)));
communicator_
->
InitImpl
(
program
,
recv_scope
);
}
}
template
<
typename
T
>
static
void
InitWithTranspilerInfo
(
const
paddle
::
framework
::
ProgramDesc
&
program
,
Scope
*
training_scope
,
std
::
map
<
std
::
string
,
std
::
map
<
std
::
string
,
std
::
vector
<
std
::
string
>>>&
vars_info
,
const
int
&
trainers
,
const
int
&
geo_need_push_nums
,
const
std
::
map
<
std
::
string
,
int
>&
env_flags
)
{
if
(
communicator_
.
get
()
==
nullptr
)
{
communicator_
.
reset
(
new
T
(
std
::
ref
(
env_flags
)));
communicator_
->
InitImpl
(
program
,
training_scope
,
std
::
ref
(
vars_info
),
std
::
ref
(
trainers
),
std
::
ref
(
geo_need_push_nums
));
}
}
protected:
bool
running_
=
false
;
static
std
::
shared_ptr
<
Communicator
>
communicator_
;
static
std
::
once_flag
init_flag_
;
std
::
unordered_map
<
std
::
string
,
int
>
env_flags_dict
;
std
::
unordered_map
<
std
::
string
,
std
::
string
>
envs
;
};
using
SparseIdsMap
=
...
...
@@ -294,14 +234,23 @@ using SparseIdsMap =
class
AsyncCommunicator
:
public
Communicator
{
public:
AsyncCommunicator
()
:
Communicator
()
{}
explicit
AsyncCommunicator
(
const
std
::
map
<
std
::
string
,
int
>&
env_flags
)
:
Communicator
(
env_flags
)
{}
explicit
AsyncCommunicator
(
const
std
::
map
<
std
::
string
,
std
::
string
>&
envs
)
:
Communicator
(
envs
)
{
independent_recv_thread_
=
static_cast
<
bool
>
(
std
::
stoi
(
envs
.
at
(
"communicator_independent_recv_thread"
)));
min_send_grad_num_before_recv_
=
std
::
stoi
(
envs
.
at
(
"communicator_min_send_grad_num_before_recv"
));
thread_pool_size_
=
std
::
stoi
(
envs
.
at
(
"communicator_thread_pool_size"
));
max_merge_var_num_
=
std
::
stoi
(
envs
.
at
(
"communicator_max_merge_var_num"
));
send_wait_times_
=
std
::
stoi
(
envs
.
at
(
"communicator_send_wait_times"
));
send_queue_size_
=
std
::
stoi
(
envs
.
at
(
"communicator_send_queue_size"
));
is_sgd_optimizer_
=
static_cast
<
bool
>
(
std
::
stoi
(
envs
.
at
(
"communicator_is_sgd_optimizer"
)));
}
~
AsyncCommunicator
();
void
Start
()
override
;
void
Stop
()
override
;
void
Send
(
const
std
::
string
&
var_name
,
const
framework
::
Scope
&
scope
)
override
;
void
Recv
()
override
;
void
RecvAll
();
...
...
@@ -315,15 +264,18 @@ class AsyncCommunicator : public Communicator {
void
SendThread
();
void
RecvThread
();
void
Send
(
const
std
::
vector
<
std
::
string
>&
sparse_
var_names
,
const
std
::
vector
<
std
::
string
>&
sparse_
var_tables
,
void
Send
(
const
std
::
vector
<
std
::
string
>&
var_names
,
const
std
::
vector
<
std
::
string
>&
var_tables
,
const
framework
::
Scope
&
scope
)
override
;
void
InitImpl
(
const
paddle
::
framework
::
ProgramDesc
&
program
,
Scope
*
param_scope
,
std
::
map
<
std
::
string
,
std
::
map
<
std
::
string
,
std
::
vector
<
std
::
string
>>>&
vars_info
,
const
int
&
trainers
,
const
int
&
geo_need_push_nums
)
override
;
private:
int
min_send_grad_num_before_recv_
;
int
thread_pool_size_
;
int
max_merge_var_num_
;
int
send_wait_times_
;
int
send_queue_size_
;
bool
independent_recv_thread_
;
bool
is_sgd_optimizer_
;
private:
std
::
unordered_map
<
std
::
string
,
...
...
@@ -340,30 +292,32 @@ class AsyncCommunicator : public Communicator {
std
::
atomic_uint
grad_num_
{
0
};
// the num of gradient sent since last recv
};
class
GeoSgd
Communicator
:
public
Communicator
{
class
HalfAsync
Communicator
:
public
Communicator
{
public:
GeoSgdCommunicator
()
:
Communicator
()
{}
explicit
GeoSgdCommunicator
(
const
std
::
map
<
std
::
string
,
int
>&
env_flags
)
:
Communicator
(
env_flags
)
{}
~
GeoSgdCommunicator
();
void
InitImpl
(
const
paddle
::
framework
::
ProgramDesc
&
program
,
Scope
*
training_scope
,
std
::
map
<
std
::
string
,
std
::
map
<
std
::
string
,
std
::
vector
<
std
::
string
>>>&
vars_info
,
const
int
&
trainers
,
const
int
&
geo_need_push_nums
)
override
;
HalfAsyncCommunicator
()
{}
explicit
HalfAsyncCommunicator
(
const
std
::
map
<
std
::
string
,
std
::
string
>&
envs
)
:
Communicator
(
envs
)
{
max_merge_var_num_
=
std
::
stoi
(
envs
.
at
(
"communicator_max_merge_var_num"
));
send_wait_times_
=
std
::
stoi
(
envs
.
at
(
"communicator_send_wait_times"
));
thread_pool_size_
=
std
::
stoi
(
envs
.
at
(
"communicator_thread_pool_size"
));
send_queue_size_
=
std
::
stoi
(
envs
.
at
(
"communicator_send_queue_size"
));
}
~
HalfAsyncCommunicator
();
void
Start
()
override
;
void
Stop
()
override
;
void
Send
(
const
std
::
string
&
var_name
,
const
framework
::
Scope
&
scope
)
override
;
void
Send
(
const
std
::
vector
<
std
::
string
>&
sparse_var_names
,
const
std
::
vector
<
std
::
string
>&
sparse_var_tables
,
void
Send
(
const
std
::
vector
<
std
::
string
>&
var_names
,
const
std
::
vector
<
std
::
string
>&
var_tables
,
const
framework
::
Scope
&
scope
)
override
;
void
Recv
()
override
;
void
Barrier
()
override
;
void
BarrierWeakUp
();
void
BarrierTriggerDecrement
()
override
;
void
BarrierTriggerReset
(
int
initial_val
)
override
;
void
InitImpl
(
const
RpcCtxMap
&
send_varname_to_ctx
,
const
RpcCtxMap
&
recv_varname_to_ctx
,
Scope
*
recv_scope
)
override
;
...
...
@@ -371,6 +325,58 @@ class GeoSgdCommunicator : public Communicator {
void
InitImpl
(
const
paddle
::
framework
::
ProgramDesc
&
program
,
Scope
*
recv_scope
)
override
;
void
ConsumeThread
();
private:
int
max_merge_var_num_
;
int
send_wait_times_
;
int
thread_pool_size_
;
int
send_queue_size_
;
private:
std
::
unordered_map
<
std
::
string
,
std
::
shared_ptr
<
BlockingQueue
<
std
::
shared_ptr
<
Variable
>>>>
send_varname_to_queue_
;
RpcCtxMap
send_varname_to_ctx_
;
RpcCtxMap
recv_varname_to_ctx_
;
std
::
unique_ptr
<
std
::
thread
>
consume_thread_
{
nullptr
};
Scope
*
recv_scope_
;
// should be global scope
std
::
unique_ptr
<
Scope
>
send_scope_
;
// an independent scope
std
::
unique_ptr
<::
ThreadPool
>
consume_threadpool_
{
nullptr
};
std
::
unique_ptr
<::
ThreadPool
>
recv_threadpool_
{
nullptr
};
// mutex for Wait for barrier
std
::
mutex
barrier_mutex_
;
std
::
condition_variable
barrier_cond_
;
std
::
atomic
<
int64_t
>
barrier_trigger_
{
0
};
std
::
atomic
<
int64_t
>
barrier_counter_
{
0
};
};
class
GeoSgdCommunicator
:
public
Communicator
{
public:
GeoSgdCommunicator
()
:
Communicator
()
{}
explicit
GeoSgdCommunicator
(
const
std
::
map
<
std
::
string
,
std
::
string
>&
envs
)
:
Communicator
(
envs
)
{
geo_need_push_nums_
=
std
::
stoi
(
envs
.
at
(
"geo_need_push_nums"
));
trainer_nums_
=
std
::
stoi
(
envs
.
at
(
"geo_trainer_nums"
));
thread_pool_size_
=
std
::
stoi
(
envs
.
at
(
"communicator_thread_pool_size"
));
send_wait_times_
=
std
::
stoi
(
envs
.
at
(
"communicator_send_wait_times"
));
}
~
GeoSgdCommunicator
();
void
Start
()
override
;
void
Stop
()
override
;
void
Send
(
const
std
::
vector
<
std
::
string
>&
var_names
,
const
std
::
vector
<
std
::
string
>&
var_tables
,
const
framework
::
Scope
&
scope
)
override
;
void
Recv
()
override
;
void
InitImpl
(
const
paddle
::
framework
::
ProgramDesc
&
program
,
Scope
*
recv_scope
)
override
;
private:
void
SendThread
();
std
::
unordered_set
<
int64_t
>
SparseIdsMerge
(
...
...
@@ -379,6 +385,7 @@ class GeoSgdCommunicator : public Communicator {
void
SendUpdateDenseVars
(
const
std
::
string
&
var_name
,
const
std
::
string
&
splited_var_name
);
void
SendUpdateSparseVars
(
const
std
::
string
&
var_name
,
const
std
::
string
&
splited_var_name
,
const
std
::
unordered_set
<
int64_t
>&
ids_table
);
...
...
@@ -433,8 +440,11 @@ class GeoSgdCommunicator : public Communicator {
private:
int
trainer_nums_
=
1
;
size_t
geo_need_push_nums_
=
100
;
bool
is_geo_sgd_
=
false
;
int
geo_need_push_nums_
=
100
;
int
thread_pool_size_
;
int
send_wait_times_
;
private:
int
send_var_nums_
=
0
;
RpcCtxMap
send_varname_to_ctx_
;
...
...
paddle/fluid/operators/distributed/distributed.h
浏览文件 @
82bc814a
...
...
@@ -17,6 +17,7 @@
#ifdef PADDLE_WITH_DISTRIBUTE
#ifdef PADDLE_WITH_GRPC
#include "paddle/fluid/operators/distributed/communicator.h"
#include "paddle/fluid/operators/distributed/grpc/grpc_client.h"
#include "paddle/fluid/operators/distributed/grpc/grpc_server.h"
...
...
paddle/fluid/operators/distributed_ops/send_barrier_op.cc
浏览文件 @
82bc814a
...
...
@@ -36,6 +36,13 @@ class SendBarrierOp : public framework::OperatorBase {
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
auto
is_half_async
=
Attr
<
bool
>
(
"half_async"
);
if
(
is_half_async
)
{
distributed
::
Communicator
::
GetInstance
()
->
Barrier
();
return
;
}
std
::
vector
<
std
::
string
>
eps
=
Attr
<
std
::
vector
<
std
::
string
>>
(
"endpoints"
);
distributed
::
RPCClient
*
rpc_client
=
...
...
@@ -76,6 +83,12 @@ the Parameter Server would knew all variables have been sent.
"(string vector, default 127.0.0.1:6164)"
"Server endpoints to send variables to."
)
.
SetDefault
({
"127.0.0.1:6164"
});
AddAttr
<
bool
>
(
"half_async"
,
"(bool, default false)"
"half_async=True is for half_async mode, this will send signal "
"to HalfAsyncCommunicator Instance"
)
.
SetDefault
(
false
);
}
};
...
...
paddle/fluid/operators/distributed_ops/send_op.cc
浏览文件 @
82bc814a
...
...
@@ -48,12 +48,7 @@ class SendOp : public framework::OperatorBase {
auto
use_send_handler
=
Attr
<
bool
>
(
"use_send_handler"
);
if
(
send_varnames
.
size
()
>
0
)
{
if
(
ins
.
size
()
>
1
)
{
distributed
::
Communicator
::
GetInstance
()
->
Send
(
ins
,
send_varnames
,
scope
);
}
else
{
distributed
::
Communicator
::
GetInstance
()
->
Send
(
ins
[
0
],
scope
);
}
distributed
::
Communicator
::
GetInstance
()
->
Send
(
ins
,
send_varnames
,
scope
);
}
else
{
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
...
...
paddle/fluid/pybind/communicator_py.cc
浏览文件 @
82bc814a
...
...
@@ -27,10 +27,11 @@ limitations under the License. */
namespace
py
=
pybind11
;
using
paddle
::
framework
::
ProgramDesc
;
using
paddle
::
operators
::
distributed
::
Communicator
;
using
paddle
::
framework
::
Scope
;
using
paddle
::
operators
::
distributed
::
AsyncCommunicator
;
using
paddle
::
operators
::
distributed
::
Communicator
;
using
paddle
::
operators
::
distributed
::
GeoSgdCommunicator
;
using
paddle
::
framework
::
Scope
;
using
paddle
::
operators
::
distributed
::
HalfAsyncCommunicator
;
namespace
paddle
{
namespace
pybind
{
...
...
@@ -39,29 +40,27 @@ void BindCommunicator(py::module* m) {
// Communicator is already used by nccl, change to DistCommunicator
py
::
class_
<
Communicator
,
std
::
shared_ptr
<
Communicator
>>
(
*
m
,
"DistCommunicator"
)
.
def
(
py
::
init
([](
const
ProgramDesc
&
program
,
Scope
*
param_scope
,
std
::
map
<
std
::
string
,
int
>&
env_flags
)
{
VLOG
(
0
)
<<
"using communicator"
;
Communicator
::
InitInstance
<
AsyncCommunicator
>
(
program
,
param_scope
,
env_flags
);
return
Communicator
::
GetInstantcePtr
();
}))
.
def
(
py
::
init
([](
const
ProgramDesc
&
program
,
Scope
*
training_scope
,
std
::
map
<
std
::
string
,
std
::
map
<
std
::
string
,
std
::
vector
<
std
::
string
>>>&
vars_info
,
int
&
trainers
,
int
&
geo_need_push_nums
,
std
::
map
<
std
::
string
,
int
>&
env_flags
)
{
VLOG
(
0
)
<<
"using geo sgd communicator"
;
Communicator
::
InitInstance
<
GeoSgdCommunicator
>
(
program
,
training_scope
,
vars_info
,
trainers
,
geo_need_push_nums
,
env_flags
);
.
def
(
py
::
init
([](
const
std
::
string
&
mode
,
const
ProgramDesc
&
program
,
Scope
*
param_scope
,
std
::
map
<
std
::
string
,
std
::
string
>&
envs
)
{
if
(
mode
==
"HALF_ASYNC"
)
{
Communicator
::
InitInstance
<
HalfAsyncCommunicator
>
(
program
,
param_scope
,
envs
);
}
else
if
(
mode
==
"ASYNC"
)
{
Communicator
::
InitInstance
<
AsyncCommunicator
>
(
program
,
param_scope
,
envs
);
}
else
if
(
mode
==
"GEO"
)
{
Communicator
::
InitInstance
<
GeoSgdCommunicator
>
(
program
,
param_scope
,
envs
);
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"unsuported communicator MODE"
));
}
return
Communicator
::
GetInstantcePtr
();
}))
.
def
(
"stop"
,
&
Communicator
::
Stop
)
.
def
(
"start"
,
&
Communicator
::
Start
)
.
def
(
"is_running"
,
&
Communicator
::
IsRunning
);
}
}
// namespace pybind
}
// namespace paddle
python/paddle/fluid/__init__.py
浏览文件 @
82bc814a
...
...
@@ -199,17 +199,6 @@ def __bootstrap__():
read_env_flags
.
append
(
'worker_update_interval_secs'
)
# env for communicator
read_env_flags
.
append
(
'communicator_independent_recv_thread'
)
read_env_flags
.
append
(
'communicator_send_queue_size'
)
read_env_flags
.
append
(
'communicator_min_send_grad_num_before_recv'
)
read_env_flags
.
append
(
'communicator_thread_pool_size'
)
read_env_flags
.
append
(
'communicator_max_merge_var_num'
)
read_env_flags
.
append
(
'communicator_merge_sparse_bucket'
)
read_env_flags
.
append
(
'communicator_fake_rpc'
)
read_env_flags
.
append
(
'communicator_send_wait_times'
)
read_env_flags
.
append
(
'communicator_merge_sparse_grad'
)
read_env_flags
.
append
(
'communicator_is_sgd_optimizer'
)
if
core
.
is_compiled_with_brpc
():
read_env_flags
.
append
(
'max_body_size'
)
#set brpc max body size
...
...
python/paddle/fluid/communicator.py
浏览文件 @
82bc814a
...
...
@@ -19,17 +19,13 @@ It's a wrapper of a cpp class Communicator and should be used inside fleet API.
"""
from
.
import
core
from
.framework
import
Program
from
.transpiler.distribute_transpiler
import
DistributedMode
__all__
=
[
'Communicator'
]
class
Communicator
(
object
):
def
__init__
(
self
,
program
,
vars_info
=
None
,
trainers
=
None
,
geo_sgd_need_push_nums
=
None
,
env_flags
=
None
):
def
__init__
(
self
,
program
,
mode
,
kwargs
=
None
,
envs
=
{}):
"""
Communicator is used for async distribute training in distribute_transpiler mode.
It's a wrapper of a cpp class Communicator and should be used inside fleet API.
...
...
@@ -56,20 +52,37 @@ class Communicator(object):
for
op
in
program
.
block
(
0
).
ops
:
if
op
.
type
==
"recv"
:
op
.
_set_attr
(
'do_not_run'
,
True
)
# Todo: Add check
if
env_flags
is
None
:
env_flags
=
{}
if
vars_info
and
trainers
and
geo_sgd_need_push_nums
:
# for geo sgd
self
.
communicator_
=
core
.
DistCommunicator
(
program
.
desc
,
global_scope
(),
vars_info
,
trainers
,
geo_sgd_need_push_nums
,
env_flags
)
else
:
self
.
communicator_
=
core
.
DistCommunicator
(
program
.
desc
,
global_scope
(),
env_flags
)
if
mode
==
DistributedMode
.
GEO
:
push_vars
=
kwargs
[
"push_vars"
]
push_var_names
=
[]
for
k
,
vs
in
push_vars
.
items
():
varnames
=
"&"
.
join
(
vs
[
"var_names"
])
sections
=
"&"
.
join
([
str
(
v
)
for
v
in
vs
[
"sections"
]])
endpoints
=
"&"
.
join
(
vs
[
"epmap"
])
is_sparse
=
"1"
if
vs
[
"is_sparse"
]
else
"0"
push_var_names
.
append
(
k
)
envs
[
k
]
=
"#"
.
join
([
varnames
,
sections
,
endpoints
,
is_sparse
])
envs
[
"geo_trainer_nums"
]
=
str
(
kwargs
[
"trainers"
])
envs
[
"geo_need_push_nums"
]
=
str
(
kwargs
[
"push_nums"
])
envs
[
"geo_send_varnames"
]
=
'#'
.
join
(
push_var_names
)
mode_str
=
None
if
mode
==
DistributedMode
.
SYNC
:
mode_str
=
"SYNC"
elif
mode
==
DistributedMode
.
ASYNC
:
mode_str
=
"ASYNC"
elif
mode
==
DistributedMode
.
HALF_ASYNC
:
mode_str
=
"HALF_ASYNC"
elif
mode
==
DistributedMode
.
GEO
:
mode_str
=
"GEO"
self
.
communicator_
=
core
.
DistCommunicator
(
mode_str
,
program
.
desc
,
global_scope
(),
envs
)
def
start
(
self
):
"""
...
...
python/paddle/fluid/executor.py
浏览文件 @
82bc814a
...
...
@@ -963,6 +963,7 @@ class Executor(object):
program
.
_pipeline_opt
)
else
:
trainer
=
TrainerFactory
().
_create_trainer
(
program
.
_fleet_opt
)
trainer
.
_set_thread_barrier
(
program
.
_is_distributed
)
trainer
.
_set_program
(
program
)
else
:
if
program
.
_pipeline_opt
:
...
...
python/paddle/fluid/incubate/fleet/parameter_server/distribute_transpiler/__init__.py
浏览文件 @
82bc814a
...
...
@@ -17,7 +17,6 @@ import warnings
"""
Convert the fluid program to distributed data-parallelism programs.
"""
from
.distributed_strategy
import
*
import
paddle.fluid.io
as
io
from
paddle.fluid.communicator
import
Communicator
from
paddle.fluid.framework
import
default_main_program
...
...
@@ -27,8 +26,11 @@ from paddle.fluid.compiler import CompiledProgram
from
paddle.fluid.executor
import
Executor
from
paddle.fluid.parallel_executor
import
ParallelExecutor
from
paddle.fluid.optimizer
import
Optimizer
from
paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler.distributed_strategy
import
TrainerRuntimeConfig
,
DistributedStrategy
,
SyncStrategy
,
AsyncStrategy
,
HalfAsyncStrategy
,
GeoStrategy
,
StrategyFactory
from
paddle.fluid.transpiler.distribute_transpiler
import
DistributeTranspiler
as
OriginTranspiler
from
paddle.fluid.transpiler.distribute_transpiler
import
DistributeTranspilerConfig
,
ServerRuntimeConfig
from
paddle.fluid.transpiler.distribute_transpiler
import
DistributeTranspilerConfig
,
ServerRuntimeConfig
,
DistributedMode
from
paddle.fluid.incubate.fleet.base.fleet_base
import
DistributedOptimizer
from
paddle.fluid.incubate.fleet.base.fleet_base
import
Fleet
...
...
@@ -70,25 +72,39 @@ class DistributedTranspiler(Fleet):
program_config
=
self
.
_transpile_config
.
get_program_config
()
trainer_communicator_config
=
self
.
_transpile_config
.
get_trainer_runtime_config
(
)
if
isinstance
(
self
.
_transpile_config
,
SyncStrategy
):
return
print
(
trainer_communicator_config
)
need_communicator_flag
=
False
if
isinstance
(
self
.
_transpile_config
,
GeoStrategy
):
need_communicator_flag
=
True
kwargs
=
{}
kwargs
[
"push_vars"
]
=
self
.
vars_info
kwargs
[
"trainers"
]
=
fleet
.
worker_num
()
kwargs
[
"push_nums"
]
=
self
.
_transpile_config
.
get_program_config
(
).
geo_sgd_need_push_nums
self
.
_communicator
=
Communicator
(
self
.
main_program
,
self
.
vars_info
,
fleet
.
worker_num
(),
program_config
.
geo_sgd_need_push_nums
,
self
.
main_program
,
DistributedMode
.
GEO
,
kwargs
,
trainer_communicator_config
.
get_communicator_flags
())
elif
isinstance
(
self
.
_transpile_config
,
AsyncStrategy
):
need_communicator_flag
=
True
self
.
_communicator
=
Communicator
(
self
.
main_program
,
env_flags
=
trainer_communicator_config
.
get_communicator_flags
())
if
need_communicator_flag
:
if
not
self
.
_communicator
.
is_running
():
self
.
_communicator
.
start
()
else
:
warnings
.
warn
(
"communicator has been initialized, skip"
)
self
.
main_program
,
DistributedMode
.
ASYNC
,
None
,
trainer_communicator_config
.
get_communicator_flags
())
elif
isinstance
(
self
.
_transpile_config
,
HalfAsyncStrategy
):
self
.
_communicator
=
Communicator
(
self
.
main_program
,
DistributedMode
.
HALF_ASYNC
,
None
,
trainer_communicator_config
.
get_communicator_flags
())
else
:
raise
TypeError
(
"Training MODE do not supported"
)
if
not
self
.
_communicator
.
is_running
():
self
.
_communicator
.
start
()
else
:
warnings
.
warn
(
"communicator has been initialized, skip"
)
def
init_server
(
self
,
model_dir
=
None
):
"""
...
...
@@ -139,12 +155,12 @@ class DistributedTranspiler(Fleet):
Returns:
None
"""
if
isinstance
(
self
.
_transpile_config
,
GeoStrategy
)
or
isinstance
(
self
.
_transpile_config
,
As
yncStrategy
):
if
not
isinstance
(
self
.
_transpile_config
,
S
yncStrategy
):
self
.
_communicator
.
stop
()
self
.
_executor
.
close
()
if
isinstance
(
self
.
_role_maker
,
MPISymetricRoleMaker
):
self
.
_role_maker
.
_finalize
()
self
.
_executor
.
close
()
def
distributed_optimizer
(
self
,
optimizer
,
strategy
=
None
):
"""
...
...
@@ -250,14 +266,22 @@ class DistributedTranspiler(Fleet):
io
.
save_persistables
(
executor
,
dirname
,
main_program
,
None
)
def
_transpile
(
self
,
config
):
if
isinstance
(
config
,
DistributeTranspilerConfig
):
self
.
_transpile_config
=
DistributedStrategy
()
self
.
_transpile_config
.
set_program_config
(
config
)
elif
isinstance
(
config
,
DistributedStrategy
):
if
isinstance
(
config
,
DistributedStrategy
):
self
.
_transpile_config
=
config
elif
isinstance
(
config
,
DistributeTranspilerConfig
):
if
config
.
sync_mode
:
self
.
_transpile_config
=
SyncStrategy
()
elif
config
.
geo_sgd_mode
:
self
.
_transpile_config
=
GeoStrategy
(
config
.
geo_sgd_need_push_nums
)
elif
config
.
runtime_split_send_recv
and
config
.
half_async
:
self
.
_transpile_config
=
HalfAsyncStrategy
()
else
:
self
.
_transpile_config
=
AsyncStrategy
()
self
.
_transpile_config
.
set_program_config
(
config
)
else
:
raise
TypeError
(
"config must be an instance of DistributeTranspilerConfig
or DistributedStrategy
"
"config must be an instance of DistributeTranspilerConfig
, SyncStrategy, HalfAsyncStrategy, AsyncStrategy or GeoStratey.
"
)
program_config
=
self
.
_transpile_config
.
get_program_config
()
...
...
@@ -327,14 +351,12 @@ class TranspilerOptimizer(DistributedOptimizer):
super
(
TranspilerOptimizer
,
self
).
__init__
(
optimizer
,
strategy
)
if
strategy
:
if
isinstance
(
strategy
,
DistributedStrategy
):
if
isinstance
(
strategy
,
DistributeTranspilerConfig
)
or
isinstance
(
strategy
,
DistributedStrategy
):
self
.
_strategy
=
strategy
elif
isinstance
(
strategy
,
DistributeTranspilerConfig
):
self
.
_strategy
=
DistributedStrategy
()
self
.
_strategy
.
set_program_config
(
strategy
)
else
:
raise
TypeError
(
"In {} mode, strategy must be an instance of DistributeTranspilerConfig
or Distributed
Strategy"
.
"In {} mode, strategy must be an instance of DistributeTranspilerConfig
, SyncStrategy, HalfAsyncStrategy, AsyncStrategy, or Geo
Strategy"
.
format
(
fleet
.
_mode
))
else
:
self
.
_strategy
=
DistributedStrategy
()
...
...
python/paddle/fluid/incubate/fleet/parameter_server/distribute_transpiler/distributed_strategy.py
浏览文件 @
82bc814a
...
...
@@ -24,49 +24,51 @@ from paddle.fluid.transpiler.distribute_transpiler import DistributeTranspilerCo
class
TrainerRuntimeConfig
(
object
):
def
__init__
(
self
):
self
.
max_merge_var_num
=
int
(
os
.
getenv
(
"FLAGS_communicator_max_merge_var_num"
,
"20"
)
)
self
.
send_queue_size
=
int
(
os
.
getenv
(
"FLAGS_communicator_send_queue_size"
,
"20"
)
)
self
.
independent_recv_thread
=
int
(
os
.
getenv
(
"FLAGS_communicator_independent_recv_thread"
,
"1"
)
)
self
.
min_send_grad_num_before_recv
=
int
(
os
.
getenv
(
"FLAGS_communicator_min_send_grad_num_before_recv"
,
"20"
)
)
self
.
thread_pool_size
=
int
(
os
.
getenv
(
"FLAGS_communicator_thread_pool_size"
,
"5"
)
)
self
.
send_wait_times
=
int
(
os
.
getenv
(
"FLAGS_communicator_send_wait_times"
,
"5"
)
)
self
.
fake_rpc
=
int
(
os
.
getenv
(
"FLAGS_communicator_fake_rpc"
,
"0"
)
)
self
.
merge_sparse_grad
=
int
(
os
.
getenv
(
"FLAGS_communicator_merge_sparse_grad"
,
"1"
)
)
self
.
is_sgd_optimizer
=
int
(
os
.
getenv
(
"FLAGS_communicator_is_sgd_optimizer"
,
"1"
)
)
self
.
max_merge_var_num
=
os
.
getenv
(
"FLAGS_communicator_max_merge_var_num"
,
"20"
)
self
.
send_queue_size
=
os
.
getenv
(
"FLAGS_communicator_send_queue_size"
,
"20"
)
self
.
independent_recv_thread
=
os
.
getenv
(
"FLAGS_communicator_independent_recv_thread"
,
"1"
)
self
.
min_send_grad_num_before_recv
=
os
.
getenv
(
"FLAGS_communicator_min_send_grad_num_before_recv"
,
"20"
)
self
.
thread_pool_size
=
os
.
getenv
(
"FLAGS_communicator_thread_pool_size"
,
"5"
)
self
.
send_wait_times
=
os
.
getenv
(
"FLAGS_communicator_send_wait_times"
,
"5"
)
self
.
fake_rpc
=
os
.
getenv
(
"FLAGS_communicator_fake_rpc"
,
"0"
)
self
.
merge_sparse_grad
=
os
.
getenv
(
"FLAGS_communicator_merge_sparse_grad"
,
"1"
)
self
.
is_sgd_optimizer
=
os
.
getenv
(
"FLAGS_communicator_is_sgd_optimizer"
,
"1"
)
# not used
self
.
_rpc_deadline
=
int
(
os
.
getenv
(
"FLAGS_rpc_deadline"
,
"180000"
)
)
self
.
_rpc_retry_times
=
int
(
os
.
getenv
(
"FLAGS_rpc_retry_times"
,
"3"
)
)
self
.
_rpc_deadline
=
os
.
getenv
(
"FLAGS_rpc_deadline"
,
"180000"
)
self
.
_rpc_retry_times
=
os
.
getenv
(
"FLAGS_rpc_retry_times"
,
"3"
)
def
get_communicator_flags
(
self
):
_communicator_flags
=
dict
()
_communicator_flags
[
"max_merge_var_num"
]
=
self
.
max_merge_var_num
_communicator_flags
[
"send_queue_size"
]
=
self
.
send_queue_size
_communicator_flags
[
"
independent_recv_thread"
]
=
self
.
independent_recv_thread
"
communicator_max_merge_var_num"
]
=
self
.
max_merge_var_num
_communicator_flags
[
"min_send_grad_num_before_recv"
]
=
self
.
min_send_grad_num_before_recv
_communicator_flags
[
"thread_pool_size"
]
=
self
.
thread_pool_size
_communicator_flags
[
"send_wait_times"
]
=
self
.
send_wait_times
_communicator_flags
[
"fake_rpc"
]
=
self
.
fake_rpc
_communicator_flags
[
"merge_sparse_grad"
]
=
self
.
merge_sparse_grad
_communicator_flags
[
"is_sgd_optimizer"
]
=
self
.
is_sgd_optimizer
"communicator_send_queue_size"
]
=
self
.
send_queue_size
_communicator_flags
[
"communicator_independent_recv_thread"
]
=
self
.
independent_recv_thread
_communicator_flags
[
"communicator_min_send_grad_num_before_recv"
]
=
self
.
min_send_grad_num_before_recv
_communicator_flags
[
"communicator_thread_pool_size"
]
=
self
.
thread_pool_size
_communicator_flags
[
"communicator_send_wait_times"
]
=
self
.
send_wait_times
_communicator_flags
[
"communicator_is_sgd_optimizer"
]
=
self
.
is_sgd_optimizer
return
_communicator_flags
def
__repr__
(
self
):
_str
=
"please check that TrainerRuntimeConfig is as expected:
\n
"
_communicator_flags
=
self
.
get_communicator_flags
()
for
key
in
_communicator_flags
:
_str
+=
"communicator_{}: {}
\n
"
.
format
(
key
,
_communicator_flags
[
key
])
_str
+=
"{}: {}
\n
"
.
format
(
key
,
_communicator_flags
[
key
])
return
_str
...
...
@@ -193,8 +195,9 @@ class HalfAsyncStrategy(DistributedStrategy):
def
__init__
(
self
):
super
(
HalfAsyncStrategy
,
self
).
__init__
()
self
.
_program_config
.
sync_mode
=
False
self
.
_program_config
.
runtime_split_send_recv
=
False
self
.
_build_strategy
.
async_mode
=
False
self
.
_program_config
.
runtime_split_send_recv
=
True
self
.
_build_strategy
.
async_mode
=
True
self
.
_program_config
.
half_async
=
True
class
GeoStrategy
(
DistributedStrategy
):
...
...
@@ -202,9 +205,9 @@ class GeoStrategy(DistributedStrategy):
super
(
GeoStrategy
,
self
).
__init__
()
self
.
_program_config
.
sync_mode
=
False
self
.
_program_config
.
runtime_split_send_recv
=
True
self
.
_build_strategy
.
async_mode
=
True
self
.
_program_config
.
geo_sgd_mode
=
True
self
.
_program_config
.
geo_sgd_need_push_nums
=
update_frequency
self
.
_build_strategy
.
async_mode
=
True
class
StrategyFactory
(
object
):
...
...
python/paddle/fluid/tests/CMakeLists.txt
浏览文件 @
82bc814a
file
(
GLOB TEST_OPS RELATIVE
"
${
CMAKE_CURRENT_SOURCE_DIR
}
"
"test_*.py"
)
string
(
REPLACE
".py"
""
TEST_OPS
"
${
TEST_OPS
}
"
)
if
(
NOT WITH_DISTRIBUTE
)
list
(
REMOVE_ITEM TEST_OPS test_communicator
)
endif
(
NOT WITH_DISTRIBUTE
)
foreach
(
src
${
TEST_OPS
}
)
py_test
(
${
src
}
SRCS
${
src
}
.py
)
endforeach
()
...
...
python/paddle/fluid/tests/unittests/CMakeLists.txt
浏览文件 @
82bc814a
...
...
@@ -20,6 +20,10 @@ list(APPEND MIXED_DIST_TEST_OPS test_transpiler_ops)
list
(
APPEND MIXED_DIST_TEST_OPS test_lookup_remote_table_op
)
list
(
APPEND MIXED_DIST_TEST_OPS test_launch
)
list
(
APPEND MIXED_DIST_TEST_OPS test_launch_ps
)
list
(
APPEND MIXED_DIST_TEST_OPS test_communicator_async
)
list
(
APPEND MIXED_DIST_TEST_OPS test_communicator_geo
)
list
(
APPEND MIXED_DIST_TEST_OPS test_communicator_half_async
)
list
(
APPEND MIXED_DIST_TEST_OPS test_fleet_api_input
)
foreach
(
TEST_OP
${
MIXED_DIST_TEST_OPS
}
)
list
(
REMOVE_ITEM TEST_OPS
${
TEST_OP
}
)
endforeach
()
...
...
@@ -269,6 +273,9 @@ if(WITH_DISTRIBUTE)
py_test_modules
(
test_nce_remote_table_op MODULES test_nce_remote_table_op ENVS
${
dist_ENVS
}
)
py_test_modules
(
test_recv_save_op MODULES test_recv_save_op ENVS
${
dist_ENVS
}
)
py_test_modules
(
test_transpiler_ops MODULES test_transpiler_ops ENVS
${
dist_ENVS
}
)
py_test_modules
(
test_communicator_async MODULES test_communicator_async ENVS
${
dist_ENVS
}
)
py_test_modules
(
test_communicator_geo MODULES test_communicator_geo ENVS
${
dist_ENVS
}
)
py_test_modules
(
test_communicator_half_async MODULES test_communicator_half_async ENVS
${
dist_ENVS
}
FLAGS_communicator_send_queue_size=1 FLAGS_communicator_max_merge_var_num=1
)
if
(
WITH_DGC
)
# if with dgc, test all dgc tests.
# NOTE. dist dgc tests is already in DIST_TEST_OPS
...
...
python/paddle/fluid/tests/unittests/dist_fleet_ctr.py
浏览文件 @
82bc814a
...
...
@@ -110,8 +110,10 @@ class TestDistCTR2x2(FleetDistRunnerBase):
predict
=
fluid
.
layers
.
fc
(
input
=
merge_layer
,
size
=
2
,
act
=
'softmax'
)
acc
=
fluid
.
layers
.
accuracy
(
input
=
predict
,
label
=
label
)
auc_var
,
batch_auc_var
,
auc_states
=
fluid
.
layers
.
auc
(
input
=
predict
,
label
=
label
)
cost
=
fluid
.
layers
.
cross_entropy
(
input
=
predict
,
label
=
label
)
avg_cost
=
fluid
.
layers
.
mean
(
x
=
cost
)
...
...
@@ -242,11 +244,13 @@ class TestDistCTR2x2(FleetDistRunnerBase):
debug
=
False
)
pass_time
=
time
.
time
()
-
pass_start
model_dir
=
tempfile
.
mkdtemp
()
fleet
.
save_inference_model
(
exe
,
model_dir
,
[
feed
.
name
for
feed
in
self
.
feeds
],
self
.
avg_cost
)
self
.
check_model_right
(
model_dir
)
shutil
.
rmtree
(
model_dir
)
if
os
.
getenv
(
"SAVE_MODEL"
)
==
"1"
:
model_dir
=
tempfile
.
mkdtemp
()
fleet
.
save_inference_model
(
exe
,
model_dir
,
[
feed
.
name
for
feed
in
self
.
feeds
],
self
.
avg_cost
)
self
.
check_model_right
(
model_dir
)
shutil
.
rmtree
(
model_dir
)
fleet
.
stop_worker
()
...
...
python/paddle/fluid/tests/
test_communicator
.py
→
python/paddle/fluid/tests/
unittests/test_communicator_async
.py
浏览文件 @
82bc814a
...
...
@@ -16,7 +16,10 @@ from __future__ import print_function
import
unittest
import
time
import
threading
import
numpy
import
paddle
import
paddle.fluid
as
fluid
from
paddle.fluid.communicator
import
Communicator
...
...
@@ -35,7 +38,7 @@ class TestCommunicator(unittest.TestCase):
avg_cost
=
fluid
.
layers
.
mean
(
cost
)
return
avg_cost
def
test_communicator_
init_and_start
(
self
):
def
test_communicator_
async
(
self
):
role
=
role_maker
.
UserDefinedRoleMaker
(
current_id
=
0
,
role
=
role_maker
.
Role
.
WORKER
,
...
...
@@ -48,23 +51,15 @@ class TestCommunicator(unittest.TestCase):
optimizer
=
fluid
.
optimizer
.
SGD
(
0.01
)
strategy
=
DistributeTranspilerConfig
()
strategy
.
sync_mode
=
True
strategy
.
sync_mode
=
False
strategy
.
runtime_split_send_recv
=
True
strategy
.
wait_port
=
False
optimizer
=
fleet
.
distributed_optimizer
(
optimizer
,
strategy
)
optimizer
.
minimize
(
avg_cost
)
comm
=
Communicator
(
fleet
.
main_program
)
comm
.
start
()
fleet
.
init_worker
()
time
.
sleep
(
10
)
comm
.
stop
()
class
TestCommunicator2
(
unittest
.
TestCase
):
def
test_communicator_init_and_start
(
self
):
prog
=
fluid
.
Program
()
comm
=
Communicator
(
prog
)
comm
.
start
()
comm
.
stop
()
fleet
.
stop_worker
()
if
__name__
==
'__main__'
:
...
...
python/paddle/fluid/tests/unittests/test_communicator_geo.py
0 → 100644
浏览文件 @
82bc814a
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
unittest
import
time
import
threading
import
numpy
import
paddle
import
paddle.fluid
as
fluid
from
paddle.fluid.communicator
import
Communicator
from
paddle.fluid.transpiler.distribute_transpiler
import
DistributedMode
import
paddle.fluid.incubate.fleet.base.role_maker
as
role_maker
from
paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler
import
fleet
from
paddle.fluid.transpiler.distribute_transpiler
import
DistributeTranspilerConfig
class
TestCommunicator
(
unittest
.
TestCase
):
def
net
(
self
):
x
=
fluid
.
layers
.
data
(
name
=
'x'
,
shape
=
[
13
],
dtype
=
'float32'
)
y_predict
=
fluid
.
layers
.
fc
(
input
=
x
,
size
=
1
,
act
=
None
)
y
=
fluid
.
layers
.
data
(
name
=
'y'
,
shape
=
[
1
],
dtype
=
'float32'
)
cost
=
fluid
.
layers
.
square_error_cost
(
input
=
y_predict
,
label
=
y
)
avg_cost
=
fluid
.
layers
.
mean
(
cost
)
return
avg_cost
def
test_communicator_geo
(
self
):
role
=
role_maker
.
UserDefinedRoleMaker
(
current_id
=
0
,
role
=
role_maker
.
Role
.
WORKER
,
worker_num
=
2
,
server_endpoints
=
[
"127.0.0.1:6001"
,
"127.0.0.1:6002"
])
fleet
.
init
(
role
)
avg_cost
=
self
.
net
()
optimizer
=
fluid
.
optimizer
.
SGD
(
0.01
)
strategy
=
DistributeTranspilerConfig
()
strategy
.
sync_mode
=
False
strategy
.
runtime_split_send_recv
=
True
strategy
.
geo_sgd_mode
=
True
strategy
.
wait_port
=
False
optimizer
=
fleet
.
distributed_optimizer
(
optimizer
,
strategy
)
optimizer
.
minimize
(
avg_cost
)
fleet
.
init_worker
()
time
.
sleep
(
10
)
fleet
.
stop_worker
()
# class TestCommunicatorGEO(unittest.TestCase):
# def test_communicator_init_and_start(self):
# prog = fluid.Program()
# envs = {}
# envs["communicator_thread_pool_size"] = "5"
# envs["communicator_send_wait_times"] = "5"
# kwargs = {}
# kwargs["push_vars"] = {}
# kwargs["trainers"] = 10
# kwargs["push_nums"] = 10
# comm = Communicator(prog, DistributedMode.GEO, kwargs, envs)
if
__name__
==
'__main__'
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/test_communicator_half_async.py
0 → 100644
浏览文件 @
82bc814a
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
os
import
sys
import
time
import
threading
import
subprocess
import
unittest
import
numpy
import
paddle
import
paddle.fluid
as
fluid
from
paddle.fluid.communicator
import
Communicator
import
paddle.fluid.incubate.fleet.base.role_maker
as
role_maker
from
paddle.fluid.transpiler.distribute_transpiler
import
DistributedMode
from
paddle.fluid.transpiler.distribute_transpiler
import
DistributeTranspilerConfig
from
paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler
import
fleet
class
TestCommunicatorHalfAsyncEnd2End
(
unittest
.
TestCase
):
def
net
(
self
):
x
=
fluid
.
layers
.
data
(
name
=
'x'
,
shape
=
[
13
],
dtype
=
'float32'
)
y_predict
=
fluid
.
layers
.
fc
(
input
=
x
,
size
=
1
,
act
=
None
)
y
=
fluid
.
layers
.
data
(
name
=
'y'
,
shape
=
[
1
],
dtype
=
'float32'
)
cost
=
fluid
.
layers
.
square_error_cost
(
input
=
y_predict
,
label
=
y
)
avg_cost
=
fluid
.
layers
.
mean
(
cost
)
return
avg_cost
,
x
,
y
def
fake_reader
(
self
):
def
reader
():
for
i
in
range
(
10000
):
x
=
numpy
.
random
.
random
((
1
,
13
)).
astype
(
'float32'
)
y
=
numpy
.
random
.
randint
(
0
,
2
,
(
1
,
1
)).
astype
(
'int64'
)
yield
x
,
y
return
reader
def
run_pserver
(
self
,
role
,
strategy
):
fleet
.
init
(
role
)
avg_cost
,
x
,
y
=
self
.
net
()
optimizer
=
fluid
.
optimizer
.
SGD
(
0.01
)
optimizer
=
fleet
.
distributed_optimizer
(
optimizer
,
strategy
)
optimizer
.
minimize
(
avg_cost
)
fleet
.
init_server
()
fleet
.
run_server
()
def
run_trainer
(
self
,
role
,
strategy
):
place
=
fluid
.
core
.
CPUPlace
()
exe
=
fluid
.
Executor
(
place
)
fleet
.
init
(
role
)
avg_cost
,
x
,
y
=
self
.
net
()
optimizer
=
fluid
.
optimizer
.
SGD
(
0.01
)
optimizer
=
fleet
.
distributed_optimizer
(
optimizer
,
strategy
)
optimizer
.
minimize
(
avg_cost
)
exe
.
run
(
fleet
.
startup_program
)
fleet
.
init_worker
()
train_reader
=
paddle
.
batch
(
self
.
fake_reader
(),
batch_size
=
24
)
feeder
=
fluid
.
DataFeeder
(
place
=
place
,
feed_list
=
[
x
,
y
])
for
batch_id
,
data
in
enumerate
(
train_reader
()):
exe
.
run
(
fleet
.
main_program
,
feed
=
feeder
.
feed
(
data
),
fetch_list
=
[])
fleet
.
stop_worker
()
def
run_ut
(
self
):
strategy
=
DistributeTranspilerConfig
()
strategy
.
sync_mode
=
False
strategy
.
runtime_split_send_recv
=
True
strategy
.
half_async
=
True
training_role
=
os
.
getenv
(
"TRAINING_ROLE"
,
"TRAINER"
)
role
=
role_maker
.
UserDefinedRoleMaker
(
current_id
=
0
,
role
=
role_maker
.
Role
.
WORKER
if
training_role
==
"TRAINER"
else
role_maker
.
Role
.
SERVER
,
worker_num
=
2
,
server_endpoints
=
[
"127.0.0.1:6002"
])
if
training_role
==
"TRAINER"
:
self
.
run_trainer
(
role
,
strategy
)
else
:
self
.
run_pserver
(
role
,
strategy
)
def
test_communicator
(
self
):
run_server_cmd
=
"""
from __future__ import print_function
import sys
import os
import time
import threading
import subprocess
import unittest
import numpy
import paddle
import paddle.fluid as fluid
from paddle.fluid.communicator import Communicator
from paddle.fluid.communicator import DistributedMode
import paddle.fluid.incubate.fleet.base.role_maker as role_maker
from test_communicator_half_async import TestCommunicatorHalfAsyncEnd2End
from paddle.fluid.transpiler.distribute_transpiler import DistributeTranspilerConfig
from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler import fleet
class RunServer(TestCommunicatorHalfAsyncEnd2End):
def runTest(self):
pass
os.environ["TRAINING_ROLE"] = "PSERVER"
half_run_server = RunServer()
half_run_server.run_ut()
"""
server_file
=
"run_server_for_communicator_haflaysnc.py"
with
open
(
server_file
,
"w"
)
as
wb
:
wb
.
write
(
run_server_cmd
)
os
.
environ
[
"TRAINING_ROLE"
]
=
"PSERVER"
_python
=
sys
.
executable
ps_cmd
=
"{} {}"
.
format
(
_python
,
server_file
)
ps_proc
=
subprocess
.
Popen
(
ps_cmd
.
strip
().
split
(
" "
),
stdout
=
subprocess
.
PIPE
,
stderr
=
subprocess
.
PIPE
)
os
.
environ
[
"TRAINING_ROLE"
]
=
"TRAINER"
os
.
environ
[
"FLAGS_communicator_send_queue_size"
]
=
"1"
os
.
environ
[
"FLAGS_communicator_max_merge_var_num"
]
=
"1"
self
.
run_ut
()
ps_proc
.
kill
()
if
os
.
path
.
exists
(
server_file
):
os
.
remove
(
server_file
)
# class TestCommunicatorHalfAsync2(unittest.TestCase):
# def test_communicator_init_and_start(self):
# prog = fluid.Program()
# envs = {}
# envs["communicator_send_queue_size"] = "12"
# envs["communicator_max_merge_var_num"] = "12"
# envs["communicator_thread_pool_size"] = "5"
# envs["communicator_send_wait_times"] = "5"
# comm = Communicator(prog, DistributedMode.HALF_ASYNC, None, envs)
# comm.start()
# time.sleep(10)
# comm.stop()
if
__name__
==
'__main__'
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/test_dist_ctr.py
浏览文件 @
82bc814a
...
...
@@ -32,7 +32,6 @@ class TestDistCTR2x2(TestDistBase):
"dist_ctr.py"
,
delta
=
1e-2
,
check_error_log
=
True
,
log_name
=
flag_name
)
@
unittest
.
skip
(
reason
=
"Skip unstable ci"
)
class
TestDistCTRWithL2Decay2x2
(
TestDistBase
):
def
_setup_config
(
self
):
self
.
_sync_mode
=
True
...
...
@@ -48,6 +47,7 @@ class TestDistCTRWithL2Decay2x2(TestDistBase):
log_name
=
flag_name
)
@
unittest
.
skip
(
reason
=
"Skip unstable ci"
)
class
TestDistCTR2x2_ASYNC
(
TestDistBase
):
def
_setup_config
(
self
):
self
.
_sync_mode
=
False
...
...
@@ -69,6 +69,7 @@ class TestDistCTR2x2_ASYNC(TestDistBase):
log_name
=
flag_name
)
@
unittest
.
skip
(
reason
=
"Skip unstable ci"
)
class
TestDistCTR2x2_ASYNCWithLRDecay2x2
(
TestDistBase
):
def
_setup_config
(
self
):
self
.
_sync_mode
=
False
...
...
@@ -91,6 +92,7 @@ class TestDistCTR2x2_ASYNCWithLRDecay2x2(TestDistBase):
log_name
=
flag_name
)
@
unittest
.
skip
(
reason
=
"Skip unstable ci"
)
class
TestDistCTR2x2_ASYNC2
(
TestDistBase
):
def
_setup_config
(
self
):
self
.
_sync_mode
=
False
...
...
python/paddle/fluid/tests/unittests/test_dist_fleet_base.py
浏览文件 @
82bc814a
...
...
@@ -53,7 +53,22 @@ class FleetDistRunnerBase(object):
do training : exe run program
"""
def
generate_strategy
(
self
,
args
):
def
build_role
(
self
,
args
):
if
args
.
role
.
upper
()
==
"PSERVER"
:
role
=
role_maker
.
UserDefinedRoleMaker
(
current_id
=
args
.
current_id
,
role
=
role_maker
.
Role
.
SERVER
,
worker_num
=
args
.
trainers
,
server_endpoints
=
args
.
endpoints
.
split
(
","
))
else
:
role
=
role_maker
.
UserDefinedRoleMaker
(
current_id
=
args
.
current_id
,
role
=
role_maker
.
Role
.
WORKER
,
worker_num
=
args
.
trainers
,
server_endpoints
=
args
.
endpoints
.
split
(
","
))
return
role
def
build_strategy
(
self
,
args
):
self
.
strategy
=
None
if
args
.
mode
==
"async"
:
self
.
strategy
=
StrategyFactory
.
create_async_strategy
()
...
...
@@ -66,22 +81,7 @@ class FleetDistRunnerBase(object):
args
.
geo_sgd_need_push_nums
)
return
self
.
strategy
def
run_pserver
(
self
,
args
):
if
args
.
role
.
upper
()
!=
"PSERVER"
:
raise
ValueError
(
"args role must be PSERVER"
)
role
=
role_maker
.
UserDefinedRoleMaker
(
current_id
=
args
.
current_id
,
role
=
role_maker
.
Role
.
SERVER
,
worker_num
=
args
.
trainers
,
server_endpoints
=
args
.
endpoints
.
split
(
","
))
fleet
.
init
(
role
)
strategy
=
self
.
generate_strategy
(
args
)
avg_cost
=
self
.
net
()
def
build_optimizer
(
self
,
avg_cost
,
strategy
):
use_grad_clip
=
int
(
os
.
getenv
(
'GRAD_CLIP'
,
0
))
if
use_grad_clip
:
# 1: clip_by_value; 2: clip_by_norm; 3:clip_by_global_norm
...
...
@@ -99,70 +99,33 @@ class FleetDistRunnerBase(object):
optimizer
=
fleet
.
distributed_optimizer
(
optimizer
,
strategy
)
optimizer
.
minimize
(
avg_cost
)
def
run_pserver
(
self
,
args
):
fleet
.
init
(
self
.
build_role
(
args
))
strategy
=
self
.
build_strategy
(
args
)
avg_cost
=
self
.
net
()
self
.
build_optimizer
(
avg_cost
,
strategy
)
fleet
.
init_server
()
fleet
.
run_server
()
def
run_dataset_trainer
(
self
,
args
):
if
args
.
role
.
upper
()
!=
"TRAINER"
:
raise
ValueError
(
"args role must be TRAINER"
)
role
=
role_maker
.
UserDefinedRoleMaker
(
current_id
=
args
.
current_id
,
role
=
role_maker
.
Role
.
WORKER
,
worker_num
=
args
.
trainers
,
server_endpoints
=
args
.
endpoints
.
split
(
","
))
fleet
.
init
(
role
)
strategy
=
self
.
generate_strategy
(
args
)
fleet
.
init
(
self
.
build_role
(
args
))
strategy
=
self
.
build_strategy
(
args
)
avg_cost
=
self
.
net
()
use_grad_clip
=
int
(
os
.
getenv
(
'GRAD_CLIP'
,
0
))
if
use_grad_clip
:
# 1: clip_by_value; 2: clip_by_norm; 3:clip_by_global_norm
if
use_grad_clip
==
1
:
fluid
.
clip
.
set_gradient_clip
(
clip
=
fluid
.
clip
.
GradientClipByValue
(
2.0
))
elif
use_grad_clip
==
2
:
fluid
.
clip
.
set_gradient_clip
(
clip
=
fluid
.
clip
.
GradientClipByNorm
(
2.0
))
elif
use_grad_clip
==
3
:
fluid
.
clip
.
set_gradient_clip
(
clip
=
fluid
.
clip
.
GradientClipByGlobalNorm
(
2.0
))
optimizer
=
fluid
.
optimizer
.
SGD
(
LEARNING_RATE
)
optimizer
=
fleet
.
distributed_optimizer
(
optimizer
,
strategy
)
optimizer
.
minimize
(
avg_cost
)
self
.
build_optimizer
(
avg_cost
,
strategy
)
out
=
self
.
do_dataset_training
(
fleet
)
def
run_pyreader_trainer
(
self
,
args
):
if
args
.
role
.
upper
()
!=
"TRAINER"
:
raise
ValueError
(
"args role must be TRAINER"
)
role
=
role_maker
.
UserDefinedRoleMaker
(
current_id
=
args
.
current_id
,
role
=
role_maker
.
Role
.
WORKER
,
worker_num
=
args
.
trainers
,
server_endpoints
=
args
.
endpoints
.
split
(
","
))
fleet
.
init
(
role
)
strategy
=
self
.
generate_strategy
(
args
)
fleet
.
init
(
self
.
build_role
(
args
))
strategy
=
self
.
build_strategy
(
args
)
avg_cost
=
self
.
net
()
self
.
reader
=
fluid
.
io
.
PyReader
(
feed_list
=
self
.
feeds
,
capacity
=
64
,
iterable
=
False
,
use_double_buffer
=
False
)
optimizer
=
fluid
.
optimizer
.
SGD
(
LEARNING_RATE
)
optimizer
=
fleet
.
distributed_optimizer
(
optimizer
,
strategy
)
optimizer
.
minimize
(
avg_cost
)
self
.
build_optimizer
(
avg_cost
,
strategy
)
out
=
self
.
do_pyreader_training
(
fleet
)
def
net
(
self
,
batch_size
=
4
,
lr
=
0.01
):
...
...
@@ -263,7 +226,7 @@ class TestFleetBase(unittest.TestCase):
return
tr0_proc
,
tr1_proc
,
tr0_pipe
,
tr1_pipe
def
_run_cluster
(
self
,
model
,
envs
):
env
=
{
'
CPU_NUM'
:
'1'
,
'
GRAD_CLIP'
:
str
(
self
.
_grad_clip_mode
)}
env
=
{
'GRAD_CLIP'
:
str
(
self
.
_grad_clip_mode
)}
env
.
update
(
envs
)
python_path
=
self
.
_python_interp
...
...
@@ -307,29 +270,6 @@ class TestFleetBase(unittest.TestCase):
ps0
.
terminate
()
ps1
.
terminate
()
'''
with open("/tmp/tr0_out.log", "wb+") as wn:
wn.write(tr0_out)
with open("/tmp/tr1_out.log", "wb+") as wn:
wn.write(tr1_out)
# print server log
'''
# print server log
'''
with open("/tmp/ps0_err.log", "r") as fn:
sys.stderr.write("ps0 stderr: %s
\n
" % fn.read())
with open("/tmp/ps1_err.log", "r") as fn:
sys.stderr.write("ps1 stderr: %s
\n
" % fn.read())
'''
# print log
'''
with open("/tmp/tr0_err.log", "r") as fn:
sys.stderr.write('trainer 0 stderr: %s
\n
' % fn.read())
with open("/tmp/tr1_err.log", "r") as fn:
sys.stderr.write('trainer 1 stderr: %s
\n
' % fn.read())
'''
return
0
,
0
...
...
python/paddle/fluid/tests/unittests/test_dist_fleet_ctr.py
浏览文件 @
82bc814a
...
...
@@ -50,9 +50,9 @@ class TestDistMnistSync2x2(TestFleetBase):
"dist_fleet_ctr.py"
,
delta
=
1e-5
,
check_error_log
=
True
)
class
TestDistMnist
Half
Async2x2
(
TestFleetBase
):
class
TestDistMnistAsync2x2
(
TestFleetBase
):
def
_setup_config
(
self
):
self
.
_mode
=
"
half_
async"
self
.
_mode
=
"async"
self
.
_reader
=
"pyreader"
def
check_with_place
(
self
,
...
...
@@ -81,10 +81,10 @@ class TestDistMnistHalfAsync2x2(TestFleetBase):
"dist_fleet_ctr.py"
,
delta
=
1e-5
,
check_error_log
=
True
)
class
TestDistMnistAsync2x2
(
TestFleetBase
):
class
TestDistMnistAsync
Dataset
2x2
(
TestFleetBase
):
def
_setup_config
(
self
):
self
.
_mode
=
"async"
self
.
_reader
=
"
pyreader
"
self
.
_reader
=
"
dataset
"
def
check_with_place
(
self
,
model_file
,
...
...
@@ -96,7 +96,8 @@ class TestDistMnistAsync2x2(TestFleetBase):
"PYTHONPATH"
:
os
.
getenv
(
"PYTHONPATH"
,
""
),
"LD_LIBRARY_PATH"
:
os
.
getenv
(
"LD_LIBRARY_PATH"
,
""
),
"FLAGS_rpc_deadline"
:
"5000"
,
# 5sec to fail fast
"http_proxy"
:
""
"http_proxy"
:
""
,
"SAVE_MODEL"
:
"1"
}
required_envs
.
update
(
need_envs
)
...
...
@@ -112,10 +113,10 @@ class TestDistMnistAsync2x2(TestFleetBase):
"dist_fleet_ctr.py"
,
delta
=
1e-5
,
check_error_log
=
True
)
class
TestDist
MnistAsyncDataset
2x2
(
TestFleetBase
):
class
TestDist
CtrHalfAsync
2x2
(
TestFleetBase
):
def
_setup_config
(
self
):
self
.
_mode
=
"async"
self
.
_reader
=
"
dataset
"
self
.
_mode
=
"
half_
async"
self
.
_reader
=
"
pyreader
"
def
check_with_place
(
self
,
model_file
,
...
...
@@ -126,8 +127,12 @@ class TestDistMnistAsyncDataset2x2(TestFleetBase):
"PATH"
:
os
.
getenv
(
"PATH"
,
""
),
"PYTHONPATH"
:
os
.
getenv
(
"PYTHONPATH"
,
""
),
"LD_LIBRARY_PATH"
:
os
.
getenv
(
"LD_LIBRARY_PATH"
,
""
),
"FLAGS_rpc_deadline"
:
"5000"
,
# 5sec to fail fast
"http_proxy"
:
""
"FLAGS_rpc_deadline"
:
"30000"
,
# 5sec to fail fast
"http_proxy"
:
""
,
"FLAGS_communicator_send_queue_size"
:
"1"
,
"FLAGS_communicator_max_merge_var_num"
:
"1"
,
"CPU_NUM"
:
"1"
,
"SAVE_MODEL"
:
"0"
}
required_envs
.
update
(
need_envs
)
...
...
python/paddle/fluid/tests/unittests/test_distributed_strategy.py
浏览文件 @
82bc814a
...
...
@@ -102,8 +102,10 @@ class TestStrategyFactor(unittest.TestCase):
trainer_runtime_config
=
strategy
.
get_trainer_runtime_config
()
trainer_communicator_flags
=
trainer_runtime_config
.
get_communicator_flags
(
)
self
.
assertIn
(
'send_queue_size'
,
trainer_communicator_flags
)
self
.
assertEqual
(
trainer_communicator_flags
[
'send_queue_size'
],
100
)
self
.
assertIn
(
'communicator_send_queue_size'
,
trainer_communicator_flags
)
self
.
assertEqual
(
trainer_communicator_flags
[
'communicator_send_queue_size'
],
100
)
# test set_trainer_runtime_config exception
trainer_runtime_config_dict
[
'unknown'
]
=
None
...
...
@@ -138,9 +140,8 @@ class TestStrategyFactor(unittest.TestCase):
def
test_half_async_strategy
(
self
):
strategy
=
StrategyFactory
.
create_half_async_strategy
()
self
.
assertEqual
(
strategy
.
_program_config
.
sync_mode
,
False
)
self
.
assertEqual
(
strategy
.
_program_config
.
runtime_split_send_recv
,
False
)
self
.
assertEqual
(
strategy
.
_build_strategy
.
async_mode
,
False
)
self
.
assertEqual
(
strategy
.
_program_config
.
runtime_split_send_recv
,
True
)
self
.
assertEqual
(
strategy
.
_build_strategy
.
async_mode
,
True
)
# test set_server_runtime_config using ServerRuntimeConfig
server_runtime_config_class
=
ServerRuntimeConfig
()
...
...
python/paddle/fluid/tests/unittests/test_fleet_api_input.py
浏览文件 @
82bc814a
...
...
@@ -100,9 +100,10 @@ class FleetTest(unittest.TestCase):
self
.
assertRaises
(
Exception
,
fleet
.
_transpile
,
"config"
)
def
set_program
(
self
,
avg_cost
,
strategy
):
optimizer
=
fluid
.
optimizer
.
SGD
(
0.1
)
optimizer
=
fleet
.
distributed_optimizer
(
optimizer
,
strategy
)
optimizer
.
minimize
(
avg_cost
)
with
fluid
.
scope_guard
(
fluid
.
Scope
()):
optimizer
=
fluid
.
optimizer
.
SGD
(
0.1
)
optimizer
=
fleet
.
distributed_optimizer
(
optimizer
,
strategy
)
optimizer
.
minimize
(
avg_cost
)
def
test_init_role
(
self
):
role
=
role_maker
.
UserDefinedRoleMaker
(
...
...
@@ -123,6 +124,27 @@ class FleetTest(unittest.TestCase):
self
.
assertRaises
(
Exception
,
self
.
set_program
,
avg_cost
,
strategy
)
def
test_transpile
(
self
):
role
=
role_maker
.
UserDefinedRoleMaker
(
current_id
=
0
,
role
=
role_maker
.
Role
.
SERVER
,
worker_num
=
2
,
server_endpoints
=
[
"127.0.0.1:36011"
,
"127.0.0.1:36012"
])
# for test optimizer without init(role)
fleet
.
init
(
role
)
batch_size
=
128
is_sparse
=
True
is_distribute
=
False
strategy
=
DistributeTranspilerConfig
()
strategy
.
sync_mode
=
False
strategy
.
runtime_split_send_recv
=
True
avg_cost
,
_
,
_
=
train_network
(
batch_size
,
is_distribute
,
is_sparse
)
self
.
set_program
(
avg_cost
,
strategy
)
strategy
.
runtime_split_send_recv
=
False
self
.
set_program
(
avg_cost
,
strategy
)
class
TranspilerOptimizerTest
(
unittest
.
TestCase
):
def
testInvalidInputs
(
self
):
...
...
python/paddle/fluid/trainer_desc.py
浏览文件 @
82bc814a
...
...
@@ -108,6 +108,9 @@ class TrainerDesc(object):
for
param
in
dump_param
:
self
.
proto_desc
.
dump_param
.
append
(
param
)
def
_set_thread_barrier
(
self
,
thread_barrier
):
self
.
proto_desc
.
thread_barrier
=
thread_barrier
def
_set_check_nan_var_names
(
self
,
check_nan_var_names
):
for
var
in
check_nan_var_names
:
self
.
proto_desc
.
check_nan_var_names
.
append
(
var
)
...
...
python/paddle/fluid/transpiler/distribute_transpiler.py
浏览文件 @
82bc814a
...
...
@@ -190,6 +190,9 @@ class DistributeTranspilerConfig(object):
__runtime_split_send_recv
=
False
__sync_mode
=
True
# half_async
half_async
=
False
# Geo-sgd algorithm
geo_sgd_mode
=
False
geo_sgd_need_push_nums
=
100
...
...
@@ -744,27 +747,15 @@ class DistributeTranspiler(object):
for
_
,
var
in
enumerate
(
splited_vars
):
send_vars
.
append
(
var
)
if
self
.
sync_mode
:
fetch_barrier_input
=
[]
send_barrier_out
=
program
.
global_block
().
create_var
(
name
=
framework
.
generate_control_dev_var_name
())
if
self
.
has_distributed_lookup_table
:
self
.
grad_name_to_send_dummy_out
[
self
.
table_name
]
=
program
.
global_block
().
create_var
(
name
=
framework
.
generate_control_dev_var_name
())
input_deps
=
list
(
self
.
grad_name_to_send_dummy_out
.
values
())
send_barrier_out
=
program
.
global_block
().
create_var
(
name
=
framework
.
generate_control_dev_var_name
())
if
self
.
has_distributed_lookup_table
:
self
.
grad_name_to_send_dummy_out
[
self
.
table_name
]
=
program
.
global_block
().
create_var
(
name
=
framework
.
generate_control_dev_var_name
())
input_deps
=
list
(
self
.
grad_name_to_send_dummy_out
.
values
())
program
.
global_block
().
append_op
(
type
=
"send_barrier"
,
inputs
=
{
"X"
:
list
(
input_deps
)},
outputs
=
{
"Out"
:
send_barrier_out
},
attrs
=
{
"endpoints"
:
pserver_endpoints
,
"trainer_id"
:
self
.
trainer_id
,
RPC_OP_ROLE_ATTR_NAME
:
RPC_OP_ROLE_ATTR_VALUE
})
fetch_barrier_input
.
append
(
send_barrier_out
)
else
:
if
not
self
.
sync_mode
:
lr_ops
=
self
.
_get_lr_ops
()
if
len
(
lr_ops
)
>
0
and
self
.
counter_var
:
decay_dummy_output
=
program
.
global_block
().
create_var
(
...
...
@@ -789,6 +780,35 @@ class DistributeTranspiler(object):
OP_ROLE_VAR_ATTR_NAME
:
[
self
.
counter_var
.
name
,
self
.
counter_var
.
name
]
})
input_deps
.
append
(
decay_dummy_output
)
if
self
.
sync_mode
:
fetch_barrier_input
=
[]
program
.
global_block
().
append_op
(
type
=
"send_barrier"
,
inputs
=
{
"X"
:
list
(
input_deps
)},
outputs
=
{
"Out"
:
send_barrier_out
},
attrs
=
{
"endpoints"
:
pserver_endpoints
,
"trainer_id"
:
self
.
trainer_id
,
"half_async"
:
False
,
RPC_OP_ROLE_ATTR_NAME
:
RPC_OP_ROLE_ATTR_VALUE
})
fetch_barrier_input
.
append
(
send_barrier_out
)
else
:
if
self
.
config
.
runtime_split_send_recv
and
self
.
config
.
half_async
:
program
.
global_block
().
append_op
(
type
=
"send_barrier"
,
inputs
=
{
"X"
:
list
(
input_deps
)},
outputs
=
{
"Out"
:
send_barrier_out
},
attrs
=
{
"endpoints"
:
pserver_endpoints
,
"trainer_id"
:
self
.
trainer_id
,
"half_async"
:
True
,
RPC_OP_ROLE_ATTR_NAME
:
RPC_OP_ROLE_ATTR_VALUE
})
# step 3: insert recv op to receive parameters from parameter server
recv_vars
=
[]
...
...
@@ -859,8 +879,6 @@ class DistributeTranspiler(object):
OP_ROLE_VAR_ATTR_NAME
:
[
param_varname
,
recv_op_role_var_name
]
})
if
self
.
sync_mode
:
fetch_barrier_input
.
extend
(
splited_var
)
self
.
_update_remote_sparse_update_op
(
program
,
need_sparse_update_params
)
...
...
@@ -877,10 +895,11 @@ class DistributeTranspiler(object):
})
for
param_varname
,
splited_var
in
six
.
iteritems
(
self
.
param_var_mapping
):
if
len
(
splited_var
)
<=
1
:
continue
orig_param
=
program
.
global_block
().
vars
[
param_varname
]
if
param_varname
not
in
self
.
sparse_param_to_height_sections
:
if
len
(
splited_var
)
>
1
and
not
self
.
config
.
runtime_split_send_recv
:
if
not
self
.
config
.
runtime_split_send_recv
:
program
.
global_block
().
append_op
(
type
=
"concat"
,
inputs
=
{
"X"
:
splited_var
},
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录