提交 d13ce358 编写于 作者: 武毅 提交者: gongweibao

Feature/send recv can now retry (#9027)

上级 14fe40aa
...@@ -97,7 +97,7 @@ bool RPCClient::AsyncGetVariable(const std::string& ep, ...@@ -97,7 +97,7 @@ bool RPCClient::AsyncGetVariable(const std::string& ep,
return true; return true;
} }
bool RPCClient::AsyncSendBatchBarrier(const std::string& ep, int64_t time_out) { void RPCClient::AsyncSendBatchBarrier(const std::string& ep, int64_t time_out) {
const auto ch = GetChannel(ep); const auto ch = GetChannel(ep);
BatchBarrierProcessor* s = new BatchBarrierProcessor(ch); BatchBarrierProcessor* s = new BatchBarrierProcessor(ch);
...@@ -108,8 +108,18 @@ bool RPCClient::AsyncSendBatchBarrier(const std::string& ep, int64_t time_out) { ...@@ -108,8 +108,18 @@ bool RPCClient::AsyncSendBatchBarrier(const std::string& ep, int64_t time_out) {
auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_); auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_);
rpc->Finish(&s->reply_, &s->status_, (void*)s); rpc->Finish(&s->reply_, &s->status_, (void*)s);
req_count_++; req_count_++;
}
return true; void RPCClient::AsyncSendFetchBarrier(const std::string& ep, int64_t time_out) {
const auto ch = GetChannel(ep);
FetchBarrierProcessor* s = new FetchBarrierProcessor(ch);
s->Prepare(time_out);
sendrecv::VariableMessage req;
req.set_varname(FETCH_BARRIER_MESSAGE);
auto rpc = s->stub_->AsyncGetVariable(s->context_.get(), req, &cq_);
rpc->Finish(&s->reply_, &s->status_, (void*)s);
req_count_++;
} }
bool RPCClient::Wait() { bool RPCClient::Wait() {
...@@ -154,7 +164,7 @@ bool RPCClient::Proceed() { ...@@ -154,7 +164,7 @@ bool RPCClient::Proceed() {
PADDLE_ENFORCE(tag); PADDLE_ENFORCE(tag);
// TODO(gongwb): add more retries. // TODO(gongwb): add more retries.
ClientBase* c = static_cast<ClientBase*>(tag); BaseProcessor* c = static_cast<BaseProcessor*>(tag);
if (!c->status_.ok()) { if (!c->status_.ok()) {
LOG(ERROR) << "proc param error:" << c->var_h_.String() LOG(ERROR) << "proc param error:" << c->var_h_.String()
<< " grpc error:" << c->status_.error_message(); << " grpc error:" << c->status_.error_message();
...@@ -174,6 +184,8 @@ std::shared_ptr<grpc::Channel> RPCClient::GetChannel(const std::string& ep) { ...@@ -174,6 +184,8 @@ std::shared_ptr<grpc::Channel> RPCClient::GetChannel(const std::string& ep) {
} }
grpc::ChannelArguments args; grpc::ChannelArguments args;
args.SetInt("grpc.testing.fixed_reconnect_backoff_ms", 5000);
args.SetCompressionAlgorithm(GRPC_COMPRESS_NONE);
args.SetMaxSendMessageSize(std::numeric_limits<int>::max()); args.SetMaxSendMessageSize(std::numeric_limits<int>::max());
args.SetMaxReceiveMessageSize(std::numeric_limits<int>::max()); args.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());
......
...@@ -52,14 +52,14 @@ struct VarHandle { ...@@ -52,14 +52,14 @@ struct VarHandle {
void ProcGetResponse(const VarHandle& var_h, void ProcGetResponse(const VarHandle& var_h,
const sendrecv::VariableMessage& msg); const sendrecv::VariableMessage& msg);
class ClientBase { class BaseProcessor {
public: public:
explicit ClientBase(std::shared_ptr<grpc::Channel> ch) { explicit BaseProcessor(std::shared_ptr<grpc::Channel> ch) {
stub_ = sendrecv::SendRecvService::NewStub(ch); stub_ = sendrecv::SendRecvService::NewStub(ch);
context_ = NULL; context_ = NULL;
} }
virtual ~ClientBase() {} virtual ~BaseProcessor() {}
virtual void Prepare(const VarHandle& var_info, int64_t time_out) { virtual void Prepare(const VarHandle& var_info, int64_t time_out) {
context_.reset(new grpc::ClientContext()); context_.reset(new grpc::ClientContext());
...@@ -91,9 +91,10 @@ class ClientBase { ...@@ -91,9 +91,10 @@ class ClientBase {
typedef std::function<void(const VarHandle&, const sendrecv::VoidMessage&)> typedef std::function<void(const VarHandle&, const sendrecv::VoidMessage&)>
RequestSendCallBack; RequestSendCallBack;
class SendProcessor : public ClientBase { class SendProcessor : public BaseProcessor {
public: public:
explicit SendProcessor(std::shared_ptr<grpc::Channel> ch) : ClientBase(ch) {} explicit SendProcessor(std::shared_ptr<grpc::Channel> ch)
: BaseProcessor(ch) {}
virtual ~SendProcessor() {} virtual ~SendProcessor() {}
...@@ -110,9 +111,10 @@ class SendProcessor : public ClientBase { ...@@ -110,9 +111,10 @@ class SendProcessor : public ClientBase {
typedef std::function<void(const VarHandle&, const sendrecv::VariableMessage&)> typedef std::function<void(const VarHandle&, const sendrecv::VariableMessage&)>
RequestGetCallBack; RequestGetCallBack;
class GetProcessor : public ClientBase { class GetProcessor : public BaseProcessor {
public: public:
explicit GetProcessor(std::shared_ptr<grpc::Channel> ch) : ClientBase(ch) {} explicit GetProcessor(std::shared_ptr<grpc::Channel> ch)
: BaseProcessor(ch) {}
virtual ~GetProcessor() {} virtual ~GetProcessor() {}
...@@ -126,10 +128,10 @@ class GetProcessor : public ClientBase { ...@@ -126,10 +128,10 @@ class GetProcessor : public ClientBase {
RequestGetCallBack response_call_back_ = ProcGetResponse; RequestGetCallBack response_call_back_ = ProcGetResponse;
}; };
class BatchBarrierProcessor : public ClientBase { class BatchBarrierProcessor : public BaseProcessor {
public: public:
explicit BatchBarrierProcessor(std::shared_ptr<grpc::Channel> ch) explicit BatchBarrierProcessor(std::shared_ptr<grpc::Channel> ch)
: ClientBase(ch) {} : BaseProcessor(ch) {}
virtual ~BatchBarrierProcessor() {} virtual ~BatchBarrierProcessor() {}
...@@ -137,6 +139,17 @@ class BatchBarrierProcessor : public ClientBase { ...@@ -137,6 +139,17 @@ class BatchBarrierProcessor : public ClientBase {
sendrecv::VoidMessage reply_; sendrecv::VoidMessage reply_;
}; };
class FetchBarrierProcessor : public BaseProcessor {
public:
explicit FetchBarrierProcessor(std::shared_ptr<grpc::Channel> ch)
: BaseProcessor(ch) {}
virtual ~FetchBarrierProcessor() {}
virtual void Process() {}
sendrecv::VariableMessage reply_;
};
class RPCClient { class RPCClient {
public: public:
bool AsyncSendVariable(const std::string& ep, bool AsyncSendVariable(const std::string& ep,
...@@ -151,7 +164,10 @@ class RPCClient { ...@@ -151,7 +164,10 @@ class RPCClient {
const std::string& var_name, const std::string& var_name,
int64_t time_out = 600 * 1000); int64_t time_out = 600 * 1000);
bool AsyncSendBatchBarrier(const std::string& ep, void AsyncSendBatchBarrier(const std::string& ep,
int64_t time_out = 600 * 1000);
void AsyncSendFetchBarrier(const std::string& ep,
int64_t time_out = 600 * 1000); int64_t time_out = 600 * 1000);
bool Wait(); bool Wait();
......
...@@ -84,7 +84,7 @@ class RequestGet final : public RequestBase { ...@@ -84,7 +84,7 @@ class RequestGet final : public RequestBase {
explicit RequestGet(sendrecv::SendRecvService::AsyncService* service, explicit RequestGet(sendrecv::SendRecvService::AsyncService* service,
grpc::ServerCompletionQueue* cq, framework::Scope* scope, grpc::ServerCompletionQueue* cq, framework::Scope* scope,
const platform::DeviceContext* dev_ctx, const platform::DeviceContext* dev_ctx,
SimpleBlockQueue<char>* queue) SimpleBlockQueue<MessageWithName>* queue)
: RequestBase(service, cq), : RequestBase(service, cq),
responder_(&ctx_), responder_(&ctx_),
scope_(scope), scope_(scope),
...@@ -101,11 +101,16 @@ class RequestGet final : public RequestBase { ...@@ -101,11 +101,16 @@ class RequestGet final : public RequestBase {
// proc request. // proc request.
std::string var_name = request_.varname(); std::string var_name = request_.varname();
auto* var = scope_->FindVar(var_name); auto* var = scope_->FindVar(var_name);
SerializeToMessage(var_name, var, *dev_ctx_, &reply_); if (var_name != FETCH_BARRIER_MESSAGE) {
SerializeToMessage(var_name, var, *dev_ctx_, &reply_);
}
// TODO(gongwb): check var's info. // TODO(gongwb): check var's info.
responder_.Finish(reply_, grpc::Status::OK, this); responder_.Finish(reply_, grpc::Status::OK, this);
status_ = FINISH; status_ = FINISH;
queue_->Push('c'); MessageWithName msg_with_name =
// request name reply
std::make_pair(var_name, std::move(reply_));
queue_->Push(msg_with_name);
} }
protected: protected:
...@@ -114,12 +119,16 @@ class RequestGet final : public RequestBase { ...@@ -114,12 +119,16 @@ class RequestGet final : public RequestBase {
ServerAsyncResponseWriter<sendrecv::VariableMessage> responder_; ServerAsyncResponseWriter<sendrecv::VariableMessage> responder_;
framework::Scope* scope_; framework::Scope* scope_;
const platform::DeviceContext* dev_ctx_; const platform::DeviceContext* dev_ctx_;
SimpleBlockQueue<char>* queue_; SimpleBlockQueue<MessageWithName>* queue_;
}; };
void AsyncGRPCServer::WaitClientGet(int count) { void AsyncGRPCServer::WaitClientGet(int count) {
for (int i = 0; i < count; ++i) { int fetch_barriers = 0;
var_get_queue_.Pop(); while (fetch_barriers < count) {
auto msg = var_get_queue_.Pop();
if (msg.first == FETCH_BARRIER_MESSAGE) {
fetch_barriers++;
}
} }
} }
......
...@@ -77,7 +77,7 @@ class AsyncGRPCServer final : public sendrecv::SendRecvService::Service { ...@@ -77,7 +77,7 @@ class AsyncGRPCServer final : public sendrecv::SendRecvService::Service {
const platform::DeviceContext *dev_ctx_; const platform::DeviceContext *dev_ctx_;
// received variable from RPC, operators fetch variable from this queue. // received variable from RPC, operators fetch variable from this queue.
SimpleBlockQueue<MessageWithName> var_recv_queue_; SimpleBlockQueue<MessageWithName> var_recv_queue_;
SimpleBlockQueue<char> var_get_queue_; SimpleBlockQueue<MessageWithName> var_get_queue_;
// condition of the sub program // condition of the sub program
std::mutex barrier_mutex_; std::mutex barrier_mutex_;
......
...@@ -32,6 +32,7 @@ namespace detail { ...@@ -32,6 +32,7 @@ namespace detail {
#define LISTEN_TERMINATE_MESSAGE "TERMINATE@RECV" #define LISTEN_TERMINATE_MESSAGE "TERMINATE@RECV"
#define BATCH_BARRIER_MESSAGE "BATCH_BARRIER@RECV" #define BATCH_BARRIER_MESSAGE "BATCH_BARRIER@RECV"
#define FETCH_BARRIER_MESSAGE "FETCH_BARRIER@RECV"
typedef void (*DestroyCallback)(void*); typedef void (*DestroyCallback)(void*);
......
...@@ -128,8 +128,8 @@ class ListenAndServOp : public framework::OperatorBase { ...@@ -128,8 +128,8 @@ class ListenAndServOp : public framework::OperatorBase {
} }
} }
if (exit_flag) { if (exit_flag) {
rpc_service_->ShutDown();
rpc_service_->SetCond(1); rpc_service_->SetCond(1);
rpc_service_->ShutDown();
break; break;
} }
try { try {
...@@ -148,7 +148,7 @@ class ListenAndServOp : public framework::OperatorBase { ...@@ -148,7 +148,7 @@ class ListenAndServOp : public framework::OperatorBase {
} }
rpc_service_->SetCond(1); rpc_service_->SetCond(1);
// FIXME(typhoonzero): use another condition to sync wait clients get. // FIXME(typhoonzero): use another condition to sync wait clients get.
rpc_service_->WaitClientGet(ins.size()); rpc_service_->WaitClientGet(fan_in);
sparse_vars.clear(); sparse_vars.clear();
} // while(true) } // while(true)
} }
......
...@@ -88,6 +88,12 @@ class SendOp : public framework::OperatorBase { ...@@ -88,6 +88,12 @@ class SendOp : public framework::OperatorBase {
rpc_client->AsyncGetVariable(epmap[i], ctx, scope, outs[i]); rpc_client->AsyncGetVariable(epmap[i], ctx, scope, outs[i]);
} }
PADDLE_ENFORCE(rpc_client->Wait()); PADDLE_ENFORCE(rpc_client->Wait());
// tell pservers that current trainer have called fetch
for (auto& ep : endpoints) {
VLOG(3) << "send fetch barrier, ep: " << ep;
rpc_client->AsyncSendFetchBarrier(ep);
}
PADDLE_ENFORCE(rpc_client->Wait());
} }
} }
}; };
......
...@@ -250,6 +250,8 @@ class DistributeTranspiler: ...@@ -250,6 +250,8 @@ class DistributeTranspiler:
def get_trainer_program(self): def get_trainer_program(self):
# remove optimize ops and add a send op to main_program # remove optimize ops and add a send op to main_program
self.program.global_block().delete_ops(self.optimize_ops) self.program.global_block().delete_ops(self.optimize_ops)
# FIXME(typhoonzero): serialize once will fix error occurs when clone.
self.program.__str__()
return self.program return self.program
def get_pserver_program(self, endpoint): def get_pserver_program(self, endpoint):
...@@ -309,7 +311,8 @@ class DistributeTranspiler: ...@@ -309,7 +311,8 @@ class DistributeTranspiler:
for _, opt_op in enumerate(opt_op_on_pserver): for _, opt_op in enumerate(opt_op_on_pserver):
if ufind.is_connected(op, opt_op): if ufind.is_connected(op, opt_op):
if self._is_opt_op(op): if self._is_opt_op(op):
self._append_pserver_ops(optimize_block, op, endpoint) self._append_pserver_ops(optimize_block, op, endpoint,
default_main_program())
else: else:
self._append_pserver_non_opt_ops(optimize_block, op) self._append_pserver_non_opt_ops(optimize_block, op)
break break
...@@ -520,7 +523,8 @@ class DistributeTranspiler: ...@@ -520,7 +523,8 @@ class DistributeTranspiler:
orig_var_name = varname[:suff_idx] orig_var_name = varname[:suff_idx]
return orig_var_name return orig_var_name
def _append_pserver_ops(self, optimize_block, opt_op, endpoint): def _append_pserver_ops(self, optimize_block, opt_op, endpoint,
origin_program):
program = optimize_block.program program = optimize_block.program
pserver_block = program.global_block() pserver_block = program.global_block()
new_inputs = dict() new_inputs = dict()
...@@ -576,7 +580,17 @@ class DistributeTranspiler: ...@@ -576,7 +580,17 @@ class DistributeTranspiler:
elif key == "LearningRate": elif key == "LearningRate":
# leraning rate variable has already be created by non-optimize op, # leraning rate variable has already be created by non-optimize op,
# don't create it once again. # don't create it once again.
new_inputs[key] = pserver_block.vars[opt_op.input(key)[0]] lr_varname = opt_op.input(key)[0]
if pserver_block.vars.has_key(lr_varname):
new_inputs[key] = pserver_block.vars[opt_op.input(key)[0]]
else:
origin_var = origin_program.global_block().vars[lr_varname]
tmpvar = pserver_block.create_var(
name=origin_var.name,
persistable=origin_var.persistable,
dtype=origin_var.dtype,
shape=origin_var.shape)
new_inputs[key] = tmpvar
for key in opt_op.input_names: for key in opt_op.input_names:
new_shape = None new_shape = None
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册