Unverified commit 58be41fa, authored by 武毅, committed by GitHub

Merge pull request #7608 from typhoonzero/distributed_split_selectedrows

Enhance distributed train performance
@@ -63,9 +63,6 @@ bool RPCClient::AsyncGetVariable(const std::string& ep,
   sendrecv::VariableMessage req;
   req.set_varname(var_name);
-  auto* var = scope.FindVar(var_name);
-  SerializeToMessage(var_name, var, ctx, &req);
-
   // varhandle
   VarHandle var_h;
   var_h.ep = ep;
......
@@ -36,7 +36,10 @@ class RequestBase {
   CallStatus Status() { return status_; }
   void SetStatus(CallStatus status) { status_ = status; }
-  virtual std::string GetReqName() { assert(false); }
+  virtual std::string GetReqName() {
+    assert(false);
+    return "";
+  }

  protected:
   grpc::ServerContext ctx_;
@@ -80,11 +83,13 @@ class RequestGet final : public RequestBase {
  public:
   explicit RequestGet(sendrecv::SendRecvService::AsyncService* service,
                       grpc::ServerCompletionQueue* cq, framework::Scope* scope,
-                      const platform::DeviceContext* dev_ctx)
+                      const platform::DeviceContext* dev_ctx,
+                      SimpleBlockQueue<char>* queue)
       : RequestBase(service, cq),
         responder_(&ctx_),
         scope_(scope),
-        dev_ctx_(dev_ctx) {
+        dev_ctx_(dev_ctx),
+        queue_(queue) {
     service_->RequestGetVariable(&ctx_, &request_, &responder_, cq_, cq_, this);
   }
@@ -100,6 +105,7 @@ class RequestGet final : public RequestBase {
     // TODO(gongwb): check var's info.
     responder_.Finish(reply_, grpc::Status::OK, this);
     status_ = FINISH;
+    queue_->Push('c');
   }
@@ -108,8 +114,15 @@ class RequestGet final : public RequestBase {
  protected:
   ServerAsyncResponseWriter<sendrecv::VariableMessage> responder_;
   framework::Scope* scope_;
   const platform::DeviceContext* dev_ctx_;
+  SimpleBlockQueue<char>* queue_;
 };

+void AsyncGRPCServer::WaitClientGet(int count) {
+  for (int i = 0; i < count; ++i) {
+    var_get_queue_.Pop();
+  }
+}
+
 void AsyncGRPCServer::RunSyncUpdate() {
   grpc::ServerBuilder builder;
   builder.AddListeningPort(address_, grpc::InsecureServerCredentials());
@@ -149,7 +162,6 @@ void AsyncGRPCServer::ShutdownQueue() {
 }

 // This URL explains why shutdown is complicate:
-// https://stackoverflow.com/questions/35708348/grpc-what-is-the-recommended-way-to-shut-down-an-asynchronous-server-in-c
 void AsyncGRPCServer::ShutDown() {
   server_->Shutdown();
   ShutdownQueue();
@@ -170,10 +182,12 @@ void AsyncGRPCServer::TryToRegisterNewGetOne() {
   if (is_shut_down_) {
     return;
   }
-  RequestGet* get = new RequestGet(&service_, cq_get_.get(), scope_, dev_ctx_);
+  RequestGet* get = new RequestGet(&service_, cq_get_.get(), scope_, dev_ctx_,
+                                   &var_get_queue_);
   VLOG(4) << "create Requestget status:" << get->Status();
 }

+// FIXME(typhoonzero): remove wait argument and change cq_name to enum.
 void AsyncGRPCServer::HandleRequest(bool wait, grpc::ServerCompletionQueue* cq,
                                     std::string cq_name,
                                     std::function<void()> TryToRegisterNewOne) {
@@ -188,9 +202,9 @@ void AsyncGRPCServer::HandleRequest(bool wait, grpc::ServerCompletionQueue* cq,
     }
     PADDLE_ENFORCE(tag);
-    if (wait && !done_) {
-      Wait();
-    }
+    // FIXME(typhoonzero): de-couple the barriers with recv_op
+    if (cq_name == "cq_get") WaitCond(1);
+    if (cq_name == "cq_send") WaitCond(0);
     RequestBase* base = (RequestBase*)tag;
     // reference:
@@ -222,22 +236,18 @@ void AsyncGRPCServer::HandleRequest(bool wait, grpc::ServerCompletionQueue* cq,
   }
 }

-void AsyncGRPCServer::Wait() {
-  std::unique_lock<std::mutex> lock(this->mutex_);
-  condition_.wait(lock, [=] { return this->done_ == true; });
-}
-
-void AsyncGRPCServer::Reset() {
-  std::lock_guard<std::mutex> lock(this->mutex_);
-  done_ = false;
-}
-
-void AsyncGRPCServer::Done() {
+void AsyncGRPCServer::WaitCond(int cond) {
+  std::unique_lock<std::mutex> lock(this->barrier_mutex_);
+  barrier_condition_.wait(lock,
+                          [=] { return this->barrier_cond_step_ == cond; });
+}
+
+void AsyncGRPCServer::SetCond(int cond) {
   {
-    std::lock_guard<std::mutex> lock(this->mutex_);
-    done_ = true;
+    std::lock_guard<std::mutex> lock(this->barrier_mutex_);
+    barrier_cond_step_ = cond;
   }
-  condition_.notify_all();
+  barrier_condition_.notify_all();
 }

 }  // namespace detail
......
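The Wait/Reset/Done trio in this file (presumably grpc_server.cc) is replaced by a step-style barrier: SetCond publishes the current step under a mutex and notifies all waiters, and WaitCond blocks until the step matches. Below is a minimal standalone sketch of that pattern; it is illustrative only, not the Paddle classes, and all names are made up.

// Minimal sketch of the SetCond/WaitCond pattern (illustrative only, not Paddle code).
#include <condition_variable>
#include <iostream>
#include <mutex>
#include <thread>

class StepBarrier {
 public:
  void SetCond(int cond) {
    {
      std::lock_guard<std::mutex> lock(mu_);
      step_ = cond;
    }
    cv_.notify_all();  // wake every waiter so it can re-check the step
  }
  void WaitCond(int cond) {
    std::unique_lock<std::mutex> lock(mu_);
    cv_.wait(lock, [=] { return step_ == cond; });  // block until the step matches
  }

 private:
  std::mutex mu_;
  std::condition_variable cv_;
  int step_ = 0;
};

int main() {
  StepBarrier barrier;
  std::thread getter([&] {
    barrier.WaitCond(1);  // a "get" handler would block here until optimization finished
    std::cout << "step 1 reached, serving gets\n";
  });
  barrier.SetCond(1);  // the server flips the step after running the optimize program
  getter.join();
  return 0;
}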
@@ -41,9 +41,10 @@ class AsyncGRPCServer final : public sendrecv::SendRecvService::Service {
   void RunSyncUpdate();

-  void Reset();
-
-  void Done();
+  // functions to sync server barrier status.
+  void WaitCond(int cond);
+  void SetCond(int cond);
+  void WaitClientGet(int count);

   void SetScope(framework::Scope *scope) { scope_ = scope; }
@@ -56,7 +57,6 @@ class AsyncGRPCServer final : public sendrecv::SendRecvService::Service {
   void ShutDown();

  protected:
-  void Wait();
   void HandleRequest(bool wait, grpc::ServerCompletionQueue *cq,
                      std::string cq_name,
                      std::function<void()> TryToRegisterNewOne);
@@ -78,11 +78,12 @@ class AsyncGRPCServer final : public sendrecv::SendRecvService::Service {
   const platform::DeviceContext *dev_ctx_;
   // received variable from RPC, operators fetch variable from this queue.
   SimpleBlockQueue<MessageWithName> var_recv_queue_;
+  SimpleBlockQueue<char> var_get_queue_;

   // condition of the sub program
-  std::mutex mutex_;
-  volatile mutable bool done_;
-  std::condition_variable condition_;
+  std::mutex barrier_mutex_;
+  mutable int barrier_cond_step_;
+  std::condition_variable barrier_condition_;

   std::unique_ptr<std::thread> t_send_;
   std::unique_ptr<std::thread> t_get_;
......
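The new var_get_queue_ uses SimpleBlockQueue<char> as a counting channel: each served Get pushes one token and WaitClientGet(count) pops count tokens, so the server knows every trainer has fetched the updated parameters before starting the next round. A rough sketch of such a blocking queue, assuming only the Push/Pop interface seen in this diff (this is not the actual Paddle implementation):

// Illustrative blocking queue with the Push/Pop interface assumed above.
#include <condition_variable>
#include <deque>
#include <mutex>
#include <thread>

template <typename T>
class SimpleBlockQueueSketch {
 public:
  void Push(const T& item) {
    {
      std::lock_guard<std::mutex> lock(mu_);
      queue_.push_back(item);
    }
    cv_.notify_one();  // wake one blocked Pop()
  }
  T Pop() {
    std::unique_lock<std::mutex> lock(mu_);
    cv_.wait(lock, [this] { return !queue_.empty(); });  // block until data arrives
    T item = queue_.front();
    queue_.pop_front();
    return item;
  }

 private:
  std::mutex mu_;
  std::condition_variable cv_;
  std::deque<T> queue_;
};

int main() {
  SimpleBlockQueueSketch<char> q;
  std::thread producer([&] { q.Push('c'); });  // e.g. one token per served Get
  q.Pop();                                     // blocks until the token arrives
  producer.join();
  return 0;
}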
@@ -27,12 +27,17 @@ limitations under the License. */
 #include "paddle/operators/detail/grpc_server.h"
 #include "paddle/operators/detail/sendrecvop_utils.h"
 #include "paddle/operators/detail/simple_block_queue.h"
+#include "paddle/string/printf.h"

 #define LISTEN_TERMINATE_MESSAGE "TERMINATE@RECV"

 namespace paddle {
 namespace operators {

+constexpr int kCondStart = 0;
+constexpr int kCondRunning = 1;
+constexpr int kCondDone = 2;
+
 void RunServer(std::shared_ptr<detail::AsyncGRPCServer> service) {
   service->RunSyncUpdate();
   VLOG(4) << "RunServer thread end";
@@ -77,42 +82,41 @@ class RecvOp : public framework::OperatorBase {
     if (grads_counter_.find(varname) == grads_counter_.end()) {
       grads_counter_[varname] = 0;
     }
-    char ret[256];
-    snprintf(ret, sizeof(ret), "%s.trainer_%d", varname.c_str(),
-             grads_counter_[varname]++);
-    return std::string(ret);
+    return string::Sprintf("%s.trainer_%d", varname, grads_counter_[varname]++);
   }

   void Run(const framework::Scope &scope,
            const platform::Place &dev_place) const override {
-    // FIXME(typhoonzero): no new scopes for every run.
-    framework::Scope &recv_scope = scope.NewScope();
     platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
     auto &dev_ctx = *pool.Get(dev_place);
+    framework::Scope &recv_scope = scope.NewScope();

     // FIXME(Yancey1989): initialize rpc server with laze mode.
     rpc_service_->SetScope(&recv_scope);
     rpc_service_->SetDevCtx(&dev_ctx);
     auto param_list = Attr<std::vector<std::string>>("ParamList");
     auto grad_list = Attr<std::vector<std::string>>("GradList");
-    auto trainer_count = Attr<int>("Trainers");
+    auto fan_in = Attr<int>("Fanin");
     size_t param_count = param_list.size();
-    rpc_service_->Reset();
+
+    std::string program_str = Attr<std::string>("OptimizeProgram");
+    framework::proto::ProgramDesc program_desc;
+    program_desc.ParseFromString(program_str);
+    framework::ProgramDesc program(program_desc);
+    framework::Executor executor(dev_place);

     // TODO(typhoonzero): change this to a while_op for every cluster-batch.
     bool exit_flag = false;
-    VLOG(4) << "param_count:" << param_count
-            << " trainer_count:" << trainer_count;
+    int64_t barrier_size = param_count * fan_in;
     while (!exit_flag) {
-      // TODO(gognwb): simply this loop.
-      // Get from multiple trainers, we don't care about order in which
-      // the gradient arrives, just add suffix 0~n then average the gradient.
-      for (size_t i = 0; i < param_count * trainer_count; ++i) {
-        // blocking get one var from client.
+      // Get from multiple trainers, we don't care about the order in which
+      // the gradients arrives, just add suffix 0~n and merge the gradient.
+      rpc_service_->SetCond(0);
+      for (size_t i = 0; i < barrier_size; ++i) {
         const detail::MessageWithName &v = rpc_service_->Get();
         auto grad_var_name = v.first;
         if (grad_var_name == LISTEN_TERMINATE_MESSAGE) {
-          VLOG(4) << "received LISTEN_TERMINATE_MESSAGE and RunOp.Run() exit";
+          LOG(INFO) << "received terminate message and exit";
           exit_flag = true;
           break;
         }
@@ -121,49 +125,31 @@ class RecvOp : public framework::OperatorBase {
         if (it != grad_list.end()) {
           param_var_name = param_list[it - grad_list.begin()];
         } else {
-          LOG(ERROR) << "grad have no paired param found!\"" << grad_var_name
-                     << "\"";
+          LOG(ERROR) << "grad have no paired param:" << grad_var_name;
         }
         VLOG(3) << "recved grad: " << grad_var_name
                 << " updating param: " << param_var_name;
-        auto *merged_grad = recv_scope.FindVar(grad_var_name);
-        if (merged_grad == nullptr) {
-          auto *ptr = recv_scope.Var(grad_var_name);
-          CreateTensorFromMessageType(ptr, v.second.type());
-          VLOG(3) << "Create Variable " << grad_var_name
-                  << " on recv scope, which pointer is " << ptr << " type is "
-                  << v.second.type();
-        }
-        if (trainer_count > 1) {
+        if (fan_in > 1) {
           grad_var_name = this->GetGradVarNameForTrainer(grad_var_name);
         }
-        auto *var = recv_scope.Var(grad_var_name);
+        auto *var = recv_scope.FindVar(grad_var_name);
+        if (var == nullptr) {
+          LOG(ERROR) << "can not find server side var: " << grad_var_name;
+          PADDLE_THROW("can not find server side var");
+        }
         detail::DeserializeFromMessage(v.second, dev_ctx, var);
       }
       if (exit_flag) {
         break;
       }
-      rpc_service_->Reset();
-
-      std::string program_str = Attr<std::string>("OptimizeProgram");
-      framework::proto::ProgramDesc program_desc;
-      program_desc.ParseFromString(program_str);
-      framework::ProgramDesc program(program_desc);
-      framework::Executor executor(dev_place);
-      // Run sub graph to get optimized tensor
       try {
         executor.Run(program, &recv_scope, 0, /*global_block*/
                      false /*create_local_scope*/, false /*create_vars*/);
       } catch (std::exception &e) {
         LOG(ERROR) << "run sub program error " << e.what();
       }
-      rpc_service_->Done();
+      rpc_service_->SetCond(1);
+      rpc_service_->WaitClientGet(barrier_size);
       grads_counter_.clear();
     }  // while(true)
   }
@@ -199,7 +185,7 @@ This operator will recv tensor from send_op
              "GradList", "type list of string",
              "grad->param name mapping to find which param to optimize.")
         .SetDefault({});
-    AddAttr<int>("Trainers", "type int",
+    AddAttr<int>("Fanin", "type int",
                  "Number of trainers in the current cluster job")
         .SetDefault(1);
   }
......
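For context, the fan-in path stores each incoming gradient under a per-trainer suffix ("<name>.trainer_<k>") so copies arriving from different trainers do not overwrite each other before merging; GetGradVarNameForTrainer does this with a per-name counter. A self-contained sketch of that naming scheme follows (hypothetical helper and variable names, not Paddle code):

#include <cstdio>
#include <iostream>
#include <string>
#include <unordered_map>

// Sketch of per-trainer gradient naming: every time a gradient with the same
// name arrives, it is stored under "<name>.trainer_<k>" with an increasing k.
std::string GradNameForTrainer(std::unordered_map<std::string, int>& counters,
                               const std::string& varname) {
  int id = counters[varname]++;  // value-initialized to 0 on first use
  char buf[256];
  std::snprintf(buf, sizeof(buf), "%s.trainer_%d", varname.c_str(), id);
  return std::string(buf);
}

int main() {
  std::unordered_map<std::string, int> counters;
  std::cout << GradNameForTrainer(counters, "fc_0.w@GRAD") << "\n";  // fc_0.w@GRAD.trainer_0
  std::cout << GradNameForTrainer(counters, "fc_0.w@GRAD") << "\n";  // fc_0.w@GRAD.trainer_1
  return 0;
}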
@@ -41,10 +41,13 @@ class SendOp : public framework::OperatorBase {
     platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
     auto& ctx = *pool.Get(place);
     for (size_t i = 0; i < ins.size(); i++) {
+      VLOG(3) << "sending " << ins[i];
       client_.AsyncSendVariable(epmap[i], ctx, scope, ins[i]);
     }
+    PADDLE_ENFORCE(client_.Wait());
     for (size_t i = 0; i < outs.size(); i++) {
+      VLOG(3) << "getting " << outs[i];
       client_.AsyncGetVariable(epmap[i], ctx, scope, outs[i]);
     }
......
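The added PADDLE_ENFORCE(client_.Wait()) splits the send op into two phases: launch every send asynchronously, block until all of them have completed, and only then issue the gets for the updated parameters. A generic sketch of the same phasing with std::future (not the Paddle RPC client API; the two RPC functions below are stand-ins):

#include <future>
#include <iostream>
#include <string>
#include <vector>

// Stand-in "send" and "get" RPCs; in the operator these are async gRPC calls.
bool SendVariable(const std::string& name) { return true; }
bool GetVariable(const std::string& name) { return true; }

int main() {
  std::vector<std::string> ins = {"w@GRAD", "b@GRAD"};
  std::vector<std::string> outs = {"w", "b"};

  // Phase 1: launch every send without waiting on any single one.
  std::vector<std::future<bool>> pending;
  for (const auto& name : ins) {
    pending.push_back(std::async(std::launch::async, SendVariable, name));
  }
  // Barrier: all sends must finish before asking the server for updated params.
  for (auto& f : pending) {
    if (!f.get()) std::cerr << "send failed\n";
  }
  // Phase 2: fetch the updated parameters.
  for (const auto& name : outs) {
    GetVariable(name);
  }
  return 0;
}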
@@ -420,6 +420,19 @@ class DistributeTranspiler:
        pserver_program = Program()
        for v in self.param_grad_ep_mapping[endpoint]["params"]:
            self._clone_var(pserver_program.global_block(), v)
+        for v in self.param_grad_ep_mapping[endpoint]["grads"]:
+            # create vars for each trainer in global scope, so
+            # we don't need to create them when grad arrives.
+            pserver_program.global_block().create_var(
+                name=v.name, persistable=True, dtype=v.dtype, shape=v.shape)
+            for trainer_id in xrange(self.trainers):
+                print("create variable for program: %s.trainer_%d" %
+                      (v.name, trainer_id))
+                pserver_program.global_block().create_var(
+                    name="%s.trainer_%d" % (v.name, trainer_id),
+                    persistable=True,
+                    dtype=v.dtype,
+                    shape=v.shape)
        # step6
        optimize_sub_program = Program()
        for idx, opt_op in enumerate(self.optimize_ops):
@@ -449,7 +462,7 @@ class DistributeTranspiler:
                    p.name
                    for p in self.param_grad_ep_mapping[endpoint]["grads"]
                ],
-                "Trainers": self.trainers
+                "Fanin": self.trainers
            })
        pserver_program.sync_with_cpp()
        return pserver_program
......
@@ -52,26 +52,27 @@ train_reader = paddle.batch(
 place = fluid.CPUPlace()
 exe = fluid.Executor(place)

-t = fluid.DistributeTranspiler()
-# all parameter server endpoints list for spliting parameters
-pserver_endpoints = os.getenv("PSERVERS")
-# server endpoint for current node
-current_endpoint = os.getenv("SERVER_ENDPOINT")
-# run as trainer or parameter server
+pserver_endpoints = os.getenv("PSERVERS")  # all pserver endpoints
+trainers = int(os.getenv("TRAINERS"))  # total trainer count
+current_endpoint = os.getenv("SERVER_ENDPOINT")  # current pserver endpoint
 training_role = os.getenv("TRAINING_ROLE",
                           "TRAINER")  # get the training role: trainer/pserver
-t.transpile(optimize_ops, params_grads, pservers=pserver_endpoints, trainers=2)
+
+t = fluid.DistributeTranspiler()
+t.transpile(
+    optimize_ops, params_grads, pservers=pserver_endpoints, trainers=trainers)

 if training_role == "PSERVER":
     if not current_endpoint:
         print("need env SERVER_ENDPOINT")
         exit(1)
     pserver_prog = t.get_pserver_program(current_endpoint)
-    exe.run(fluid.default_startup_program())
+    pserver_startup = t.get_startup_program(current_endpoint, pserver_prog)
+    exe.run(pserver_startup)
     exe.run(pserver_prog)
 elif training_role == "TRAINER":
     trainer_prog = t.get_trainer_program()
     feeder = fluid.DataFeeder(feed_list=[images, label], place=place)
+    # TODO(typhoonzero): change trainer startup program to fetch parameters from pserver
     exe.run(fluid.default_startup_program())
     for pass_id in range(PASS_NUM):
......