提交 ae19d2ea 编写于 作者: T typhoonzero

fix comm issues

上级 f233b936
...@@ -36,7 +36,10 @@ class RequestBase { ...@@ -36,7 +36,10 @@ class RequestBase {
CallStatus Status() { return status_; } CallStatus Status() { return status_; }
void SetStatus(CallStatus status) { status_ = status; } void SetStatus(CallStatus status) { status_ = status; }
virtual std::string GetReqName() { assert(false); } virtual std::string GetReqName() {
assert(false);
return "";
}
protected: protected:
grpc::ServerContext ctx_; grpc::ServerContext ctx_;
...@@ -80,11 +83,13 @@ class RequestGet final : public RequestBase { ...@@ -80,11 +83,13 @@ class RequestGet final : public RequestBase {
public: public:
explicit RequestGet(sendrecv::SendRecvService::AsyncService* service, explicit RequestGet(sendrecv::SendRecvService::AsyncService* service,
grpc::ServerCompletionQueue* cq, framework::Scope* scope, grpc::ServerCompletionQueue* cq, framework::Scope* scope,
const platform::DeviceContext* dev_ctx) const platform::DeviceContext* dev_ctx,
SimpleBlockQueue<char>* queue)
: RequestBase(service, cq), : RequestBase(service, cq),
responder_(&ctx_), responder_(&ctx_),
scope_(scope), scope_(scope),
dev_ctx_(dev_ctx) { dev_ctx_(dev_ctx),
queue_(queue) {
service_->RequestGetVariable(&ctx_, &request_, &responder_, cq_, cq_, this); service_->RequestGetVariable(&ctx_, &request_, &responder_, cq_, cq_, this);
} }
...@@ -100,6 +105,7 @@ class RequestGet final : public RequestBase { ...@@ -100,6 +105,7 @@ class RequestGet final : public RequestBase {
// TODO(gongwb): check var's info. // TODO(gongwb): check var's info.
responder_.Finish(reply_, grpc::Status::OK, this); responder_.Finish(reply_, grpc::Status::OK, this);
status_ = FINISH; status_ = FINISH;
queue_->Push('c');
} }
protected: protected:
...@@ -108,8 +114,15 @@ class RequestGet final : public RequestBase { ...@@ -108,8 +114,15 @@ class RequestGet final : public RequestBase {
ServerAsyncResponseWriter<sendrecv::VariableMessage> responder_; ServerAsyncResponseWriter<sendrecv::VariableMessage> responder_;
framework::Scope* scope_; framework::Scope* scope_;
const platform::DeviceContext* dev_ctx_; const platform::DeviceContext* dev_ctx_;
SimpleBlockQueue<char>* queue_;
}; };
// Pops `count` entries from var_get_queue_, one per loop iteration, and
// discards them. The RequestGet handler pushes one entry ('c') onto the same
// queue right after it calls responder_.Finish(), so each popped entry
// corresponds to one Get request that has been answered.
// NOTE(review): presumably Pop() blocks while the queue is empty (the type is
// SimpleBlockQueue), which would make this a wait for `count` finished Get
// requests -- confirm against the queue's definition.
// NOTE(review): the RecvOp caller passes `param_count * fan_in`, which appears
// to be an unsigned product narrowed to int here -- verify the range.
void AsyncGRPCServer::WaitClientGet(int count) {
for (int i = 0; i < count; ++i) {
var_get_queue_.Pop();
}
}
void AsyncGRPCServer::RunSyncUpdate() { void AsyncGRPCServer::RunSyncUpdate() {
grpc::ServerBuilder builder; grpc::ServerBuilder builder;
builder.AddListeningPort(address_, grpc::InsecureServerCredentials()); builder.AddListeningPort(address_, grpc::InsecureServerCredentials());
...@@ -170,7 +183,8 @@ void AsyncGRPCServer::TryToRegisterNewGetOne() { ...@@ -170,7 +183,8 @@ void AsyncGRPCServer::TryToRegisterNewGetOne() {
if (is_shut_down_) { if (is_shut_down_) {
return; return;
} }
RequestGet* get = new RequestGet(&service_, cq_get_.get(), scope_, dev_ctx_); RequestGet* get = new RequestGet(&service_, cq_get_.get(), scope_, dev_ctx_,
&var_get_queue_);
VLOG(4) << "create Requestget status:" << get->Status(); VLOG(4) << "create Requestget status:" << get->Status();
} }
...@@ -188,9 +202,8 @@ void AsyncGRPCServer::HandleRequest(bool wait, grpc::ServerCompletionQueue* cq, ...@@ -188,9 +202,8 @@ void AsyncGRPCServer::HandleRequest(bool wait, grpc::ServerCompletionQueue* cq,
} }
PADDLE_ENFORCE(tag); PADDLE_ENFORCE(tag);
if (wait && !done_) { if (cq_name == "cq_get") WaitCond(2);
Wait(); if (cq_name == "cq_send") WaitCond(0);
}
RequestBase* base = (RequestBase*)tag; RequestBase* base = (RequestBase*)tag;
// reference: // reference:
...@@ -222,22 +235,18 @@ void AsyncGRPCServer::HandleRequest(bool wait, grpc::ServerCompletionQueue* cq, ...@@ -222,22 +235,18 @@ void AsyncGRPCServer::HandleRequest(bool wait, grpc::ServerCompletionQueue* cq,
} }
} }
void AsyncGRPCServer::Wait() { void AsyncGRPCServer::WaitCond(int cond) {
std::unique_lock<std::mutex> lock(this->mutex_); std::unique_lock<std::mutex> lock(this->barrier_mutex_);
condition_.wait(lock, [=] { return this->done_ == true; }); barrier_condition_.wait(lock,
} [=] { return this->barrier_cond_step_ == cond; });
void AsyncGRPCServer::Reset() {
std::lock_guard<std::mutex> lock(this->mutex_);
done_ = false;
} }
void AsyncGRPCServer::Done() { void AsyncGRPCServer::SetCond(int cond) {
{ {
std::lock_guard<std::mutex> lock(this->mutex_); std::lock_guard<std::mutex> lock(this->barrier_mutex_);
done_ = true; barrier_cond_step_ = cond;
} }
condition_.notify_all(); barrier_condition_.notify_all();
} }
} // namespace detail } // namespace detail
......
...@@ -41,9 +41,12 @@ class AsyncGRPCServer final : public sendrecv::SendRecvService::Service { ...@@ -41,9 +41,12 @@ class AsyncGRPCServer final : public sendrecv::SendRecvService::Service {
void RunSyncUpdate(); void RunSyncUpdate();
void Reset(); // functions to sync server barrier status.
void WaitStart();
void WaitDone();
void Start();
void Done(); void Done();
void WaitClientGet(int count);
void SetScope(framework::Scope *scope) { scope_ = scope; } void SetScope(framework::Scope *scope) { scope_ = scope; }
...@@ -56,7 +59,6 @@ class AsyncGRPCServer final : public sendrecv::SendRecvService::Service { ...@@ -56,7 +59,6 @@ class AsyncGRPCServer final : public sendrecv::SendRecvService::Service {
void ShutDown(); void ShutDown();
protected: protected:
void Wait();
void HandleRequest(bool wait, grpc::ServerCompletionQueue *cq, void HandleRequest(bool wait, grpc::ServerCompletionQueue *cq,
std::string cq_name, std::string cq_name,
std::function<void()> TryToRegisterNewOne); std::function<void()> TryToRegisterNewOne);
...@@ -78,11 +80,12 @@ class AsyncGRPCServer final : public sendrecv::SendRecvService::Service { ...@@ -78,11 +80,12 @@ class AsyncGRPCServer final : public sendrecv::SendRecvService::Service {
const platform::DeviceContext *dev_ctx_; const platform::DeviceContext *dev_ctx_;
// received variable from RPC, operators fetch variable from this queue. // received variable from RPC, operators fetch variable from this queue.
SimpleBlockQueue<MessageWithName> var_recv_queue_; SimpleBlockQueue<MessageWithName> var_recv_queue_;
SimpleBlockQueue<char> var_get_queue_;
// condition of the sub program // condition of the sub program
std::mutex mutex_; std::mutex barrier_mutex_;
volatile mutable bool done_; mutable int barrier_cond_step_;
std::condition_variable condition_; std::condition_variable barrier_condition_;
std::unique_ptr<std::thread> t_send_; std::unique_ptr<std::thread> t_send_;
std::unique_ptr<std::thread> t_get_; std::unique_ptr<std::thread> t_get_;
......
...@@ -34,6 +34,10 @@ limitations under the License. */ ...@@ -34,6 +34,10 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
// Step values for the server barrier; RecvOp passes them to
// AsyncGRPCServer::SetCond(). kCondStart is set at the top of each loop
// iteration, before gradients are fetched from the service; kCondDone is set
// after the sub program has been run.
// NOTE(review): AsyncGRPCServer::HandleRequest waits on the literals 0
// ("cq_send") and 2 ("cq_get") rather than on these names -- keep the values
// in sync if they ever change.
constexpr int kCondStart = 0;
// NOTE(review): kCondRunning is not referenced in the code visible here.
constexpr int kCondRunning = 1;
constexpr int kCondDone = 2;
void RunServer(std::shared_ptr<detail::AsyncGRPCServer> service) { void RunServer(std::shared_ptr<detail::AsyncGRPCServer> service) {
service->RunSyncUpdate(); service->RunSyncUpdate();
VLOG(4) << "RunServer thread end"; VLOG(4) << "RunServer thread end";
...@@ -101,12 +105,14 @@ class RecvOp : public framework::OperatorBase { ...@@ -101,12 +105,14 @@ class RecvOp : public framework::OperatorBase {
framework::ProgramDesc program(program_desc); framework::ProgramDesc program(program_desc);
framework::Executor executor(dev_place); framework::Executor executor(dev_place);
rpc_service_->Reset(); // rpc_service_->Reset();
// TODO(typhoonzero): change this to a while_op for every cluster-batch. // TODO(typhoonzero): change this to a while_op for every cluster-batch.
bool exit_flag = false; bool exit_flag = false;
while (!exit_flag) { while (!exit_flag) {
// Get from multiple trainers; we don't care about the order in which // Get from multiple trainers; we don't care about the order in which
// the gradients arrive, just add suffix 0~n and merge the gradient. // the gradients arrive, just add suffix 0~n and merge the gradient.
rpc_service_->SetCond(kCondStart);
VLOG(3) << "================ start get from service ===========";
for (size_t i = 0; i < param_count * fan_in; ++i) { for (size_t i = 0; i < param_count * fan_in; ++i) {
const detail::MessageWithName &v = rpc_service_->Get(); const detail::MessageWithName &v = rpc_service_->Get();
auto grad_var_name = v.first; auto grad_var_name = v.first;
...@@ -139,15 +145,16 @@ class RecvOp : public framework::OperatorBase { ...@@ -139,15 +145,16 @@ class RecvOp : public framework::OperatorBase {
if (exit_flag) { if (exit_flag) {
break; break;
} }
rpc_service_->Reset(); // rpc_service_->Reset();
try { try {
executor.Run(program, &recv_scope, 0, /*global_block*/ executor.Run(program, &recv_scope, 0, /*global_block*/
false /*create_local_scope*/, false /*create_vars*/); false /*create_local_scope*/, false /*create_vars*/);
} catch (std::exception &e) { } catch (std::exception &e) {
LOG(ERROR) << "run sub program error " << e.what(); LOG(ERROR) << "run sub program error " << e.what();
} }
VLOG(3) << "================ run sub program end ===========";
rpc_service_->Done(); rpc_service_->SetCond(kCondDone);
rpc_service_->WaitClientGet(param_count * fan_in);
grads_counter_.clear(); grads_counter_.clear();
} // while(true) } // while(true)
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册