Commit ae19d2ea authored by typhoonzero

fix comm issues

Parent f233b936
@@ -36,7 +36,10 @@ class RequestBase {
CallStatus Status() { return status_; }
void SetStatus(CallStatus status) { status_ = status; }
virtual std::string GetReqName() { assert(false); }
virtual std::string GetReqName() {
assert(false);
return "";
}
protected:
grpc::ServerContext ctx_;
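A note on the GetReqName() rewrite in the hunk above: assert(false) is compiled away when NDEBUG is defined, which is the usual release configuration, so the old one-line version could fall off the end of a value-returning function -- undefined behavior. The added return ""; keeps release builds well-defined. A minimal standalone sketch of the failure mode (illustrative, not Paddle code):

#define NDEBUG          // what a release build effectively does
#include <cassert>
#include <string>

std::string GetReqName() {
  assert(false);        // expands to nothing under NDEBUG
  return "";            // without this, control runs off the end: UB
}

int main() { return static_cast<int>(GetReqName().size()); }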
@@ -80,11 +83,13 @@ class RequestGet final : public RequestBase {
public:
explicit RequestGet(sendrecv::SendRecvService::AsyncService* service,
grpc::ServerCompletionQueue* cq, framework::Scope* scope,
const platform::DeviceContext* dev_ctx)
const platform::DeviceContext* dev_ctx,
SimpleBlockQueue<char>* queue)
: RequestBase(service, cq),
responder_(&ctx_),
scope_(scope),
dev_ctx_(dev_ctx) {
dev_ctx_(dev_ctx),
queue_(queue) {
service_->RequestGetVariable(&ctx_, &request_, &responder_, cq_, cq_, this);
}
@@ -100,6 +105,7 @@ class RequestGet final : public RequestBase {
// TODO(gongwb): check var's info.
responder_.Finish(reply_, grpc::Status::OK, this);
status_ = FINISH;
queue_->Push('c');
}
protected:
@@ -108,8 +114,15 @@ class RequestGet final : public RequestBase {
ServerAsyncResponseWriter<sendrecv::VariableMessage> responder_;
framework::Scope* scope_;
const platform::DeviceContext* dev_ctx_;
SimpleBlockQueue<char>* queue_;
};
void AsyncGRPCServer::WaitClientGet(int count) {
for (int i = 0; i < count; ++i) {
var_get_queue_.Pop();
}
}
void AsyncGRPCServer::RunSyncUpdate() {
grpc::ServerBuilder builder;
builder.AddListeningPort(address_, grpc::InsecureServerCredentials());
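The new queue_->Push('c') in RequestGet::Process and the WaitClientGet loop above form a counting rendezvous: every completed Get RPC pushes one token into var_get_queue_, and the server pops one token per expected fetch before starting the next round, so it cannot race ahead while a trainer is still downloading parameters. SimpleBlockQueue itself is not part of this diff; a minimal stand-in implementing just the Push/Pop contract used here might look like this (a sketch, not Paddle's actual header):

#include <condition_variable>
#include <deque>
#include <mutex>

template <typename T>
class SimpleBlockQueue {
 public:
  void Push(const T& item) {
    {
      std::lock_guard<std::mutex> lock(mutex_);
      queue_.push_back(item);
    }
    cv_.notify_one();  // wake one blocked Pop()
  }

  T Pop() {
    std::unique_lock<std::mutex> lock(mutex_);
    cv_.wait(lock, [this] { return !queue_.empty(); });  // block until data
    T item = queue_.front();
    queue_.pop_front();
    return item;
  }

 private:
  std::mutex mutex_;
  std::condition_variable cv_;
  std::deque<T> queue_;
};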
@@ -170,7 +183,8 @@ void AsyncGRPCServer::TryToRegisterNewGetOne() {
if (is_shut_down_) {
return;
}
RequestGet* get = new RequestGet(&service_, cq_get_.get(), scope_, dev_ctx_);
RequestGet* get = new RequestGet(&service_, cq_get_.get(), scope_, dev_ctx_,
&var_get_queue_);
VLOG(4) << "create Requestget status:" << get->Status();
}
@@ -188,9 +202,8 @@ void AsyncGRPCServer::HandleRequest(bool wait, grpc::ServerCompletionQueue* cq,
}
PADDLE_ENFORCE(tag);
if (wait && !done_) {
Wait();
}
if (cq_name == "cq_get") WaitCond(2);
if (cq_name == "cq_send") WaitCond(0);
RequestBase* base = (RequestBase*)tag;
// reference:
@@ -222,22 +235,18 @@ void AsyncGRPCServer::HandleRequest(bool wait, grpc::ServerCompletionQueue* cq,
}
}
void AsyncGRPCServer::Wait() {
std::unique_lock<std::mutex> lock(this->mutex_);
condition_.wait(lock, [=] { return this->done_ == true; });
}
void AsyncGRPCServer::Reset() {
std::lock_guard<std::mutex> lock(this->mutex_);
done_ = false;
void AsyncGRPCServer::WaitCond(int cond) {
std::unique_lock<std::mutex> lock(this->barrier_mutex_);
barrier_condition_.wait(lock,
[=] { return this->barrier_cond_step_ == cond; });
}
void AsyncGRPCServer::Done() {
void AsyncGRPCServer::SetCond(int cond) {
{
std::lock_guard<std::mutex> lock(this->mutex_);
done_ = true;
std::lock_guard<std::mutex> lock(this->barrier_mutex_);
barrier_cond_step_ = cond;
}
condition_.notify_all();
barrier_condition_.notify_all();
}
} // namespace detail
......
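The hunk above replaces the boolean Wait()/Reset()/Done() flag with an integer barrier step: WaitCond(cond) blocks until barrier_cond_step_ reaches the requested value, and SetCond(cond) advances the step and wakes every waiter. The driver (RecvOp, further below) moves the step forward while the completion-queue threads gate on it in HandleRequest. A self-contained sketch of the same handshake, with illustrative names rather than Paddle's types:

#include <condition_variable>
#include <iostream>
#include <mutex>
#include <thread>

int step = -1;                  // plays the role of barrier_cond_step_
std::mutex m;
std::condition_variable cv;

void WaitCond(int cond) {
  std::unique_lock<std::mutex> lock(m);
  cv.wait(lock, [=] { return step == cond; });  // sleep until step matches
}

void SetCond(int cond) {
  {
    std::lock_guard<std::mutex> lock(m);
    step = cond;
  }
  cv.notify_all();              // release every thread waiting on this step
}

int main() {
  std::thread send_thread([] { WaitCond(0); std::cout << "handle sends\n"; });
  std::thread get_thread([] { WaitCond(2); std::cout << "handle gets\n"; });
  SetCond(0);                   // driver opens the send phase
  send_thread.join();
  SetCond(2);                   // driver opens the get phase
  get_thread.join();
}

The header hunks that follow adjust the class declaration for this new barrier state.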
@@ -41,9 +41,12 @@ class AsyncGRPCServer final : public sendrecv::SendRecvService::Service {
void RunSyncUpdate();
void Reset();
// functions to sync server barrier status.
void WaitStart();
void WaitDone();
void Start();
void Done();
void WaitClientGet(int count);
void SetScope(framework::Scope *scope) { scope_ = scope; }
@@ -56,7 +59,6 @@ class AsyncGRPCServer final : public sendrecv::SendRecvService::Service {
void ShutDown();
protected:
void Wait();
void HandleRequest(bool wait, grpc::ServerCompletionQueue *cq,
std::string cq_name,
std::function<void()> TryToRegisterNewOne);
@@ -78,11 +80,12 @@ class AsyncGRPCServer final : public sendrecv::SendRecvService::Service {
const platform::DeviceContext *dev_ctx_;
// received variable from RPC, operators fetch variable from this queue.
SimpleBlockQueue<MessageWithName> var_recv_queue_;
SimpleBlockQueue<char> var_get_queue_;
// condition of the sub program
std::mutex mutex_;
volatile mutable bool done_;
std::condition_variable condition_;
std::mutex barrier_mutex_;
mutable int barrier_cond_step_;
std::condition_variable barrier_condition_;
std::unique_ptr<std::thread> t_send_;
std::unique_ptr<std::thread> t_get_;
......
@@ -34,6 +34,10 @@ limitations under the License. */
namespace paddle {
namespace operators {
constexpr int kCondStart = 0;
constexpr int kCondRunning = 1;
constexpr int kCondDone = 2;
void RunServer(std::shared_ptr<detail::AsyncGRPCServer> service) {
service->RunSyncUpdate();
VLOG(4) << "RunServer thread end";
@@ -101,12 +105,14 @@ class RecvOp : public framework::OperatorBase {
framework::ProgramDesc program(program_desc);
framework::Executor executor(dev_place);
rpc_service_->Reset();
// rpc_service_->Reset();
// TODO(typhoonzero): change this to a while_op for every cluster-batch.
bool exit_flag = false;
while (!exit_flag) {
// Get from multiple trainers; we don't care about the order in which
// the gradients arrive, just add suffix 0~n and merge the gradient.
rpc_service_->SetCond(kCondStart);
VLOG(3) << "================ start get from service ===========";
for (size_t i = 0; i < param_count * fan_in; ++i) {
const detail::MessageWithName &v = rpc_service_->Get();
auto grad_var_name = v.first;
@@ -139,15 +145,16 @@ class RecvOp : public framework::OperatorBase {
if (exit_flag) {
break;
}
rpc_service_->Reset();
// rpc_service_->Reset();
try {
executor.Run(program, &recv_scope, 0, /*global_block*/
false /*create_local_scope*/, false /*create_vars*/);
} catch (std::exception &e) {
LOG(ERROR) << "run sub program error " << e.what();
}
rpc_service_->Done();
VLOG(3) << "================ run sub program end ===========";
rpc_service_->SetCond(kCondDone);
rpc_service_->WaitClientGet(param_count * fan_in);
grads_counter_.clear();
} // while(true)
}
......
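Taken together, one RecvOp iteration now runs: SetCond(kCondStart) releases the send handlers; param_count * fan_in gradient messages are drained from the receive queue; the optimizer sub-program executes; SetCond(kCondDone) releases the get handlers; and WaitClientGet(param_count * fan_in) blocks until every trainer has pulled the updated parameters. A toy C++20 model of just this counting discipline, with semaphores standing in for the two block queues and threads standing in for trainers (no gRPC involved; all names are illustrative):

#include <iostream>
#include <semaphore>
#include <thread>
#include <vector>

int main() {
  const int fan_in = 2;       // number of trainers
  const int param_count = 1;  // parameters each trainer contributes
  std::counting_semaphore<64> grads(0);     // stands in for var_recv_queue_
  std::counting_semaphore<64> get_acks(0);  // stands in for var_get_queue_

  std::vector<std::thread> trainers;
  for (int t = 0; t < fan_in; ++t) {
    trainers.emplace_back([&] {
      grads.release();     // "send gradient" arrives at the server
      get_acks.release();  // "GetVariable finished" -- the Push('c') above
    });
  }

  for (int i = 0; i < param_count * fan_in; ++i) grads.acquire();  // gather
  // ... the optimizer sub-program would run here ...
  for (int i = 0; i < param_count * fan_in; ++i) get_acks.acquire();  // WaitClientGet

  for (auto& t : trainers) t.join();
  std::cout << "one cluster step completed\n";
}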