Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
58f7695a
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
大约 1 年 前同步成功
通知
695
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
58f7695a
编写于
5月 23, 2019
作者:
Q
Qiao Longfei
提交者:
GitHub
5月 23, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Async exe support communicator (#17386)
Async exe support communicator
上级
38da1030
变更
23
显示空白变更内容
内联
并排
Showing
23 changed file
with
805 addition
and
149 deletion
+805
-149
paddle/fluid/framework/details/async_ssa_graph_executor.cc
paddle/fluid/framework/details/async_ssa_graph_executor.cc
+30
-36
paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_pass.h
...rk/ir/multi_devices_graph_pass/multi_devices_graph_pass.h
+1
-1
paddle/fluid/framework/operator.cc
paddle/fluid/framework/operator.cc
+1
-1
paddle/fluid/operators/distributed/communicator.cc
paddle/fluid/operators/distributed/communicator.cc
+95
-13
paddle/fluid/operators/distributed/communicator.h
paddle/fluid/operators/distributed/communicator.h
+11
-14
paddle/fluid/operators/distributed_ops/recv_op.cc
paddle/fluid/operators/distributed_ops/recv_op.cc
+2
-2
paddle/fluid/pybind/CMakeLists.txt
paddle/fluid/pybind/CMakeLists.txt
+23
-1
paddle/fluid/pybind/communicator_py.cc
paddle/fluid/pybind/communicator_py.cc
+47
-0
paddle/fluid/pybind/communicator_py.h
paddle/fluid/pybind/communicator_py.h
+27
-0
paddle/fluid/pybind/pybind.cc
paddle/fluid/pybind/pybind.cc
+7
-0
python/paddle/fluid/communicator.py
python/paddle/fluid/communicator.py
+88
-0
python/paddle/fluid/incubate/fleet/base/fleet_base.py
python/paddle/fluid/incubate/fleet/base/fleet_base.py
+10
-16
python/paddle/fluid/incubate/fleet/base/role_maker.py
python/paddle/fluid/incubate/fleet/base/role_maker.py
+38
-4
python/paddle/fluid/incubate/fleet/collective/__init__.py
python/paddle/fluid/incubate/fleet/collective/__init__.py
+2
-7
python/paddle/fluid/incubate/fleet/parameter_server/distributed_transpiler/__init__.py
...fleet/parameter_server/distributed_transpiler/__init__.py
+37
-29
python/paddle/fluid/incubate/fleet/parameter_server/pslib/__init__.py
...e/fluid/incubate/fleet/parameter_server/pslib/__init__.py
+4
-15
python/paddle/fluid/incubate/fleet/tests/cluster_train.sh
python/paddle/fluid/incubate/fleet/tests/cluster_train.sh
+33
-0
python/paddle/fluid/incubate/fleet/tests/ctr_dataset_reader.py
...n/paddle/fluid/incubate/fleet/tests/ctr_dataset_reader.py
+100
-0
python/paddle/fluid/incubate/fleet/tests/fleet_deep_ctr.py
python/paddle/fluid/incubate/fleet/tests/fleet_deep_ctr.py
+204
-0
python/paddle/fluid/optimizer.py
python/paddle/fluid/optimizer.py
+7
-8
python/paddle/fluid/tests/CMakeLists.txt
python/paddle/fluid/tests/CMakeLists.txt
+4
-0
python/paddle/fluid/tests/test_communicator.py
python/paddle/fluid/tests/test_communicator.py
+32
-0
python/paddle/fluid/transpiler/distribute_transpiler.py
python/paddle/fluid/transpiler/distribute_transpiler.py
+2
-2
未找到文件。
paddle/fluid/framework/details/async_ssa_graph_executor.cc
浏览文件 @
58f7695a
...
@@ -51,9 +51,7 @@ void ProcessGraph(std::vector<ir::Graph *> graphs, Scope *scope) {
...
@@ -51,9 +51,7 @@ void ProcessGraph(std::vector<ir::Graph *> graphs, Scope *scope) {
VLOG
(
3
)
<<
"ProcessGraph"
;
VLOG
(
3
)
<<
"ProcessGraph"
;
RpcCtxMap
send_varname_to_ctx
;
RpcCtxMap
send_varname_to_ctx
;
RpcCtxMap
recv_varname_to_ctx
;
RpcCtxMap
recv_varname_to_ctx
;
for
(
auto
i
=
0
;
i
<
graphs
.
size
();
++
i
)
{
for
(
auto
&
node
:
graphs
[
0
]
->
Nodes
())
{
std
::
vector
<
ir
::
Node
*>
nodes_to_delete
;
for
(
auto
&
node
:
graphs
[
i
]
->
Nodes
())
{
VLOG
(
3
)
<<
"node name "
<<
node
->
Name
();
VLOG
(
3
)
<<
"node name "
<<
node
->
Name
();
if
(
node
&&
node
->
IsOp
())
{
if
(
node
&&
node
->
IsOp
())
{
if
(
node
->
Name
()
==
"send"
)
{
if
(
node
->
Name
()
==
"send"
)
{
...
@@ -66,10 +64,8 @@ void ProcessGraph(std::vector<ir::Graph *> graphs, Scope *scope) {
...
@@ -66,10 +64,8 @@ void ProcessGraph(std::vector<ir::Graph *> graphs, Scope *scope) {
node
->
Op
()
->
GetNullableAttr
(
"sections"
));
node
->
Op
()
->
GetNullableAttr
(
"sections"
));
auto
trainer_id
=
auto
trainer_id
=
boost
::
get
<
int
>
(
node
->
Op
()
->
GetNullableAttr
(
"trainer_id"
));
boost
::
get
<
int
>
(
node
->
Op
()
->
GetNullableAttr
(
"trainer_id"
));
send_varname_to_ctx
[
send_var_name
]
=
send_varname_to_ctx
[
send_var_name
]
=
operators
::
distributed
::
RpcContext
(
operators
::
distributed
::
RpcContext
(
send_var_name
,
send_varnames
,
send_var_name
,
send_varnames
,
epmap
,
height_section
,
trainer_id
);
epmap
,
height_section
,
trainer_id
);
VLOG
(
3
)
<<
"find and init an send op: "
VLOG
(
3
)
<<
"find and init an send op: "
<<
send_varname_to_ctx
[
send_var_name
];
<<
send_varname_to_ctx
[
send_var_name
];
}
else
if
(
node
->
Name
()
==
"recv"
)
{
}
else
if
(
node
->
Name
()
==
"recv"
)
{
...
@@ -80,16 +76,14 @@ void ProcessGraph(std::vector<ir::Graph *> graphs, Scope *scope) {
...
@@ -80,16 +76,14 @@ void ProcessGraph(std::vector<ir::Graph *> graphs, Scope *scope) {
node
->
Op
()
->
GetNullableAttr
(
"epmap"
));
node
->
Op
()
->
GetNullableAttr
(
"epmap"
));
auto
trainer_id
=
auto
trainer_id
=
boost
::
get
<
int
>
(
node
->
Op
()
->
GetNullableAttr
(
"trainer_id"
));
boost
::
get
<
int
>
(
node
->
Op
()
->
GetNullableAttr
(
"trainer_id"
));
recv_varname_to_ctx
[
recv_var_name
]
=
recv_varname_to_ctx
[
recv_var_name
]
=
operators
::
distributed
::
RpcContext
(
operators
::
distributed
::
RpcContext
(
recv_var_name
,
recv_varnames
,
recv_var_name
,
recv_varnames
,
epmap
,
{},
trainer_id
);
epmap
,
{},
trainer_id
);
nodes_to_delete
.
push_back
(
node
);
VLOG
(
3
)
<<
"find and remove an recv op: "
VLOG
(
3
)
<<
"find and remove an recv op: "
<<
recv_varname_to_ctx
[
recv_var_name
];
<<
recv_varname_to_ctx
[
recv_var_name
];
}
}
}
}
}
}
}
// init communicator here
// init communicator here
if
(
send_varname_to_ctx
.
size
()
>
0
)
{
if
(
send_varname_to_ctx
.
size
()
>
0
)
{
VLOG
(
3
)
<<
"this is distribute mode, will use communicator"
;
VLOG
(
3
)
<<
"this is distribute mode, will use communicator"
;
...
...
paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_pass.h
浏览文件 @
58f7695a
...
@@ -130,7 +130,7 @@ class AsyncSSAGraphBuilder : public MultiDevSSAGraphBuilderBase {
...
@@ -130,7 +130,7 @@ class AsyncSSAGraphBuilder : public MultiDevSSAGraphBuilderBase {
bool
DealWithSpecialOp
(
ir
::
Graph
*
result
,
ir
::
Node
*
node
)
const
override
{
bool
DealWithSpecialOp
(
ir
::
Graph
*
result
,
ir
::
Node
*
node
)
const
override
{
if
(
node
->
Op
()
->
Type
()
==
"recv"
)
{
if
(
node
->
Op
()
->
Type
()
==
"recv"
)
{
VLOG
(
1
)
<<
"set recv op do_not_run to true"
;
VLOG
(
1
)
<<
"set recv op do_not_run to true"
;
node
->
Op
()
->
SetAttr
(
"do_not_run"
,
true
);
node
->
Op
()
->
SetAttr
(
"do_not_run"
,
1
);
node
->
Op
()
->
Flush
();
node
->
Op
()
->
Flush
();
}
else
if
(
node
->
Name
()
==
"lookup_table"
||
node
->
Name
()
==
"nce"
||
}
else
if
(
node
->
Name
()
==
"lookup_table"
||
node
->
Name
()
==
"nce"
||
node
->
Name
()
==
"hierarchical_sigmoid"
)
{
node
->
Name
()
==
"hierarchical_sigmoid"
)
{
...
...
paddle/fluid/framework/operator.cc
浏览文件 @
58f7695a
...
@@ -1142,7 +1142,7 @@ proto::VarType::Type OperatorWithKernel::IndicateDataType(
...
@@ -1142,7 +1142,7 @@ proto::VarType::Type OperatorWithKernel::IndicateDataType(
t
=
&
(
var
->
Get
<
SelectedRows
>
().
value
());
t
=
&
(
var
->
Get
<
SelectedRows
>
().
value
());
}
}
if
(
t
!=
nullptr
)
{
if
(
t
!=
nullptr
)
{
PADDLE_ENFORCE
(
t
->
IsInitialized
(),
"Input %s(%lu)is not initialized"
,
PADDLE_ENFORCE
(
t
->
IsInitialized
(),
"Input %s(%lu)
is not initialized"
,
input
.
first
,
i
);
input
.
first
,
i
);
proto
::
VarType
::
Type
tmp
=
t
->
type
();
proto
::
VarType
::
Type
tmp
=
t
->
type
();
PADDLE_ENFORCE
(
PADDLE_ENFORCE
(
...
...
paddle/fluid/operators/distributed/communicator.cc
浏览文件 @
58f7695a
...
@@ -15,6 +15,7 @@ limitations under the License. */
...
@@ -15,6 +15,7 @@ limitations under the License. */
#include "paddle/fluid/operators/distributed/communicator.h"
#include "paddle/fluid/operators/distributed/communicator.h"
#include <gflags/gflags.h>
#include <gflags/gflags.h>
#include <paddle/fluid/framework/program_desc.h>
#include <chrono> // NOLINT
#include <chrono> // NOLINT
#include <thread> // NOLINT
#include <thread> // NOLINT
...
@@ -50,8 +51,7 @@ inline double GetCurrentUS() {
...
@@ -50,8 +51,7 @@ inline double GetCurrentUS() {
return
1e+6
*
time
.
tv_sec
+
time
.
tv_usec
;
return
1e+6
*
time
.
tv_sec
+
time
.
tv_usec
;
}
}
std
::
unique_ptr
<
Communicator
>
Communicator
::
communicator_
(
nullptr
);
std
::
shared_ptr
<
Communicator
>
Communicator
::
communicator_
(
nullptr
);
std
::
once_flag
Communicator
::
init_flag_
;
Communicator
::
Communicator
(
const
RpcCtxMap
&
send_varname_to_ctx
,
Communicator
::
Communicator
(
const
RpcCtxMap
&
send_varname_to_ctx
,
const
RpcCtxMap
&
recv_varname_to_ctx
,
const
RpcCtxMap
&
recv_varname_to_ctx
,
...
@@ -84,11 +84,17 @@ Communicator::Communicator(const RpcCtxMap &send_varname_to_ctx,
...
@@ -84,11 +84,17 @@ Communicator::Communicator(const RpcCtxMap &send_varname_to_ctx,
}
}
Communicator
::~
Communicator
()
{
Communicator
::~
Communicator
()
{
VLOG
(
3
)
<<
"~Communicator"
;
if
(
FLAGS_v
>=
3
)
{
std
::
string
msg
(
"~Communicator"
);
fwrite
(
msg
.
c_str
(),
msg
.
length
(),
1
,
stdout
);
}
running_
=
false
;
running_
=
false
;
if
(
send_thread_
)
send_thread_
->
join
();
if
(
send_thread_
)
send_thread_
->
join
();
if
(
recv_thread_
)
recv_thread_
->
join
();
if
(
recv_thread_
)
recv_thread_
->
join
();
VLOG
(
3
)
<<
"~Communicator done"
;
if
(
FLAGS_v
>=
3
)
{
std
::
string
msg
(
"~Communicator done"
);
fwrite
(
msg
.
c_str
(),
msg
.
length
(),
1
,
stdout
);
}
}
}
void
Communicator
::
SendThread
()
{
void
Communicator
::
SendThread
()
{
...
@@ -144,7 +150,7 @@ void Communicator::SendThread() {
...
@@ -144,7 +150,7 @@ void Communicator::SendThread() {
task_futures
.
emplace_back
(
task_futures
.
emplace_back
(
send_threadpool_
->
enqueue
(
std
::
move
(
send_task
)));
send_threadpool_
->
enqueue
(
std
::
move
(
send_task
)));
}
else
{
}
else
{
VLOG
(
3
)
<<
var_name
<<
" queue empty"
;
VLOG
(
4
)
<<
var_name
<<
" queue empty"
;
}
}
}
}
for
(
auto
&
task_f
:
task_futures
)
{
for
(
auto
&
task_f
:
task_futures
)
{
...
@@ -160,17 +166,19 @@ void Communicator::SendThread() {
...
@@ -160,17 +166,19 @@ void Communicator::SendThread() {
RecvAll
();
RecvAll
();
}
}
}
}
VLOG
(
0
)
<<
"communicator stopped, send thread exit"
;
}
}
void
Communicator
::
RecvAll
()
{
void
Communicator
::
RecvAll
()
{
VLOG
(
3
)
<<
"parallel run recv graph"
;
VLOG
(
3
)
<<
"parallel run recv graph"
;
if
(
!
running_
)
return
;
auto
before_send
=
GetCurrentUS
();
auto
before_send
=
GetCurrentUS
();
std
::
vector
<
std
::
future
<
void
>>
task_futures
;
std
::
vector
<
std
::
future
<
void
>>
task_futures
;
task_futures
.
reserve
(
recv_varname_to_ctx_
.
size
());
task_futures
.
reserve
(
recv_varname_to_ctx_
.
size
());
for
(
auto
&
iter
:
recv_varname_to_ctx_
)
{
for
(
auto
&
iter
:
recv_varname_to_ctx_
)
{
auto
recv_task
=
[
this
,
&
iter
]
{
auto
recv_task
=
[
this
,
&
iter
]
{
auto
&
var_name
=
iter
.
first
;
auto
&
var_name
=
iter
.
first
;
VLOG
(
3
)
<<
"recv var "
<<
var_name
;
VLOG
(
4
)
<<
"recv var "
<<
var_name
;
auto
recv_functor
=
distributed
::
ParameterRecv
<
float
>
();
auto
recv_functor
=
distributed
::
ParameterRecv
<
float
>
();
if
(
!
FLAGS_communicator_fake_rpc
)
{
if
(
!
FLAGS_communicator_fake_rpc
)
{
recv_functor
(
iter
.
second
,
*
recv_scope_
);
recv_functor
(
iter
.
second
,
*
recv_scope_
);
...
@@ -197,6 +205,7 @@ void Communicator::RecvThread() {
...
@@ -197,6 +205,7 @@ void Communicator::RecvThread() {
std
::
this_thread
::
sleep_for
(
std
::
chrono
::
milliseconds
(
10
));
std
::
this_thread
::
sleep_for
(
std
::
chrono
::
milliseconds
(
10
));
}
}
}
}
VLOG
(
0
)
<<
"communicator stopped, recv thread exit"
;
}
}
void
Communicator
::
Send
(
const
std
::
string
&
var_name
,
void
Communicator
::
Send
(
const
std
::
string
&
var_name
,
...
@@ -212,9 +221,61 @@ void Communicator::Send(const std::string &var_name,
...
@@ -212,9 +221,61 @@ void Communicator::Send(const std::string &var_name,
queue
->
Push
(
tmp_grad_var
);
queue
->
Push
(
tmp_grad_var
);
}
}
void
Communicator
::
Init
(
const
paddle
::
framework
::
ProgramDesc
&
program
,
Scope
*
param_scope
)
{
using
RpcCtxMap
=
operators
::
distributed
::
RpcCtxMap
;
VLOG
(
3
)
<<
"ProcessGraph"
;
RpcCtxMap
send_varname_to_ctx
;
RpcCtxMap
recv_varname_to_ctx
;
for
(
auto
*
op
:
program
.
Block
(
0
).
AllOps
())
{
VLOG
(
3
)
<<
"node name "
<<
op
->
Type
();
if
(
op
->
Type
()
==
"send"
)
{
auto
send_var_name
=
op
->
Input
(
"X"
)[
0
];
auto
send_varnames
=
boost
::
get
<
std
::
vector
<
std
::
string
>>
(
op
->
GetNullableAttr
(
"send_varnames"
));
auto
epmap
=
boost
::
get
<
std
::
vector
<
std
::
string
>>
(
op
->
GetNullableAttr
(
"epmap"
));
auto
height_section
=
boost
::
get
<
std
::
vector
<
int64_t
>>
(
op
->
GetNullableAttr
(
"sections"
));
auto
trainer_id
=
boost
::
get
<
int
>
(
op
->
GetNullableAttr
(
"trainer_id"
));
send_varname_to_ctx
[
send_var_name
]
=
operators
::
distributed
::
RpcContext
(
send_var_name
,
send_varnames
,
epmap
,
height_section
,
trainer_id
);
VLOG
(
3
)
<<
"find and init an send op: "
<<
send_varname_to_ctx
[
send_var_name
];
}
else
if
(
op
->
Type
()
==
"recv"
)
{
auto
do_not_run
=
boost
::
get
<
int
>
(
op
->
GetNullableAttr
(
"do_not_run"
));
PADDLE_ENFORCE_GT
(
do_not_run
,
0
,
"recv should not run!"
);
auto
recv_var_name
=
op
->
Output
(
"Out"
)[
0
];
auto
recv_varnames
=
boost
::
get
<
std
::
vector
<
std
::
string
>>
(
op
->
GetNullableAttr
(
"recv_varnames"
));
auto
epmap
=
boost
::
get
<
std
::
vector
<
std
::
string
>>
(
op
->
GetNullableAttr
(
"epmap"
));
auto
trainer_id
=
boost
::
get
<
int
>
(
op
->
GetNullableAttr
(
"trainer_id"
));
recv_varname_to_ctx
[
recv_var_name
]
=
operators
::
distributed
::
RpcContext
(
recv_var_name
,
recv_varnames
,
epmap
,
{},
trainer_id
);
}
}
// init communicator here
if
(
send_varname_to_ctx
.
size
()
==
0
&&
recv_varname_to_ctx
.
size
()
==
0
)
{
LOG
(
WARNING
)
<<
"no var need to send and recv!!"
;
}
operators
::
distributed
::
Communicator
::
Init
(
send_varname_to_ctx
,
recv_varname_to_ctx
,
param_scope
);
}
Communicator
*
Communicator
::
GetInstance
()
{
return
communicator_
.
get
();
}
Communicator
*
Communicator
::
GetInstance
()
{
return
communicator_
.
get
();
}
std
::
shared_ptr
<
Communicator
>
Communicator
::
GetInstantcePtr
()
{
return
communicator_
;
}
void
Communicator
::
Start
()
{
void
Communicator
::
Start
()
{
VLOG
(
0
)
<<
"Communicator start"
;
if
(
!
communicator_
)
{
VLOG
(
0
)
<<
"Communicator is not inited, do nothing"
;
}
else
{
VLOG
(
1
)
<<
"start send thread and recv thread"
;
running_
=
true
;
running_
=
true
;
// start send and recv thread
// start send and recv thread
send_thread_
.
reset
(
send_thread_
.
reset
(
...
@@ -223,6 +284,27 @@ void Communicator::Start() {
...
@@ -223,6 +284,27 @@ void Communicator::Start() {
recv_thread_
.
reset
(
recv_thread_
.
reset
(
new
std
::
thread
(
std
::
bind
(
&
Communicator
::
RecvThread
,
this
)));
new
std
::
thread
(
std
::
bind
(
&
Communicator
::
RecvThread
,
this
)));
}
}
}
}
void
Communicator
::
Stop
()
{
VLOG
(
0
)
<<
"Communicator stop"
;
running_
=
false
;
if
(
!
communicator_
)
{
VLOG
(
0
)
<<
"Communicator is not inited, do nothing"
;
}
else
{
if
(
send_thread_
)
{
VLOG
(
1
)
<<
"stop send thread"
;
send_thread_
->
join
();
send_thread_
.
reset
(
nullptr
);
}
if
(
recv_thread_
)
{
VLOG
(
1
)
<<
"stop recv thread"
;
recv_thread_
->
join
();
recv_thread_
.
reset
(
nullptr
);
}
}
VLOG
(
0
)
<<
"Communicator stop done"
;
}
}
}
// namespace distributed
}
// namespace distributed
...
...
paddle/fluid/operators/distributed/communicator.h
浏览文件 @
58f7695a
...
@@ -165,6 +165,7 @@ class Communicator {
...
@@ -165,6 +165,7 @@ class Communicator {
~
Communicator
();
~
Communicator
();
void
Start
();
void
Start
();
void
Stop
();
// send grad
// send grad
void
Send
(
const
std
::
string
&
var_name
,
const
framework
::
Scope
&
scope
);
void
Send
(
const
std
::
string
&
var_name
,
const
framework
::
Scope
&
scope
);
...
@@ -181,8 +182,8 @@ class Communicator {
...
@@ -181,8 +182,8 @@ class Communicator {
send_varname_to_queue_
;
send_varname_to_queue_
;
RpcCtxMap
send_varname_to_ctx_
;
RpcCtxMap
send_varname_to_ctx_
;
RpcCtxMap
recv_varname_to_ctx_
;
RpcCtxMap
recv_varname_to_ctx_
;
std
::
unique_ptr
<
std
::
thread
>
send_thread_
;
std
::
unique_ptr
<
std
::
thread
>
send_thread_
{
nullptr
}
;
std
::
unique_ptr
<
std
::
thread
>
recv_thread_
;
std
::
unique_ptr
<
std
::
thread
>
recv_thread_
{
nullptr
}
;
Scope
*
recv_scope_
;
// should be global scope
Scope
*
recv_scope_
;
// should be global scope
std
::
unique_ptr
<
Scope
>
send_scope_
;
// an independent scope
std
::
unique_ptr
<
Scope
>
send_scope_
;
// an independent scope
std
::
unique_ptr
<::
ThreadPool
>
send_threadpool_
{
nullptr
};
std
::
unique_ptr
<::
ThreadPool
>
send_threadpool_
{
nullptr
};
...
@@ -193,25 +194,21 @@ class Communicator {
...
@@ -193,25 +194,21 @@ class Communicator {
public:
public:
static
void
Init
(
const
RpcCtxMap
&
send_varname_to_ctx
,
static
void
Init
(
const
RpcCtxMap
&
send_varname_to_ctx
,
const
RpcCtxMap
&
recv_varname_to_ctx
,
Scope
*
recv_scope
)
{
const
RpcCtxMap
&
recv_varname_to_ctx
,
Scope
*
recv_scope
)
{
InitImpl
(
send_varname_to_ctx
,
recv_varname_to_ctx
,
recv_scope
);
}
static
Communicator
*
GetInstance
();
private:
// Init is called by GetInstance.
static
void
InitImpl
(
const
RpcCtxMap
&
send_varname_to_ctx
,
const
RpcCtxMap
&
recv_varname_to_ctx
,
Scope
*
recv_scope
)
{
if
(
communicator_
==
nullptr
)
{
if
(
communicator_
==
nullptr
)
{
communicator_
.
reset
(
new
Communicator
(
send_varname_to_ctx
,
communicator_
.
reset
(
new
Communicator
(
send_varname_to_ctx
,
recv_varname_to_ctx
,
recv_scope
));
recv_varname_to_ctx
,
recv_scope
));
}
}
}
}
static
void
Init
(
const
paddle
::
framework
::
ProgramDesc
&
program
,
Scope
*
param_scope
);
static
Communicator
*
GetInstance
();
static
std
::
shared_ptr
<
Communicator
>
GetInstantcePtr
();
private:
private:
static
std
::
once_flag
init_flag_
;
static
std
::
shared_ptr
<
Communicator
>
communicator_
;
static
std
::
unique_ptr
<
Communicator
>
communicator_
;
};
};
}
// namespace distributed
}
// namespace distributed
...
...
paddle/fluid/operators/distributed_ops/recv_op.cc
浏览文件 @
58f7695a
...
@@ -36,7 +36,7 @@ class RecvOp : public framework::OperatorBase {
...
@@ -36,7 +36,7 @@ class RecvOp : public framework::OperatorBase {
void
RunImpl
(
const
framework
::
Scope
&
scope
,
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
const
platform
::
Place
&
place
)
const
override
{
bool
do_not_run
=
Attr
<
bool
>
(
"do_not_run"
);
int
do_not_run
=
Attr
<
int
>
(
"do_not_run"
);
if
(
do_not_run
)
{
if
(
do_not_run
)
{
VLOG
(
3
)
<<
"recv do not run!"
;
VLOG
(
3
)
<<
"recv do not run!"
;
return
;
return
;
...
@@ -132,7 +132,7 @@ This operator can get variables from server side.
...
@@ -132,7 +132,7 @@ This operator can get variables from server side.
"(vector<string>) "
"(vector<string>) "
"the splited parameter varnames to be recved from pserver"
)
"the splited parameter varnames to be recved from pserver"
)
.
SetDefault
(
std
::
vector
<
std
::
string
>
{});
.
SetDefault
(
std
::
vector
<
std
::
string
>
{});
AddAttr
<
bool
>
(
"do_not_run"
,
"if recv need to really run"
).
SetDefault
(
false
);
AddAttr
<
int
>
(
"do_not_run"
,
"if recv need to really run"
).
SetDefault
(
0
);
}
}
};
};
...
...
paddle/fluid/pybind/CMakeLists.txt
浏览文件 @
58f7695a
...
@@ -5,7 +5,29 @@ set(PYBIND_DEPS pybind python proto_desc memory executor async_executor fleet_wr
...
@@ -5,7 +5,29 @@ set(PYBIND_DEPS pybind python proto_desc memory executor async_executor fleet_wr
if
(
WITH_PYTHON
)
if
(
WITH_PYTHON
)
list
(
APPEND PYBIND_DEPS py_func_op
)
list
(
APPEND PYBIND_DEPS py_func_op
)
endif
()
endif
()
set
(
PYBIND_SRCS pybind.cc exception.cc protobuf.cc const_value.cc recordio.cc reader_py.cc async_executor_py.cc fleet_wrapper_py.cc nccl_wrapper_py.cc data_set_py.cc imperative.cc ir.cc inference_api.cc
)
if
(
WITH_DISTRIBUTE
)
list
(
APPEND PYBIND_DEPS communicator
)
endif
()
set
(
PYBIND_SRCS
pybind.cc
exception.cc
protobuf.cc
const_value.cc
recordio.cc
reader_py.cc
async_executor_py.cc
fleet_wrapper_py.cc
nccl_wrapper_py.cc
data_set_py.cc
imperative.cc
ir.cc
inference_api.cc
)
if
(
WITH_DISTRIBUTE
)
list
(
APPEND PYBIND_SRCS communicator_py.cc
)
endif
()
if
(
WITH_PYTHON
)
if
(
WITH_PYTHON
)
if
(
WITH_AMD_GPU
)
if
(
WITH_AMD_GPU
)
...
...
paddle/fluid/pybind/communicator_py.cc
0 → 100644
浏览文件 @
58f7695a
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/pybind/communicator_py.h"
#include <Python.h>
#include <memory>
#include "paddle/fluid/framework/program_desc.h"
#include "pybind11/pybind11.h"
#include "paddle/fluid/operators/distributed/communicator.h"
namespace
py
=
pybind11
;
using
paddle
::
framework
::
ProgramDesc
;
using
paddle
::
operators
::
distributed
::
Communicator
;
using
paddle
::
framework
::
Scope
;
namespace
paddle
{
namespace
pybind
{
void
BindCommunicator
(
py
::
module
*
m
)
{
// Communicator is already used by nccl, change to DistCommunicator
py
::
class_
<
Communicator
,
std
::
shared_ptr
<
Communicator
>>
(
*
m
,
"DistCommunicator"
)
.
def
(
py
::
init
([](
const
ProgramDesc
&
program
,
Scope
*
param_scope
)
{
Communicator
::
Init
(
program
,
param_scope
);
return
Communicator
::
GetInstantcePtr
();
}))
.
def
(
"stop"
,
&
Communicator
::
Stop
)
.
def
(
"start"
,
&
Communicator
::
Start
);
}
}
// namespace pybind
}
// namespace paddle
paddle/fluid/pybind/communicator_py.h
0 → 100644
浏览文件 @
58f7695a
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <Python.h>
#include "pybind11/pybind11.h"
namespace
paddle
{
namespace
pybind
{
void
BindCommunicator
(
pybind11
::
module
*
m
);
}
// namespace pybind
}
// namespace paddle
paddle/fluid/pybind/pybind.cc
浏览文件 @
58f7695a
...
@@ -77,6 +77,10 @@ limitations under the License. */
...
@@ -77,6 +77,10 @@ limitations under the License. */
#include "paddle/fluid/platform/gpu_info.h"
#include "paddle/fluid/platform/gpu_info.h"
#endif
#endif
#ifdef PADDLE_WITH_DISTRIBUTE
#include "paddle/fluid/pybind/communicator_py.h"
#endif
#include "pybind11/stl.h"
#include "pybind11/stl.h"
DEFINE_bool
(
reader_queue_speed_test_mode
,
false
,
DEFINE_bool
(
reader_queue_speed_test_mode
,
false
,
...
@@ -1547,6 +1551,9 @@ All parameter, weight, gradient are variables in Paddle.
...
@@ -1547,6 +1551,9 @@ All parameter, weight, gradient are variables in Paddle.
BindNode
(
&
m
);
BindNode
(
&
m
);
BindInferenceApi
(
&
m
);
BindInferenceApi
(
&
m
);
BindDataset
(
&
m
);
BindDataset
(
&
m
);
#ifdef PADDLE_WITH_DISTRIBUTE
BindCommunicator
(
&
m
);
#endif
}
}
}
// namespace pybind
}
// namespace pybind
}
// namespace paddle
}
// namespace paddle
python/paddle/fluid/communicator.py
0 → 100644
浏览文件 @
58f7695a
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
.executor
import
global_scope
from
.
import
core
from
.framework
import
Program
__all__
=
[
'Communicator'
]
class
Communicator
(
object
):
def
__init__
(
self
,
program
):
"""
Communicator is used for async distribute training in distribute_transpiler mode.
It's a wrapper of a cpp class Communicator and should be used inside fleet API.
Args:
program(Program): the trainers program after transpile of distribute_transpiler.
It's used by communicator to extract the information to do communication.
Returns:
None
Examples:
.. code-block:: python
import paddle.fluid as fluid
prog = fluid.Program()
comm = fluid.communicator.Communicator(prog)
comm.start()
comm.stop()
"""
# set all recv op to not_run mode
assert
isinstance
(
program
,
Program
)
for
op
in
program
.
block
(
0
).
ops
:
if
op
.
type
==
"recv"
:
op
.
_set_attr
(
'do_not_run'
,
True
)
self
.
communicator_
=
core
.
DistCommunicator
(
program
.
desc
,
global_scope
())
def
start
(
self
):
"""
Start communicator. Should call before training process.
Returns:
None
Examples:
.. code-block:: python
import paddle.fluid as fluid
prog = fluid.Program()
comm = fluid.communicator.Communicator(prog)
comm.start()
comm.stop()
"""
self
.
communicator_
.
start
()
def
stop
(
self
):
"""
Stop communicator. Should call after training process.
Returns:
None
Examples:
.. code-block:: python
import paddle.fluid as fluid
prog = fluid.Program()
comm = fluid.communicator.Communicator(prog)
comm.start()
comm.stop()
"""
self
.
communicator_
.
stop
()
python/paddle/fluid/incubate/fleet/base/fleet_base.py
浏览文件 @
58f7695a
...
@@ -15,14 +15,14 @@
...
@@ -15,14 +15,14 @@
from
__future__
import
print_function
from
__future__
import
print_function
import
abc
import
abc
from
enum
import
Enum
from
enum
import
Enum
from
paddle.fluid.optimizer
import
SGD
import
paddle.fluid
as
fluid
from
paddle.fluid.executor
import
Executor
from
paddle.fluid.executor
import
Executor
from
paddle.fluid.optimizer
import
SGD
from
role_maker
import
RoleMakerBase
from
role_maker
import
MPISymetricRoleMaker
from
role_maker
import
MPISymetricRoleMaker
from
role_maker
import
RoleMakerBase
from
role_maker
import
UserDefinedRoleMaker
from
role_maker
import
UserDefinedRoleMaker
...
@@ -48,7 +48,6 @@ class Fleet(object):
...
@@ -48,7 +48,6 @@ class Fleet(object):
__metaclass__
=
abc
.
ABCMeta
__metaclass__
=
abc
.
ABCMeta
def
__init__
(
self
,
mode
):
def
__init__
(
self
,
mode
):
assert
isinstance
(
mode
,
Mode
)
self
.
_is_initialized
=
False
self
.
_is_initialized
=
False
self
.
_mode
=
mode
self
.
_mode
=
mode
self
.
_optimizer
=
None
self
.
_optimizer
=
None
...
@@ -79,9 +78,9 @@ class Fleet(object):
...
@@ -79,9 +78,9 @@ class Fleet(object):
Get current total worker number.
Get current total worker number.
Returns:
Returns:
int: worker number
int: worker number
s
"""
"""
return
len
(
self
.
_role_maker
.
get_trainer_endpoints
()
)
return
self
.
_role_maker
.
worker_num
(
)
def
is_worker
(
self
):
def
is_worker
(
self
):
"""
"""
...
@@ -173,21 +172,19 @@ class Fleet(object):
...
@@ -173,21 +172,19 @@ class Fleet(object):
end
+=
length
end
+=
length
return
files
[
start
:
end
]
return
files
[
start
:
end
]
def
init
(
self
,
executor
,
role_maker
=
None
):
def
init
(
self
,
role_maker
=
None
):
"""
"""
should be called only once in user's python scripts,
should be called only once in user's python scripts,
init() will initialize RoleMaker which is used for identifying
init() will initialize RoleMaker which is used for identifying
current node's role, e.g. worker, server, etc.
current node's role, e.g. worker, server, etc.
Args:
Args:
executor(Executor): The executor to run fleet.
role_maker(RoleMakerBase): subclass of RoleMakerBase.
role_maker(RoleMakerBase): subclass of RoleMakerBase.
Returns:
Returns:
None
None
"""
"""
if
not
isinstance
(
executor
,
Executor
):
self
.
_executor
=
Executor
(
fluid
.
CPUPlace
())
raise
ValueError
(
"executor must be an instance of Executor"
)
if
role_maker
and
not
isinstance
(
role_maker
,
RoleMakerBase
):
if
role_maker
and
not
isinstance
(
role_maker
,
RoleMakerBase
):
raise
ValueError
(
"role_maker must be an instance of RoleMakerBase"
)
raise
ValueError
(
"role_maker must be an instance of RoleMakerBase"
)
...
@@ -215,23 +212,20 @@ class Fleet(object):
...
@@ -215,23 +212,20 @@ class Fleet(object):
pass
pass
@
abc
.
abstractmethod
@
abc
.
abstractmethod
def
run_server
(
self
,
):
def
run_server
(
self
):
pass
pass
@
abc
.
abstractmethod
@
abc
.
abstractmethod
def
stop_worker
(
self
):
def
stop_worker
(
self
):
pass
pass
@
abc
.
abstractmethod
def
stop
(
self
):
pass
@
abc
.
abstractmethod
@
abc
.
abstractmethod
def
distributed_optimizer
(
self
,
optimizer
,
strategy
=
None
):
def
distributed_optimizer
(
self
,
optimizer
,
strategy
=
None
):
pass
pass
@
abc
.
abstractmethod
@
abc
.
abstractmethod
def
save_inference_model
(
self
,
def
save_inference_model
(
self
,
executor
,
dirname
,
dirname
,
feeded_var_names
,
feeded_var_names
,
target_vars
,
target_vars
,
...
@@ -240,7 +234,7 @@ class Fleet(object):
...
@@ -240,7 +234,7 @@ class Fleet(object):
pass
pass
@
abc
.
abstractmethod
@
abc
.
abstractmethod
def
save_persistables
(
self
,
dirname
,
main_program
=
None
):
def
save_persistables
(
self
,
executor
,
dirname
,
main_program
=
None
):
pass
pass
...
...
python/paddle/fluid/incubate/fleet/base/role_maker.py
浏览文件 @
58f7695a
...
@@ -61,6 +61,15 @@ class RoleMakerBase(object):
...
@@ -61,6 +61,15 @@ class RoleMakerBase(object):
"""
"""
raise
NotImplementedError
(
"Please implement this method in child class"
)
raise
NotImplementedError
(
"Please implement this method in child class"
)
def
worker_num
(
self
):
"""
Get current total worker number.
Returns:
int: worker number
"""
raise
NotImplementedError
(
"Please implement this method in child class"
)
def
worker_index
(
self
):
def
worker_index
(
self
):
"""
"""
Get current worker id.
Get current worker id.
...
@@ -197,6 +206,9 @@ class MPISymetricRoleMaker(MPIRoleMaker):
...
@@ -197,6 +206,9 @@ class MPISymetricRoleMaker(MPIRoleMaker):
return
self
.
is_worker
()
and
0
==
self
.
worker_index
()
return
self
.
is_worker
()
and
0
==
self
.
worker_index
()
return
False
return
False
def
worker_num
(
self
):
return
self
.
_worker_num
()
def
is_worker
(
self
):
def
is_worker
(
self
):
"""
"""
return whether current process is worker assigned by role maker
return whether current process is worker assigned by role maker
...
@@ -293,9 +305,28 @@ class UserDefinedRoleMaker(RoleMakerBase):
...
@@ -293,9 +305,28 @@ class UserDefinedRoleMaker(RoleMakerBase):
"""
"""
super
(
UserDefinedRoleMaker
,
self
).
__init__
()
super
(
UserDefinedRoleMaker
,
self
).
__init__
()
if
not
isinstance
(
current_id
,
int
):
raise
TypeError
(
"current_id must be as int"
)
else
:
if
current_id
<
0
:
raise
ValueError
(
"current_id must be gather or equal 0"
)
self
.
_current_id
=
current_id
self
.
_current_id
=
current_id
if
not
isinstance
(
role
,
Role
):
raise
TypeError
(
"role must be as Role"
)
else
:
self
.
_role
=
role
self
.
_role
=
role
if
not
isinstance
(
worker_num
,
int
):
raise
TypeError
(
"worker_num must be as int"
)
else
:
if
worker_num
<
0
:
raise
ValueError
(
"worker_num must be gather or equal 0"
)
self
.
_worker_num
=
worker_num
self
.
_worker_num
=
worker_num
if
not
isinstance
(
server_endpoints
,
list
):
raise
TypeError
(
"server_endpoints must be as string list"
)
else
:
self
.
_server_endpoints
=
server_endpoints
self
.
_server_endpoints
=
server_endpoints
def
is_worker
(
self
):
def
is_worker
(
self
):
...
@@ -312,3 +343,6 @@ class UserDefinedRoleMaker(RoleMakerBase):
...
@@ -312,3 +343,6 @@ class UserDefinedRoleMaker(RoleMakerBase):
def
server_index
(
self
):
def
server_index
(
self
):
return
self
.
_current_id
return
self
.
_current_id
def
worker_num
(
self
):
return
self
.
_worker_num
python/paddle/fluid/incubate/fleet/collective/__init__.py
浏览文件 @
58f7695a
...
@@ -47,17 +47,12 @@ class Collective(Fleet):
...
@@ -47,17 +47,12 @@ class Collective(Fleet):
logging
.
warn
(
logging
.
warn
(
"You should not call 'stop_worker' method for collective mode."
)
"You should not call 'stop_worker' method for collective mode."
)
def
stop
(
self
):
"""
stop(): will be called after a user finishes his/her training task.
"""
logging
.
warn
(
"You should not call 'stop' method for collective mode."
)
def
distributed_optimizer
(
self
,
optimizer
,
strategy
=
None
):
def
distributed_optimizer
(
self
,
optimizer
,
strategy
=
None
):
self
.
_optimizer
=
CollectiveOptimizer
(
optimizer
,
strategy
)
self
.
_optimizer
=
CollectiveOptimizer
(
optimizer
,
strategy
)
return
self
.
_optimizer
return
self
.
_optimizer
def
save_inference_model
(
self
,
def
save_inference_model
(
self
,
executor
,
dirname
,
dirname
,
feeded_var_names
=
None
,
feeded_var_names
=
None
,
target_vars
=
None
,
target_vars
=
None
,
...
@@ -67,7 +62,7 @@ class Collective(Fleet):
...
@@ -67,7 +62,7 @@ class Collective(Fleet):
self
.
_executor
,
main_program
,
None
,
None
,
self
.
_executor
,
main_program
,
None
,
None
,
export_for_deployment
)
export_for_deployment
)
def
save_persistables
(
self
,
dirname
,
main_program
=
None
):
def
save_persistables
(
self
,
executor
,
dirname
,
main_program
=
None
):
io
.
save_persistables
(
self
.
_executor
,
dirname
,
main_program
,
None
)
io
.
save_persistables
(
self
.
_executor
,
dirname
,
main_program
,
None
)
...
...
python/paddle/fluid/incubate/fleet/parameter_server/distributed_transpiler/__init__.py
浏览文件 @
58f7695a
...
@@ -13,18 +13,16 @@
...
@@ -13,18 +13,16 @@
# limitations under the License.
# limitations under the License.
import
os
import
os
import
paddle.fluid.io
as
io
from
paddle.fluid.communicator
import
Communicator
from
paddle.fluid.framework
import
default_startup_program
from
paddle.fluid.framework
import
default_startup_program
from
paddle.fluid.optimizer
import
Optimizer
from
paddle.fluid.optimizer
import
Optimizer
import
paddle.fluid.io
as
io
from
paddle.fluid.transpiler.distribute_transpiler
import
DistributeTranspilerConfig
from
paddle.fluid.transpiler.distribute_transpiler
import
DistributeTranspiler
as
OriginTranspiler
from
paddle.fluid.transpiler.distribute_transpiler
import
DistributeTranspiler
as
OriginTranspiler
from
paddle.fluid.transpiler.distribute_transpiler
import
DistributeTranspilerConfig
from
...base.fleet_base
import
DistributedOptimizer
from
...base.fleet_base
import
Fleet
from
...base.fleet_base
import
Fleet
from
...base.fleet_base
import
Mode
from
...base.fleet_base
import
Mode
from
...base.fleet_base
import
DistributedOptimizer
class
DistributedTranspiler
(
Fleet
):
class
DistributedTranspiler
(
Fleet
):
...
@@ -34,9 +32,11 @@ class DistributedTranspiler(Fleet):
...
@@ -34,9 +32,11 @@ class DistributedTranspiler(Fleet):
def
__init__
(
self
):
def
__init__
(
self
):
super
(
DistributedTranspiler
,
self
).
__init__
(
Mode
.
TRANSPILER
)
super
(
DistributedTranspiler
,
self
).
__init__
(
Mode
.
TRANSPILER
)
self
.
_transpiler
=
OriginTranspiler
()
self
.
_transpile_config
=
None
self
.
_startup_program
=
None
self
.
_transpiler
=
None
self
.
_main_program
=
None
self
.
startup_program
=
None
self
.
main_program
=
None
self
.
_communicator
=
None
def
init_worker
(
self
):
def
init_worker
(
self
):
"""
"""
...
@@ -48,10 +48,9 @@ class DistributedTranspiler(Fleet):
...
@@ -48,10 +48,9 @@ class DistributedTranspiler(Fleet):
Returns:
Returns:
None
None
"""
"""
pass
if
not
self
.
_transpile_config
.
sync_mode
:
self
.
_communicator
=
Communicator
(
self
.
main_program
)
def
run_worker
(
self
,
main_programs
=
None
,
scopes
=
None
):
self
.
_communicator
.
start
()
pass
def
init_server
(
self
,
model_dir
=
None
):
def
init_server
(
self
,
model_dir
=
None
):
"""
"""
...
@@ -65,19 +64,19 @@ class DistributedTranspiler(Fleet):
...
@@ -65,19 +64,19 @@ class DistributedTranspiler(Fleet):
Returns:
Returns:
None
None
"""
"""
if
not
self
.
_
startup_program
:
if
not
self
.
startup_program
:
raise
ValueError
(
raise
ValueError
(
"startup_program is None, need invoke DistributedOptimizer.minimize first"
"startup_program is None, need invoke DistributedOptimizer.minimize first"
)
)
self
.
_executor
.
run
(
self
.
_
startup_program
)
self
.
_executor
.
run
(
self
.
startup_program
)
if
model_dir
:
if
model_dir
:
if
not
os
.
path
.
isdir
(
model_dir
):
if
not
os
.
path
.
isdir
(
model_dir
):
raise
ValueError
(
"There is no directory named '%s'"
,
model_dir
)
raise
ValueError
(
"There is no directory named '%s'"
,
model_dir
)
io
.
load_persistables
(
self
.
_executor
,
model_dir
,
io
.
load_persistables
(
self
.
_executor
,
model_dir
,
self
.
_
startup_program
)
self
.
startup_program
)
def
run_server
(
self
):
def
run_server
(
self
):
"""
"""
...
@@ -86,17 +85,14 @@ class DistributedTranspiler(Fleet):
...
@@ -86,17 +85,14 @@ class DistributedTranspiler(Fleet):
Returns:
Returns:
None
None
"""
"""
if
not
self
.
_
main_program
:
if
not
self
.
main_program
:
raise
ValueError
(
raise
ValueError
(
"main_program is None, need invoke DistributedOptimizer.minimize first"
"main_program is None, need invoke DistributedOptimizer.minimize first"
)
)
self
.
_executor
.
run
(
self
.
_
main_program
)
self
.
_executor
.
run
(
self
.
main_program
)
def
stop_worker
(
self
):
def
stop_worker
(
self
):
pass
def
stop
(
self
):
"""
"""
Close this executor.
Close this executor.
...
@@ -106,6 +102,8 @@ class DistributedTranspiler(Fleet):
...
@@ -106,6 +102,8 @@ class DistributedTranspiler(Fleet):
Returns:
Returns:
None
None
"""
"""
if
not
self
.
_transpile_config
.
sync_mode
:
self
.
_communicator
.
stop
()
self
.
_executor
.
close
()
self
.
_executor
.
close
()
def
distributed_optimizer
(
self
,
optimizer
,
strategy
=
None
):
def
distributed_optimizer
(
self
,
optimizer
,
strategy
=
None
):
...
@@ -129,6 +127,7 @@ class DistributedTranspiler(Fleet):
...
@@ -129,6 +127,7 @@ class DistributedTranspiler(Fleet):
return
self
.
_optimizer
return
self
.
_optimizer
def
save_inference_model
(
self
,
def
save_inference_model
(
self
,
executor
,
dirname
,
dirname
,
feeded_var_names
,
feeded_var_names
,
target_vars
,
target_vars
,
...
@@ -139,10 +138,10 @@ class DistributedTranspiler(Fleet):
...
@@ -139,10 +138,10 @@ class DistributedTranspiler(Fleet):
and then save it and all related parameters to given `dirname` by the `executor`.
and then save it and all related parameters to given `dirname` by the `executor`.
"""
"""
io
.
save_inference_model
(
dirname
,
feeded_var_names
,
target_vars
,
io
.
save_inference_model
(
dirname
,
feeded_var_names
,
target_vars
,
self
.
_
executor
,
main_program
,
None
,
None
,
executor
,
main_program
,
None
,
None
,
export_for_deployment
)
export_for_deployment
)
def
save_persistables
(
self
,
dirname
,
main_program
=
None
):
def
save_persistables
(
self
,
executor
,
dirname
,
main_program
=
None
):
"""
"""
This function filters out all variables with `persistable==True` from the
This function filters out all variables with `persistable==True` from the
give `main_program` and then saves these variables to the folder `dirname`
give `main_program` and then saves these variables to the folder `dirname`
...
@@ -153,21 +152,30 @@ class DistributedTranspiler(Fleet):
...
@@ -153,21 +152,30 @@ class DistributedTranspiler(Fleet):
files, set `filename` None; if you would like to save all variables in a
files, set `filename` None; if you would like to save all variables in a
single file, use `filename` to specify the file name.
single file, use `filename` to specify the file name.
"""
"""
io
.
save_persistables
(
self
.
_
executor
,
dirname
,
main_program
,
None
)
io
.
save_persistables
(
executor
,
dirname
,
main_program
,
None
)
def
_transpile
(
self
,
config
):
def
_transpile
(
self
,
config
):
if
not
isinstance
(
config
,
DistributeTranspilerConfig
):
raise
ValueError
(
"config must be an instance of DistributeTranspilerConfig"
)
if
not
config
.
sync_mode
:
config
.
runtime_split_send_recv
=
True
self
.
_transpile_config
=
config
self
.
_transpiler
=
OriginTranspiler
(
config
)
self
.
_transpiler
=
OriginTranspiler
(
config
)
self
.
_transpiler
.
transpile
(
self
.
_transpiler
.
transpile
(
trainer_id
=
fleet
.
worker_index
(),
trainer_id
=
fleet
.
worker_index
(),
pservers
=
fleet
.
server_endpoints
(
to_string
=
True
),
pservers
=
fleet
.
server_endpoints
(
to_string
=
True
),
trainers
=
fleet
.
worker_num
())
trainers
=
fleet
.
worker_num
(),
sync_mode
=
config
.
sync_mode
)
if
self
.
is_worker
():
if
self
.
is_worker
():
self
.
_
main_program
=
self
.
_transpiler
.
get_trainer_program
()
self
.
main_program
=
self
.
_transpiler
.
get_trainer_program
()
self
.
_
startup_program
=
default_startup_program
()
self
.
startup_program
=
default_startup_program
()
else
:
else
:
self
.
_main_program
,
self
.
_
startup_program
=
\
self
.
main_program
,
self
.
startup_program
=
\
self
.
_transpiler
.
get_pserver_programs
(
self
.
server_endpoints
(
self
.
server_index
())
)
self
.
_transpiler
.
get_pserver_programs
(
self
.
server_endpoints
(
)[
self
.
server_index
()]
)
fleet
=
DistributedTranspiler
()
fleet
=
DistributedTranspiler
()
...
...
python/paddle/fluid/incubate/fleet/parameter_server/pslib/__init__.py
浏览文件 @
58f7695a
...
@@ -33,8 +33,8 @@ class PSLib(Fleet):
...
@@ -33,8 +33,8 @@ class PSLib(Fleet):
self
.
_main_programs
=
[]
self
.
_main_programs
=
[]
self
.
_scopes
=
[]
self
.
_scopes
=
[]
def
init
(
self
,
executor
,
role_maker
=
None
):
def
init
(
self
,
role_maker
=
None
):
super
(
PSLib
,
self
).
init
(
executor
,
MPISymetricRoleMaker
())
super
(
PSLib
,
self
).
init
(
MPISymetricRoleMaker
())
self
.
_fleet_ptr
=
fluid
.
core
.
Fleet
()
self
.
_fleet_ptr
=
fluid
.
core
.
Fleet
()
def
init_worker
(
self
):
def
init_worker
(
self
):
...
@@ -169,23 +169,12 @@ class PSLib(Fleet):
...
@@ -169,23 +169,12 @@ class PSLib(Fleet):
self
.
_role_maker
.
_barrier_all
()
self
.
_role_maker
.
_barrier_all
()
self
.
_role_maker
.
_finalize
()
self
.
_role_maker
.
_finalize
()
def
stop
(
self
):
"""
stop(): will be called after a user finishes his/her training task. Fleet instance will be
destroyed when stop() is called.
"""
self
.
_role_maker
.
_barrier_worker
()
if
self
.
_role_maker
.
is_first_worker
():
self
.
_fleet_ptr
.
stop_server
()
self
.
_role_maker
.
_barrier_worker
()
self
.
_role_maker
.
_barrier_all
()
self
.
_role_maker
.
_finalize
()
def
distributed_optimizer
(
self
,
optimizer
,
strategy
=
{}):
def
distributed_optimizer
(
self
,
optimizer
,
strategy
=
{}):
self
.
_optimizer
=
DownpourOptimizer
(
optimizer
,
strategy
)
self
.
_optimizer
=
DownpourOptimizer
(
optimizer
,
strategy
)
return
self
.
_optimizer
return
self
.
_optimizer
def
save_inference_model
(
self
,
def
save_inference_model
(
self
,
executor
,
dirname
,
dirname
,
feeded_var_names
=
None
,
feeded_var_names
=
None
,
target_vars
=
None
,
target_vars
=
None
,
...
@@ -196,7 +185,7 @@ class PSLib(Fleet):
...
@@ -196,7 +185,7 @@ class PSLib(Fleet):
"""
"""
self
.
_fleet_ptr
.
save_model
(
dirname
)
self
.
_fleet_ptr
.
save_model
(
dirname
)
def
save_persistables
(
self
,
dirname
,
main_program
=
None
,
**
kwargs
):
def
save_persistables
(
self
,
executor
,
dirname
,
main_program
=
None
,
**
kwargs
):
"""
"""
save presistable parameters,
save presistable parameters,
when using fleet, it will save sparse and dense feature
when using fleet, it will save sparse and dense feature
...
...
python/paddle/fluid/incubate/fleet/tests/cluster_train.sh
0 → 100644
浏览文件 @
58f7695a
#!/bin/bash
# start pserver0
python fleet_deep_ctr.py
\
--role
pserver
\
--endpoints
127.0.0.1:7000,127.0.0.1:7001
\
--current_endpoint
127.0.0.1:7000
\
--trainers
2
\
>
pserver0.log 2>&1 &
# start pserver1
python fleet_deep_ctr.py
\
--role
pserver
\
--endpoints
127.0.0.1:7000,127.0.0.1:7001
\
--current_endpoint
127.0.0.1:7001
\
--trainers
2
\
>
pserver1.log 2>&1 &
# start trainer0
python fleet_deep_ctr.py
\
--role
trainer
\
--endpoints
127.0.0.1:7000,127.0.0.1:7001
\
--trainers
2
\
--trainer_id
0
\
>
trainer0.log 2>&1 &
# start trainer1
python fleet_deep_ctr.py
\
--role
trainer
\
--endpoints
127.0.0.1:7000,127.0.0.1:7001
\
--trainers
2
\
--trainer_id
1
\
>
trainer1.log 2>&1 &
python/paddle/fluid/incubate/fleet/tests/ctr_dataset_reader.py
0 → 100644
浏览文件 @
58f7695a
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
logging
import
tarfile
import
os
import
paddle
import
paddle.fluid.incubate.data_generator
as
data_generator
logging
.
basicConfig
()
logger
=
logging
.
getLogger
(
"paddle"
)
logger
.
setLevel
(
logging
.
INFO
)
DATA_URL
=
"http://paddle-ctr-data.bj.bcebos.com/avazu_ctr_data.tgz"
DATA_MD5
=
"c11df99fbd14e53cd4bfa6567344b26e"
"""
avazu_ctr_data/train.txt
avazu_ctr_data/infer.txt
avazu_ctr_data/test.txt
avazu_ctr_data/data.meta.txt
"""
def
download_file
():
file_name
=
"avazu_ctr_data"
path
=
paddle
.
dataset
.
common
.
download
(
DATA_URL
,
file_name
,
DATA_MD5
)
dir_name
=
os
.
path
.
dirname
(
path
)
text_file_dir_name
=
os
.
path
.
join
(
dir_name
,
file_name
)
if
not
os
.
path
.
exists
(
text_file_dir_name
):
tar
=
tarfile
.
open
(
path
,
"r:gz"
)
tar
.
extractall
(
dir_name
)
return
text_file_dir_name
def
load_dnn_input_record
(
sent
):
return
list
(
map
(
int
,
sent
.
split
()))
def
load_lr_input_record
(
sent
):
res
=
[]
for
_
in
[
x
.
split
(
':'
)
for
x
in
sent
.
split
()]:
res
.
append
(
int
(
_
[
0
]))
return
res
class
DatasetCtrReader
(
data_generator
.
MultiSlotDataGenerator
):
def
generate_sample
(
self
,
line
):
def
iter
():
fs
=
line
.
strip
().
split
(
'
\t
'
)
dnn_input
=
load_dnn_input_record
(
fs
[
0
])
lr_input
=
load_lr_input_record
(
fs
[
1
])
click
=
[
int
(
fs
[
2
])]
yield
(
"dnn_data"
,
dnn_input
),
\
(
"lr_data"
,
lr_input
),
\
(
"click"
,
click
)
return
iter
def
prepare_data
():
"""
load data meta info from path, return (dnn_input_dim, lr_input_dim)
"""
file_dir_name
=
download_file
()
meta_file_path
=
os
.
path
.
join
(
file_dir_name
,
'data.meta.txt'
)
train_file_path
=
os
.
path
.
join
(
file_dir_name
,
'train.txt'
)
with
open
(
meta_file_path
,
"r"
)
as
f
:
lines
=
f
.
readlines
()
err_info
=
"wrong meta format"
assert
len
(
lines
)
==
2
,
err_info
assert
'dnn_input_dim:'
in
lines
[
0
]
and
'lr_input_dim:'
in
lines
[
1
],
err_info
res
=
map
(
int
,
[
_
.
split
(
':'
)[
1
]
for
_
in
lines
])
res
=
list
(
res
)
dnn_input_dim
=
res
[
0
]
lr_input_dim
=
res
[
1
]
logger
.
info
(
'dnn input dim: %d'
%
dnn_input_dim
)
logger
.
info
(
'lr input dim: %d'
%
lr_input_dim
)
return
dnn_input_dim
,
lr_input_dim
,
train_file_path
if
__name__
==
"__main__"
:
pairwise_reader
=
DatasetCtrReader
()
pairwise_reader
.
run_from_stdin
()
python/paddle/fluid/incubate/fleet/tests/fleet_deep_ctr.py
0 → 100644
浏览文件 @
58f7695a
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
argparse
import
logging
import
time
import
paddle.fluid
as
fluid
import
paddle.fluid.incubate.fleet.base.role_maker
as
role_maker
from
paddle.fluid.incubate.fleet.parameter_server.distributed_transpiler
import
fleet
from
paddle.fluid.transpiler.distribute_transpiler
import
DistributeTranspilerConfig
import
ctr_dataset_reader
logging
.
basicConfig
(
format
=
'%(asctime)s - %(levelname)s - %(message)s'
)
logger
=
logging
.
getLogger
(
"fluid"
)
logger
.
setLevel
(
logging
.
INFO
)
def
parse_args
():
parser
=
argparse
.
ArgumentParser
(
description
=
"PaddlePaddle Fleet ctr"
)
# the following arguments is used for distributed train, if is_local == false, then you should set them
parser
.
add_argument
(
'--role'
,
type
=
str
,
default
=
'pserver'
,
# trainer or pserver
help
=
'The path for model to store (default: models)'
)
parser
.
add_argument
(
'--endpoints'
,
type
=
str
,
default
=
'127.0.0.1:6000'
,
help
=
'The pserver endpoints, like: 127.0.0.1:6000,127.0.0.1:6001'
)
parser
.
add_argument
(
'--current_endpoint'
,
type
=
str
,
default
=
'127.0.0.1:6000'
,
help
=
'The path for model to store (default: 127.0.0.1:6000)'
)
parser
.
add_argument
(
'--trainer_id'
,
type
=
int
,
default
=
0
,
help
=
'The path for model to store (default: models)'
)
parser
.
add_argument
(
'--trainers'
,
type
=
int
,
default
=
1
,
help
=
'The num of trainers, (default: 1)'
)
return
parser
.
parse_args
()
def
model
():
dnn_input_dim
,
lr_input_dim
,
train_file_path
=
ctr_dataset_reader
.
prepare_data
(
)
""" network definition """
dnn_data
=
fluid
.
layers
.
data
(
name
=
"dnn_data"
,
shape
=
[
-
1
,
1
],
dtype
=
"int64"
,
lod_level
=
1
,
append_batch_size
=
False
)
lr_data
=
fluid
.
layers
.
data
(
name
=
"lr_data"
,
shape
=
[
-
1
,
1
],
dtype
=
"int64"
,
lod_level
=
1
,
append_batch_size
=
False
)
label
=
fluid
.
layers
.
data
(
name
=
"click"
,
shape
=
[
-
1
,
1
],
dtype
=
"int64"
,
lod_level
=
0
,
append_batch_size
=
False
)
datas
=
[
dnn_data
,
lr_data
,
label
]
# build dnn model
dnn_layer_dims
=
[
128
,
64
,
32
,
1
]
dnn_embedding
=
fluid
.
layers
.
embedding
(
is_distributed
=
False
,
input
=
dnn_data
,
size
=
[
dnn_input_dim
,
dnn_layer_dims
[
0
]],
param_attr
=
fluid
.
ParamAttr
(
name
=
"deep_embedding"
,
initializer
=
fluid
.
initializer
.
Constant
(
value
=
0.01
)),
is_sparse
=
True
)
dnn_pool
=
fluid
.
layers
.
sequence_pool
(
input
=
dnn_embedding
,
pool_type
=
"sum"
)
dnn_out
=
dnn_pool
for
i
,
dim
in
enumerate
(
dnn_layer_dims
[
1
:]):
fc
=
fluid
.
layers
.
fc
(
input
=
dnn_out
,
size
=
dim
,
act
=
"relu"
,
param_attr
=
fluid
.
ParamAttr
(
initializer
=
fluid
.
initializer
.
Constant
(
value
=
0.01
)),
name
=
'dnn-fc-%d'
%
i
)
dnn_out
=
fc
# build lr model
lr_embbding
=
fluid
.
layers
.
embedding
(
is_distributed
=
False
,
input
=
lr_data
,
size
=
[
lr_input_dim
,
1
],
param_attr
=
fluid
.
ParamAttr
(
name
=
"wide_embedding"
,
initializer
=
fluid
.
initializer
.
Constant
(
value
=
0.01
)),
is_sparse
=
True
)
lr_pool
=
fluid
.
layers
.
sequence_pool
(
input
=
lr_embbding
,
pool_type
=
"sum"
)
merge_layer
=
fluid
.
layers
.
concat
(
input
=
[
dnn_out
,
lr_pool
],
axis
=
1
)
predict
=
fluid
.
layers
.
fc
(
input
=
merge_layer
,
size
=
2
,
act
=
'softmax'
)
acc
=
fluid
.
layers
.
accuracy
(
input
=
predict
,
label
=
label
)
auc_var
,
batch_auc_var
,
auc_states
=
fluid
.
layers
.
auc
(
input
=
predict
,
label
=
label
)
cost
=
fluid
.
layers
.
cross_entropy
(
input
=
predict
,
label
=
label
)
avg_cost
=
fluid
.
layers
.
mean
(
x
=
cost
)
return
datas
,
avg_cost
,
predict
,
train_file_path
def
train
(
args
):
datas
,
avg_cost
,
predict
,
train_file_path
=
model
()
endpoints
=
args
.
endpoints
.
split
(
","
)
if
args
.
role
.
upper
()
==
"PSERVER"
:
current_id
=
endpoints
.
index
(
args
.
current_endpoint
)
else
:
current_id
=
0
role
=
role_maker
.
UserDefinedRoleMaker
(
current_id
=
current_id
,
role
=
role_maker
.
Role
.
WORKER
if
args
.
role
.
upper
()
==
"TRAINER"
else
role_maker
.
Role
.
SERVER
,
worker_num
=
args
.
trainers
,
server_endpoints
=
endpoints
)
exe
=
fluid
.
Executor
(
fluid
.
CPUPlace
())
fleet
.
init
(
role
)
strategy
=
DistributeTranspilerConfig
()
strategy
.
sync_mode
=
False
optimizer
=
fluid
.
optimizer
.
SGD
(
learning_rate
=
0.0001
)
optimizer
=
fleet
.
distributed_optimizer
(
optimizer
,
strategy
)
optimizer
.
minimize
(
avg_cost
)
if
fleet
.
is_server
():
logger
.
info
(
"run pserver"
)
fleet
.
init_server
()
fleet
.
run_server
()
elif
fleet
.
is_worker
():
logger
.
info
(
"run trainer"
)
fleet
.
init_worker
()
exe
.
run
(
fleet
.
startup_program
)
thread_num
=
2
filelist
=
[]
for
_
in
range
(
thread_num
):
filelist
.
append
(
train_file_path
)
# config dataset
dataset
=
fluid
.
DatasetFactory
().
create_dataset
()
dataset
.
set_batch_size
(
128
)
dataset
.
set_use_var
(
datas
)
pipe_command
=
'python ctr_dataset_reader.py'
dataset
.
set_pipe_command
(
pipe_command
)
dataset
.
set_filelist
(
filelist
)
dataset
.
set_thread
(
thread_num
)
for
epoch_id
in
range
(
10
):
logger
.
info
(
"epoch {} start"
.
format
(
epoch_id
))
pass_start
=
time
.
time
()
dataset
.
set_filelist
(
filelist
)
exe
.
train_from_dataset
(
program
=
fleet
.
main_program
,
dataset
=
dataset
,
fetch_list
=
[
avg_cost
],
fetch_info
=
[
"cost"
],
print_period
=
100
,
debug
=
False
)
pass_time
=
time
.
time
()
-
pass_start
logger
.
info
(
"epoch {} finished, pass_time {}"
.
format
(
epoch_id
,
pass_time
))
fleet
.
stop_worker
()
if
__name__
==
"__main__"
:
args
=
parse_args
()
train
(
args
)
python/paddle/fluid/optimizer.py
浏览文件 @
58f7695a
...
@@ -15,27 +15,26 @@
...
@@ -15,27 +15,26 @@
from
__future__
import
print_function
from
__future__
import
print_function
from
collections
import
defaultdict
from
collections
import
defaultdict
from
.wrapped_decorator
import
signature_safe_contextmanager
from
functools
import
reduce
from
paddle.fluid
.framework
import
Program
,
Variable
,
name_scope
,
default_main_program
,
default_startup_program
from
paddle.fluid
import
core
from
paddle.fluid.distribute_lookup_table
import
find_distributed_lookup_table
from
paddle.fluid.distribute_lookup_table
import
find_distributed_lookup_table
from
paddle.fluid.framework
import
Program
,
Variable
,
name_scope
,
default_main_program
,
default_startup_program
from
paddle.fluid.layers
import
tensor
from
.
import
framework
from
.
import
framework
from
.
import
layers
from
.
import
layers
from
.
import
unique_name
from
.
import
unique_name
from
.backward
import
append_backward
from
.backward
import
append_backward
from
.clip
import
append_gradient_clip_ops
,
error_clip_callback
from
.clip
import
append_gradient_clip_ops
,
error_clip_callback
from
.dygraph
import
base
as
imperative_base
from
.dygraph.learning_rate_scheduler
import
LearningRateDecay
from
.framework
import
program_guard
from
.framework
import
program_guard
from
.initializer
import
Constant
from
.initializer
import
Constant
from
.layer_helper
import
LayerHelper
from
.layer_helper
import
LayerHelper
from
.layers
import
ops
from
.layers
import
ops
from
.regularizer
import
append_regularization_ops
from
.regularizer
import
append_regularization_ops
from
.dygraph
import
base
as
imperative_base
from
.wrapped_decorator
import
signature_safe_contextmanager
from
.dygraph.learning_rate_scheduler
import
LearningRateDecay
from
paddle.fluid
import
core
from
paddle.fluid.layers
import
tensor
from
functools
import
reduce
import
copy
__all__
=
[
__all__
=
[
'SGD'
,
'Momentum'
,
'Adagrad'
,
'Adam'
,
'Adamax'
,
'DecayedAdagrad'
,
'Ftrl'
,
'SGD'
,
'Momentum'
,
'Adagrad'
,
'Adam'
,
'Adamax'
,
'DecayedAdagrad'
,
'Ftrl'
,
...
...
python/paddle/fluid/tests/CMakeLists.txt
浏览文件 @
58f7695a
file
(
GLOB TEST_OPS RELATIVE
"
${
CMAKE_CURRENT_SOURCE_DIR
}
"
"test_*.py"
)
file
(
GLOB TEST_OPS RELATIVE
"
${
CMAKE_CURRENT_SOURCE_DIR
}
"
"test_*.py"
)
string
(
REPLACE
".py"
""
TEST_OPS
"
${
TEST_OPS
}
"
)
string
(
REPLACE
".py"
""
TEST_OPS
"
${
TEST_OPS
}
"
)
if
(
NOT WITH_DISTRIBUTE
)
list
(
REMOVE_ITEM TEST_OPS test_communicator
)
endif
(
NOT WITH_DISTRIBUTE
)
foreach
(
src
${
TEST_OPS
}
)
foreach
(
src
${
TEST_OPS
}
)
py_test
(
${
src
}
SRCS
${
src
}
.py
)
py_test
(
${
src
}
SRCS
${
src
}
.py
)
endforeach
()
endforeach
()
...
...
python/paddle/fluid/tests/test_communicator.py
0 → 100644
浏览文件 @
58f7695a
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
unittest
import
paddle.fluid
as
fluid
from
paddle.fluid.communicator
import
Communicator
class
TestCommunicator
(
unittest
.
TestCase
):
def
test_communicator_init_and_start
(
self
):
prog
=
fluid
.
Program
()
comm
=
Communicator
(
prog
)
comm
.
start
()
comm
.
stop
()
if
__name__
==
'__main__'
:
unittest
.
main
()
python/paddle/fluid/transpiler/distribute_transpiler.py
浏览文件 @
58f7695a
...
@@ -158,7 +158,7 @@ class DistributeTranspilerConfig(object):
...
@@ -158,7 +158,7 @@ class DistributeTranspilerConfig(object):
wait_port
=
True
wait_port
=
True
# split the send recv var in runtime
# split the send recv var in runtime
runtime_split_send_recv
=
False
runtime_split_send_recv
=
False
sync_mode
=
Non
e
sync_mode
=
Tru
e
class
DistributeTranspiler
(
object
):
class
DistributeTranspiler
(
object
):
...
@@ -330,7 +330,7 @@ class DistributeTranspiler(object):
...
@@ -330,7 +330,7 @@ class DistributeTranspiler(object):
return
return
self
.
trainer_num
=
trainers
self
.
trainer_num
=
trainers
self
.
sync_mode
=
s
elf
.
config
.
sync_mode
if
self
.
config
.
sync_mode
else
s
ync_mode
self
.
sync_mode
=
sync_mode
self
.
trainer_id
=
trainer_id
self
.
trainer_id
=
trainer_id
pserver_endpoints
=
pservers
.
split
(
","
)
pserver_endpoints
=
pservers
.
split
(
","
)
self
.
pserver_endpoints
=
pserver_endpoints
self
.
pserver_endpoints
=
pserver_endpoints
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录