未验证 提交 a585b585 编写于 作者: Y Yancey 提交者: GitHub

Batch barrier in send/recv op (#7847)

* initialize batch barrier

* add some comments

* update

* fix batch barrier

* use sendvariable rpc interface to send batch barrier

* fix comment

* fix method

* fix by comment

* fix by comment
上级 f224b90d
...@@ -97,6 +97,21 @@ bool RPCClient::AsyncGetVariable(const std::string& ep, ...@@ -97,6 +97,21 @@ bool RPCClient::AsyncGetVariable(const std::string& ep,
return true; return true;
} }
bool RPCClient::AsyncSendBatchBarrier(const std::string& ep, int64_t time_out) {
const auto ch = GetChannel(ep);
BatchBarrierProcessor* s = new BatchBarrierProcessor(ch);
s->Prepare(time_out);
sendrecv::VariableMessage req;
req.set_varname(BATCH_BARRIER_MESSAGE);
auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_);
rpc->Finish(&s->reply_, &s->status_, (void*)s);
req_count_++;
return true;
}
bool RPCClient::Wait() { bool RPCClient::Wait() {
if (req_count_ <= 0) { if (req_count_ <= 0) {
return true; return true;
......
...@@ -71,6 +71,15 @@ class ClientBase { ...@@ -71,6 +71,15 @@ class ClientBase {
context_->set_deadline(deadline); context_->set_deadline(deadline);
} }
virtual void Prepare(int64_t time_out) {
context_.reset(new grpc::ClientContext());
std::chrono::system_clock::time_point deadline =
std::chrono::system_clock::now() + std::chrono::milliseconds(time_out);
context_->set_deadline(deadline);
}
virtual void Process() = 0; virtual void Process() = 0;
std::unique_ptr<sendrecv::SendRecvService::Stub> stub_; std::unique_ptr<sendrecv::SendRecvService::Stub> stub_;
...@@ -117,6 +126,17 @@ class GetProcessor : public ClientBase { ...@@ -117,6 +126,17 @@ class GetProcessor : public ClientBase {
RequestGetCallBack response_call_back_ = ProcGetResponse; RequestGetCallBack response_call_back_ = ProcGetResponse;
}; };
class BatchBarrierProcessor : public ClientBase {
public:
explicit BatchBarrierProcessor(std::shared_ptr<grpc::Channel> ch)
: ClientBase(ch) {}
virtual ~BatchBarrierProcessor() {}
virtual void Process() {}
sendrecv::VoidMessage reply_;
};
class RPCClient { class RPCClient {
public: public:
bool AsyncSendVariable(const std::string& ep, bool AsyncSendVariable(const std::string& ep,
...@@ -130,6 +150,10 @@ class RPCClient { ...@@ -130,6 +150,10 @@ class RPCClient {
const framework::Scope& scope, const framework::Scope& scope,
const std::string& var_name, const std::string& var_name,
int64_t time_out = 600 * 1000); int64_t time_out = 600 * 1000);
bool AsyncSendBatchBarrier(const std::string& ep,
int64_t time_out = 600 * 1000);
bool Wait(); bool Wait();
private: private:
......
...@@ -132,6 +132,7 @@ void AsyncGRPCServer::RunSyncUpdate() { ...@@ -132,6 +132,7 @@ void AsyncGRPCServer::RunSyncUpdate() {
cq_send_ = builder.AddCompletionQueue(); cq_send_ = builder.AddCompletionQueue();
cq_get_ = builder.AddCompletionQueue(); cq_get_ = builder.AddCompletionQueue();
server_ = builder.BuildAndStart(); server_ = builder.BuildAndStart();
LOG(INFO) << "Server listening on " << address_ << std::endl; LOG(INFO) << "Server listening on " << address_ << std::endl;
...@@ -141,11 +142,11 @@ void AsyncGRPCServer::RunSyncUpdate() { ...@@ -141,11 +142,11 @@ void AsyncGRPCServer::RunSyncUpdate() {
std::bind(&AsyncGRPCServer::TryToRegisterNewGetOne, this); std::bind(&AsyncGRPCServer::TryToRegisterNewGetOne, this);
t_send_.reset( t_send_.reset(
new std::thread(std::bind(&AsyncGRPCServer::HandleRequest, this, false, new std::thread(std::bind(&AsyncGRPCServer::HandleRequest, this,
cq_send_.get(), "cq_send", send_register))); cq_send_.get(), "cq_send", send_register)));
t_get_.reset( t_get_.reset(
new std::thread(std::bind(&AsyncGRPCServer::HandleRequest, this, true, new std::thread(std::bind(&AsyncGRPCServer::HandleRequest, this,
cq_get_.get(), "cq_get", get_register))); cq_get_.get(), "cq_get", get_register)));
// wait server // wait server
...@@ -174,7 +175,7 @@ void AsyncGRPCServer::TryToRegisterNewSendOne() { ...@@ -174,7 +175,7 @@ void AsyncGRPCServer::TryToRegisterNewSendOne() {
} }
RequestSend* send = RequestSend* send =
new RequestSend(&service_, cq_send_.get(), &var_recv_queue_); new RequestSend(&service_, cq_send_.get(), &var_recv_queue_);
VLOG(4) << "create RequestSend status:" << send->Status(); VLOG(4) << "Create RequestSend status:" << send->Status();
} }
void AsyncGRPCServer::TryToRegisterNewGetOne() { void AsyncGRPCServer::TryToRegisterNewGetOne() {
...@@ -184,11 +185,11 @@ void AsyncGRPCServer::TryToRegisterNewGetOne() { ...@@ -184,11 +185,11 @@ void AsyncGRPCServer::TryToRegisterNewGetOne() {
} }
RequestGet* get = new RequestGet(&service_, cq_get_.get(), scope_, dev_ctx_, RequestGet* get = new RequestGet(&service_, cq_get_.get(), scope_, dev_ctx_,
&var_get_queue_); &var_get_queue_);
VLOG(4) << "create Requestget status:" << get->Status(); VLOG(4) << "Create RequestGet status:" << get->Status();
} }
// FIXME(typhoonzero): remove wait argument and change cq_name to enum. // FIXME(typhoonzero): change cq_name to enum.
void AsyncGRPCServer::HandleRequest(bool wait, grpc::ServerCompletionQueue* cq, void AsyncGRPCServer::HandleRequest(grpc::ServerCompletionQueue* cq,
std::string cq_name, std::string cq_name,
std::function<void()> TryToRegisterNewOne) { std::function<void()> TryToRegisterNewOne) {
TryToRegisterNewOne(); TryToRegisterNewOne();
......
...@@ -57,8 +57,7 @@ class AsyncGRPCServer final : public sendrecv::SendRecvService::Service { ...@@ -57,8 +57,7 @@ class AsyncGRPCServer final : public sendrecv::SendRecvService::Service {
void ShutDown(); void ShutDown();
protected: protected:
void HandleRequest(bool wait, grpc::ServerCompletionQueue *cq, void HandleRequest(grpc::ServerCompletionQueue *cq, std::string cq_name,
std::string cq_name,
std::function<void()> TryToRegisterNewOne); std::function<void()> TryToRegisterNewOne);
void TryToRegisterNewSendOne(); void TryToRegisterNewSendOne();
void TryToRegisterNewGetOne(); void TryToRegisterNewGetOne();
......
...@@ -30,6 +30,9 @@ namespace paddle { ...@@ -30,6 +30,9 @@ namespace paddle {
namespace operators { namespace operators {
namespace detail { namespace detail {
#define LISTEN_TERMINATE_MESSAGE "TERMINATE@RECV"
#define BATCH_BARRIER_MESSAGE "BATCH_BARRIER@RECV"
void SerializeToMessage(const std::string& name, const framework::Variable* var, void SerializeToMessage(const std::string& name, const framework::Variable* var,
const platform::DeviceContext& ctx, const platform::DeviceContext& ctx,
sendrecv::VariableMessage* msg); sendrecv::VariableMessage* msg);
......
...@@ -29,8 +29,6 @@ limitations under the License. */ ...@@ -29,8 +29,6 @@ limitations under the License. */
#include "paddle/operators/detail/simple_block_queue.h" #include "paddle/operators/detail/simple_block_queue.h"
#include "paddle/string/printf.h" #include "paddle/string/printf.h"
#define LISTEN_TERMINATE_MESSAGE "TERMINATE@RECV"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -95,7 +93,6 @@ class RecvOp : public framework::OperatorBase { ...@@ -95,7 +93,6 @@ class RecvOp : public framework::OperatorBase {
auto param_list = Attr<std::vector<std::string>>("ParamList"); auto param_list = Attr<std::vector<std::string>>("ParamList");
auto grad_list = Attr<std::vector<std::string>>("GradList"); auto grad_list = Attr<std::vector<std::string>>("GradList");
auto fan_in = Attr<int>("Fanin"); auto fan_in = Attr<int>("Fanin");
size_t param_count = param_list.size();
auto *block = Attr<framework::BlockDesc *>(kOptimizeBlock); auto *block = Attr<framework::BlockDesc *>(kOptimizeBlock);
auto *program = block->Program(); auto *program = block->Program();
...@@ -103,38 +100,50 @@ class RecvOp : public framework::OperatorBase { ...@@ -103,38 +100,50 @@ class RecvOp : public framework::OperatorBase {
// TODO(typhoonzero): change this to a while_op for every cluster-batch. // TODO(typhoonzero): change this to a while_op for every cluster-batch.
bool exit_flag = false; bool exit_flag = false;
size_t barrier_size = param_count * fan_in;
while (!exit_flag) { while (!exit_flag) {
// Get from multiple trainers, we don't care about the order in which // Get from multiple trainers, we don't care about the order in which
// the gradients arrives, just add suffix 0~n and merge the gradient. // the gradients arrives, just add suffix 0~n and merge the gradient.
rpc_service_->SetCond(0); rpc_service_->SetCond(0);
for (size_t i = 0; i < barrier_size; ++i) { size_t recv_var_cnt = 0;
int batch_barrier = 0;
while (batch_barrier != fan_in) {
const detail::MessageWithName &v = rpc_service_->Get(); const detail::MessageWithName &v = rpc_service_->Get();
auto grad_var_name = v.first; auto grad_var_name = v.first;
if (grad_var_name == LISTEN_TERMINATE_MESSAGE) { if (grad_var_name == LISTEN_TERMINATE_MESSAGE) {
LOG(INFO) << "received terminate message and exit"; LOG(INFO) << "received terminate message and exit";
exit_flag = true; exit_flag = true;
break; break;
} } else if (grad_var_name == BATCH_BARRIER_MESSAGE) {
auto it = std::find(grad_list.begin(), grad_list.end(), grad_var_name); VLOG(3) << "recv batch barrier message";
std::string param_var_name; batch_barrier++;
if (it != grad_list.end()) { continue;
param_var_name = param_list[it - grad_list.begin()];
} else { } else {
LOG(ERROR) << "grad has no paired param:" << grad_var_name; // receive a variable
} recv_var_cnt++;
VLOG(3) << "received grad: " << grad_var_name auto it =
<< " updating param: " << param_var_name; std::find(grad_list.begin(), grad_list.end(), grad_var_name);
if (fan_in > 1) { std::string param_var_name;
grad_var_name = this->GetGradVarNameForTrainer(grad_var_name); if (it != grad_list.end()) {
} param_var_name = param_list[it - grad_list.begin()];
auto *var = recv_scope.FindVar(grad_var_name); } else {
if (var == nullptr) { LOG(ERROR) << "grad has no paired param:" << grad_var_name;
LOG(ERROR) << "Can not find server side var: " << grad_var_name; }
PADDLE_THROW("Can not find server side var"); VLOG(3) << "received grad: " << grad_var_name
<< " updating param: " << param_var_name;
if (fan_in > 1) {
grad_var_name = this->GetGradVarNameForTrainer(grad_var_name);
}
auto *var = recv_scope.FindVar(grad_var_name);
if (var == nullptr) {
LOG(ERROR) << "Can not find server side var: " << grad_var_name;
PADDLE_THROW("Can not find server side var");
}
detail::DeserializeFromMessage(v.second, dev_ctx, var);
} }
detail::DeserializeFromMessage(v.second, dev_ctx, var);
} }
VLOG(3) << "recv " << recv_var_cnt << " parmeters for one barrier.";
// TODO(Yancey1989): merge SelectedRows variables here
if (exit_flag) { if (exit_flag) {
break; break;
} }
...@@ -146,7 +155,7 @@ class RecvOp : public framework::OperatorBase { ...@@ -146,7 +155,7 @@ class RecvOp : public framework::OperatorBase {
LOG(ERROR) << "run sub program error " << e.what(); LOG(ERROR) << "run sub program error " << e.what();
} }
rpc_service_->SetCond(1); rpc_service_->SetCond(1);
rpc_service_->WaitClientGet(barrier_size); rpc_service_->WaitClientGet(recv_var_cnt);
grads_counter_.clear(); grads_counter_.clear();
} // while(true) } // while(true)
} }
......
...@@ -37,17 +37,25 @@ class SendOp : public framework::OperatorBase { ...@@ -37,17 +37,25 @@ class SendOp : public framework::OperatorBase {
auto ins = Inputs("X"); auto ins = Inputs("X");
auto outs = Outputs("Out"); auto outs = Outputs("Out");
std::vector<std::string> epmap = Attr<std::vector<std::string>>("epmap"); std::vector<std::string> epmap = Attr<std::vector<std::string>>("epmap");
std::vector<std::string> endpoints =
Attr<std::vector<std::string>>("endpoints");
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto& ctx = *pool.Get(place); auto& ctx = *pool.Get(place);
for (size_t i = 0; i < ins.size(); i++) { for (size_t i = 0; i < ins.size(); i++) {
VLOG(3) << "sending " << ins[i]; VLOG(3) << "sending " << ins[i] << " to " << epmap[i];
client_.AsyncSendVariable(epmap[i], ctx, scope, ins[i]); client_.AsyncSendVariable(epmap[i], ctx, scope, ins[i]);
} }
PADDLE_ENFORCE(client_.Wait()); PADDLE_ENFORCE(client_.Wait());
for (auto& ep : endpoints) {
VLOG(3) << "batch barrier, ep: " << ep;
client_.AsyncSendBatchBarrier(ep);
}
PADDLE_ENFORCE(client_.Wait());
for (size_t i = 0; i < outs.size(); i++) { for (size_t i = 0; i < outs.size(); i++) {
VLOG(3) << "getting " << outs[i]; VLOG(3) << "getting " << outs[i] << " from " << epmap[i];
client_.AsyncGetVariable(epmap[i], ctx, scope, outs[i]); client_.AsyncGetVariable(epmap[i], ctx, scope, outs[i]);
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册