提交 d13ce358 编写于 作者: 武毅 提交者: gongweibao

Feature/send recv can now retry (#9027)

上级 14fe40aa
......@@ -97,7 +97,7 @@ bool RPCClient::AsyncGetVariable(const std::string& ep,
return true;
}
bool RPCClient::AsyncSendBatchBarrier(const std::string& ep, int64_t time_out) {
void RPCClient::AsyncSendBatchBarrier(const std::string& ep, int64_t time_out) {
const auto ch = GetChannel(ep);
BatchBarrierProcessor* s = new BatchBarrierProcessor(ch);
......@@ -108,8 +108,18 @@ bool RPCClient::AsyncSendBatchBarrier(const std::string& ep, int64_t time_out) {
auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_);
rpc->Finish(&s->reply_, &s->status_, (void*)s);
req_count_++;
}
return true;
void RPCClient::AsyncSendFetchBarrier(const std::string& ep, int64_t time_out) {
const auto ch = GetChannel(ep);
FetchBarrierProcessor* s = new FetchBarrierProcessor(ch);
s->Prepare(time_out);
sendrecv::VariableMessage req;
req.set_varname(FETCH_BARRIER_MESSAGE);
auto rpc = s->stub_->AsyncGetVariable(s->context_.get(), req, &cq_);
rpc->Finish(&s->reply_, &s->status_, (void*)s);
req_count_++;
}
bool RPCClient::Wait() {
......@@ -154,7 +164,7 @@ bool RPCClient::Proceed() {
PADDLE_ENFORCE(tag);
// TODO(gongwb): add more retries.
ClientBase* c = static_cast<ClientBase*>(tag);
BaseProcessor* c = static_cast<BaseProcessor*>(tag);
if (!c->status_.ok()) {
LOG(ERROR) << "proc param error:" << c->var_h_.String()
<< " grpc error:" << c->status_.error_message();
......@@ -174,6 +184,8 @@ std::shared_ptr<grpc::Channel> RPCClient::GetChannel(const std::string& ep) {
}
grpc::ChannelArguments args;
args.SetInt("grpc.testing.fixed_reconnect_backoff_ms", 5000);
args.SetCompressionAlgorithm(GRPC_COMPRESS_NONE);
args.SetMaxSendMessageSize(std::numeric_limits<int>::max());
args.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());
......
......@@ -52,14 +52,14 @@ struct VarHandle {
void ProcGetResponse(const VarHandle& var_h,
const sendrecv::VariableMessage& msg);
class ClientBase {
class BaseProcessor {
public:
explicit ClientBase(std::shared_ptr<grpc::Channel> ch) {
explicit BaseProcessor(std::shared_ptr<grpc::Channel> ch) {
stub_ = sendrecv::SendRecvService::NewStub(ch);
context_ = NULL;
}
virtual ~ClientBase() {}
virtual ~BaseProcessor() {}
virtual void Prepare(const VarHandle& var_info, int64_t time_out) {
context_.reset(new grpc::ClientContext());
......@@ -91,9 +91,10 @@ class ClientBase {
typedef std::function<void(const VarHandle&, const sendrecv::VoidMessage&)>
RequestSendCallBack;
class SendProcessor : public ClientBase {
class SendProcessor : public BaseProcessor {
public:
explicit SendProcessor(std::shared_ptr<grpc::Channel> ch) : ClientBase(ch) {}
explicit SendProcessor(std::shared_ptr<grpc::Channel> ch)
: BaseProcessor(ch) {}
virtual ~SendProcessor() {}
......@@ -110,9 +111,10 @@ class SendProcessor : public ClientBase {
typedef std::function<void(const VarHandle&, const sendrecv::VariableMessage&)>
RequestGetCallBack;
class GetProcessor : public ClientBase {
class GetProcessor : public BaseProcessor {
public:
explicit GetProcessor(std::shared_ptr<grpc::Channel> ch) : ClientBase(ch) {}
explicit GetProcessor(std::shared_ptr<grpc::Channel> ch)
: BaseProcessor(ch) {}
virtual ~GetProcessor() {}
......@@ -126,10 +128,10 @@ class GetProcessor : public ClientBase {
RequestGetCallBack response_call_back_ = ProcGetResponse;
};
class BatchBarrierProcessor : public ClientBase {
class BatchBarrierProcessor : public BaseProcessor {
public:
explicit BatchBarrierProcessor(std::shared_ptr<grpc::Channel> ch)
: ClientBase(ch) {}
: BaseProcessor(ch) {}
virtual ~BatchBarrierProcessor() {}
......@@ -137,6 +139,17 @@ class BatchBarrierProcessor : public ClientBase {
sendrecv::VoidMessage reply_;
};
class FetchBarrierProcessor : public BaseProcessor {
public:
explicit FetchBarrierProcessor(std::shared_ptr<grpc::Channel> ch)
: BaseProcessor(ch) {}
virtual ~FetchBarrierProcessor() {}
virtual void Process() {}
sendrecv::VariableMessage reply_;
};
class RPCClient {
public:
bool AsyncSendVariable(const std::string& ep,
......@@ -151,7 +164,10 @@ class RPCClient {
const std::string& var_name,
int64_t time_out = 600 * 1000);
bool AsyncSendBatchBarrier(const std::string& ep,
void AsyncSendBatchBarrier(const std::string& ep,
int64_t time_out = 600 * 1000);
void AsyncSendFetchBarrier(const std::string& ep,
int64_t time_out = 600 * 1000);
bool Wait();
......
......@@ -84,7 +84,7 @@ class RequestGet final : public RequestBase {
explicit RequestGet(sendrecv::SendRecvService::AsyncService* service,
grpc::ServerCompletionQueue* cq, framework::Scope* scope,
const platform::DeviceContext* dev_ctx,
SimpleBlockQueue<char>* queue)
SimpleBlockQueue<MessageWithName>* queue)
: RequestBase(service, cq),
responder_(&ctx_),
scope_(scope),
......@@ -101,11 +101,16 @@ class RequestGet final : public RequestBase {
// proc request.
std::string var_name = request_.varname();
auto* var = scope_->FindVar(var_name);
if (var_name != FETCH_BARRIER_MESSAGE) {
SerializeToMessage(var_name, var, *dev_ctx_, &reply_);
}
// TODO(gongwb): check var's info.
responder_.Finish(reply_, grpc::Status::OK, this);
status_ = FINISH;
queue_->Push('c');
MessageWithName msg_with_name =
// request name reply
std::make_pair(var_name, std::move(reply_));
queue_->Push(msg_with_name);
}
protected:
......@@ -114,12 +119,16 @@ class RequestGet final : public RequestBase {
ServerAsyncResponseWriter<sendrecv::VariableMessage> responder_;
framework::Scope* scope_;
const platform::DeviceContext* dev_ctx_;
SimpleBlockQueue<char>* queue_;
SimpleBlockQueue<MessageWithName>* queue_;
};
void AsyncGRPCServer::WaitClientGet(int count) {
for (int i = 0; i < count; ++i) {
var_get_queue_.Pop();
int fetch_barriers = 0;
while (fetch_barriers < count) {
auto msg = var_get_queue_.Pop();
if (msg.first == FETCH_BARRIER_MESSAGE) {
fetch_barriers++;
}
}
}
......
......@@ -77,7 +77,7 @@ class AsyncGRPCServer final : public sendrecv::SendRecvService::Service {
const platform::DeviceContext *dev_ctx_;
// received variable from RPC, operators fetch variable from this queue.
SimpleBlockQueue<MessageWithName> var_recv_queue_;
SimpleBlockQueue<char> var_get_queue_;
SimpleBlockQueue<MessageWithName> var_get_queue_;
// condition of the sub program
std::mutex barrier_mutex_;
......
......@@ -32,6 +32,7 @@ namespace detail {
#define LISTEN_TERMINATE_MESSAGE "TERMINATE@RECV"
#define BATCH_BARRIER_MESSAGE "BATCH_BARRIER@RECV"
#define FETCH_BARRIER_MESSAGE "FETCH_BARRIER@RECV"
typedef void (*DestroyCallback)(void*);
......
......@@ -128,8 +128,8 @@ class ListenAndServOp : public framework::OperatorBase {
}
}
if (exit_flag) {
rpc_service_->ShutDown();
rpc_service_->SetCond(1);
rpc_service_->ShutDown();
break;
}
try {
......@@ -148,7 +148,7 @@ class ListenAndServOp : public framework::OperatorBase {
}
rpc_service_->SetCond(1);
// FIXME(typhoonzero): use another condition to sync wait clients get.
rpc_service_->WaitClientGet(ins.size());
rpc_service_->WaitClientGet(fan_in);
sparse_vars.clear();
} // while(true)
}
......
......@@ -88,6 +88,12 @@ class SendOp : public framework::OperatorBase {
rpc_client->AsyncGetVariable(epmap[i], ctx, scope, outs[i]);
}
PADDLE_ENFORCE(rpc_client->Wait());
// tell pservers that current trainer have called fetch
for (auto& ep : endpoints) {
VLOG(3) << "send fetch barrier, ep: " << ep;
rpc_client->AsyncSendFetchBarrier(ep);
}
PADDLE_ENFORCE(rpc_client->Wait());
}
}
};
......
......@@ -250,6 +250,8 @@ class DistributeTranspiler:
def get_trainer_program(self):
# remove optimize ops and add a send op to main_program
self.program.global_block().delete_ops(self.optimize_ops)
# FIXME(typhoonzero): serialize once will fix error occurs when clone.
self.program.__str__()
return self.program
def get_pserver_program(self, endpoint):
......@@ -309,7 +311,8 @@ class DistributeTranspiler:
for _, opt_op in enumerate(opt_op_on_pserver):
if ufind.is_connected(op, opt_op):
if self._is_opt_op(op):
self._append_pserver_ops(optimize_block, op, endpoint)
self._append_pserver_ops(optimize_block, op, endpoint,
default_main_program())
else:
self._append_pserver_non_opt_ops(optimize_block, op)
break
......@@ -520,7 +523,8 @@ class DistributeTranspiler:
orig_var_name = varname[:suff_idx]
return orig_var_name
def _append_pserver_ops(self, optimize_block, opt_op, endpoint):
def _append_pserver_ops(self, optimize_block, opt_op, endpoint,
origin_program):
program = optimize_block.program
pserver_block = program.global_block()
new_inputs = dict()
......@@ -576,7 +580,17 @@ class DistributeTranspiler:
elif key == "LearningRate":
# leraning rate variable has already be created by non-optimize op,
# don't create it once again.
lr_varname = opt_op.input(key)[0]
if pserver_block.vars.has_key(lr_varname):
new_inputs[key] = pserver_block.vars[opt_op.input(key)[0]]
else:
origin_var = origin_program.global_block().vars[lr_varname]
tmpvar = pserver_block.create_var(
name=origin_var.name,
persistable=origin_var.persistable,
dtype=origin_var.dtype,
shape=origin_var.shape)
new_inputs[key] = tmpvar
for key in opt_op.input_names:
new_shape = None
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册