未验证 提交 6d934560 编写于 作者: Q Qiao Longfei 提交者: GitHub

Merge pull request #10042 from jacquesqiao/add-async-listen-and-serv-op

listen_and_serv_op support async update
...@@ -30,9 +30,13 @@ enum CallStatus { PROCESS = 0, FINISH }; ...@@ -30,9 +30,13 @@ enum CallStatus { PROCESS = 0, FINISH };
class RequestBase { class RequestBase {
public: public:
explicit RequestBase(GrpcService::AsyncService* service, explicit RequestBase(GrpcService::AsyncService* service,
::grpc::ServerCompletionQueue* cq, ::grpc::ServerCompletionQueue* cq, bool sync_mode,
const platform::DeviceContext* dev_ctx) const platform::DeviceContext* dev_ctx)
: service_(service), cq_(cq), status_(PROCESS), dev_ctx_(dev_ctx) { : service_(service),
cq_(cq),
sync_mode_(sync_mode),
status_(PROCESS),
dev_ctx_(dev_ctx) {
PADDLE_ENFORCE(cq_); PADDLE_ENFORCE(cq_);
} }
virtual ~RequestBase() {} virtual ~RequestBase() {}
...@@ -49,6 +53,7 @@ class RequestBase { ...@@ -49,6 +53,7 @@ class RequestBase {
::grpc::ServerContext ctx_; ::grpc::ServerContext ctx_;
GrpcService::AsyncService* service_; GrpcService::AsyncService* service_;
::grpc::ServerCompletionQueue* cq_; ::grpc::ServerCompletionQueue* cq_;
const bool sync_mode_;
CallStatus status_; CallStatus status_;
const platform::DeviceContext* dev_ctx_; const platform::DeviceContext* dev_ctx_;
}; };
...@@ -56,11 +61,17 @@ class RequestBase { ...@@ -56,11 +61,17 @@ class RequestBase {
class RequestSend final : public RequestBase { class RequestSend final : public RequestBase {
public: public:
explicit RequestSend(GrpcService::AsyncService* service, explicit RequestSend(GrpcService::AsyncService* service,
::grpc::ServerCompletionQueue* cq, ::grpc::ServerCompletionQueue* cq, bool sync_mode,
framework::Scope* scope, ReceivedQueue* queue, framework::Scope* scope, ReceivedQueue* queue,
const platform::DeviceContext* dev_ctx) const platform::DeviceContext* dev_ctx)
: RequestBase(service, cq, dev_ctx), queue_(queue), responder_(&ctx_) { : RequestBase(service, cq, sync_mode, dev_ctx),
request_.reset(new VariableResponse(scope, dev_ctx_)); queue_(queue),
responder_(&ctx_) {
if (sync_mode_) {
request_.reset(new VariableResponse(scope, dev_ctx_, false));
} else {
request_.reset(new VariableResponse(scope, dev_ctx_, true));
}
int method_id = static_cast<int>(detail::GrpcMethod::kSendVariable); int method_id = static_cast<int>(detail::GrpcMethod::kSendVariable);
service_->RequestAsyncUnary(method_id, &ctx_, request_.get(), &responder_, service_->RequestAsyncUnary(method_id, &ctx_, request_.get(), &responder_,
cq_, cq_, this); cq_, cq_, this);
...@@ -87,11 +98,11 @@ class RequestSend final : public RequestBase { ...@@ -87,11 +98,11 @@ class RequestSend final : public RequestBase {
class RequestGet final : public RequestBase { class RequestGet final : public RequestBase {
public: public:
explicit RequestGet(GrpcService::AsyncService* service, explicit RequestGet(GrpcService::AsyncService* service,
::grpc::ServerCompletionQueue* cq, ::grpc::ServerCompletionQueue* cq, bool sync_mode,
framework::Scope* scope, framework::Scope* scope,
const platform::DeviceContext* dev_ctx, const platform::DeviceContext* dev_ctx,
framework::BlockingQueue<MessageWithName>* queue) framework::BlockingQueue<MessageWithName>* queue)
: RequestBase(service, cq, dev_ctx), : RequestBase(service, cq, sync_mode, dev_ctx),
responder_(&ctx_), responder_(&ctx_),
scope_(scope), scope_(scope),
queue_(queue) { queue_(queue) {
...@@ -134,19 +145,23 @@ class RequestGet final : public RequestBase { ...@@ -134,19 +145,23 @@ class RequestGet final : public RequestBase {
class RequestPrefetch final : public RequestBase { class RequestPrefetch final : public RequestBase {
public: public:
explicit RequestPrefetch(GrpcService::AsyncService* service, explicit RequestPrefetch(GrpcService::AsyncService* service,
::grpc::ServerCompletionQueue* cq, ::grpc::ServerCompletionQueue* cq, bool sync_mode,
framework::Scope* scope, framework::Scope* scope,
const platform::DeviceContext* dev_ctx, const platform::DeviceContext* dev_ctx,
framework::Executor* executor, framework::Executor* executor,
framework::ProgramDesc* program, framework::ProgramDesc* program,
framework::ExecutorPrepareContext* prefetch_ctx) framework::ExecutorPrepareContext* prefetch_ctx)
: RequestBase(service, cq, dev_ctx), : RequestBase(service, cq, sync_mode, dev_ctx),
responder_(&ctx_), responder_(&ctx_),
scope_(scope), scope_(scope),
executor_(executor), executor_(executor),
program_(program), program_(program),
prefetch_ctx_(prefetch_ctx) { prefetch_ctx_(prefetch_ctx) {
request_.reset(new VariableResponse(scope, dev_ctx_)); if (sync_mode_) {
request_.reset(new VariableResponse(scope, dev_ctx_, false));
} else {
request_.reset(new VariableResponse(scope, dev_ctx_, true));
}
int method_id = static_cast<int>(detail::GrpcMethod::kPrefetchVariable); int method_id = static_cast<int>(detail::GrpcMethod::kPrefetchVariable);
service_->RequestAsyncUnary(method_id, &ctx_, request_.get(), &responder_, service_->RequestAsyncUnary(method_id, &ctx_, request_.get(), &responder_,
cq_, cq_, this); cq_, cq_, this);
...@@ -181,7 +196,6 @@ class RequestPrefetch final : public RequestBase { ...@@ -181,7 +196,6 @@ class RequestPrefetch final : public RequestBase {
framework::Executor* executor_; framework::Executor* executor_;
framework::ProgramDesc* program_; framework::ProgramDesc* program_;
framework::ExecutorPrepareContext* prefetch_ctx_; framework::ExecutorPrepareContext* prefetch_ctx_;
int blkid_;
}; };
void AsyncGRPCServer::WaitClientGet(int count) { void AsyncGRPCServer::WaitClientGet(int count) {
...@@ -254,8 +268,8 @@ void AsyncGRPCServer::TryToRegisterNewSendOne() { ...@@ -254,8 +268,8 @@ void AsyncGRPCServer::TryToRegisterNewSendOne() {
VLOG(3) << "shutdown, do not TryToRegisterNewSendOne"; VLOG(3) << "shutdown, do not TryToRegisterNewSendOne";
return; return;
} }
RequestSend* send = new RequestSend(&service_, cq_send_.get(), scope_, RequestSend* send = new RequestSend(&service_, cq_send_.get(), sync_mode_,
&var_recv_queue_, dev_ctx_); scope_, &var_recv_queue_, dev_ctx_);
VLOG(4) << "Create RequestSend status:" << send->Status(); VLOG(4) << "Create RequestSend status:" << send->Status();
} }
...@@ -265,8 +279,8 @@ void AsyncGRPCServer::TryToRegisterNewGetOne() { ...@@ -265,8 +279,8 @@ void AsyncGRPCServer::TryToRegisterNewGetOne() {
VLOG(3) << "shutdown, do not TryToRegisterNewGetOne"; VLOG(3) << "shutdown, do not TryToRegisterNewGetOne";
return; return;
} }
RequestGet* get = new RequestGet(&service_, cq_get_.get(), scope_, dev_ctx_, RequestGet* get = new RequestGet(&service_, cq_get_.get(), sync_mode_, scope_,
&var_get_queue_); dev_ctx_, &var_get_queue_);
VLOG(4) << "Create RequestGet status:" << get->Status(); VLOG(4) << "Create RequestGet status:" << get->Status();
} }
...@@ -277,8 +291,8 @@ void AsyncGRPCServer::TryToRegisterNewPrefetchOne() { ...@@ -277,8 +291,8 @@ void AsyncGRPCServer::TryToRegisterNewPrefetchOne() {
return; return;
} }
RequestPrefetch* prefetch = RequestPrefetch* prefetch =
new RequestPrefetch(&service_, cq_prefetch_.get(), scope_, dev_ctx_, new RequestPrefetch(&service_, cq_prefetch_.get(), sync_mode_, scope_,
executor_, program_, prefetch_ctx_); dev_ctx_, executor_, program_, prefetch_ctx_);
VLOG(4) << "Create RequestPrefetch status:" << prefetch->Status(); VLOG(4) << "Create RequestPrefetch status:" << prefetch->Status();
} }
...@@ -301,9 +315,11 @@ void AsyncGRPCServer::HandleRequest(::grpc::ServerCompletionQueue* cq, ...@@ -301,9 +315,11 @@ void AsyncGRPCServer::HandleRequest(::grpc::ServerCompletionQueue* cq,
VLOG(3) << "HandleRequest for " << cq_name << " while after Next"; VLOG(3) << "HandleRequest for " << cq_name << " while after Next";
PADDLE_ENFORCE(tag); PADDLE_ENFORCE(tag);
if (sync_mode_) {
// FIXME(typhoonzero): de-couple the barriers with recv_op // FIXME(typhoonzero): de-couple the barriers with recv_op
if (!is_shut_down_ && cq_name == "cq_get") WaitCond(1); if (!is_shut_down_ && cq_name == "cq_get") WaitCond(1);
if (!is_shut_down_ && cq_name == "cq_send") WaitCond(0); if (!is_shut_down_ && cq_name == "cq_send") WaitCond(0);
}
RequestBase* base = reinterpret_cast<RequestBase*>(tag); RequestBase* base = reinterpret_cast<RequestBase*>(tag);
// reference: // reference:
...@@ -320,13 +336,13 @@ void AsyncGRPCServer::HandleRequest(::grpc::ServerCompletionQueue* cq, ...@@ -320,13 +336,13 @@ void AsyncGRPCServer::HandleRequest(::grpc::ServerCompletionQueue* cq,
switch (base->Status()) { switch (base->Status()) {
case PROCESS: { case PROCESS: {
VLOG(4) << cq_name << " status:" << base->Status(); VLOG(4) << cq_name << " PROCESS status:" << base->Status();
TryToRegisterNewOne(); TryToRegisterNewOne();
base->Process(); base->Process();
break; break;
} }
case FINISH: { case FINISH: {
VLOG(4) << cq_name << " status:" << base->Status(); VLOG(4) << cq_name << " FINISH status:" << base->Status();
delete base; delete base;
break; break;
} }
......
...@@ -44,7 +44,8 @@ class RequestBase; ...@@ -44,7 +44,8 @@ class RequestBase;
class AsyncGRPCServer final { class AsyncGRPCServer final {
public: public:
explicit AsyncGRPCServer(const std::string &address) : address_(address) {} explicit AsyncGRPCServer(const std::string &address, bool sync_mode)
: address_(address), sync_mode_(sync_mode) {}
void RunSyncUpdate(); void RunSyncUpdate();
...@@ -95,6 +96,7 @@ class AsyncGRPCServer final { ...@@ -95,6 +96,7 @@ class AsyncGRPCServer final {
std::unique_ptr<::grpc::Server> server_; std::unique_ptr<::grpc::Server> server_;
std::string address_; std::string address_;
const bool sync_mode_;
framework::Scope *scope_; framework::Scope *scope_;
const platform::DeviceContext *dev_ctx_; const platform::DeviceContext *dev_ctx_;
......
...@@ -89,7 +89,7 @@ void InitTensorsOnServer(framework::Scope* scope, platform::CPUPlace* place, ...@@ -89,7 +89,7 @@ void InitTensorsOnServer(framework::Scope* scope, platform::CPUPlace* place,
} }
void StartServer(const std::string& endpoint) { void StartServer(const std::string& endpoint) {
rpc_service_.reset(new detail::AsyncGRPCServer(endpoint)); rpc_service_.reset(new detail::AsyncGRPCServer(endpoint, true));
framework::ProgramDesc program; framework::ProgramDesc program;
framework::Scope scope; framework::Scope scope;
platform::CPUPlace place; platform::CPUPlace place;
......
...@@ -46,7 +46,9 @@ class VariableResponse { ...@@ -46,7 +46,9 @@ class VariableResponse {
} }
virtual ~VariableResponse() { virtual ~VariableResponse() {
if (create_scope_) scope_->DeleteScope(local_scope_); if (create_scope_) {
scope_->DeleteScope(local_scope_);
}
} }
// return: // return:
...@@ -63,6 +65,8 @@ class VariableResponse { ...@@ -63,6 +65,8 @@ class VariableResponse {
const framework::Scope& GetLocalScope() const { return *local_scope_; } const framework::Scope& GetLocalScope() const { return *local_scope_; }
framework::Scope* GetMutableLocalScope() const { return local_scope_; }
inline std::string Varname() { return meta_.varname(); } inline std::string Varname() { return meta_.varname(); }
inline std::string OutVarname() { return meta_.out_varname(); } inline std::string OutVarname() { return meta_.out_varname(); }
......
...@@ -27,6 +27,38 @@ void RunServer(std::shared_ptr<detail::AsyncGRPCServer> service) { ...@@ -27,6 +27,38 @@ void RunServer(std::shared_ptr<detail::AsyncGRPCServer> service) {
VLOG(4) << "RunServer thread end"; VLOG(4) << "RunServer thread end";
} }
static void split(const std::string &str, char sep,
std::vector<std::string> *pieces) {
pieces->clear();
if (str.empty()) {
return;
}
size_t pos = 0;
size_t next = str.find(sep, pos);
while (next != std::string::npos) {
pieces->push_back(str.substr(pos, next - pos));
pos = next + 1;
next = str.find(sep, pos);
}
if (!str.substr(pos).empty()) {
pieces->push_back(str.substr(pos));
}
}
static void AsyncExecuteBlock(framework::Executor *executor,
framework::ExecutorPrepareContext *prepared,
framework::Scope *scope) {
std::future<void> future = framework::Async([&executor, &prepared, &scope]() {
try {
executor->RunPreparedContext(prepared, scope, false, false);
} catch (std::exception &e) {
LOG(ERROR) << "run sub program error " << e.what();
}
});
// TODO(qiao) maybe we can remove this
future.wait();
}
static void ParallelExecuteBlocks( static void ParallelExecuteBlocks(
const std::vector<size_t> &parallel_blkids, framework::Executor *executor, const std::vector<size_t> &parallel_blkids, framework::Executor *executor,
const std::vector<std::shared_ptr<framework::ExecutorPrepareContext>> const std::vector<std::shared_ptr<framework::ExecutorPrepareContext>>
...@@ -169,15 +201,82 @@ void ListenAndServOp::RunSyncLoop(framework::Executor *executor, ...@@ -169,15 +201,82 @@ void ListenAndServOp::RunSyncLoop(framework::Executor *executor,
} // while(true) } // while(true)
} }
void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
framework::ProgramDesc *program,
framework::Scope *recv_scope,
framework::BlockDesc *prefetch_block) const {
VLOG(3) << "RunAsyncLoop in";
// grad name to block id
std::unordered_map<std::string, int32_t> grad_to_block_id;
std::unordered_map<int32_t, std::string> id_to_grad;
auto grad_to_block_id_str =
Attr<std::vector<std::string>>("grad_to_block_id");
for (auto &grad_and_id : grad_to_block_id_str) {
std::vector<std::string> pieces;
split(grad_and_id, ':', &pieces);
VLOG(3) << "after split, grad = " << pieces[0] << ", id=" << pieces[1];
PADDLE_ENFORCE_EQ(pieces.size(), 2);
PADDLE_ENFORCE_EQ(grad_to_block_id.count(pieces[0]), 0);
int block_id = std::stoi(pieces[1]);
grad_to_block_id[pieces[0]] = block_id;
id_to_grad[block_id] = pieces[0];
}
size_t num_blocks = program->Size();
PADDLE_ENFORCE_GE(num_blocks, 2,
"server program should have at least 2 blocks");
std::vector<int> block_list;
for (size_t blkid = 1; blkid < num_blocks; ++blkid) {
block_list.push_back(blkid);
}
auto optimize_prepared = executor->Prepare(*program, block_list);
std::unordered_map<std::string,
std::shared_ptr<framework::ExecutorPrepareContext>>
grad_to_prepared_ctx;
for (size_t i = 0; i < block_list.size(); ++i) {
grad_to_prepared_ctx[id_to_grad[block_list[i]]] = optimize_prepared[i];
}
VLOG(3) << "RunAsyncLoop into while";
bool exit_flag = false;
while (!exit_flag) {
const detail::ReceivedMessage v = rpc_service_->Get();
auto recv_var_name = v.first;
if (recv_var_name == LISTEN_TERMINATE_MESSAGE) {
LOG(INFO) << "received terminate message and exit";
exit_flag = true;
break;
} else {
VLOG(3) << "received grad: " << recv_var_name;
auto var = v.second->GetVar();
if (var == nullptr) {
LOG(ERROR) << "Can not find server side var: " << recv_var_name;
PADDLE_THROW("Can not find server side var");
}
AsyncExecuteBlock(executor, grad_to_prepared_ctx[recv_var_name].get(),
v.second->GetMutableLocalScope());
}
if (exit_flag) {
rpc_service_->ShutDown();
break;
}
} // while(true)
}
void ListenAndServOp::RunImpl(const framework::Scope &scope, void ListenAndServOp::RunImpl(const framework::Scope &scope,
const platform::Place &dev_place) const { const platform::Place &dev_place) const {
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(dev_place); auto &dev_ctx = *pool.Get(dev_place);
framework::Scope &recv_scope = scope.NewScope(); framework::Scope &recv_scope = scope.NewScope();
bool sync_mode = Attr<bool>("sync_mode");
PADDLE_ENFORCE(!rpc_service_); PADDLE_ENFORCE(!rpc_service_);
std::string endpoint = Attr<std::string>("endpoint"); std::string endpoint = Attr<std::string>("endpoint");
rpc_service_.reset(new detail::AsyncGRPCServer(endpoint));
rpc_service_.reset(new detail::AsyncGRPCServer(endpoint, sync_mode));
auto *optimize_block = Attr<framework::BlockDesc *>(kOptimizeBlock); auto *optimize_block = Attr<framework::BlockDesc *>(kOptimizeBlock);
auto *prefetch_block = Attr<framework::BlockDesc *>(kPrefetchBlock); auto *prefetch_block = Attr<framework::BlockDesc *>(kPrefetchBlock);
...@@ -202,7 +301,11 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, ...@@ -202,7 +301,11 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
sleep(5); sleep(5);
// Write to a file of server selected port for python use. // Write to a file of server selected port for python use.
SavePort(rpc_service_); SavePort(rpc_service_);
if (sync_mode) {
RunSyncLoop(&executor, program, &recv_scope, prefetch_block); RunSyncLoop(&executor, program, &recv_scope, prefetch_block);
} else {
RunAsyncLoop(&executor, program, &recv_scope, prefetch_block);
}
} }
class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker { class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker {
...@@ -221,6 +324,12 @@ from send_op and send back variables to recv_op. ...@@ -221,6 +324,12 @@ from send_op and send back variables to recv_op.
"IP address to listen on.") "IP address to listen on.")
.SetDefault("127.0.0.1:6164") .SetDefault("127.0.0.1:6164")
.AddCustomChecker([](const std::string &ip) { return !ip.empty(); }); .AddCustomChecker([](const std::string &ip) { return !ip.empty(); });
AddAttr<std::vector<std::string>>(
"grad_to_block_id",
"['param1@GRAD.block0:1', 'param2@GRAD.blockn:2'] "
"a map from grad name to it's optimize block id")
.SetDefault({});
AddAttr<bool>("sync_mode", "if works at sync_mode or not").SetDefault(true);
AddAttr<framework::BlockDesc *>(kOptimizeBlock, AddAttr<framework::BlockDesc *>(kOptimizeBlock,
"BlockID to run on server side."); "BlockID to run on server side.");
AddAttr<framework::BlockDesc *>(kPrefetchBlock, AddAttr<framework::BlockDesc *>(kPrefetchBlock,
......
...@@ -46,6 +46,11 @@ class ListenAndServOp : public framework::OperatorBase { ...@@ -46,6 +46,11 @@ class ListenAndServOp : public framework::OperatorBase {
framework::Scope* recv_scope, framework::Scope* recv_scope,
framework::BlockDesc* prefetch_block) const; framework::BlockDesc* prefetch_block) const;
void RunAsyncLoop(framework::Executor* executor,
framework::ProgramDesc* program,
framework::Scope* recv_scope,
framework::BlockDesc* prefetch_block) const;
void Stop() override; void Stop() override;
void RunImpl(const framework::Scope& scope, void RunImpl(const framework::Scope& scope,
......
...@@ -41,6 +41,8 @@ class SendOp : public framework::OperatorBase { ...@@ -41,6 +41,8 @@ class SendOp : public framework::OperatorBase {
std::vector<std::string> endpoints = std::vector<std::string> endpoints =
Attr<std::vector<std::string>>("endpoints"); Attr<std::vector<std::string>>("endpoints");
bool sync_mode = Attr<bool>("sync_mode");
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto& ctx = *pool.Get(place); auto& ctx = *pool.Get(place);
...@@ -64,11 +66,13 @@ class SendOp : public framework::OperatorBase { ...@@ -64,11 +66,13 @@ class SendOp : public framework::OperatorBase {
} }
PADDLE_ENFORCE(rpc_client->Wait()); PADDLE_ENFORCE(rpc_client->Wait());
if (sync_mode) {
for (auto& ep : endpoints) { for (auto& ep : endpoints) {
VLOG(3) << "batch barrier, ep: " << ep; VLOG(3) << "batch barrier, ep: " << ep;
rpc_client->AsyncSendBatchBarrier(ep); rpc_client->AsyncSendBatchBarrier(ep);
} }
PADDLE_ENFORCE(rpc_client->Wait()); PADDLE_ENFORCE(rpc_client->Wait());
}
if (outs.size() > 0) { if (outs.size() > 0) {
for (size_t i = 0; i < outs.size(); i++) { for (size_t i = 0; i < outs.size(); i++) {
...@@ -112,6 +116,7 @@ This operator will send tensor to recv_op at the parameter server. ...@@ -112,6 +116,7 @@ This operator will send tensor to recv_op at the parameter server.
"Server endpoints in the order of input " "Server endpoints in the order of input "
"variables for mapping") "variables for mapping")
.SetDefault({}); .SetDefault({});
AddAttr<bool>("sync_mode", "work in sync_mode or not").SetDefault(true);
} }
}; };
......
...@@ -137,6 +137,8 @@ void StartServerNet(bool is_sparse) { ...@@ -137,6 +137,8 @@ void StartServerNet(bool is_sparse) {
attrs.insert({"GradList", std::vector<std::string>({"x1"})}); attrs.insert({"GradList", std::vector<std::string>({"x1"})});
attrs.insert({"OptimizeBlock", optimize_block}); attrs.insert({"OptimizeBlock", optimize_block});
attrs.insert({"PrefetchBlock", prefetch_block}); attrs.insert({"PrefetchBlock", prefetch_block});
attrs.insert({"grad_to_block_id", std::vector<std::string>({""})});
attrs.insert({"sync_mode", true});
listen_and_serv_op = listen_and_serv_op =
f::OpRegistry::CreateOp("listen_and_serv", {{"X", {"x1"}}}, {}, attrs); f::OpRegistry::CreateOp("listen_and_serv", {{"X", {"x1"}}}, {}, attrs);
listen_and_serv_op->Run(scope, place); listen_and_serv_op->Run(scope, place);
......
...@@ -143,7 +143,8 @@ class DistributeTranspiler: ...@@ -143,7 +143,8 @@ class DistributeTranspiler:
program=None, program=None,
pservers="127.0.0.1:6174", pservers="127.0.0.1:6174",
trainers=1, trainers=1,
split_method=splitter.round_robin): split_method=splitter.round_robin,
sync_mode=True):
""" """
Transpile the program to distributed data-parallelism programs. Transpile the program to distributed data-parallelism programs.
The main_program will be transformed to use a remote parameter server The main_program will be transformed to use a remote parameter server
...@@ -184,6 +185,9 @@ class DistributeTranspiler: ...@@ -184,6 +185,9 @@ class DistributeTranspiler:
:param split_method: A function to determin how to split variables :param split_method: A function to determin how to split variables
to different servers equally. to different servers equally.
:type split_method: function :type split_method: function
:type sync_mode: boolean default True
:param sync_mode: if sync_mode is set True, it means that dist transpiler
will transpile the program into sync_mode pserver and trainer program.
""" """
assert (callable(split_method)) assert (callable(split_method))
if program is None: if program is None:
...@@ -191,6 +195,7 @@ class DistributeTranspiler: ...@@ -191,6 +195,7 @@ class DistributeTranspiler:
self.origin_program = program self.origin_program = program
self.trainer_num = trainers self.trainer_num = trainers
self.optimize_ops = optimize_ops self.optimize_ops = optimize_ops
self.sync_mode = sync_mode
# TODO(typhoonzero): currently trainer_id is fetched from cluster system # TODO(typhoonzero): currently trainer_id is fetched from cluster system
# like Kubernetes, we should port this to use etcd later when developing # like Kubernetes, we should port this to use etcd later when developing
# fluid distributed training with fault-tolerance. # fluid distributed training with fault-tolerance.
...@@ -295,8 +300,11 @@ class DistributeTranspiler: ...@@ -295,8 +300,11 @@ class DistributeTranspiler:
inputs={"X": send_inputs}, inputs={"X": send_inputs},
outputs={"Out": send_outputs, outputs={"Out": send_outputs,
"RPCClient": rpc_client_var}, "RPCClient": rpc_client_var},
attrs={"endpoints": pserver_endpoints, attrs={
"epmap": eplist}) "endpoints": pserver_endpoints,
"epmap": eplist,
"sync_mode": self.sync_mode
})
# step4: Concat the parameters splits together after recv. # step4: Concat the parameters splits together after recv.
for varname, splited_var in param_var_mapping.iteritems(): for varname, splited_var in param_var_mapping.iteritems():
if len(splited_var) <= 1: if len(splited_var) <= 1:
...@@ -356,7 +364,7 @@ class DistributeTranspiler: ...@@ -356,7 +364,7 @@ class DistributeTranspiler:
type=v.type, type=v.type,
dtype=v.dtype, dtype=v.dtype,
shape=v.shape) shape=v.shape)
if self.trainer_num > 1: if self.sync_mode and self.trainer_num > 1:
for trainer_id in xrange(self.trainer_num): for trainer_id in xrange(self.trainer_num):
var = pserver_program.global_block().create_var( var = pserver_program.global_block().create_var(
name="%s.trainer_%d" % (orig_var_name, trainer_id), name="%s.trainer_%d" % (orig_var_name, trainer_id),
...@@ -402,13 +410,13 @@ class DistributeTranspiler: ...@@ -402,13 +410,13 @@ class DistributeTranspiler:
for op in self.optimize_ops: for op in self.optimize_ops:
if op.type == "scale": if op.type == "scale":
for in_name in op.input_arg_names: for in_name in op.input_arg_names:
if in_name.startswith("beta1_pow_acc") or\ if in_name.startswith("beta1_pow_acc") or \
in_name.startswith("beta2_pow_acc"): in_name.startswith("beta2_pow_acc"):
global_ops.append(op) global_ops.append(op)
def __append_optimize_op__(op, block): def __append_optimize_op__(op, block, grad_to_block_id):
if self._is_opt_op(op): if self._is_opt_op(op):
self._append_pserver_ops(block, op, endpoint, self._append_pserver_ops(block, op, endpoint, grad_to_block_id,
default_main_program()) default_main_program())
else: else:
self._append_pserver_non_opt_ops(block, op) self._append_pserver_non_opt_ops(block, op)
...@@ -422,16 +430,16 @@ class DistributeTranspiler: ...@@ -422,16 +430,16 @@ class DistributeTranspiler:
self._append_pserver_non_opt_ops(lr_decay_block, op) self._append_pserver_non_opt_ops(lr_decay_block, op)
# append op to the current block # append op to the current block
grad_to_block_id = []
pre_block_idx = pserver_program.num_blocks - 1 pre_block_idx = pserver_program.num_blocks - 1
for idx, opt_op in enumerate(opt_op_on_pserver): for idx, opt_op in enumerate(opt_op_on_pserver):
per_opt_block = pserver_program.create_block(pre_block_idx) per_opt_block = pserver_program.create_block(pre_block_idx)
for _, op in enumerate(self.optimize_ops): for _, op in enumerate(self.optimize_ops):
# optimizer is connected to itself # optimizer is connected to itself
if ufind.is_connected(op, opt_op) and op not in global_ops: if ufind.is_connected(op, opt_op) and op not in global_ops:
__append_optimize_op__(op, per_opt_block) __append_optimize_op__(op, per_opt_block, grad_to_block_id)
# append global ops # append global ops
opt_state_block = None
if global_ops: if global_ops:
opt_state_block = pserver_program.create_block( opt_state_block = pserver_program.create_block(
pserver_program.num_blocks - 1) pserver_program.num_blocks - 1)
...@@ -472,7 +480,9 @@ class DistributeTranspiler: ...@@ -472,7 +480,9 @@ class DistributeTranspiler:
"OptimizeBlock": pserver_program.block(1), "OptimizeBlock": pserver_program.block(1),
"endpoint": endpoint, "endpoint": endpoint,
"Fanin": self.trainer_num, "Fanin": self.trainer_num,
"PrefetchBlock": prefetch_block "PrefetchBlock": prefetch_block,
"sync_mode": self.sync_mode,
"grad_to_block_id": grad_to_block_id
}) })
pserver_program.sync_with_cpp() pserver_program.sync_with_cpp()
...@@ -683,6 +693,16 @@ class DistributeTranspiler: ...@@ -683,6 +693,16 @@ class DistributeTranspiler:
self.table_name)], self.table_name)],
persistable=False) persistable=False)
# create table optimize block in pserver program
table_opt_op = [
op for op in self.optimize_ops
if op.input("Param")[0] == self.table_name
][0]
table_opt_block = pserver_program.create_block(pre_block_idx)
# only support sgd now
assert table_opt_op.type == "sgd"
if self.sync_mode:
# create grad vars in pserver program # create grad vars in pserver program
table_grad_var = self.table_param_grad[1] table_grad_var = self.table_param_grad[1]
table_grad_list = [ table_grad_list = [
...@@ -691,18 +711,10 @@ class DistributeTranspiler: ...@@ -691,18 +711,10 @@ class DistributeTranspiler:
(table_grad_var.name, index, pserver_index), (table_grad_var.name, index, pserver_index),
type=table_grad_var.type, type=table_grad_var.type,
shape=table_grad_var.shape, shape=table_grad_var.shape,
dtype=table_grad_var.dtype) for index in range(self.trainer_num) dtype=table_grad_var.dtype)
for index in range(self.trainer_num)
] ]
# create table optimize block in pserver program
table_opt_op = [
op for op in self.optimize_ops
if op.input("Param")[0] == self.table_name
][0]
table_opt_block = pserver_program.create_block(pre_block_idx)
# only support sgd now
assert table_opt_op.type == "sgd"
# append sum op for table_grad_list # append sum op for table_grad_list
table_opt_block.append_op( table_opt_block.append_op(
type="sum", type="sum",
...@@ -746,7 +758,7 @@ class DistributeTranspiler: ...@@ -746,7 +758,7 @@ class DistributeTranspiler:
for varname, splited in block_map.iteritems(): for varname, splited in block_map.iteritems():
orig_var = program.global_block().var(varname) orig_var = program.global_block().var(varname)
if len(splited) == 1: if len(splited) == 1:
if add_trainer_suffix: if self.sync_mode and add_trainer_suffix:
new_var_name = "%s.trainer_%d" % \ new_var_name = "%s.trainer_%d" % \
(orig_var.name, self.trainer_id) (orig_var.name, self.trainer_id)
program.global_block().rename_var(varname, new_var_name) program.global_block().rename_var(varname, new_var_name)
...@@ -770,7 +782,7 @@ class DistributeTranspiler: ...@@ -770,7 +782,7 @@ class DistributeTranspiler:
if len(orig_shape) >= 2: if len(orig_shape) >= 2:
splited_shape.extend(orig_shape[1:]) splited_shape.extend(orig_shape[1:])
new_var_name = "" new_var_name = ""
if add_trainer_suffix: if self.sync_mode and add_trainer_suffix:
new_var_name = "%s.block%d.trainer_%d" % \ new_var_name = "%s.block%d.trainer_%d" % \
(varname, i, self.trainer_id) (varname, i, self.trainer_id)
else: else:
...@@ -879,7 +891,7 @@ class DistributeTranspiler: ...@@ -879,7 +891,7 @@ class DistributeTranspiler:
return orig_var_name return orig_var_name
def _append_pserver_ops(self, optimize_block, opt_op, endpoint, def _append_pserver_ops(self, optimize_block, opt_op, endpoint,
origin_program): grad_to_block_id, origin_program):
program = optimize_block.program program = optimize_block.program
pserver_block = program.global_block() pserver_block = program.global_block()
new_inputs = dict() new_inputs = dict()
...@@ -900,7 +912,9 @@ class DistributeTranspiler: ...@@ -900,7 +912,9 @@ class DistributeTranspiler:
return return
merged_var = \ merged_var = \
pserver_block.vars[self._orig_varname(grad_block.name)] pserver_block.vars[self._orig_varname(grad_block.name)]
if self.trainer_num > 1: grad_to_block_id.append(merged_var.name + ":" + str(
optimize_block.idx))
if self.sync_mode and self.trainer_num > 1:
vars2merge = [] vars2merge = []
for i in xrange(self.trainer_num): for i in xrange(self.trainer_num):
per_trainer_name = "%s.trainer_%d" % \ per_trainer_name = "%s.trainer_%d" % \
...@@ -918,6 +932,7 @@ class DistributeTranspiler: ...@@ -918,6 +932,7 @@ class DistributeTranspiler:
inputs={"X": merged_var}, inputs={"X": merged_var},
outputs={"Out": merged_var}, outputs={"Out": merged_var},
attrs={"scale": 1.0 / float(self.trainer_num)}) attrs={"scale": 1.0 / float(self.trainer_num)})
new_inputs[key] = merged_var new_inputs[key] = merged_var
elif key == "Param": elif key == "Param":
# param is already created on global program # param is already created on global program
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册