未验证 提交 58be41fa 编写于 作者: 武毅 提交者: GitHub

Merge pull request #7608 from typhoonzero/distributed_split_selectedrows

Enhance distributed train performance
......@@ -63,9 +63,6 @@ bool RPCClient::AsyncGetVariable(const std::string& ep,
sendrecv::VariableMessage req;
req.set_varname(var_name);
auto* var = scope.FindVar(var_name);
SerializeToMessage(var_name, var, ctx, &req);
// varhandle
VarHandle var_h;
var_h.ep = ep;
......
......@@ -36,7 +36,10 @@ class RequestBase {
CallStatus Status() { return status_; }
void SetStatus(CallStatus status) { status_ = status; }
virtual std::string GetReqName() { assert(false); }
virtual std::string GetReqName() {
assert(false);
return "";
}
protected:
grpc::ServerContext ctx_;
......@@ -80,11 +83,13 @@ class RequestGet final : public RequestBase {
public:
explicit RequestGet(sendrecv::SendRecvService::AsyncService* service,
grpc::ServerCompletionQueue* cq, framework::Scope* scope,
const platform::DeviceContext* dev_ctx)
const platform::DeviceContext* dev_ctx,
SimpleBlockQueue<char>* queue)
: RequestBase(service, cq),
responder_(&ctx_),
scope_(scope),
dev_ctx_(dev_ctx) {
dev_ctx_(dev_ctx),
queue_(queue) {
service_->RequestGetVariable(&ctx_, &request_, &responder_, cq_, cq_, this);
}
......@@ -100,6 +105,7 @@ class RequestGet final : public RequestBase {
// TODO(gongwb): check var's info.
responder_.Finish(reply_, grpc::Status::OK, this);
status_ = FINISH;
queue_->Push('c');
}
protected:
......@@ -108,8 +114,15 @@ class RequestGet final : public RequestBase {
ServerAsyncResponseWriter<sendrecv::VariableMessage> responder_;
framework::Scope* scope_;
const platform::DeviceContext* dev_ctx_;
SimpleBlockQueue<char>* queue_;
};
void AsyncGRPCServer::WaitClientGet(int count) {
for (int i = 0; i < count; ++i) {
var_get_queue_.Pop();
}
}
void AsyncGRPCServer::RunSyncUpdate() {
grpc::ServerBuilder builder;
builder.AddListeningPort(address_, grpc::InsecureServerCredentials());
......@@ -149,7 +162,6 @@ void AsyncGRPCServer::ShutdownQueue() {
}
// This URL explains why shutdown is complicate:
// https://stackoverflow.com/questions/35708348/grpc-what-is-the-recommended-way-to-shut-down-an-asynchronous-server-in-c
void AsyncGRPCServer::ShutDown() {
server_->Shutdown();
ShutdownQueue();
......@@ -170,10 +182,12 @@ void AsyncGRPCServer::TryToRegisterNewGetOne() {
if (is_shut_down_) {
return;
}
RequestGet* get = new RequestGet(&service_, cq_get_.get(), scope_, dev_ctx_);
RequestGet* get = new RequestGet(&service_, cq_get_.get(), scope_, dev_ctx_,
&var_get_queue_);
VLOG(4) << "create Requestget status:" << get->Status();
}
// FIXME(typhoonzero): remove wait argument and change cq_name to enum.
void AsyncGRPCServer::HandleRequest(bool wait, grpc::ServerCompletionQueue* cq,
std::string cq_name,
std::function<void()> TryToRegisterNewOne) {
......@@ -188,9 +202,9 @@ void AsyncGRPCServer::HandleRequest(bool wait, grpc::ServerCompletionQueue* cq,
}
PADDLE_ENFORCE(tag);
if (wait && !done_) {
Wait();
}
// FIXME(typhoonzero): de-couple the barriers with recv_op
if (cq_name == "cq_get") WaitCond(1);
if (cq_name == "cq_send") WaitCond(0);
RequestBase* base = (RequestBase*)tag;
// reference:
......@@ -222,22 +236,18 @@ void AsyncGRPCServer::HandleRequest(bool wait, grpc::ServerCompletionQueue* cq,
}
}
void AsyncGRPCServer::Wait() {
std::unique_lock<std::mutex> lock(this->mutex_);
condition_.wait(lock, [=] { return this->done_ == true; });
}
void AsyncGRPCServer::Reset() {
std::lock_guard<std::mutex> lock(this->mutex_);
done_ = false;
void AsyncGRPCServer::WaitCond(int cond) {
std::unique_lock<std::mutex> lock(this->barrier_mutex_);
barrier_condition_.wait(lock,
[=] { return this->barrier_cond_step_ == cond; });
}
void AsyncGRPCServer::Done() {
void AsyncGRPCServer::SetCond(int cond) {
{
std::lock_guard<std::mutex> lock(this->mutex_);
done_ = true;
std::lock_guard<std::mutex> lock(this->barrier_mutex_);
barrier_cond_step_ = cond;
}
condition_.notify_all();
barrier_condition_.notify_all();
}
} // namespace detail
......
......@@ -41,9 +41,10 @@ class AsyncGRPCServer final : public sendrecv::SendRecvService::Service {
void RunSyncUpdate();
void Reset();
void Done();
// functions to sync server barrier status.
void WaitCond(int cond);
void SetCond(int cond);
void WaitClientGet(int count);
void SetScope(framework::Scope *scope) { scope_ = scope; }
......@@ -56,7 +57,6 @@ class AsyncGRPCServer final : public sendrecv::SendRecvService::Service {
void ShutDown();
protected:
void Wait();
void HandleRequest(bool wait, grpc::ServerCompletionQueue *cq,
std::string cq_name,
std::function<void()> TryToRegisterNewOne);
......@@ -78,11 +78,12 @@ class AsyncGRPCServer final : public sendrecv::SendRecvService::Service {
const platform::DeviceContext *dev_ctx_;
// received variable from RPC, operators fetch variable from this queue.
SimpleBlockQueue<MessageWithName> var_recv_queue_;
SimpleBlockQueue<char> var_get_queue_;
// condition of the sub program
std::mutex mutex_;
volatile mutable bool done_;
std::condition_variable condition_;
std::mutex barrier_mutex_;
mutable int barrier_cond_step_;
std::condition_variable barrier_condition_;
std::unique_ptr<std::thread> t_send_;
std::unique_ptr<std::thread> t_get_;
......
......@@ -27,12 +27,17 @@ limitations under the License. */
#include "paddle/operators/detail/grpc_server.h"
#include "paddle/operators/detail/sendrecvop_utils.h"
#include "paddle/operators/detail/simple_block_queue.h"
#include "paddle/string/printf.h"
#define LISTEN_TERMINATE_MESSAGE "TERMINATE@RECV"
namespace paddle {
namespace operators {
constexpr int kCondStart = 0;
constexpr int kCondRunning = 1;
constexpr int kCondDone = 2;
void RunServer(std::shared_ptr<detail::AsyncGRPCServer> service) {
service->RunSyncUpdate();
VLOG(4) << "RunServer thread end";
......@@ -77,42 +82,41 @@ class RecvOp : public framework::OperatorBase {
if (grads_counter_.find(varname) == grads_counter_.end()) {
grads_counter_[varname] = 0;
}
char ret[256];
snprintf(ret, sizeof(ret), "%s.trainer_%d", varname.c_str(),
grads_counter_[varname]++);
return std::string(ret);
return string::Sprintf("%s.trainer_%d", varname, grads_counter_[varname]++);
}
void Run(const framework::Scope &scope,
const platform::Place &dev_place) const override {
// FIXME(typhoonzero): no new scopes for every run.
framework::Scope &recv_scope = scope.NewScope();
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(dev_place);
framework::Scope &recv_scope = scope.NewScope();
// FIXME(Yancey1989): initialize rpc server with laze mode.
rpc_service_->SetScope(&recv_scope);
rpc_service_->SetDevCtx(&dev_ctx);
auto param_list = Attr<std::vector<std::string>>("ParamList");
auto grad_list = Attr<std::vector<std::string>>("GradList");
auto trainer_count = Attr<int>("Trainers");
auto fan_in = Attr<int>("Fanin");
size_t param_count = param_list.size();
rpc_service_->Reset();
std::string program_str = Attr<std::string>("OptimizeProgram");
framework::proto::ProgramDesc program_desc;
program_desc.ParseFromString(program_str);
framework::ProgramDesc program(program_desc);
framework::Executor executor(dev_place);
// TODO(typhoonzero): change this to a while_op for every cluster-batch.
bool exit_flag = false;
VLOG(4) << "param_count:" << param_count
<< " trainer_count:" << trainer_count;
int64_t barrier_size = param_count * fan_in;
while (!exit_flag) {
// TODO(gognwb): simply this loop.
// Get from multiple trainers, we don't care about order in which
// the gradient arrives, just add suffix 0~n then average the gradient.
for (size_t i = 0; i < param_count * trainer_count; ++i) {
// blocking get one var from client.
// Get from multiple trainers, we don't care about the order in which
// the gradients arrives, just add suffix 0~n and merge the gradient.
rpc_service_->SetCond(0);
for (size_t i = 0; i < barrier_size; ++i) {
const detail::MessageWithName &v = rpc_service_->Get();
auto grad_var_name = v.first;
if (grad_var_name == LISTEN_TERMINATE_MESSAGE) {
VLOG(4) << "received LISTEN_TERMINATE_MESSAGE and RunOp.Run() exit";
LOG(INFO) << "received terminate message and exit";
exit_flag = true;
break;
}
......@@ -121,49 +125,31 @@ class RecvOp : public framework::OperatorBase {
if (it != grad_list.end()) {
param_var_name = param_list[it - grad_list.begin()];
} else {
LOG(ERROR) << "grad have no paired param found!\"" << grad_var_name
<< "\"";
LOG(ERROR) << "grad have no paired param:" << grad_var_name;
}
VLOG(3) << "recved grad: " << grad_var_name
<< " updating param: " << param_var_name;
auto *merged_grad = recv_scope.FindVar(grad_var_name);
if (merged_grad == nullptr) {
auto *ptr = recv_scope.Var(grad_var_name);
CreateTensorFromMessageType(ptr, v.second.type());
VLOG(3) << "Create Variable " << grad_var_name
<< " on recv scope, which pointer is " << ptr << " type is "
<< v.second.type();
}
if (trainer_count > 1) {
if (fan_in > 1) {
grad_var_name = this->GetGradVarNameForTrainer(grad_var_name);
}
auto *var = recv_scope.Var(grad_var_name);
auto *var = recv_scope.FindVar(grad_var_name);
if (var == nullptr) {
LOG(ERROR) << "can not find server side var: " << grad_var_name;
PADDLE_THROW("can not find server side var");
}
detail::DeserializeFromMessage(v.second, dev_ctx, var);
}
if (exit_flag) {
break;
}
rpc_service_->Reset();
std::string program_str = Attr<std::string>("OptimizeProgram");
framework::proto::ProgramDesc program_desc;
program_desc.ParseFromString(program_str);
framework::ProgramDesc program(program_desc);
framework::Executor executor(dev_place);
// Run sub graph to get optimized tensor
try {
executor.Run(program, &recv_scope, 0, /*global_block*/
false /*create_local_scope*/, false /*create_vars*/);
} catch (std::exception &e) {
LOG(ERROR) << "run sub program error " << e.what();
}
rpc_service_->Done();
rpc_service_->SetCond(1);
rpc_service_->WaitClientGet(barrier_size);
grads_counter_.clear();
} // while(true)
}
......@@ -199,7 +185,7 @@ This operator will recv tensor from send_op
"GradList", "type list of string",
"grad->param name mapping to find which param to optimize.")
.SetDefault({});
AddAttr<int>("Trainers", "type int",
AddAttr<int>("Fanin", "type int",
"Number of trainers in the current cluster job")
.SetDefault(1);
}
......
......@@ -41,10 +41,13 @@ class SendOp : public framework::OperatorBase {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto& ctx = *pool.Get(place);
for (size_t i = 0; i < ins.size(); i++) {
VLOG(3) << "sending " << ins[i];
client_.AsyncSendVariable(epmap[i], ctx, scope, ins[i]);
}
PADDLE_ENFORCE(client_.Wait());
for (size_t i = 0; i < outs.size(); i++) {
VLOG(3) << "getting " << outs[i];
client_.AsyncGetVariable(epmap[i], ctx, scope, outs[i]);
}
......
......@@ -420,6 +420,19 @@ class DistributeTranspiler:
pserver_program = Program()
for v in self.param_grad_ep_mapping[endpoint]["params"]:
self._clone_var(pserver_program.global_block(), v)
for v in self.param_grad_ep_mapping[endpoint]["grads"]:
# create vars for each trainer in global scope, so
# we don't need to create them when grad arrives.
pserver_program.global_block().create_var(
name=v.name, persistable=True, dtype=v.dtype, shape=v.shape)
for trainer_id in xrange(self.trainers):
print("create variable for program: %s.trainer_%d" %
(v.name, trainer_id))
pserver_program.global_block().create_var(
name="%s.trainer_%d" % (v.name, trainer_id),
persistable=True,
dtype=v.dtype,
shape=v.shape)
# step6
optimize_sub_program = Program()
for idx, opt_op in enumerate(self.optimize_ops):
......@@ -449,7 +462,7 @@ class DistributeTranspiler:
p.name
for p in self.param_grad_ep_mapping[endpoint]["grads"]
],
"Trainers": self.trainers
"Fanin": self.trainers
})
pserver_program.sync_with_cpp()
return pserver_program
......
......@@ -52,26 +52,27 @@ train_reader = paddle.batch(
place = fluid.CPUPlace()
exe = fluid.Executor(place)
t = fluid.DistributeTranspiler()
# all parameter server endpoints list for spliting parameters
pserver_endpoints = os.getenv("PSERVERS")
# server endpoint for current node
current_endpoint = os.getenv("SERVER_ENDPOINT")
# run as trainer or parameter server
pserver_endpoints = os.getenv("PSERVERS") # all pserver endpoints
trainers = int(os.getenv("TRAINERS")) # total trainer count
current_endpoint = os.getenv("SERVER_ENDPOINT") # current pserver endpoint
training_role = os.getenv("TRAINING_ROLE",
"TRAINER") # get the training role: trainer/pserver
t.transpile(optimize_ops, params_grads, pservers=pserver_endpoints, trainers=2)
t = fluid.DistributeTranspiler()
t.transpile(
optimize_ops, params_grads, pservers=pserver_endpoints, trainers=trainers)
if training_role == "PSERVER":
if not current_endpoint:
print("need env SERVER_ENDPOINT")
exit(1)
pserver_prog = t.get_pserver_program(current_endpoint)
exe.run(fluid.default_startup_program())
pserver_startup = t.get_startup_program(current_endpoint, pserver_prog)
exe.run(pserver_startup)
exe.run(pserver_prog)
elif training_role == "TRAINER":
trainer_prog = t.get_trainer_program()
feeder = fluid.DataFeeder(feed_list=[images, label], place=place)
# TODO(typhoonzero): change trainer startup program to fetch parameters from pserver
exe.run(fluid.default_startup_program())
for pass_id in range(PASS_NUM):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册