Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
02ea3491
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
02ea3491
编写于
1月 17, 2018
作者:
T
typhoonzero
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
enhance dist train performance
上级
bcfb82d3
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
55 addition
and
56 deletion
+55
-56
paddle/operators/detail/grpc_client.cc
paddle/operators/detail/grpc_client.cc
+1
-4
paddle/operators/detail/grpc_client.h
paddle/operators/detail/grpc_client.h
+1
-1
paddle/operators/recv_op.cc
paddle/operators/recv_op.cc
+26
-40
paddle/operators/send_op.cc
paddle/operators/send_op.cc
+4
-2
python/paddle/v2/fluid/distribute_transpiler.py
python/paddle/v2/fluid/distribute_transpiler.py
+14
-1
python/paddle/v2/fluid/tests/book_distribute/notest_recognize_digits_conv_dist.py
...ests/book_distribute/notest_recognize_digits_conv_dist.py
+9
-8
未找到文件。
paddle/operators/detail/grpc_client.cc
浏览文件 @
02ea3491
...
...
@@ -63,9 +63,6 @@ bool RPCClient::AsyncGetVariable(const std::string& ep,
sendrecv
::
VariableMessage
req
;
req
.
set_varname
(
var_name
);
auto
*
var
=
scope
.
FindVar
(
var_name
);
SerializeToMessage
(
var_name
,
var
,
ctx
,
&
req
);
// varhandle
VarHandle
var_h
;
var_h
.
ep
=
ep
;
...
...
@@ -87,7 +84,7 @@ bool RPCClient::AsyncGetVariable(const std::string& ep,
return
true
;
}
bool
RPCClient
::
w
ait
()
{
bool
RPCClient
::
W
ait
()
{
bool
ok
=
true
;
while
(
true
)
{
...
...
paddle/operators/detail/grpc_client.h
浏览文件 @
02ea3491
...
...
@@ -130,7 +130,7 @@ class RPCClient {
const
framework
::
Scope
&
scope
,
const
std
::
string
&
var_name
,
int64_t
time_out
=
600
*
1000
);
bool
w
ait
();
bool
W
ait
();
private:
bool
Proceed
();
...
...
paddle/operators/recv_op.cc
浏览文件 @
02ea3491
...
...
@@ -27,6 +27,7 @@ limitations under the License. */
#include "paddle/operators/detail/grpc_server.h"
#include "paddle/operators/detail/sendrecvop_utils.h"
#include "paddle/operators/detail/simple_block_queue.h"
#include "paddle/string/printf.h"
#define LISTEN_TERMINATE_MESSAGE "TERMINATE@RECV"
...
...
@@ -77,35 +78,37 @@ class RecvOp : public framework::OperatorBase {
if
(
grads_counter_
.
find
(
varname
)
==
grads_counter_
.
end
())
{
grads_counter_
[
varname
]
=
0
;
}
char
ret
[
256
];
snprintf
(
ret
,
sizeof
(
ret
),
"%s.trainer_%d"
,
varname
.
c_str
(),
grads_counter_
[
varname
]
++
);
return
std
::
string
(
ret
);
return
string
::
Sprintf
(
"%s.trainer_%d"
,
varname
,
grads_counter_
[
varname
]
++
);
}
void
Run
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
// FIXME(typhoonzero): no new scopes for every run.
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
&
dev_ctx
=
*
pool
.
Get
(
dev_place
);
framework
::
Scope
&
recv_scope
=
scope
.
NewScope
();
rpc_service_
->
SetScope
(
&
recv_scope
);
auto
param_list
=
Attr
<
std
::
vector
<
std
::
string
>>
(
"ParamList"
);
auto
grad_list
=
Attr
<
std
::
vector
<
std
::
string
>>
(
"GradList"
);
auto
trainer_count
=
Attr
<
int
>
(
"Trainers
"
);
auto
fan_in
=
Attr
<
int
>
(
"Fanin
"
);
size_t
param_count
=
param_list
.
size
();
std
::
string
program_str
=
Attr
<
std
::
string
>
(
"OptimizeProgram"
);
framework
::
proto
::
ProgramDesc
program_desc
;
program_desc
.
ParseFromString
(
program_str
);
framework
::
ProgramDesc
program
(
program_desc
);
framework
::
Executor
executor
(
dev_place
);
rpc_service_
->
Reset
();
// TODO(typhoonzero): change this to a while_op for every cluster-batch.
bool
exit_flag
=
false
;
while
(
!
exit_flag
)
{
// TODO(gognwb): simply this loop.
// Get from multiple trainers, we don't care about order in which
// the gradient arrives, just add suffix 0~n then average the gradient.
for
(
size_t
i
=
0
;
i
<
param_count
*
trainer_count
;
++
i
)
{
// blocking get one var from client.
// Get from multiple trainers, we don't care about the order in which
// the gradients arrives, just add suffix 0~n and merge the gradient.
for
(
size_t
i
=
0
;
i
<
param_count
*
fan_in
;
++
i
)
{
const
detail
::
MessageWithName
&
v
=
rpc_service_
->
Get
();
auto
grad_var_name
=
v
.
first
;
if
(
grad_var_name
==
LISTEN_TERMINATE_MESSAGE
)
{
VLOG
(
4
)
<<
"received LISTEN_TERMINATE_MESSAGE and RunOp.Run()
exit"
;
LOG
(
INFO
)
<<
"received terminate message and
exit"
;
exit_flag
=
true
;
break
;
}
...
...
@@ -114,44 +117,27 @@ class RecvOp : public framework::OperatorBase {
if
(
it
!=
grad_list
.
end
())
{
param_var_name
=
param_list
[
it
-
grad_list
.
begin
()];
}
else
{
LOG
(
ERROR
)
<<
"grad have no paired param found!
\"
"
<<
grad_var_name
<<
"
\"
"
;
LOG
(
ERROR
)
<<
"grad have no paired param:"
<<
grad_var_name
;
}
VLOG
(
3
)
<<
"recved grad: "
<<
grad_var_name
<<
" updating param: "
<<
param_var_name
;
auto
*
merged_grad
=
recv_scope
.
FindVar
(
grad_var_name
);
if
(
merged_grad
==
nullptr
)
{
auto
*
ptr
=
recv_scope
.
Var
(
grad_var_name
);
CreateTensorFromMessageType
(
ptr
,
v
.
second
.
type
());
VLOG
(
3
)
<<
"Create Variable "
<<
grad_var_name
<<
" on recv scope, which pointer is "
<<
ptr
<<
" type is "
<<
v
.
second
.
type
();
// Assume grad_var_name must appear in global scope.
std
::
string
grad_var_name_trainer
;
if
(
fan_in
>
1
)
{
grad_var_name_trainer
=
this
->
GetGradVarNameForTrainer
(
grad_var_name
);
}
if
(
trainer_count
>
1
)
{
grad_var_name
=
this
->
GetGradVarNameForTrainer
(
grad_var_name
);
auto
*
var
=
recv_scope
.
FindVar
(
grad_var_name_trainer
);
if
(
var
==
nullptr
)
{
LOG
(
ERROR
)
<<
"can not find server side var: "
<<
grad_var_name_trainer
;
PADDLE_THROW
(
"can not find server side var"
);
}
auto
*
var
=
recv_scope
.
Var
(
grad_var_name
);
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
&
dev_ctx
=
*
pool
.
Get
(
dev_place
);
detail
::
DeserializeFromMessage
(
v
.
second
,
dev_ctx
,
var
);
}
if
(
exit_flag
)
{
break
;
}
rpc_service_
->
Reset
();
std
::
string
program_str
=
Attr
<
std
::
string
>
(
"OptimizeProgram"
);
framework
::
proto
::
ProgramDesc
program_desc
;
program_desc
.
ParseFromString
(
program_str
);
framework
::
ProgramDesc
program
(
program_desc
);
framework
::
Executor
executor
(
dev_place
);
// Run sub graph to get optimized tensor
try
{
executor
.
Run
(
program
,
&
recv_scope
,
0
,
/*global_block*/
false
/*create_local_scope*/
,
false
/*create_vars*/
);
...
...
@@ -195,7 +181,7 @@ This operator will recv tensor from send_op
"GradList"
,
"type list of string"
,
"grad->param name mapping to find which param to optimize."
)
.
SetDefault
({});
AddAttr
<
int
>
(
"
Trainers
"
,
"type int"
,
AddAttr
<
int
>
(
"
Fanin
"
,
"type int"
,
"Number of trainers in the current cluster job"
)
.
SetDefault
(
1
);
}
...
...
paddle/operators/send_op.cc
浏览文件 @
02ea3491
...
...
@@ -41,14 +41,16 @@ class SendOp : public framework::OperatorBase {
// FIXME(gongwb): DeviceContext?
auto
ctx
=
platform
::
CPUDeviceContext
();
for
(
size_t
i
=
0
;
i
<
ins
.
size
();
i
++
)
{
VLOG
(
3
)
<<
"sending "
<<
ins
[
i
];
client_
.
AsyncSendVariable
(
epmap
[
i
],
ctx
,
scope
,
ins
[
i
]);
}
client_
.
Wait
();
for
(
size_t
i
=
0
;
i
<
outs
.
size
();
i
++
)
{
VLOG
(
3
)
<<
"getting "
<<
outs
[
i
];
client_
.
AsyncGetVariable
(
epmap
[
i
],
ctx
,
scope
,
outs
[
i
]);
}
client_
.
wait
();
client_
.
Wait
();
}
private:
...
...
python/paddle/v2/fluid/distribute_transpiler.py
浏览文件 @
02ea3491
...
...
@@ -452,6 +452,19 @@ class DistributeTranspiler:
pserver_program
=
Program
()
for
v
in
self
.
param_grad_ep_mapping
[
endpoint
][
"params"
]:
self
.
_clone_var
(
pserver_program
.
global_block
(),
v
)
for
v
in
self
.
param_grad_ep_mapping
[
endpoint
][
"grads"
]:
# create vars for each trainer in global scope, so
# we don't need to create them when grad arrives.
pserver_program
.
global_block
().
create_var
(
name
=
v
.
name
,
persistable
=
True
,
dtype
=
v
.
dtype
,
shape
=
v
.
shape
)
for
trainer_id
in
xrange
(
self
.
trainers
):
print
(
"create variable for program: %s.trainer_%d"
%
(
v
.
name
,
trainer_id
))
pserver_program
.
global_block
().
create_var
(
name
=
"%s.trainer_%d"
%
(
v
.
name
,
trainer_id
),
persistable
=
True
,
dtype
=
v
.
dtype
,
shape
=
v
.
shape
)
# step6
optimize_sub_program
=
Program
()
for
idx
,
opt_op
in
enumerate
(
optimize_ops
):
...
...
@@ -481,7 +494,7 @@ class DistributeTranspiler:
p
.
name
for
p
in
self
.
param_grad_ep_mapping
[
endpoint
][
"grads"
]
],
"
Trainers
"
:
self
.
trainers
"
Fanin
"
:
self
.
trainers
})
pserver_program
.
sync_with_cpp
()
return
pserver_program
...
...
python/paddle/v2/fluid/tests/book_distribute/notest_recognize_digits_conv_dist.py
浏览文件 @
02ea3491
...
...
@@ -39,26 +39,27 @@ train_reader = paddle.batch(
place
=
fluid
.
CPUPlace
()
exe
=
fluid
.
Executor
(
place
)
t
=
fluid
.
DistributeTranspiler
()
# all parameter server endpoints list for spliting parameters
pserver_endpoints
=
os
.
getenv
(
"PSERVERS"
)
# server endpoint for current node
current_endpoint
=
os
.
getenv
(
"SERVER_ENDPOINT"
)
# run as trainer or parameter server
pserver_endpoints
=
os
.
getenv
(
"PSERVERS"
)
# all pserver endpoints
trainers
=
int
(
os
.
getenv
(
"TRAINERS"
))
# total trainer count
current_endpoint
=
os
.
getenv
(
"SERVER_ENDPOINT"
)
# current pserver endpoint
training_role
=
os
.
getenv
(
"TRAINING_ROLE"
,
"TRAINER"
)
# get the training role: trainer/pserver
t
.
transpile
(
optimize_ops
,
params_grads
,
pservers
=
pserver_endpoints
,
trainers
=
2
)
t
=
fluid
.
DistributeTranspiler
()
t
.
transpile
(
optimize_ops
,
params_grads
,
pservers
=
pserver_endpoints
,
trainers
=
trainers
)
if
training_role
==
"PSERVER"
:
if
not
current_endpoint
:
print
(
"need env SERVER_ENDPOINT"
)
exit
(
1
)
pserver_prog
=
t
.
get_pserver_program
(
current_endpoint
,
optimize_ops
)
exe
.
run
(
fluid
.
default_startup_program
())
pserver_startup
=
t
.
get_startup_program
(
current_endpoint
,
pserver_prog
)
exe
.
run
(
pserver_startup
)
exe
.
run
(
pserver_prog
)
elif
training_role
==
"TRAINER"
:
trainer_prog
=
t
.
get_trainer_program
()
feeder
=
fluid
.
DataFeeder
(
feed_list
=
[
images
,
label
],
place
=
place
)
# TODO(typhoonzero): change trainer startup program to fetch parameters from pserver
exe
.
run
(
fluid
.
default_startup_program
())
for
pass_id
in
range
(
PASS_NUM
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录