PaddlePaddle / PaddleDetection
Commit 02ea3491
enhance dist train performance

Authored on Jan 17, 2018 by typhoonzero
Parent: bcfb82d3

Showing 6 changed files with 55 additions and 56 deletions (+55 -56)
paddle/operators/detail/grpc_client.cc                                       +1  -4
paddle/operators/detail/grpc_client.h                                        +1  -1
paddle/operators/recv_op.cc                                                 +26 -40
paddle/operators/send_op.cc                                                  +4  -2
python/paddle/v2/fluid/distribute_transpiler.py                             +14  -1
python/paddle/v2/fluid/tests/book_distribute/notest_recognize_digits_conv_dist.py  +9 -8
paddle/operators/detail/grpc_client.cc

@@ -63,9 +63,6 @@ bool RPCClient::AsyncGetVariable(const std::string& ep,
   sendrecv::VariableMessage req;
   req.set_varname(var_name);
-  auto* var = scope.FindVar(var_name);
-  SerializeToMessage(var_name, var, ctx, &req);
-
   // varhandle
   VarHandle var_h;
   var_h.ep = ep;

@@ -87,7 +84,7 @@ bool RPCClient::AsyncGetVariable(const std::string& ep,
   return true;
 }

-bool RPCClient::wait() {
+bool RPCClient::Wait() {
   bool ok = true;

   while (true) {
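Why the deletion helps: a get RPC only has to tell the server which variable it wants; the tensor bytes come back in the reply, so serializing the local variable into the request was pure overhead on the hot path (the wait() -> Wait() rename just matches the naming style of the other client methods). Below is a self-contained sketch of the idea, with a plain struct standing in for the generated sendrecv protobuf type — VariableMessage and both Make* helpers here are hypothetical illustrations, not Paddle's actual API:

#include <iostream>
#include <string>
#include <vector>

// Hypothetical stand-in for the sendrecv protobuf message.
struct VariableMessage {
  std::string varname;
  std::string serialized;  // tensor bytes, filled only when needed
};

// Before this commit: the get request also carried the serialized local
// tensor, even though the server only needs the name to look it up.
VariableMessage MakeGetRequestOld(const std::string& name,
                                  const std::vector<float>& local_tensor) {
  VariableMessage req;
  req.varname = name;
  req.serialized.assign(reinterpret_cast<const char*>(local_tensor.data()),
                        local_tensor.size() * sizeof(float));  // wasted bytes
  return req;
}

// After this commit: the name is the whole request.
VariableMessage MakeGetRequestNew(const std::string& name) {
  VariableMessage req;
  req.varname = name;
  return req;
}

int main() {
  std::vector<float> tensor(1 << 20, 0.f);  // a 4 MB parameter block
  auto old_req = MakeGetRequestOld("fc_0.w_0@GRAD", tensor);
  auto new_req = MakeGetRequestNew("fc_0.w_0@GRAD");
  std::cout << "old get-request payload: " << old_req.serialized.size()
            << " bytes, new: " << new_req.serialized.size() << " bytes\n";
}

For a multi-megabyte parameter, that is several megabytes of serialization and wire traffic saved per fetch.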
paddle/operators/detail/grpc_client.h

@@ -130,7 +130,7 @@ class RPCClient {
                         const framework::Scope& scope,
                         const std::string& var_name,
                         int64_t time_out = 600 * 1000);
-  bool wait();
+  bool Wait();

  private:
  bool Proceed();
paddle/operators/recv_op.cc

@@ -27,6 +27,7 @@ limitations under the License. */
 #include "paddle/operators/detail/grpc_server.h"
 #include "paddle/operators/detail/sendrecvop_utils.h"
 #include "paddle/operators/detail/simple_block_queue.h"
+#include "paddle/string/printf.h"

 #define LISTEN_TERMINATE_MESSAGE "TERMINATE@RECV"

@@ -77,35 +78,37 @@ class RecvOp : public framework::OperatorBase {
     if (grads_counter_.find(varname) == grads_counter_.end()) {
       grads_counter_[varname] = 0;
     }
-    char ret[256];
-    snprintf(ret, sizeof(ret), "%s.trainer_%d", varname.c_str(),
-             grads_counter_[varname]++);
-    return std::string(ret);
+    return string::Sprintf("%s.trainer_%d", varname, grads_counter_[varname]++);
   }

   void Run(const framework::Scope& scope,
            const platform::Place& dev_place) const override {
+    // FIXME(typhoonzero): no new scopes for every run.
+    platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
+    auto& dev_ctx = *pool.Get(dev_place);
     framework::Scope& recv_scope = scope.NewScope();
     rpc_service_->SetScope(&recv_scope);
     auto param_list = Attr<std::vector<std::string>>("ParamList");
     auto grad_list = Attr<std::vector<std::string>>("GradList");
-    auto trainer_count = Attr<int>("Trainers");
+    auto fan_in = Attr<int>("Fanin");
     size_t param_count = param_list.size();
+
+    std::string program_str = Attr<std::string>("OptimizeProgram");
+    framework::proto::ProgramDesc program_desc;
+    program_desc.ParseFromString(program_str);
+    framework::ProgramDesc program(program_desc);
+    framework::Executor executor(dev_place);
+
     rpc_service_->Reset();
     // TODO(typhoonzero): change this to a while_op for every cluster-batch.
     bool exit_flag = false;
     while (!exit_flag) {
-      // TODO(gognwb): simply this loop.
-      // Get from multiple trainers, we don't care about the order in which
-      // the gradients arrives, just add suffix 0~n and merge the gradient.
-      for (size_t i = 0; i < param_count * trainer_count; ++i) {
+      // Get from multiple trainers, we don't care about order in which
+      // the gradient arrives, just add suffix 0~n then average the gradient.
+      for (size_t i = 0; i < param_count * fan_in; ++i) {
+        // blocking get one var from client.
         const detail::MessageWithName& v = rpc_service_->Get();
         auto grad_var_name = v.first;
         if (grad_var_name == LISTEN_TERMINATE_MESSAGE) {
-          VLOG(4) << "received LISTEN_TERMINATE_MESSAGE and RunOp.Run() exit";
+          LOG(INFO) << "received terminate message and exit";
           exit_flag = true;
           break;
         }

@@ -114,44 +117,27 @@ class RecvOp : public framework::OperatorBase {
         if (it != grad_list.end()) {
           param_var_name = param_list[it - grad_list.begin()];
         } else {
-          LOG(ERROR) << "grad have no paired param found!\"" << grad_var_name
-                     << "\"";
+          LOG(ERROR) << "grad have no paired param:" << grad_var_name;
         }
         VLOG(3) << "recved grad: " << grad_var_name
                 << " updating param: " << param_var_name;
-        auto* merged_grad = recv_scope.FindVar(grad_var_name);
-        if (merged_grad == nullptr) {
-          auto* ptr = recv_scope.Var(grad_var_name);
-          CreateTensorFromMessageType(ptr, v.second.type());
-          VLOG(3) << "Create Variable " << grad_var_name
-                  << " on recv scope, which pointer is " << ptr << " type is "
-                  << v.second.type();
+        // Assume grad_var_name must appear in global scope.
+        std::string grad_var_name_trainer;
+        if (fan_in > 1) {
+          grad_var_name_trainer = this->GetGradVarNameForTrainer(grad_var_name);
         }
-        if (trainer_count > 1) {
-          grad_var_name = this->GetGradVarNameForTrainer(grad_var_name);
+        auto* var = recv_scope.FindVar(grad_var_name_trainer);
+        if (var == nullptr) {
+          LOG(ERROR) << "can not find server side var: "
+                     << grad_var_name_trainer;
+          PADDLE_THROW("can not find server side var");
         }
-        auto* var = recv_scope.Var(grad_var_name);
-        platform::DeviceContextPool& pool =
-            platform::DeviceContextPool::Instance();
-        auto& dev_ctx = *pool.Get(dev_place);
         detail::DeserializeFromMessage(v.second, dev_ctx, var);
       }
       if (exit_flag) {
         break;
       }
       rpc_service_->Reset();
-      std::string program_str = Attr<std::string>("OptimizeProgram");
-      framework::proto::ProgramDesc program_desc;
-      program_desc.ParseFromString(program_str);
-      framework::ProgramDesc program(program_desc);
-      framework::Executor executor(dev_place);
-      // Run sub graph to get optimized tensor
+
       try {
         executor.Run(program, &recv_scope, 0, /*global_block*/
                      false /*create_local_scope*/, false /*create_vars*/);

@@ -195,7 +181,7 @@ This operator will recv tensor from send_op
         "GradList", "type list of string",
         "grad->param name mapping to find which param to optimize.")
         .SetDefault({});
-    AddAttr<int>("Trainers", "type int",
+    AddAttr<int>("Fanin", "type int",
                  "Number of trainers in the current cluster job")
         .SetDefault(1);
   }
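The headline fix in this file: the serialized OptimizeProgram used to be re-parsed, and the Executor re-constructed, on every pass of the server's while (!exit_flag) loop; both are now hoisted above the loop, together with the DeviceContextPool lookup, so each cluster batch only pays for executor.Run. A minimal, self-contained sketch of why hoisting matters — ParseProgram and RunOptimizeStep are hypothetical stand-ins for ProgramDesc::ParseFromString plus Executor construction, not Paddle's API:

#include <chrono>
#include <iostream>
#include <string>

struct Program { std::string desc; };

// Stand-in for deserializing the optimize program (the string copy plays
// the role of the repeated parsing cost).
Program ParseProgram(const std::string& s) { return Program{s}; }

// Stand-in for executor.Run applying the optimizer to received gradients.
void RunOptimizeStep(const Program&) {}

int main() {
  const std::string program_str(1 << 16, 'x');  // serialized program bytes
  const int batches = 1000;

  // Before: parse inside every loop pass.
  auto t0 = std::chrono::steady_clock::now();
  for (int i = 0; i < batches; ++i) {
    Program p = ParseProgram(program_str);  // repeated setup cost
    RunOptimizeStep(p);
  }
  auto t1 = std::chrono::steady_clock::now();

  // After: parse once, reuse across the loop.
  Program p = ParseProgram(program_str);
  for (int i = 0; i < batches; ++i) {
    RunOptimizeStep(p);
  }
  auto t2 = std::chrono::steady_clock::now();

  using ms = std::chrono::duration<double, std::milli>;
  std::cout << "per-iteration parse: " << ms(t1 - t0).count() << " ms, "
            << "hoisted parse: " << ms(t2 - t1).count() << " ms\n";
}

The same reasoning applies to the variable-creation change: because the transpiler now pre-creates the per-trainer gradient variables (see the distribute_transpiler.py hunk below), the loop does a FindVar lookup and throws if the slot is missing, instead of creating variables while gradients are arriving.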
paddle/operators/send_op.cc

@@ -41,14 +41,16 @@ class SendOp : public framework::OperatorBase {
     // FIXME(gongwb): DeviceContext?
     auto ctx = platform::CPUDeviceContext();
     for (size_t i = 0; i < ins.size(); i++) {
+      VLOG(3) << "sending " << ins[i];
       client_.AsyncSendVariable(epmap[i], ctx, scope, ins[i]);
     }
+    client_.Wait();

     for (size_t i = 0; i < outs.size(); i++) {
+      VLOG(3) << "getting " << outs[i];
       client_.AsyncGetVariable(epmap[i], ctx, scope, outs[i]);
     }
-    client_.wait();
+    client_.Wait();
   }

  private:
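The added Wait() between the two loops makes the round trip a two-phase pipeline: all gradient sends are launched asynchronously and overlap each other, the barrier guarantees every gradient has been pushed, and only then are the parameter fetches launched (again overlapping) before the final barrier. A sketch of the same pattern using std::async as a stand-in for RPCClient's async calls — SendVar and GetVar are hypothetical placeholders, not Paddle functions:

#include <future>
#include <iostream>
#include <string>
#include <vector>

// Stand-ins for AsyncSendVariable / AsyncGetVariable round trips.
void SendVar(const std::string&) { /* serialize tensor + RPC to pserver */ }
void GetVar(const std::string&) { /* RPC to pserver + deserialize reply */ }

int main() {
  std::vector<std::string> ins = {"w0@GRAD", "w1@GRAD"};
  std::vector<std::string> outs = {"w0", "w1"};

  // Phase 1: launch all sends concurrently, then barrier.
  std::vector<std::future<void>> pending;
  for (const auto& v : ins)
    pending.push_back(std::async(std::launch::async, SendVar, v));
  for (auto& f : pending) f.get();  // like client_.Wait(): grads are all out
  pending.clear();

  // Phase 2: only now fetch updated params, again concurrently.
  for (const auto& v : outs)
    pending.push_back(std::async(std::launch::async, GetVar, v));
  for (auto& f : pending) f.get();  // like client_.Wait(): params fetched
  std::cout << "send/recv round trip done\n";
}

Without the middle barrier, a get could race ahead of a still-in-flight gradient and fetch a stale parameter.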
python/paddle/v2/fluid/distribute_transpiler.py

@@ -452,6 +452,19 @@ class DistributeTranspiler:
         pserver_program = Program()
         for v in self.param_grad_ep_mapping[endpoint]["params"]:
             self._clone_var(pserver_program.global_block(), v)
+        for v in self.param_grad_ep_mapping[endpoint]["grads"]:
+            # create vars for each trainer in global scope, so
+            # we don't need to create them when grad arrives.
+            pserver_program.global_block().create_var(
+                name=v.name, persistable=True, dtype=v.dtype, shape=v.shape)
+            for trainer_id in xrange(self.trainers):
+                print("create variable for program: %s.trainer_%d" %
+                      (v.name, trainer_id))
+                pserver_program.global_block().create_var(
+                    name="%s.trainer_%d" % (v.name, trainer_id),
+                    persistable=True,
+                    dtype=v.dtype,
+                    shape=v.shape)
         # step6
         optimize_sub_program = Program()
         for idx, opt_op in enumerate(optimize_ops):

@@ -481,7 +494,7 @@ class DistributeTranspiler:
                     p.name
                     for p in self.param_grad_ep_mapping[endpoint]["grads"]
                 ],
-                "Trainers": self.trainers
+                "Fanin": self.trainers
             })
         pserver_program.sync_with_cpp()
         return pserver_program
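This is the counterpart of the RecvOp change: the transpiler pre-creates one persistable variable per trainer for every gradient handled by this endpoint, named "%s.trainer_%d", so the server can FindVar() an existing slot while gradients stream in instead of allocating variables on the hot path. A small sketch of the naming scheme these pre-created slots pair with, mirroring RecvOp::GetGradVarNameForTrainer (std::to_string is used here in place of Paddle's string::Sprintf):

#include <iostream>
#include <map>
#include <string>

// The k-th copy of a gradient received in a round is routed to the slot
// "<name>.trainer_<k>"; operator[] default-initializes the counter to 0,
// matching the grads_counter_ logic in recv_op.cc.
std::string GradVarNameForTrainer(std::map<std::string, int>& counters,
                                  const std::string& varname) {
  return varname + ".trainer_" + std::to_string(counters[varname]++);
}

int main() {
  std::map<std::string, int> grads_counter;
  // Two trainers pushing the same gradient in one round:
  std::cout << GradVarNameForTrainer(grads_counter, "fc_0.w_0@GRAD") << "\n"
            << GradVarNameForTrainer(grads_counter, "fc_0.w_0@GRAD") << "\n";
  // Prints fc_0.w_0@GRAD.trainer_0, then fc_0.w_0@GRAD.trainer_1 -- exactly
  // the names create_var() registered ahead of time on the pserver side.
}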
python/paddle/v2/fluid/tests/book_distribute/notest_recognize_digits_conv_dist.py

@@ -39,26 +39,27 @@ train_reader = paddle.batch(
 place = fluid.CPUPlace()
 exe = fluid.Executor(place)

-t = fluid.DistributeTranspiler()
-pserver_endpoints = os.getenv("PSERVERS")  # all pserver endpoints
-current_endpoint = os.getenv("SERVER_ENDPOINT")  # current pserver endpoint
+# all parameter server endpoints list for spliting parameters
+pserver_endpoints = os.getenv("PSERVERS")
+trainers = int(os.getenv("TRAINERS"))  # total trainer count
+# server endpoint for current node
+current_endpoint = os.getenv("SERVER_ENDPOINT")
+# run as trainer or parameter server
 training_role = os.getenv("TRAINING_ROLE",
                           "TRAINER")  # get the training role: trainer/pserver
-t.transpile(optimize_ops, params_grads, pservers=pserver_endpoints, trainers=2)
+
+t = fluid.DistributeTranspiler()
+t.transpile(
+    optimize_ops, params_grads, pservers=pserver_endpoints, trainers=trainers)

 if training_role == "PSERVER":
     if not current_endpoint:
         print("need env SERVER_ENDPOINT")
         exit(1)
     pserver_prog = t.get_pserver_program(current_endpoint, optimize_ops)
-    exe.run(fluid.default_startup_program())
+    pserver_startup = t.get_startup_program(current_endpoint, pserver_prog)
+    exe.run(pserver_startup)
     exe.run(pserver_prog)
 elif training_role == "TRAINER":
     trainer_prog = t.get_trainer_program()
     feeder = fluid.DataFeeder(feed_list=[images, label], place=place)
+    # TODO(typhoonzero): change trainer startup program to fetch parameters from pserver
     exe.run(fluid.default_startup_program())

     for pass_id in range(PASS_NUM):
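The test now derives its whole configuration from the environment, so the identical script can be launched as either role: PSERVERS lists all parameter-server endpoints, TRAINERS gives the fan-in (no more hard-coded trainers=2), SERVER_ENDPOINT identifies the current pserver, and TRAINING_ROLE selects the branch. A tiny standalone sketch of that launch contract (illustrative only, not part of the test):

#include <cstdlib>
#include <iostream>
#include <string>

int main() {
  // Every process reads the same variables and picks its role.
  const char* role_env = std::getenv("TRAINING_ROLE");
  const char* trainers_env = std::getenv("TRAINERS");
  std::string role = role_env ? role_env : "TRAINER";  // default, like the test
  int trainers = trainers_env ? std::atoi(trainers_env) : 1;

  if (role == "PSERVER") {
    const char* ep = std::getenv("SERVER_ENDPOINT");
    if (!ep) {
      std::cerr << "need env SERVER_ENDPOINT\n";  // same guard as the test
      return 1;
    }
    std::cout << "pserver at " << ep << ", fan-in " << trainers << "\n";
  } else {
    std::cout << "trainer, " << trainers << " total\n";
  }
}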