Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
58be41fa
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
未验证
提交
58be41fa
编写于
1月 19, 2018
作者:
武
武毅
提交者:
GitHub
1月 19, 2018
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #7608 from typhoonzero/distributed_split_selectedrows
Enhance distributed train performance
上级
b7b5de7f
0aff1363
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
94 addition
and
83 deletion
+94
-83
paddle/operators/detail/grpc_client.cc
paddle/operators/detail/grpc_client.cc
+0
-3
paddle/operators/detail/grpc_server.cc
paddle/operators/detail/grpc_server.cc
+30
-20
paddle/operators/detail/grpc_server.h
paddle/operators/detail/grpc_server.h
+8
-7
paddle/operators/recv_op.cc
paddle/operators/recv_op.cc
+30
-44
paddle/operators/send_op.cc
paddle/operators/send_op.cc
+3
-0
python/paddle/v2/fluid/distribute_transpiler.py
python/paddle/v2/fluid/distribute_transpiler.py
+14
-1
python/paddle/v2/fluid/tests/book_distribute/notest_recognize_digits_conv_dist.py
...ests/book_distribute/notest_recognize_digits_conv_dist.py
+9
-8
未找到文件。
paddle/operators/detail/grpc_client.cc
浏览文件 @
58be41fa
...
@@ -63,9 +63,6 @@ bool RPCClient::AsyncGetVariable(const std::string& ep,
...
@@ -63,9 +63,6 @@ bool RPCClient::AsyncGetVariable(const std::string& ep,
sendrecv
::
VariableMessage
req
;
sendrecv
::
VariableMessage
req
;
req
.
set_varname
(
var_name
);
req
.
set_varname
(
var_name
);
auto
*
var
=
scope
.
FindVar
(
var_name
);
SerializeToMessage
(
var_name
,
var
,
ctx
,
&
req
);
// varhandle
// varhandle
VarHandle
var_h
;
VarHandle
var_h
;
var_h
.
ep
=
ep
;
var_h
.
ep
=
ep
;
...
...
paddle/operators/detail/grpc_server.cc
浏览文件 @
58be41fa
...
@@ -36,7 +36,10 @@ class RequestBase {
...
@@ -36,7 +36,10 @@ class RequestBase {
CallStatus
Status
()
{
return
status_
;
}
CallStatus
Status
()
{
return
status_
;
}
void
SetStatus
(
CallStatus
status
)
{
status_
=
status
;
}
void
SetStatus
(
CallStatus
status
)
{
status_
=
status
;
}
virtual
std
::
string
GetReqName
()
{
assert
(
false
);
}
virtual
std
::
string
GetReqName
()
{
assert
(
false
);
return
""
;
}
protected:
protected:
grpc
::
ServerContext
ctx_
;
grpc
::
ServerContext
ctx_
;
...
@@ -80,11 +83,13 @@ class RequestGet final : public RequestBase {
...
@@ -80,11 +83,13 @@ class RequestGet final : public RequestBase {
public:
public:
explicit
RequestGet
(
sendrecv
::
SendRecvService
::
AsyncService
*
service
,
explicit
RequestGet
(
sendrecv
::
SendRecvService
::
AsyncService
*
service
,
grpc
::
ServerCompletionQueue
*
cq
,
framework
::
Scope
*
scope
,
grpc
::
ServerCompletionQueue
*
cq
,
framework
::
Scope
*
scope
,
const
platform
::
DeviceContext
*
dev_ctx
)
const
platform
::
DeviceContext
*
dev_ctx
,
SimpleBlockQueue
<
char
>*
queue
)
:
RequestBase
(
service
,
cq
),
:
RequestBase
(
service
,
cq
),
responder_
(
&
ctx_
),
responder_
(
&
ctx_
),
scope_
(
scope
),
scope_
(
scope
),
dev_ctx_
(
dev_ctx
)
{
dev_ctx_
(
dev_ctx
),
queue_
(
queue
)
{
service_
->
RequestGetVariable
(
&
ctx_
,
&
request_
,
&
responder_
,
cq_
,
cq_
,
this
);
service_
->
RequestGetVariable
(
&
ctx_
,
&
request_
,
&
responder_
,
cq_
,
cq_
,
this
);
}
}
...
@@ -100,6 +105,7 @@ class RequestGet final : public RequestBase {
...
@@ -100,6 +105,7 @@ class RequestGet final : public RequestBase {
// TODO(gongwb): check var's info.
// TODO(gongwb): check var's info.
responder_
.
Finish
(
reply_
,
grpc
::
Status
::
OK
,
this
);
responder_
.
Finish
(
reply_
,
grpc
::
Status
::
OK
,
this
);
status_
=
FINISH
;
status_
=
FINISH
;
queue_
->
Push
(
'c'
);
}
}
protected:
protected:
...
@@ -108,8 +114,15 @@ class RequestGet final : public RequestBase {
...
@@ -108,8 +114,15 @@ class RequestGet final : public RequestBase {
ServerAsyncResponseWriter
<
sendrecv
::
VariableMessage
>
responder_
;
ServerAsyncResponseWriter
<
sendrecv
::
VariableMessage
>
responder_
;
framework
::
Scope
*
scope_
;
framework
::
Scope
*
scope_
;
const
platform
::
DeviceContext
*
dev_ctx_
;
const
platform
::
DeviceContext
*
dev_ctx_
;
SimpleBlockQueue
<
char
>*
queue_
;
};
};
void
AsyncGRPCServer
::
WaitClientGet
(
int
count
)
{
for
(
int
i
=
0
;
i
<
count
;
++
i
)
{
var_get_queue_
.
Pop
();
}
}
void
AsyncGRPCServer
::
RunSyncUpdate
()
{
void
AsyncGRPCServer
::
RunSyncUpdate
()
{
grpc
::
ServerBuilder
builder
;
grpc
::
ServerBuilder
builder
;
builder
.
AddListeningPort
(
address_
,
grpc
::
InsecureServerCredentials
());
builder
.
AddListeningPort
(
address_
,
grpc
::
InsecureServerCredentials
());
...
@@ -149,7 +162,6 @@ void AsyncGRPCServer::ShutdownQueue() {
...
@@ -149,7 +162,6 @@ void AsyncGRPCServer::ShutdownQueue() {
}
}
// This URL explains why shutdown is complicate:
// This URL explains why shutdown is complicate:
// https://stackoverflow.com/questions/35708348/grpc-what-is-the-recommended-way-to-shut-down-an-asynchronous-server-in-c
void
AsyncGRPCServer
::
ShutDown
()
{
void
AsyncGRPCServer
::
ShutDown
()
{
server_
->
Shutdown
();
server_
->
Shutdown
();
ShutdownQueue
();
ShutdownQueue
();
...
@@ -170,10 +182,12 @@ void AsyncGRPCServer::TryToRegisterNewGetOne() {
...
@@ -170,10 +182,12 @@ void AsyncGRPCServer::TryToRegisterNewGetOne() {
if
(
is_shut_down_
)
{
if
(
is_shut_down_
)
{
return
;
return
;
}
}
RequestGet
*
get
=
new
RequestGet
(
&
service_
,
cq_get_
.
get
(),
scope_
,
dev_ctx_
);
RequestGet
*
get
=
new
RequestGet
(
&
service_
,
cq_get_
.
get
(),
scope_
,
dev_ctx_
,
&
var_get_queue_
);
VLOG
(
4
)
<<
"create Requestget status:"
<<
get
->
Status
();
VLOG
(
4
)
<<
"create Requestget status:"
<<
get
->
Status
();
}
}
// FIXME(typhoonzero): remove wait argument and change cq_name to enum.
void
AsyncGRPCServer
::
HandleRequest
(
bool
wait
,
grpc
::
ServerCompletionQueue
*
cq
,
void
AsyncGRPCServer
::
HandleRequest
(
bool
wait
,
grpc
::
ServerCompletionQueue
*
cq
,
std
::
string
cq_name
,
std
::
string
cq_name
,
std
::
function
<
void
()
>
TryToRegisterNewOne
)
{
std
::
function
<
void
()
>
TryToRegisterNewOne
)
{
...
@@ -188,9 +202,9 @@ void AsyncGRPCServer::HandleRequest(bool wait, grpc::ServerCompletionQueue* cq,
...
@@ -188,9 +202,9 @@ void AsyncGRPCServer::HandleRequest(bool wait, grpc::ServerCompletionQueue* cq,
}
}
PADDLE_ENFORCE
(
tag
);
PADDLE_ENFORCE
(
tag
);
if
(
wait
&&
!
done_
)
{
// FIXME(typhoonzero): de-couple the barriers with recv_op
Wait
(
);
if
(
cq_name
==
"cq_get"
)
WaitCond
(
1
);
}
if
(
cq_name
==
"cq_send"
)
WaitCond
(
0
);
RequestBase
*
base
=
(
RequestBase
*
)
tag
;
RequestBase
*
base
=
(
RequestBase
*
)
tag
;
// reference:
// reference:
...
@@ -222,22 +236,18 @@ void AsyncGRPCServer::HandleRequest(bool wait, grpc::ServerCompletionQueue* cq,
...
@@ -222,22 +236,18 @@ void AsyncGRPCServer::HandleRequest(bool wait, grpc::ServerCompletionQueue* cq,
}
}
}
}
void
AsyncGRPCServer
::
Wait
()
{
void
AsyncGRPCServer
::
WaitCond
(
int
cond
)
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
this
->
mutex_
);
std
::
unique_lock
<
std
::
mutex
>
lock
(
this
->
barrier_mutex_
);
condition_
.
wait
(
lock
,
[
=
]
{
return
this
->
done_
==
true
;
});
barrier_condition_
.
wait
(
lock
,
}
[
=
]
{
return
this
->
barrier_cond_step_
==
cond
;
});
void
AsyncGRPCServer
::
Reset
()
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
this
->
mutex_
);
done_
=
false
;
}
}
void
AsyncGRPCServer
::
Done
(
)
{
void
AsyncGRPCServer
::
SetCond
(
int
cond
)
{
{
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
this
->
mutex_
);
std
::
lock_guard
<
std
::
mutex
>
lock
(
this
->
barrier_
mutex_
);
done_
=
true
;
barrier_cond_step_
=
cond
;
}
}
condition_
.
notify_all
();
barrier_
condition_
.
notify_all
();
}
}
}
// namespace detail
}
// namespace detail
...
...
paddle/operators/detail/grpc_server.h
浏览文件 @
58be41fa
...
@@ -41,9 +41,10 @@ class AsyncGRPCServer final : public sendrecv::SendRecvService::Service {
...
@@ -41,9 +41,10 @@ class AsyncGRPCServer final : public sendrecv::SendRecvService::Service {
void
RunSyncUpdate
();
void
RunSyncUpdate
();
void
Reset
();
// functions to sync server barrier status.
void
WaitCond
(
int
cond
);
void
Done
();
void
SetCond
(
int
cond
);
void
WaitClientGet
(
int
count
);
void
SetScope
(
framework
::
Scope
*
scope
)
{
scope_
=
scope
;
}
void
SetScope
(
framework
::
Scope
*
scope
)
{
scope_
=
scope
;
}
...
@@ -56,7 +57,6 @@ class AsyncGRPCServer final : public sendrecv::SendRecvService::Service {
...
@@ -56,7 +57,6 @@ class AsyncGRPCServer final : public sendrecv::SendRecvService::Service {
void
ShutDown
();
void
ShutDown
();
protected:
protected:
void
Wait
();
void
HandleRequest
(
bool
wait
,
grpc
::
ServerCompletionQueue
*
cq
,
void
HandleRequest
(
bool
wait
,
grpc
::
ServerCompletionQueue
*
cq
,
std
::
string
cq_name
,
std
::
string
cq_name
,
std
::
function
<
void
()
>
TryToRegisterNewOne
);
std
::
function
<
void
()
>
TryToRegisterNewOne
);
...
@@ -78,11 +78,12 @@ class AsyncGRPCServer final : public sendrecv::SendRecvService::Service {
...
@@ -78,11 +78,12 @@ class AsyncGRPCServer final : public sendrecv::SendRecvService::Service {
const
platform
::
DeviceContext
*
dev_ctx_
;
const
platform
::
DeviceContext
*
dev_ctx_
;
// received variable from RPC, operators fetch variable from this queue.
// received variable from RPC, operators fetch variable from this queue.
SimpleBlockQueue
<
MessageWithName
>
var_recv_queue_
;
SimpleBlockQueue
<
MessageWithName
>
var_recv_queue_
;
SimpleBlockQueue
<
char
>
var_get_queue_
;
// condition of the sub program
// condition of the sub program
std
::
mutex
mutex_
;
std
::
mutex
barrier_
mutex_
;
volatile
mutable
bool
done
_
;
mutable
int
barrier_cond_step
_
;
std
::
condition_variable
condition_
;
std
::
condition_variable
barrier_
condition_
;
std
::
unique_ptr
<
std
::
thread
>
t_send_
;
std
::
unique_ptr
<
std
::
thread
>
t_send_
;
std
::
unique_ptr
<
std
::
thread
>
t_get_
;
std
::
unique_ptr
<
std
::
thread
>
t_get_
;
...
...
paddle/operators/recv_op.cc
浏览文件 @
58be41fa
...
@@ -27,12 +27,17 @@ limitations under the License. */
...
@@ -27,12 +27,17 @@ limitations under the License. */
#include "paddle/operators/detail/grpc_server.h"
#include "paddle/operators/detail/grpc_server.h"
#include "paddle/operators/detail/sendrecvop_utils.h"
#include "paddle/operators/detail/sendrecvop_utils.h"
#include "paddle/operators/detail/simple_block_queue.h"
#include "paddle/operators/detail/simple_block_queue.h"
#include "paddle/string/printf.h"
#define LISTEN_TERMINATE_MESSAGE "TERMINATE@RECV"
#define LISTEN_TERMINATE_MESSAGE "TERMINATE@RECV"
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
constexpr
int
kCondStart
=
0
;
constexpr
int
kCondRunning
=
1
;
constexpr
int
kCondDone
=
2
;
void
RunServer
(
std
::
shared_ptr
<
detail
::
AsyncGRPCServer
>
service
)
{
void
RunServer
(
std
::
shared_ptr
<
detail
::
AsyncGRPCServer
>
service
)
{
service
->
RunSyncUpdate
();
service
->
RunSyncUpdate
();
VLOG
(
4
)
<<
"RunServer thread end"
;
VLOG
(
4
)
<<
"RunServer thread end"
;
...
@@ -77,42 +82,41 @@ class RecvOp : public framework::OperatorBase {
...
@@ -77,42 +82,41 @@ class RecvOp : public framework::OperatorBase {
if
(
grads_counter_
.
find
(
varname
)
==
grads_counter_
.
end
())
{
if
(
grads_counter_
.
find
(
varname
)
==
grads_counter_
.
end
())
{
grads_counter_
[
varname
]
=
0
;
grads_counter_
[
varname
]
=
0
;
}
}
char
ret
[
256
];
return
string
::
Sprintf
(
"%s.trainer_%d"
,
varname
,
grads_counter_
[
varname
]
++
);
snprintf
(
ret
,
sizeof
(
ret
),
"%s.trainer_%d"
,
varname
.
c_str
(),
grads_counter_
[
varname
]
++
);
return
std
::
string
(
ret
);
}
}
void
Run
(
const
framework
::
Scope
&
scope
,
void
Run
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
const
platform
::
Place
&
dev_place
)
const
override
{
// FIXME(typhoonzero): no new scopes for every run.
framework
::
Scope
&
recv_scope
=
scope
.
NewScope
();
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
&
dev_ctx
=
*
pool
.
Get
(
dev_place
);
auto
&
dev_ctx
=
*
pool
.
Get
(
dev_place
);
framework
::
Scope
&
recv_scope
=
scope
.
NewScope
();
// FIXME(Yancey1989): initialize rpc server with laze mode.
// FIXME(Yancey1989): initialize rpc server with laze mode.
rpc_service_
->
SetScope
(
&
recv_scope
);
rpc_service_
->
SetScope
(
&
recv_scope
);
rpc_service_
->
SetDevCtx
(
&
dev_ctx
);
rpc_service_
->
SetDevCtx
(
&
dev_ctx
);
auto
param_list
=
Attr
<
std
::
vector
<
std
::
string
>>
(
"ParamList"
);
auto
param_list
=
Attr
<
std
::
vector
<
std
::
string
>>
(
"ParamList"
);
auto
grad_list
=
Attr
<
std
::
vector
<
std
::
string
>>
(
"GradList"
);
auto
grad_list
=
Attr
<
std
::
vector
<
std
::
string
>>
(
"GradList"
);
auto
trainer_count
=
Attr
<
int
>
(
"Trainers
"
);
auto
fan_in
=
Attr
<
int
>
(
"Fanin
"
);
size_t
param_count
=
param_list
.
size
();
size_t
param_count
=
param_list
.
size
();
rpc_service_
->
Reset
();
std
::
string
program_str
=
Attr
<
std
::
string
>
(
"OptimizeProgram"
);
framework
::
proto
::
ProgramDesc
program_desc
;
program_desc
.
ParseFromString
(
program_str
);
framework
::
ProgramDesc
program
(
program_desc
);
framework
::
Executor
executor
(
dev_place
);
// TODO(typhoonzero): change this to a while_op for every cluster-batch.
// TODO(typhoonzero): change this to a while_op for every cluster-batch.
bool
exit_flag
=
false
;
bool
exit_flag
=
false
;
VLOG
(
4
)
<<
"param_count:"
<<
param_count
int64_t
barrier_size
=
param_count
*
fan_in
;
<<
" trainer_count:"
<<
trainer_count
;
while
(
!
exit_flag
)
{
while
(
!
exit_flag
)
{
// TODO(gognwb): simply this loop.
// Get from multiple trainers, we don't care about the order in which
// Get from multiple trainers, we don't care about order in which
// the gradients arrives, just add suffix 0~n and merge the gradient.
// the gradient arrives, just add suffix 0~n then average the gradient.
rpc_service_
->
SetCond
(
0
);
for
(
size_t
i
=
0
;
i
<
param_count
*
trainer_count
;
++
i
)
{
for
(
size_t
i
=
0
;
i
<
barrier_size
;
++
i
)
{
// blocking get one var from client.
const
detail
::
MessageWithName
&
v
=
rpc_service_
->
Get
();
const
detail
::
MessageWithName
&
v
=
rpc_service_
->
Get
();
auto
grad_var_name
=
v
.
first
;
auto
grad_var_name
=
v
.
first
;
if
(
grad_var_name
==
LISTEN_TERMINATE_MESSAGE
)
{
if
(
grad_var_name
==
LISTEN_TERMINATE_MESSAGE
)
{
VLOG
(
4
)
<<
"received LISTEN_TERMINATE_MESSAGE and RunOp.Run()
exit"
;
LOG
(
INFO
)
<<
"received terminate message and
exit"
;
exit_flag
=
true
;
exit_flag
=
true
;
break
;
break
;
}
}
...
@@ -121,49 +125,31 @@ class RecvOp : public framework::OperatorBase {
...
@@ -121,49 +125,31 @@ class RecvOp : public framework::OperatorBase {
if
(
it
!=
grad_list
.
end
())
{
if
(
it
!=
grad_list
.
end
())
{
param_var_name
=
param_list
[
it
-
grad_list
.
begin
()];
param_var_name
=
param_list
[
it
-
grad_list
.
begin
()];
}
else
{
}
else
{
LOG
(
ERROR
)
<<
"grad have no paired param found!
\"
"
<<
grad_var_name
LOG
(
ERROR
)
<<
"grad have no paired param:"
<<
grad_var_name
;
<<
"
\"
"
;
}
}
VLOG
(
3
)
<<
"recved grad: "
<<
grad_var_name
VLOG
(
3
)
<<
"recved grad: "
<<
grad_var_name
<<
" updating param: "
<<
param_var_name
;
<<
" updating param: "
<<
param_var_name
;
if
(
fan_in
>
1
)
{
auto
*
merged_grad
=
recv_scope
.
FindVar
(
grad_var_name
);
if
(
merged_grad
==
nullptr
)
{
auto
*
ptr
=
recv_scope
.
Var
(
grad_var_name
);
CreateTensorFromMessageType
(
ptr
,
v
.
second
.
type
());
VLOG
(
3
)
<<
"Create Variable "
<<
grad_var_name
<<
" on recv scope, which pointer is "
<<
ptr
<<
" type is "
<<
v
.
second
.
type
();
}
if
(
trainer_count
>
1
)
{
grad_var_name
=
this
->
GetGradVarNameForTrainer
(
grad_var_name
);
grad_var_name
=
this
->
GetGradVarNameForTrainer
(
grad_var_name
);
}
}
auto
*
var
=
recv_scope
.
FindVar
(
grad_var_name
);
auto
*
var
=
recv_scope
.
Var
(
grad_var_name
);
if
(
var
==
nullptr
)
{
LOG
(
ERROR
)
<<
"can not find server side var: "
<<
grad_var_name
;
PADDLE_THROW
(
"can not find server side var"
);
}
detail
::
DeserializeFromMessage
(
v
.
second
,
dev_ctx
,
var
);
detail
::
DeserializeFromMessage
(
v
.
second
,
dev_ctx
,
var
);
}
}
if
(
exit_flag
)
{
if
(
exit_flag
)
{
break
;
break
;
}
}
rpc_service_
->
Reset
();
std
::
string
program_str
=
Attr
<
std
::
string
>
(
"OptimizeProgram"
);
framework
::
proto
::
ProgramDesc
program_desc
;
program_desc
.
ParseFromString
(
program_str
);
framework
::
ProgramDesc
program
(
program_desc
);
framework
::
Executor
executor
(
dev_place
);
// Run sub graph to get optimized tensor
try
{
try
{
executor
.
Run
(
program
,
&
recv_scope
,
0
,
/*global_block*/
executor
.
Run
(
program
,
&
recv_scope
,
0
,
/*global_block*/
false
/*create_local_scope*/
,
false
/*create_vars*/
);
false
/*create_local_scope*/
,
false
/*create_vars*/
);
}
catch
(
std
::
exception
&
e
)
{
}
catch
(
std
::
exception
&
e
)
{
LOG
(
ERROR
)
<<
"run sub program error "
<<
e
.
what
();
LOG
(
ERROR
)
<<
"run sub program error "
<<
e
.
what
();
}
}
rpc_service_
->
SetCond
(
1
);
rpc_service_
->
Done
(
);
rpc_service_
->
WaitClientGet
(
barrier_size
);
grads_counter_
.
clear
();
grads_counter_
.
clear
();
}
// while(true)
}
// while(true)
}
}
...
@@ -199,7 +185,7 @@ This operator will recv tensor from send_op
...
@@ -199,7 +185,7 @@ This operator will recv tensor from send_op
"GradList"
,
"type list of string"
,
"GradList"
,
"type list of string"
,
"grad->param name mapping to find which param to optimize."
)
"grad->param name mapping to find which param to optimize."
)
.
SetDefault
({});
.
SetDefault
({});
AddAttr
<
int
>
(
"
Trainers
"
,
"type int"
,
AddAttr
<
int
>
(
"
Fanin
"
,
"type int"
,
"Number of trainers in the current cluster job"
)
"Number of trainers in the current cluster job"
)
.
SetDefault
(
1
);
.
SetDefault
(
1
);
}
}
...
...
paddle/operators/send_op.cc
浏览文件 @
58be41fa
...
@@ -41,10 +41,13 @@ class SendOp : public framework::OperatorBase {
...
@@ -41,10 +41,13 @@ class SendOp : public framework::OperatorBase {
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
&
ctx
=
*
pool
.
Get
(
place
);
auto
&
ctx
=
*
pool
.
Get
(
place
);
for
(
size_t
i
=
0
;
i
<
ins
.
size
();
i
++
)
{
for
(
size_t
i
=
0
;
i
<
ins
.
size
();
i
++
)
{
VLOG
(
3
)
<<
"sending "
<<
ins
[
i
];
client_
.
AsyncSendVariable
(
epmap
[
i
],
ctx
,
scope
,
ins
[
i
]);
client_
.
AsyncSendVariable
(
epmap
[
i
],
ctx
,
scope
,
ins
[
i
]);
}
}
PADDLE_ENFORCE
(
client_
.
Wait
());
for
(
size_t
i
=
0
;
i
<
outs
.
size
();
i
++
)
{
for
(
size_t
i
=
0
;
i
<
outs
.
size
();
i
++
)
{
VLOG
(
3
)
<<
"getting "
<<
outs
[
i
];
client_
.
AsyncGetVariable
(
epmap
[
i
],
ctx
,
scope
,
outs
[
i
]);
client_
.
AsyncGetVariable
(
epmap
[
i
],
ctx
,
scope
,
outs
[
i
]);
}
}
...
...
python/paddle/v2/fluid/distribute_transpiler.py
浏览文件 @
58be41fa
...
@@ -420,6 +420,19 @@ class DistributeTranspiler:
...
@@ -420,6 +420,19 @@ class DistributeTranspiler:
pserver_program
=
Program
()
pserver_program
=
Program
()
for
v
in
self
.
param_grad_ep_mapping
[
endpoint
][
"params"
]:
for
v
in
self
.
param_grad_ep_mapping
[
endpoint
][
"params"
]:
self
.
_clone_var
(
pserver_program
.
global_block
(),
v
)
self
.
_clone_var
(
pserver_program
.
global_block
(),
v
)
for
v
in
self
.
param_grad_ep_mapping
[
endpoint
][
"grads"
]:
# create vars for each trainer in global scope, so
# we don't need to create them when grad arrives.
pserver_program
.
global_block
().
create_var
(
name
=
v
.
name
,
persistable
=
True
,
dtype
=
v
.
dtype
,
shape
=
v
.
shape
)
for
trainer_id
in
xrange
(
self
.
trainers
):
print
(
"create variable for program: %s.trainer_%d"
%
(
v
.
name
,
trainer_id
))
pserver_program
.
global_block
().
create_var
(
name
=
"%s.trainer_%d"
%
(
v
.
name
,
trainer_id
),
persistable
=
True
,
dtype
=
v
.
dtype
,
shape
=
v
.
shape
)
# step6
# step6
optimize_sub_program
=
Program
()
optimize_sub_program
=
Program
()
for
idx
,
opt_op
in
enumerate
(
self
.
optimize_ops
):
for
idx
,
opt_op
in
enumerate
(
self
.
optimize_ops
):
...
@@ -449,7 +462,7 @@ class DistributeTranspiler:
...
@@ -449,7 +462,7 @@ class DistributeTranspiler:
p
.
name
p
.
name
for
p
in
self
.
param_grad_ep_mapping
[
endpoint
][
"grads"
]
for
p
in
self
.
param_grad_ep_mapping
[
endpoint
][
"grads"
]
],
],
"
Trainers
"
:
self
.
trainers
"
Fanin
"
:
self
.
trainers
})
})
pserver_program
.
sync_with_cpp
()
pserver_program
.
sync_with_cpp
()
return
pserver_program
return
pserver_program
...
...
python/paddle/v2/fluid/tests/book_distribute/notest_recognize_digits_conv_dist.py
浏览文件 @
58be41fa
...
@@ -52,26 +52,27 @@ train_reader = paddle.batch(
...
@@ -52,26 +52,27 @@ train_reader = paddle.batch(
place
=
fluid
.
CPUPlace
()
place
=
fluid
.
CPUPlace
()
exe
=
fluid
.
Executor
(
place
)
exe
=
fluid
.
Executor
(
place
)
t
=
fluid
.
DistributeTranspiler
()
pserver_endpoints
=
os
.
getenv
(
"PSERVERS"
)
# all pserver endpoints
# all parameter server endpoints list for spliting parameters
trainers
=
int
(
os
.
getenv
(
"TRAINERS"
))
# total trainer count
pserver_endpoints
=
os
.
getenv
(
"PSERVERS"
)
current_endpoint
=
os
.
getenv
(
"SERVER_ENDPOINT"
)
# current pserver endpoint
# server endpoint for current node
current_endpoint
=
os
.
getenv
(
"SERVER_ENDPOINT"
)
# run as trainer or parameter server
training_role
=
os
.
getenv
(
"TRAINING_ROLE"
,
training_role
=
os
.
getenv
(
"TRAINING_ROLE"
,
"TRAINER"
)
# get the training role: trainer/pserver
"TRAINER"
)
# get the training role: trainer/pserver
t
.
transpile
(
optimize_ops
,
params_grads
,
pservers
=
pserver_endpoints
,
trainers
=
2
)
t
=
fluid
.
DistributeTranspiler
()
t
.
transpile
(
optimize_ops
,
params_grads
,
pservers
=
pserver_endpoints
,
trainers
=
trainers
)
if
training_role
==
"PSERVER"
:
if
training_role
==
"PSERVER"
:
if
not
current_endpoint
:
if
not
current_endpoint
:
print
(
"need env SERVER_ENDPOINT"
)
print
(
"need env SERVER_ENDPOINT"
)
exit
(
1
)
exit
(
1
)
pserver_prog
=
t
.
get_pserver_program
(
current_endpoint
)
pserver_prog
=
t
.
get_pserver_program
(
current_endpoint
)
exe
.
run
(
fluid
.
default_startup_program
())
pserver_startup
=
t
.
get_startup_program
(
current_endpoint
,
pserver_prog
)
exe
.
run
(
pserver_startup
)
exe
.
run
(
pserver_prog
)
exe
.
run
(
pserver_prog
)
elif
training_role
==
"TRAINER"
:
elif
training_role
==
"TRAINER"
:
trainer_prog
=
t
.
get_trainer_program
()
trainer_prog
=
t
.
get_trainer_program
()
feeder
=
fluid
.
DataFeeder
(
feed_list
=
[
images
,
label
],
place
=
place
)
feeder
=
fluid
.
DataFeeder
(
feed_list
=
[
images
,
label
],
place
=
place
)
# TODO(typhoonzero): change trainer startup program to fetch parameters from pserver
exe
.
run
(
fluid
.
default_startup_program
())
exe
.
run
(
fluid
.
default_startup_program
())
for
pass_id
in
range
(
PASS_NUM
):
for
pass_id
in
range
(
PASS_NUM
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录