Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
b4a3b750
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
b4a3b750
编写于
10月 15, 2019
作者:
1
123malin
提交者:
GitHub
10月 15, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
bug fix: invalid learning rate decay in pserver async mode (#20325)
* bug fix: invalid learning rate decay in pserver async mode
上级
cadc6a97
变更
15
隐藏空白更改
内联
并排
Showing
15 changed file
with
296 addition
and
8 deletion
+296
-8
paddle/fluid/operators/distributed/grpc/grpc_client.cc
paddle/fluid/operators/distributed/grpc/grpc_client.cc
+29
-0
paddle/fluid/operators/distributed/grpc/grpc_client.h
paddle/fluid/operators/distributed/grpc/grpc_client.h
+18
-0
paddle/fluid/operators/distributed/grpc/grpc_server.cc
paddle/fluid/operators/distributed/grpc/grpc_server.cc
+39
-0
paddle/fluid/operators/distributed/grpc/grpc_service.h
paddle/fluid/operators/distributed/grpc/grpc_service.h
+4
-1
paddle/fluid/operators/distributed/request_handler.h
paddle/fluid/operators/distributed/request_handler.h
+9
-0
paddle/fluid/operators/distributed/request_handler_impl.cc
paddle/fluid/operators/distributed/request_handler_impl.cc
+17
-0
paddle/fluid/operators/distributed/request_handler_impl.h
paddle/fluid/operators/distributed/request_handler_impl.h
+17
-0
paddle/fluid/operators/distributed/rpc_client.h
paddle/fluid/operators/distributed/rpc_client.h
+4
-0
paddle/fluid/operators/distributed/send_recv.proto.in
paddle/fluid/operators/distributed/send_recv.proto.in
+1
-0
paddle/fluid/operators/distributed_ops/distributed_notify_op.cc
.../fluid/operators/distributed_ops/distributed_notify_op.cc
+84
-0
paddle/fluid/operators/distributed_ops/listen_and_serv_op.cc
paddle/fluid/operators/distributed_ops/listen_and_serv_op.cc
+24
-5
paddle/fluid/operators/distributed_ops/listen_and_serv_op.h
paddle/fluid/operators/distributed_ops/listen_and_serv_op.h
+2
-0
python/paddle/fluid/tests/unittests/dist_ctr.py
python/paddle/fluid/tests/unittests/dist_ctr.py
+10
-2
python/paddle/fluid/tests/unittests/test_dist_ctr.py
python/paddle/fluid/tests/unittests/test_dist_ctr.py
+22
-0
python/paddle/fluid/transpiler/distribute_transpiler.py
python/paddle/fluid/transpiler/distribute_transpiler.py
+16
-0
未找到文件。
paddle/fluid/operators/distributed/grpc/grpc_client.cc
浏览文件 @
b4a3b750
...
@@ -438,6 +438,35 @@ VarHandlePtr GRPCClient::AsyncCheckpointNotify(const std::string& ep,
...
@@ -438,6 +438,35 @@ VarHandlePtr GRPCClient::AsyncCheckpointNotify(const std::string& ep,
return
h
;
return
h
;
}
}
// Asynchronously sends a DistributeNotify RPC to the endpoint `ep`.
//
// ep:       endpoint whose channel is looked up with GetChannel().
// type:     notification type; sent as the request's varname and used as the
//           name recorded in the returned handle.
// time_out: deadline forwarded to the processor's Prepare().
//
// Returns the handle of the in-flight request. When profiling is enabled the
// handle is waited on before returning.
VarHandlePtr GRPCClient::AsyncDistributeNotify(const std::string& ep,
                                               const std::string& type,
                                               int64_t time_out) {
  const auto ch = GetChannel(ep);

  // The processor's address is passed to gRPC as the completion tag below, so
  // it is deliberately a raw pointer here.
  // NOTE(review): presumably released by whoever drains cq_ -- confirm.
  DistributeNotifyProcessor* s = new DistributeNotifyProcessor(ch);

  const std::string method = kRequestNotify;

  // Label the handle with the notification actually being sent rather than a
  // hard-coded message, so the handle always matches the request.
  VarHandlePtr h(new VarHandle(ep, method, type, nullptr, nullptr));
  s->Prepare(h, time_out);

  sendrecv::VariableMessage req;
  req.set_varname(type);

  platform::RecordRPCEvent record_event(method);

  auto rpc = s->stub_->AsyncDistributeNotify(s->context_.get(), req, &cq_);
  rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
  req_count_++;

  if (UNLIKELY(platform::IsProfileEnabled())) {
    h->Wait();
  }

  return h;
}
bool
GRPCClient
::
Wait
()
{
bool
GRPCClient
::
Wait
()
{
std
::
unique_lock
<
std
::
mutex
>
lk
(
sync_mutex_
);
std
::
unique_lock
<
std
::
mutex
>
lk
(
sync_mutex_
);
sync_cond_
.
wait
(
lk
,
[
this
]
{
return
(
req_count_
==
0
||
ok_
==
false
);
});
sync_cond_
.
wait
(
lk
,
[
this
]
{
return
(
req_count_
==
0
||
ok_
==
false
);
});
...
...
paddle/fluid/operators/distributed/grpc/grpc_client.h
浏览文件 @
b4a3b750
...
@@ -173,6 +173,20 @@ class CheckpointNotifyProcessor : public BaseProcessor {
...
@@ -173,6 +173,20 @@ class CheckpointNotifyProcessor : public BaseProcessor {
std
::
unique_ptr
<
sendrecv
::
SendRecvService
::
Stub
>
stub_
;
std
::
unique_ptr
<
sendrecv
::
SendRecvService
::
Stub
>
stub_
;
};
};
// Processor for the DistributeNotify RPC. Holds the service stub created from
// the given channel and the VoidMessage that receives the reply.
class DistributeNotifyProcessor : public BaseProcessor {
 public:
  explicit DistributeNotifyProcessor(std::shared_ptr<grpc::Channel> ch)
      : BaseProcessor(), stub_(sendrecv::SendRecvService::NewStub(ch)) {}

  virtual ~DistributeNotifyProcessor() {}

  // Intentionally empty: nothing is done with the reply.
  void ProcessImpl() override {}

  sendrecv::VoidMessage reply_;
  std::unique_ptr<sendrecv::SendRecvService::Stub> stub_;
};
class
GRPCClient
:
public
RPCClient
{
class
GRPCClient
:
public
RPCClient
{
public:
public:
GRPCClient
()
:
ok_
(
true
),
completed_
(
false
),
stopped_
(
false
)
{}
GRPCClient
()
:
ok_
(
true
),
completed_
(
false
),
stopped_
(
false
)
{}
...
@@ -225,6 +239,10 @@ class GRPCClient : public RPCClient {
...
@@ -225,6 +239,10 @@ class GRPCClient : public RPCClient {
const
std
::
string
&
ep
,
const
std
::
string
&
dir
,
const
std
::
string
&
ep
,
const
std
::
string
&
dir
,
int64_t
time_out
=
FLAGS_rpc_deadline
)
override
;
int64_t
time_out
=
FLAGS_rpc_deadline
)
override
;
VarHandlePtr
AsyncDistributeNotify
(
const
std
::
string
&
ep
,
const
std
::
string
&
type
,
int64_t
time_out
=
FLAGS_rpc_deadline
)
override
;
VarHandlePtr
AsyncSendComplete
(
VarHandlePtr
AsyncSendComplete
(
const
std
::
string
&
ep
,
int64_t
time_out
=
FLAGS_rpc_deadline
)
override
;
const
std
::
string
&
ep
,
int64_t
time_out
=
FLAGS_rpc_deadline
)
override
;
...
...
paddle/fluid/operators/distributed/grpc/grpc_server.cc
浏览文件 @
b4a3b750
...
@@ -393,6 +393,43 @@ class RequestCheckpointNotify final : public RequestBase {
...
@@ -393,6 +393,43 @@ class RequestCheckpointNotify final : public RequestBase {
ServerAsyncResponseWriter
<
sendrecv
::
VoidMessage
>
responder_
;
ServerAsyncResponseWriter
<
sendrecv
::
VoidMessage
>
responder_
;
};
};
// Server-side request object for the kRequestNotify gRPC method. On
// construction it posts an async unary request on the completion queue, using
// req_id as the tag; Process() forwards the received varname to the handler
// and replies with an empty message.
class RequestNotify final : public RequestBase {
 public:
  explicit RequestNotify(GrpcService::AsyncService* service,
                         ::grpc::ServerCompletionQueue* cq,
                         RequestHandler* request_handler, int req_id)
      : RequestBase(service, cq, request_handler, req_id), responder_(&ctx_) {
    request_.reset(new GRPCVariableResponse(request_handler->scope(),
                                            request_handler->dev_ctx()));

    int notify_method =
        static_cast<int>(distributed::GrpcMethod::kRequestNotify);
    // The request id itself (not a pointer) is used as the completion tag.
    void* tag = reinterpret_cast<void*>(static_cast<intptr_t>(req_id));
    service_->RequestAsyncUnary(notify_method, &ctx_, request_.get(),
                                &responder_, cq_, cq_, tag);
  }

  virtual ~RequestNotify() {}

  std::string GetReqName() override { return request_->Varname(); }

  void Process() override {
    auto local_scope = request_->GetMutableLocalScope();
    std::string name = request_->Varname();
    int trainer = request_->GetTrainerId();

    VLOG(4) << "RequestNotify notify: " << name
            << ", trainer id: " << trainer;

    // No input/output variables accompany a notification.
    request_handler_->Handle(name, local_scope, nullptr, nullptr, trainer);
    Finish(reply_, &responder_);
  }

 protected:
  std::shared_ptr<GRPCVariableResponse> request_;
  sendrecv::VoidMessage reply_;
  ServerAsyncResponseWriter<sendrecv::VoidMessage> responder_;
};
void
AsyncGRPCServer
::
WaitServerReady
()
{
void
AsyncGRPCServer
::
WaitServerReady
()
{
VLOG
(
4
)
<<
"AsyncGRPCServer is waiting server ready"
;
VLOG
(
4
)
<<
"AsyncGRPCServer is waiting server ready"
;
std
::
unique_lock
<
std
::
mutex
>
lock
(
this
->
mutex_ready_
);
std
::
unique_lock
<
std
::
mutex
>
lock
(
this
->
mutex_ready_
);
...
@@ -526,6 +563,8 @@ void AsyncGRPCServer::TryToRegisterNewOne(const std::string& rpc_name,
...
@@ -526,6 +563,8 @@ void AsyncGRPCServer::TryToRegisterNewOne(const std::string& rpc_name,
b
=
new
RequestPrefetch
(
&
service_
,
cq
.
get
(),
handler
,
req_id
);
b
=
new
RequestPrefetch
(
&
service_
,
cq
.
get
(),
handler
,
req_id
);
}
else
if
(
rpc_name
==
kRequestCheckpoint
)
{
}
else
if
(
rpc_name
==
kRequestCheckpoint
)
{
b
=
new
RequestCheckpointNotify
(
&
service_
,
cq
.
get
(),
handler
,
req_id
);
b
=
new
RequestCheckpointNotify
(
&
service_
,
cq
.
get
(),
handler
,
req_id
);
}
else
if
(
rpc_name
==
kRequestNotify
)
{
b
=
new
RequestNotify
(
&
service_
,
cq
.
get
(),
handler
,
req_id
);
}
else
{
}
else
{
PADDLE_ENFORCE
(
false
,
"not supported rpc"
);
PADDLE_ENFORCE
(
false
,
"not supported rpc"
);
}
}
...
...
paddle/fluid/operators/distributed/grpc/grpc_service.h
浏览文件 @
b4a3b750
...
@@ -84,10 +84,11 @@ enum class GrpcMethod {
...
@@ -84,10 +84,11 @@ enum class GrpcMethod {
kGetVariableNoBarrier
,
kGetVariableNoBarrier
,
kGetMonomerVariable
,
kGetMonomerVariable
,
kGetMonomerBarrier
,
kGetMonomerBarrier
,
kRequestNotify
,
};
};
static
const
int
kGrpcNumMethods
=
static
const
int
kGrpcNumMethods
=
static_cast
<
int
>
(
GrpcMethod
::
k
GetMonomerBarrier
)
+
1
;
static_cast
<
int
>
(
GrpcMethod
::
k
RequestNotify
)
+
1
;
inline
const
char
*
GrpcMethodName
(
GrpcMethod
id
)
{
inline
const
char
*
GrpcMethodName
(
GrpcMethod
id
)
{
switch
(
id
)
{
switch
(
id
)
{
...
@@ -105,6 +106,8 @@ inline const char* GrpcMethodName(GrpcMethod id) {
...
@@ -105,6 +106,8 @@ inline const char* GrpcMethodName(GrpcMethod id) {
return
"/sendrecv.SendRecvService/PrefetchVariable"
;
return
"/sendrecv.SendRecvService/PrefetchVariable"
;
case
GrpcMethod
::
kCheckpointNotify
:
case
GrpcMethod
::
kCheckpointNotify
:
return
"/sendrecv.SendRecvService/CheckpointNotify"
;
return
"/sendrecv.SendRecvService/CheckpointNotify"
;
case
GrpcMethod
::
kRequestNotify
:
return
"/sendrecv.SendRecvService/DistributeNotify"
;
}
}
// Shouldn't be reached.
// Shouldn't be reached.
...
...
paddle/fluid/operators/distributed/request_handler.h
浏览文件 @
b4a3b750
...
@@ -45,6 +45,7 @@ constexpr char kRequestPrefetch[] = "RequestPrefetch";
...
@@ -45,6 +45,7 @@ constexpr char kRequestPrefetch[] = "RequestPrefetch";
constexpr
char
kRequestCheckpoint
[]
=
"RequestCheckpoint"
;
constexpr
char
kRequestCheckpoint
[]
=
"RequestCheckpoint"
;
constexpr
char
kRequestPassBarrier
[]
=
"RequestPassBarrier"
;
constexpr
char
kRequestPassBarrier
[]
=
"RequestPassBarrier"
;
constexpr
char
kRequestGetNoBarrier
[]
=
"GetVariableNoBarrier"
;
constexpr
char
kRequestGetNoBarrier
[]
=
"GetVariableNoBarrier"
;
constexpr
char
kRequestNotify
[]
=
"RequestNotify"
;
constexpr
char
kSendRPC
[]
=
"SendRPC"
;
constexpr
char
kSendRPC
[]
=
"SendRPC"
;
constexpr
char
kGetRPC
[]
=
"GetRPC"
;
constexpr
char
kGetRPC
[]
=
"GetRPC"
;
...
@@ -62,6 +63,7 @@ constexpr char kCheckPointNotifyRPC[] = "CheckPointNotifyRPC";
...
@@ -62,6 +63,7 @@ constexpr char kCheckPointNotifyRPC[] = "CheckPointNotifyRPC";
#define FETCH_BARRIER_MESSAGE "FETCH_BARRIER@RECV"
#define FETCH_BARRIER_MESSAGE "FETCH_BARRIER@RECV"
#define COMPLETE_MESSAGE "COMPLETE@RECV"
#define COMPLETE_MESSAGE "COMPLETE@RECV"
#define WITHOUT_BARRIER_MESSAGE "@WITHOUT_BARRIER@RECV"
#define WITHOUT_BARRIER_MESSAGE "@WITHOUT_BARRIER@RECV"
#define LEARNING_RATE_DECAY_MESSAGE "LRDECAY@RECV"
#define CHECKPOINT_SAVE_MESSAGE "SAVE@CHECKPOINTNOTIFY"
#define CHECKPOINT_SAVE_MESSAGE "SAVE@CHECKPOINTNOTIFY"
#define CHECKPOINT_LOAD_MESSAGE "LOAD@CHECKPOINTNOTIFY"
#define CHECKPOINT_LOAD_MESSAGE "LOAD@CHECKPOINTNOTIFY"
...
@@ -188,6 +190,11 @@ class RequestHandler {
...
@@ -188,6 +190,11 @@ class RequestHandler {
sparse_grad_to_param_
=
g
;
sparse_grad_to_param_
=
g
;
}
}
// Stores the given prepared executor context in lr_decay_prepared_ctx_.
void
SetLrDecayPreparedCtx
(
std
::
shared_ptr
<
framework
::
ExecutorPrepareContext
>
g
)
{
lr_decay_prepared_ctx_
=
g
;
}
void
SetRPCServer
(
RPCServer
*
rpc_server
)
{
rpc_server_
=
rpc_server
;
}
void
SetRPCServer
(
RPCServer
*
rpc_server
)
{
rpc_server_
=
rpc_server
;
}
// Get attributes.
// Get attributes.
...
@@ -238,6 +245,8 @@ class RequestHandler {
...
@@ -238,6 +245,8 @@ class RequestHandler {
grad_to_prepared_ctx_
;
grad_to_prepared_ctx_
;
std
::
unordered_map
<
std
::
string
,
std
::
string
>*
sparse_grad_to_param_
;
std
::
unordered_map
<
std
::
string
,
std
::
string
>*
sparse_grad_to_param_
;
// used for lr decay
std
::
shared_ptr
<
framework
::
ExecutorPrepareContext
>
lr_decay_prepared_ctx_
;
RPCServer
*
rpc_server_
;
RPCServer
*
rpc_server_
;
};
};
...
...
paddle/fluid/operators/distributed/request_handler_impl.cc
浏览文件 @
b4a3b750
...
@@ -251,6 +251,23 @@ bool RequestCheckpointHandler::Handle(const std::string& varname,
...
@@ -251,6 +251,23 @@ bool RequestCheckpointHandler::Handle(const std::string& varname,
return
true
;
return
true
;
}
}
// Handles a notify request.
//
// When `varname` equals LEARNING_RATE_DECAY_MESSAGE, runs the prepared
// learning-rate-decay context on the handler's own scope_ (not on the `scope`
// argument). Any other varname is ignored. The remaining parameters are unused
// here and exist to satisfy the RequestHandler::Handle interface.
//
// Always returns true. Enforces that lr_decay_block_id is not -1 when an
// LR-decay notification arrives.
bool RequestNotifyHandler::Handle(const std::string& varname,
                                  framework::Scope* scope,
                                  framework::Variable* invar,
                                  framework::Variable** outvar,
                                  const int trainer_id,
                                  const std::string& out_var_name,
                                  const std::string& table_name) {
  // Separator added so the handler name and the variable name do not run
  // together in the log line.
  VLOG(4) << "RequestNotifyHandler: " << varname;
  if (varname == LEARNING_RATE_DECAY_MESSAGE) {
    PADDLE_ENFORCE_NE(
        lr_decay_block_id, -1,
        "when lr_decay_block_id = -1, there should be no RPC invoke.");
    executor_->RunPreparedContext(lr_decay_prepared_ctx_.get(), scope_);
  }
  return true;
}
}
// namespace distributed
}
// namespace distributed
}
// namespace operators
}
// namespace operators
}
// namespace paddle
}
// namespace paddle
paddle/fluid/operators/distributed/request_handler_impl.h
浏览文件 @
b4a3b750
...
@@ -17,6 +17,7 @@
...
@@ -17,6 +17,7 @@
#include <time.h>
#include <time.h>
#include <functional>
#include <functional>
#include <memory>
#include <string>
#include <string>
#include <utility>
#include <utility>
#include <vector>
#include <vector>
...
@@ -126,6 +127,22 @@ class RequestCheckpointHandler final : public RequestHandler {
...
@@ -126,6 +127,22 @@ class RequestCheckpointHandler final : public RequestHandler {
int
checkpoint_notify_id
;
int
checkpoint_notify_id
;
};
};
// Request handler for notify RPCs. Remembers the id of the learning-rate-decay
// block it was constructed with; Handle() is defined out of line.
class RequestNotifyHandler final : public RequestHandler {
 public:
  explicit RequestNotifyHandler(bool sync_mode, int lr_decay_block_id)
      : RequestHandler(sync_mode), lr_decay_block_id(lr_decay_block_id) {}

  virtual ~RequestNotifyHandler() {}

  bool Handle(const std::string& varname, framework::Scope* scope,
              framework::Variable* var, framework::Variable** outvar,
              const int trainer_id, const std::string& out_var_name = "",
              const std::string& table_name = "") override;

 private:
  // Block id passed at construction.
  int lr_decay_block_id;
};
}
// namespace distributed
}
// namespace distributed
}
// namespace operators
}
// namespace operators
}
// namespace paddle
}
// namespace paddle
paddle/fluid/operators/distributed/rpc_client.h
浏览文件 @
b4a3b750
...
@@ -80,6 +80,10 @@ class RPCClient {
...
@@ -80,6 +80,10 @@ class RPCClient {
const
std
::
string
&
ep
,
const
std
::
string
&
dir
,
const
std
::
string
&
ep
,
const
std
::
string
&
dir
,
int64_t
time_out
=
FLAGS_rpc_deadline
)
=
0
;
int64_t
time_out
=
FLAGS_rpc_deadline
)
=
0
;
virtual
VarHandlePtr
AsyncDistributeNotify
(
const
std
::
string
&
ep
,
const
std
::
string
&
type
,
int64_t
time_out
=
FLAGS_rpc_deadline
)
=
0
;
virtual
VarHandlePtr
AsyncSendComplete
(
virtual
VarHandlePtr
AsyncSendComplete
(
const
std
::
string
&
ep
,
int64_t
time_out
=
FLAGS_rpc_deadline
)
=
0
;
const
std
::
string
&
ep
,
int64_t
time_out
=
FLAGS_rpc_deadline
)
=
0
;
...
...
paddle/fluid/operators/distributed/send_recv.proto.in
浏览文件 @
b4a3b750
...
@@ -28,6 +28,7 @@ service SendRecvService {
...
@@ -28,6 +28,7 @@ service SendRecvService {
rpc
PrefetchVariable
(
VariableMessage
)
returns
(
VariableMessage
)
{}
rpc
PrefetchVariable
(
VariableMessage
)
returns
(
VariableMessage
)
{}
rpc
CheckpointNotify
(
VariableMessage
)
returns
(
VoidMessage
)
{}
rpc
CheckpointNotify
(
VariableMessage
)
returns
(
VoidMessage
)
{}
rpc
DistributeNotify
(
VariableMessage
)
returns
(
VoidMessage
)
{}
rpc
GetMonomerVariable
(
VariableMessage
)
returns
(
VariableMessage
)
{}
rpc
GetMonomerVariable
(
VariableMessage
)
returns
(
VariableMessage
)
{}
rpc
GetMonomerBarrier
(
VariableMessage
)
returns
(
VoidMessage
)
{}
rpc
GetMonomerBarrier
(
VariableMessage
)
returns
(
VoidMessage
)
{}
...
...
paddle/fluid/operators/distributed_ops/distributed_notify_op.cc
0 → 100644
浏览文件 @
b4a3b750
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <future> // NOLINT
#include <ostream>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/distributed/distributed.h"
#include "paddle/fluid/operators/distributed_ops/send_recv_util.h"
#include "paddle/fluid/string/printf.h"
namespace
paddle
{
namespace
operators
{
// Operator that sends a notification of attribute "type" to every endpoint in
// attribute "epmap" through the RPC client, then waits for all requests.
class DistributedNotifyOp : public framework::OperatorBase {
 public:
  DistributedNotifyOp(const std::string& type,
                      const framework::VariableNameMap& inputs,
                      const framework::VariableNameMap& outputs,
                      const framework::AttributeMap& attrs)
      : OperatorBase(type, inputs, outputs, attrs) {}

  // `scope` and `place` are not used; everything comes from the attributes.
  void RunImpl(const framework::Scope& scope,
               const platform::Place& place) const override {
    const auto& endpoints = Attr<std::vector<std::string>>("epmap");
    const auto& notify_type = Attr<std::string>("type");
    const int trainer = Attr<int>("trainer_id");

    distributed::RPCClient* client =
        distributed::RPCClient::GetInstance<RPCCLIENT_T>(trainer);

    // Fire all notifications first; completion is checked once afterwards.
    for (const auto& endpoint : endpoints) {
      client->AsyncDistributeNotify(endpoint, notify_type);
      VLOG(4) << "distribute notify sending : " << notify_type << " to "
              << endpoint;
    }

    PADDLE_ENFORCE_EQ(client->Wait(), true, "internal error in RPCClient");
  }
};
// Declares the attributes and description of the distributed_notify operator.
// The operator has no inputs or outputs, only attributes.
class DistributedNotifyOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    // The two adjacent literals are concatenated by the compiler; the
    // trailing space keeps ")" and "Parameter" from running together.
    AddAttr<std::vector<std::string>>(
        "epmap",
        "(string vector, default 127.0.0.1:6164) "
        "Parameter Server endpoints in the order")
        .SetDefault({"127.0.0.1:6164"});
    // NOTE(review): the description mentions a default of '' but no
    // SetDefault is applied to this attribute -- confirm which is intended.
    AddAttr<std::string>("type",
                         "(string, default '') indicate the action type");
    AddAttr<int>("trainer_id", "trainer id from 0 ~ worker_num.")
        .SetDefault(0);
    AddComment(R"DOC(
DistributeNotify operator
This operator will send a signal to listen_and_serve op at
the parameter server.
)DOC");
  }
};
// Shape inference for distributed_notify: deliberately a no-op.
class DistributedNotifyOpShapeInference : public framework::InferShapeBase {
 public:
  // Nothing is read from or written to `ctx`.
  void operator()(framework::InferShapeContext* ctx) const override {}
};
}
// namespace operators
}
// namespace paddle
// Short alias for paddle::operators, used by the registration below.
namespace
ops
=
paddle
::
operators
;
// Registers the operator under the name distributed_notify, passing the op
// class, EmptyGradOpMaker, the proto maker and the shape-inference class.
REGISTER_OPERATOR
(
distributed_notify
,
ops
::
DistributedNotifyOp
,
paddle
::
framework
::
EmptyGradOpMaker
,
ops
::
DistributedNotifyOpMaker
,
ops
::
DistributedNotifyOpShapeInference
);
paddle/fluid/operators/distributed_ops/listen_and_serv_op.cc
浏览文件 @
b4a3b750
...
@@ -298,6 +298,7 @@ static void FillRequestCtx(
...
@@ -298,6 +298,7 @@ static void FillRequestCtx(
std
::
unordered_map
<
std
::
string
,
std
::
string
>
std
::
unordered_map
<
std
::
string
,
std
::
string
>
*
sparse_grad_name_to_param_name
,
*
sparse_grad_name_to_param_name
,
std
::
shared_ptr
<
framework
::
ExecutorPrepareContext
>
checkpoint_ctx
,
std
::
shared_ptr
<
framework
::
ExecutorPrepareContext
>
checkpoint_ctx
,
std
::
shared_ptr
<
framework
::
ExecutorPrepareContext
>
lr_decay_ctx
,
distributed
::
RPCServer
*
rpc_server
)
{
distributed
::
RPCServer
*
rpc_server
)
{
h
->
SetScope
(
scope
);
h
->
SetScope
(
scope
);
h
->
SetDevCtx
(
dev_ctx
);
h
->
SetDevCtx
(
dev_ctx
);
...
@@ -307,6 +308,7 @@ static void FillRequestCtx(
...
@@ -307,6 +308,7 @@ static void FillRequestCtx(
h
->
SetSparseGradToParam
(
sparse_grad_name_to_param_name
);
h
->
SetSparseGradToParam
(
sparse_grad_name_to_param_name
);
h
->
SetRPCServer
(
rpc_server
);
h
->
SetRPCServer
(
rpc_server
);
h
->
SetCheckpointNotifyPreparedCtx
(
checkpoint_ctx
);
h
->
SetCheckpointNotifyPreparedCtx
(
checkpoint_ctx
);
h
->
SetLrDecayPreparedCtx
(
lr_decay_ctx
);
}
}
void
ListenAndServOp
::
CacheVarsType
(
const
std
::
vector
<
std
::
string
>
&
varnames
,
void
ListenAndServOp
::
CacheVarsType
(
const
std
::
vector
<
std
::
string
>
&
varnames
,
...
@@ -345,10 +347,12 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
...
@@ -345,10 +347,12 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
PADDLE_ENFORCE
(
!
rpc_service_
);
PADDLE_ENFORCE
(
!
rpc_service_
);
std
::
string
endpoint
=
Attr
<
std
::
string
>
(
"endpoint"
);
std
::
string
endpoint
=
Attr
<
std
::
string
>
(
"endpoint"
);
int
checkpoint_block_id
=
Attr
<
int
>
(
kCheckpointBlockId
);
int
checkpoint_block_id
=
Attr
<
int
>
(
kCheckpointBlockId
);
int
lr_decay_block_id
=
Attr
<
int
>
(
kLRDecayBlockId
);
VLOG
(
4
)
<<
"pserver_id: "
<<
pserver_id
<<
", sync_mode:"
<<
sync_mode
VLOG
(
4
)
<<
"pserver_id: "
<<
pserver_id
<<
", sync_mode:"
<<
sync_mode
<<
", fan_in:"
<<
fan_in
<<
", end_point:"
<<
endpoint
<<
", fan_in:"
<<
fan_in
<<
", end_point:"
<<
endpoint
<<
", checkpoint_block_id: "
<<
checkpoint_block_id
;
<<
", checkpoint_block_id: "
<<
checkpoint_block_id
<<
", lr_decay_block_id: "
<<
lr_decay_block_id
;
rpc_service_
.
reset
(
new
RPCSERVER_T
(
endpoint
,
fan_in
));
rpc_service_
.
reset
(
new
RPCSERVER_T
(
endpoint
,
fan_in
));
...
@@ -362,6 +366,8 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
...
@@ -362,6 +366,8 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
sync_mode
,
checkpoint_block_id
));
sync_mode
,
checkpoint_block_id
));
request_get_no_barrier_handler_
.
reset
(
request_get_no_barrier_handler_
.
reset
(
new
distributed
::
RequestGetNoBarrierHandler
());
new
distributed
::
RequestGetNoBarrierHandler
());
request_notify_handler_
.
reset
(
new
distributed
::
RequestNotifyHandler
(
sync_mode
,
lr_decay_block_id
));
rpc_service_
->
RegisterRPC
(
distributed
::
kRequestSend
,
rpc_service_
->
RegisterRPC
(
distributed
::
kRequestSend
,
request_send_handler_
.
get
(),
request_send_handler_
.
get
(),
...
@@ -376,6 +382,8 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
...
@@ -376,6 +382,8 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
request_checkpoint_handler_
.
get
());
request_checkpoint_handler_
.
get
());
rpc_service_
->
RegisterRPC
(
distributed
::
kRequestGetNoBarrier
,
rpc_service_
->
RegisterRPC
(
distributed
::
kRequestGetNoBarrier
,
request_get_no_barrier_handler_
.
get
());
request_get_no_barrier_handler_
.
get
());
rpc_service_
->
RegisterRPC
(
distributed
::
kRequestNotify
,
request_notify_handler_
.
get
(),
1
);
auto
optimize_blocks
=
auto
optimize_blocks
=
Attr
<
std
::
vector
<
framework
::
BlockDesc
*>>
(
kOptimizeBlocks
);
Attr
<
std
::
vector
<
framework
::
BlockDesc
*>>
(
kOptimizeBlocks
);
...
@@ -391,6 +399,13 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
...
@@ -391,6 +399,13 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
ckpt_pre_context
=
std
::
move
(
ctx
);
ckpt_pre_context
=
std
::
move
(
ctx
);
}
}
std
::
shared_ptr
<
framework
::
ExecutorPrepareContext
>
lr_decay_context
=
nullptr
;
if
(
lr_decay_block_id
!=
-
1
)
{
auto
ctx
=
executor
.
Prepare
(
*
program
,
lr_decay_block_id
);
// see: https://stackoverflow.com/a/14856553
lr_decay_context
=
std
::
move
(
ctx
);
}
// prepare for prefetch
// prepare for prefetch
std
::
vector
<
int
>
prefetch_block_id_list
;
std
::
vector
<
int
>
prefetch_block_id_list
;
std
::
unordered_map
<
int
,
std
::
string
>
block_id_to_prefetch_var_name
;
std
::
unordered_map
<
int
,
std
::
string
>
block_id_to_prefetch_var_name
;
...
@@ -435,16 +450,18 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
...
@@ -435,16 +450,18 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
sparse_grad_name_to_param_name
[
pieces
[
0
]]
=
pieces
[
1
];
sparse_grad_name_to_param_name
[
pieces
[
0
]]
=
pieces
[
1
];
}
}
auto
f
=
std
::
bind
(
auto
f
=
FillRequestCtx
,
std
::
placeholders
::
_1
,
&
recv_scope
,
&
dev_ctx
,
&
executor
,
std
::
bind
(
FillRequestCtx
,
std
::
placeholders
::
_1
,
&
recv_scope
,
&
dev_ctx
,
program
,
&
prefetch_var_name_to_prepared_ctx
,
&
executor
,
program
,
&
prefetch_var_name_to_prepared_ctx
,
&
sparse_grad_name_to_param_name
,
ckpt_pre_context
,
rpc_service_
.
get
());
&
sparse_grad_name_to_param_name
,
ckpt_pre_context
,
lr_decay_context
,
rpc_service_
.
get
());
f
(
request_send_handler_
.
get
());
f
(
request_send_handler_
.
get
());
f
(
request_get_handler_
.
get
());
f
(
request_get_handler_
.
get
());
f
(
request_prefetch_handler_
.
get
());
f
(
request_prefetch_handler_
.
get
());
f
(
request_checkpoint_handler_
.
get
());
f
(
request_checkpoint_handler_
.
get
());
f
(
request_get_no_barrier_handler_
.
get
());
f
(
request_get_no_barrier_handler_
.
get
());
f
(
request_notify_handler_
.
get
());
// start the server listening after all member initialized.
// start the server listening after all member initialized.
server_thread_
.
reset
(
new
std
::
thread
(
RunServer
,
rpc_service_
));
server_thread_
.
reset
(
new
std
::
thread
(
RunServer
,
rpc_service_
));
...
@@ -522,6 +539,8 @@ class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker {
...
@@ -522,6 +539,8 @@ class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr
<
int
>
(
kCheckpointBlockId
,
AddAttr
<
int
>
(
kCheckpointBlockId
,
"BolckID to run save checkpoint on pserer."
)
"BolckID to run save checkpoint on pserer."
)
.
SetDefault
(
-
1
);
.
SetDefault
(
-
1
);
AddAttr
<
int
>
(
kLRDecayBlockId
,
"BolckID to run lr decay on pserer."
)
.
SetDefault
(
-
1
);
}
}
};
};
...
...
paddle/fluid/operators/distributed_ops/listen_and_serv_op.h
浏览文件 @
b4a3b750
...
@@ -37,6 +37,7 @@ namespace operators {
...
@@ -37,6 +37,7 @@ namespace operators {
constexpr
char
kOptimizeBlocks
[]
=
"optimize_blocks"
;
constexpr
char
kOptimizeBlocks
[]
=
"optimize_blocks"
;
constexpr
char
kPrefetchVarNameToBlockId
[]
=
"prefetch_var_name_to_block_id"
;
constexpr
char
kPrefetchVarNameToBlockId
[]
=
"prefetch_var_name_to_block_id"
;
constexpr
char
kCheckpointBlockId
[]
=
"checkpint_block_id"
;
constexpr
char
kCheckpointBlockId
[]
=
"checkpint_block_id"
;
constexpr
char
kLRDecayBlockId
[]
=
"lr_decay_block_id"
;
constexpr
char
kSparseGradToParam
[]
=
"sparse_grad_to_param"
;
constexpr
char
kSparseGradToParam
[]
=
"sparse_grad_to_param"
;
void
RunServer
(
std
::
shared_ptr
<
distributed
::
RPCServer
>
service
);
void
RunServer
(
std
::
shared_ptr
<
distributed
::
RPCServer
>
service
);
...
@@ -97,6 +98,7 @@ class ListenAndServOp : public framework::OperatorBase {
...
@@ -97,6 +98,7 @@ class ListenAndServOp : public framework::OperatorBase {
request_prefetch_handler_
;
request_prefetch_handler_
;
mutable
std
::
shared_ptr
<
distributed
::
RequestHandler
>
mutable
std
::
shared_ptr
<
distributed
::
RequestHandler
>
request_checkpoint_handler_
;
request_checkpoint_handler_
;
mutable
std
::
shared_ptr
<
distributed
::
RequestHandler
>
request_notify_handler_
;
mutable
std
::
shared_ptr
<
std
::
thread
>
server_thread_
;
mutable
std
::
shared_ptr
<
std
::
thread
>
server_thread_
;
mutable
std
::
vector
<
std
::
string
>
sparse_vars_
;
mutable
std
::
vector
<
std
::
string
>
sparse_vars_
;
...
...
python/paddle/fluid/tests/unittests/dist_ctr.py
浏览文件 @
b4a3b750
...
@@ -103,8 +103,16 @@ class TestDistCTR2x2(TestDistRunnerBase):
...
@@ -103,8 +103,16 @@ class TestDistCTR2x2(TestDistRunnerBase):
if
use_l2_decay
:
if
use_l2_decay
:
regularization
=
fluid
.
regularizer
.
L2DecayRegularizer
(
regularization
=
fluid
.
regularizer
.
L2DecayRegularizer
(
regularization_coeff
=
1e-1
)
regularization_coeff
=
1e-1
)
use_lr_decay
=
bool
(
os
.
getenv
(
'LR_DECAY'
,
0
))
sgd_optimizer
=
fluid
.
optimizer
.
SGD
(
learning_rate
=
0.0001
,
lr
=
0.0001
if
use_lr_decay
:
lr
=
fluid
.
layers
.
exponential_decay
(
learning_rate
=
0.0001
,
decay_steps
=
10000
,
decay_rate
=
0.999
,
staircase
=
True
)
sgd_optimizer
=
fluid
.
optimizer
.
SGD
(
learning_rate
=
lr
,
regularization
=
regularization
)
regularization
=
regularization
)
sgd_optimizer
.
minimize
(
avg_cost
)
sgd_optimizer
.
minimize
(
avg_cost
)
...
...
python/paddle/fluid/tests/unittests/test_dist_ctr.py
浏览文件 @
b4a3b750
...
@@ -80,6 +80,28 @@ class TestDistCTR2x2_ASYNC(TestDistBase):
...
@@ -80,6 +80,28 @@ class TestDistCTR2x2_ASYNC(TestDistBase):
log_name
=
flag_name
)
log_name
=
flag_name
)
class TestDistCTR2x2_ASYNCWithLRDecay2x2(TestDistBase):
    """Runs dist_ctr.py with sync mode off, hogwild mode on and CPU enforced,
    with LR_DECAY=1 set in the environment passed to the run."""

    def _setup_config(self):
        self._enforce_place = "CPU"
        self._hogwild_mode = True
        self._sync_mode = False

    def test_dist_ctr(self):
        # Communicator flags first, then the switch read by dist_ctr.py.
        communicator_flags = {
            "FLAGS_communicator_send_queue_size": "2",
            "FLAGS_communicator_max_merge_var_num": "2",
            "FLAGS_communicator_max_send_grad_num_before_recv": "2",
        }
        need_envs = dict(communicator_flags, LR_DECAY="1")

        self.check_with_place(
            "dist_ctr.py",
            delta=100,
            check_error_log=True,
            need_envs=need_envs,
            log_name=flag_name)
class
TestDistCTR2x2_ASYNC2
(
TestDistBase
):
class
TestDistCTR2x2_ASYNC2
(
TestDistBase
):
def
_setup_config
(
self
):
def
_setup_config
(
self
):
self
.
_sync_mode
=
False
self
.
_sync_mode
=
False
...
...
python/paddle/fluid/transpiler/distribute_transpiler.py
浏览文件 @
b4a3b750
...
@@ -818,6 +818,18 @@ class DistributeTranspiler(object):
...
@@ -818,6 +818,18 @@ class DistributeTranspiler(object):
self
.
_update_remote_sparse_update_op
(
program
,
self
.
_update_remote_sparse_update_op
(
program
,
need_sparse_update_params
)
need_sparse_update_params
)
if
not
self
.
sync_mode
:
lr_ops
=
self
.
_get_lr_ops
()
if
len
(
lr_ops
)
>
0
:
program
.
global_block
().
append_op
(
type
=
"distributed_notify"
,
inputs
=
{},
outputs
=
{},
attrs
=
{
"epmap"
:
pserver_endpoints
,
"trainer_id"
:
self
.
trainer_id
,
"type"
:
"LRDECAY@RECV"
})
self
.
_get_trainer_startup_program
(
recv_vars
=
recv_vars
,
eplist
=
eplist
)
self
.
_get_trainer_startup_program
(
recv_vars
=
recv_vars
,
eplist
=
eplist
)
...
@@ -1125,6 +1137,8 @@ class DistributeTranspiler(object):
...
@@ -1125,6 +1137,8 @@ class DistributeTranspiler(object):
lr_ops
=
self
.
_get_lr_ops
()
lr_ops
=
self
.
_get_lr_ops
()
# record optimize blocks and we can run them on pserver parallel
# record optimize blocks and we can run them on pserver parallel
optimize_blocks
=
[]
optimize_blocks
=
[]
lr_decay_block_id
=
-
1
if
len
(
lr_ops
)
>
0
:
if
len
(
lr_ops
)
>
0
:
lr_decay_block
=
pserver_program
.
_create_block
(
lr_decay_block
=
pserver_program
.
_create_block
(
pserver_program
.
num_blocks
-
1
)
pserver_program
.
num_blocks
-
1
)
...
@@ -1134,6 +1148,7 @@ class DistributeTranspiler(object):
...
@@ -1134,6 +1148,7 @@ class DistributeTranspiler(object):
# append sub blocks to pserver_program in lr_decay_op
# append sub blocks to pserver_program in lr_decay_op
__clone_lr_op_sub_block__
(
cloned_op
,
pserver_program
,
__clone_lr_op_sub_block__
(
cloned_op
,
pserver_program
,
lr_decay_block
)
lr_decay_block
)
lr_decay_block_id
=
lr_decay_block
.
idx
# append op to the current block
# append op to the current block
grad_to_block_id
=
[]
grad_to_block_id
=
[]
...
@@ -1211,6 +1226,7 @@ class DistributeTranspiler(object):
...
@@ -1211,6 +1226,7 @@ class DistributeTranspiler(object):
"sync_mode"
:
self
.
sync_mode
,
"sync_mode"
:
self
.
sync_mode
,
"grad_to_block_id"
:
grad_to_block_id
,
"grad_to_block_id"
:
grad_to_block_id
,
"sparse_grad_to_param"
:
sparse_grad_to_param
,
"sparse_grad_to_param"
:
sparse_grad_to_param
,
"lr_decay_block_id"
:
lr_decay_block_id
,
}
}
if
self
.
has_distributed_lookup_table
:
if
self
.
has_distributed_lookup_table
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录