Unverified commit 6d934560, authored by Qiao Longfei, committed by GitHub

Merge pull request #10042 from jacquesqiao/add-async-listen-and-serv-op

listen_and_serv_op support async update
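This PR threads a `sync_mode` flag from the Python transpiler through `send_op`, `listen_and_serv_op`, and the gRPC server. In sync mode the server keeps the existing behavior: batch barriers gate the send/get queues, and gradients from all trainers are summed and scaled before the optimize blocks run. In async mode the barriers are skipped and every received gradient is dispatched to its own optimize block as soon as it arrives. Below is a minimal, self-contained model of the two update policies (illustrative only, not Paddle code; the `Grad` struct, the queue, and the numbers are invented for the sketch):

```cpp
#include <iostream>
#include <numeric>
#include <queue>
#include <vector>

struct Grad {
  int trainer_id;
  float value;
};

int main() {
  const size_t num_trainers = 3;
  std::queue<Grad> recv_queue;  // stands in for the server's var_recv_queue_
  for (int step = 0; step < 2; ++step) {
    for (size_t t = 0; t < num_trainers; ++t) {
      recv_queue.push({static_cast<int>(t), 0.1f * (t + 1)});
    }
  }

  const bool sync_mode = true;  // flip to false for the async policy
  float param = 1.0f;
  std::vector<float> pending;
  while (!recv_queue.empty()) {
    Grad g = recv_queue.front();
    recv_queue.pop();
    if (sync_mode) {
      // Sync: buffer until every trainer has reported, then apply the
      // averaged update (the role of the generated sum + scale ops).
      pending.push_back(g.value);
      if (pending.size() == num_trainers) {
        float sum = std::accumulate(pending.begin(), pending.end(), 0.0f);
        param -= sum / num_trainers;
        pending.clear();
      }
    } else {
      // Async: apply each gradient as it arrives, no barrier.
      param -= g.value;
    }
  }
  std::cout << "param = " << param << "\n";
  return 0;
}
```

The sync policy gives the same result regardless of arrival order, at the cost of waiting for the slowest trainer each step; the async policy trades that determinism for throughput.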
@@ -30,9 +30,13 @@ enum CallStatus { PROCESS = 0, FINISH };
class RequestBase {
public:
explicit RequestBase(GrpcService::AsyncService* service,
::grpc::ServerCompletionQueue* cq,
::grpc::ServerCompletionQueue* cq, bool sync_mode,
const platform::DeviceContext* dev_ctx)
: service_(service), cq_(cq), status_(PROCESS), dev_ctx_(dev_ctx) {
: service_(service),
cq_(cq),
sync_mode_(sync_mode),
status_(PROCESS),
dev_ctx_(dev_ctx) {
PADDLE_ENFORCE(cq_);
}
virtual ~RequestBase() {}
@@ -49,6 +53,7 @@ class RequestBase {
::grpc::ServerContext ctx_;
GrpcService::AsyncService* service_;
::grpc::ServerCompletionQueue* cq_;
const bool sync_mode_;
CallStatus status_;
const platform::DeviceContext* dev_ctx_;
};
@@ -56,11 +61,17 @@ class RequestBase {
class RequestSend final : public RequestBase {
public:
explicit RequestSend(GrpcService::AsyncService* service,
::grpc::ServerCompletionQueue* cq,
::grpc::ServerCompletionQueue* cq, bool sync_mode,
framework::Scope* scope, ReceivedQueue* queue,
const platform::DeviceContext* dev_ctx)
: RequestBase(service, cq, dev_ctx), queue_(queue), responder_(&ctx_) {
request_.reset(new VariableResponse(scope, dev_ctx_));
: RequestBase(service, cq, sync_mode, dev_ctx),
queue_(queue),
responder_(&ctx_) {
    // async mode parses the received var into a per-request local scope
    request_.reset(new VariableResponse(scope, dev_ctx_, !sync_mode_));
int method_id = static_cast<int>(detail::GrpcMethod::kSendVariable);
service_->RequestAsyncUnary(method_id, &ctx_, request_.get(), &responder_,
cq_, cq_, this);
@@ -87,11 +98,11 @@ class RequestSend final : public RequestBase {
class RequestGet final : public RequestBase {
public:
explicit RequestGet(GrpcService::AsyncService* service,
::grpc::ServerCompletionQueue* cq,
::grpc::ServerCompletionQueue* cq, bool sync_mode,
framework::Scope* scope,
const platform::DeviceContext* dev_ctx,
framework::BlockingQueue<MessageWithName>* queue)
: RequestBase(service, cq, dev_ctx),
: RequestBase(service, cq, sync_mode, dev_ctx),
responder_(&ctx_),
scope_(scope),
queue_(queue) {
@@ -134,19 +145,23 @@ class RequestGet final : public RequestBase {
class RequestPrefetch final : public RequestBase {
public:
explicit RequestPrefetch(GrpcService::AsyncService* service,
::grpc::ServerCompletionQueue* cq,
::grpc::ServerCompletionQueue* cq, bool sync_mode,
framework::Scope* scope,
const platform::DeviceContext* dev_ctx,
framework::Executor* executor,
framework::ProgramDesc* program,
framework::ExecutorPrepareContext* prefetch_ctx)
: RequestBase(service, cq, dev_ctx),
: RequestBase(service, cq, sync_mode, dev_ctx),
responder_(&ctx_),
scope_(scope),
executor_(executor),
program_(program),
prefetch_ctx_(prefetch_ctx) {
request_.reset(new VariableResponse(scope, dev_ctx_));
    // async mode parses the received var into a per-request local scope
    request_.reset(new VariableResponse(scope, dev_ctx_, !sync_mode_));
int method_id = static_cast<int>(detail::GrpcMethod::kPrefetchVariable);
service_->RequestAsyncUnary(method_id, &ctx_, request_.get(), &responder_,
cq_, cq_, this);
@@ -181,7 +196,6 @@ class RequestPrefetch final : public RequestBase {
framework::Executor* executor_;
framework::ProgramDesc* program_;
framework::ExecutorPrepareContext* prefetch_ctx_;
int blkid_;
};
void AsyncGRPCServer::WaitClientGet(int count) {
@@ -254,8 +268,8 @@ void AsyncGRPCServer::TryToRegisterNewSendOne() {
VLOG(3) << "shutdown, do not TryToRegisterNewSendOne";
return;
}
RequestSend* send = new RequestSend(&service_, cq_send_.get(), scope_,
&var_recv_queue_, dev_ctx_);
RequestSend* send = new RequestSend(&service_, cq_send_.get(), sync_mode_,
scope_, &var_recv_queue_, dev_ctx_);
VLOG(4) << "Create RequestSend status:" << send->Status();
}
@@ -265,8 +279,8 @@ void AsyncGRPCServer::TryToRegisterNewGetOne() {
VLOG(3) << "shutdown, do not TryToRegisterNewGetOne";
return;
}
RequestGet* get = new RequestGet(&service_, cq_get_.get(), scope_, dev_ctx_,
&var_get_queue_);
RequestGet* get = new RequestGet(&service_, cq_get_.get(), sync_mode_, scope_,
dev_ctx_, &var_get_queue_);
VLOG(4) << "Create RequestGet status:" << get->Status();
}
@@ -277,8 +291,8 @@ void AsyncGRPCServer::TryToRegisterNewPrefetchOne() {
return;
}
RequestPrefetch* prefetch =
new RequestPrefetch(&service_, cq_prefetch_.get(), scope_, dev_ctx_,
executor_, program_, prefetch_ctx_);
new RequestPrefetch(&service_, cq_prefetch_.get(), sync_mode_, scope_,
dev_ctx_, executor_, program_, prefetch_ctx_);
VLOG(4) << "Create RequestPrefetch status:" << prefetch->Status();
}
@@ -301,9 +315,11 @@ void AsyncGRPCServer::HandleRequest(::grpc::ServerCompletionQueue* cq,
VLOG(3) << "HandleRequest for " << cq_name << " while after Next";
PADDLE_ENFORCE(tag);
// FIXME(typhoonzero): de-couple the barriers with recv_op
if (!is_shut_down_ && cq_name == "cq_get") WaitCond(1);
if (!is_shut_down_ && cq_name == "cq_send") WaitCond(0);
if (sync_mode_) {
// FIXME(typhoonzero): de-couple the barriers with recv_op
if (!is_shut_down_ && cq_name == "cq_get") WaitCond(1);
if (!is_shut_down_ && cq_name == "cq_send") WaitCond(0);
}
RequestBase* base = reinterpret_cast<RequestBase*>(tag);
// reference:
@@ -320,13 +336,13 @@ void AsyncGRPCServer::HandleRequest(::grpc::ServerCompletionQueue* cq,
switch (base->Status()) {
case PROCESS: {
VLOG(4) << cq_name << " status:" << base->Status();
VLOG(4) << cq_name << " PROCESS status:" << base->Status();
TryToRegisterNewOne();
base->Process();
break;
}
case FINISH: {
VLOG(4) << cq_name << " status:" << base->Status();
VLOG(4) << cq_name << " FINISH status:" << base->Status();
delete base;
break;
}
......
@@ -44,7 +44,8 @@ class RequestBase;
class AsyncGRPCServer final {
public:
explicit AsyncGRPCServer(const std::string &address) : address_(address) {}
explicit AsyncGRPCServer(const std::string &address, bool sync_mode)
: address_(address), sync_mode_(sync_mode) {}
void RunSyncUpdate();
@@ -95,6 +96,7 @@ class AsyncGRPCServer final {
std::unique_ptr<::grpc::Server> server_;
std::string address_;
const bool sync_mode_;
framework::Scope *scope_;
const platform::DeviceContext *dev_ctx_;
......
@@ -89,7 +89,7 @@ void InitTensorsOnServer(framework::Scope* scope, platform::CPUPlace* place,
}
void StartServer(const std::string& endpoint) {
rpc_service_.reset(new detail::AsyncGRPCServer(endpoint));
rpc_service_.reset(new detail::AsyncGRPCServer(endpoint, true));
framework::ProgramDesc program;
framework::Scope scope;
platform::CPUPlace place;
......
@@ -46,7 +46,8 @@ class VariableResponse {
}
virtual ~VariableResponse() {
if (create_scope_) scope_->DeleteScope(local_scope_);
if (create_scope_) {
scope_->DeleteScope(local_scope_);
}
}
// return:
@@ -63,6 +65,8 @@ class VariableResponse {
const framework::Scope& GetLocalScope() const { return *local_scope_; }
framework::Scope* GetMutableLocalScope() const { return local_scope_; }
inline std::string Varname() { return meta_.varname(); }
inline std::string OutVarname() { return meta_.out_varname(); }
......
@@ -27,6 +27,38 @@ void RunServer(std::shared_ptr<detail::AsyncGRPCServer> service) {
VLOG(4) << "RunServer thread end";
}
static void split(const std::string &str, char sep,
std::vector<std::string> *pieces) {
pieces->clear();
if (str.empty()) {
return;
}
size_t pos = 0;
size_t next = str.find(sep, pos);
while (next != std::string::npos) {
pieces->push_back(str.substr(pos, next - pos));
pos = next + 1;
next = str.find(sep, pos);
}
if (!str.substr(pos).empty()) {
pieces->push_back(str.substr(pos));
}
}
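For reference, `split` is used in `RunAsyncLoop` below to parse the new `grad_to_block_id` attribute, whose entries have the documented form `grad_name:block_id` (the transpiler builds them as `merged_var.name + ":" + str(optimize_block.idx)`). A standalone sketch of that parsing, with made-up entries:

```cpp
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

// Same implementation as the split() helper above.
static void split(const std::string &str, char sep,
                  std::vector<std::string> *pieces) {
  pieces->clear();
  if (str.empty()) return;
  size_t pos = 0;
  size_t next = str.find(sep, pos);
  while (next != std::string::npos) {
    pieces->push_back(str.substr(pos, next - pos));
    pos = next + 1;
    next = str.find(sep, pos);
  }
  if (!str.substr(pos).empty()) pieces->push_back(str.substr(pos));
}

int main() {
  // Made-up entries in the documented "grad_name:block_id" format.
  std::vector<std::string> grad_to_block_id_str = {"param1@GRAD.block0:1",
                                                   "param2@GRAD.blockn:2"};
  std::unordered_map<std::string, int> grad_to_block_id;
  std::unordered_map<int, std::string> id_to_grad;
  for (const auto &grad_and_id : grad_to_block_id_str) {
    std::vector<std::string> pieces;
    split(grad_and_id, ':', &pieces);
    if (pieces.size() != 2) continue;  // RunAsyncLoop enforces this instead
    int block_id = std::stoi(pieces[1]);
    grad_to_block_id[pieces[0]] = block_id;
    id_to_grad[block_id] = pieces[0];
  }
  std::cout << "param1@GRAD.block0 -> block "
            << grad_to_block_id["param1@GRAD.block0"] << "\n";  // block 1
  return 0;
}
```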
static void AsyncExecuteBlock(framework::Executor *executor,
framework::ExecutorPrepareContext *prepared,
framework::Scope *scope) {
std::future<void> future = framework::Async([&executor, &prepared, &scope]() {
try {
executor->RunPreparedContext(prepared, scope, false, false);
} catch (std::exception &e) {
LOG(ERROR) << "run sub program error " << e.what();
}
});
// TODO(qiao) maybe we can remove this
future.wait();
}
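Because of the `future.wait()`, each dispatched block runs to completion before the next queued gradient is handled, which is what the TODO above refers to. A rough `std::async` stand-in for the same launch-and-catch pattern (`framework::Async` is Paddle's own thread-pool helper; the failing task body is invented for the sketch):

```cpp
#include <future>
#include <iostream>
#include <stdexcept>

int main() {
  // Launch the work on another thread, trapping exceptions inside the task
  // just as AsyncExecuteBlock does around RunPreparedContext.
  std::future<void> done = std::async(std::launch::async, [] {
    try {
      throw std::runtime_error("boom");  // stands in for a failing sub-program
    } catch (const std::exception &e) {
      std::cerr << "run sub program error " << e.what() << "\n";
    }
  });
  // Waiting here serializes the updates: dropping the wait would let several
  // optimize blocks run concurrently on the thread pool.
  done.wait();
  return 0;
}
```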
static void ParallelExecuteBlocks(
const std::vector<size_t> &parallel_blkids, framework::Executor *executor,
const std::vector<std::shared_ptr<framework::ExecutorPrepareContext>>
@@ -169,15 +201,82 @@ void ListenAndServOp::RunSyncLoop(framework::Executor *executor,
} // while(true)
}
void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
framework::ProgramDesc *program,
framework::Scope *recv_scope,
framework::BlockDesc *prefetch_block) const {
VLOG(3) << "RunAsyncLoop in";
// grad name to block id
std::unordered_map<std::string, int32_t> grad_to_block_id;
std::unordered_map<int32_t, std::string> id_to_grad;
auto grad_to_block_id_str =
Attr<std::vector<std::string>>("grad_to_block_id");
for (auto &grad_and_id : grad_to_block_id_str) {
std::vector<std::string> pieces;
split(grad_and_id, ':', &pieces);
VLOG(3) << "after split, grad = " << pieces[0] << ", id=" << pieces[1];
PADDLE_ENFORCE_EQ(pieces.size(), 2);
PADDLE_ENFORCE_EQ(grad_to_block_id.count(pieces[0]), 0);
int block_id = std::stoi(pieces[1]);
grad_to_block_id[pieces[0]] = block_id;
id_to_grad[block_id] = pieces[0];
}
size_t num_blocks = program->Size();
PADDLE_ENFORCE_GE(num_blocks, 2,
"server program should have at least 2 blocks");
std::vector<int> block_list;
for (size_t blkid = 1; blkid < num_blocks; ++blkid) {
block_list.push_back(blkid);
}
auto optimize_prepared = executor->Prepare(*program, block_list);
std::unordered_map<std::string,
std::shared_ptr<framework::ExecutorPrepareContext>>
grad_to_prepared_ctx;
for (size_t i = 0; i < block_list.size(); ++i) {
grad_to_prepared_ctx[id_to_grad[block_list[i]]] = optimize_prepared[i];
}
VLOG(3) << "RunAsyncLoop into while";
bool exit_flag = false;
while (!exit_flag) {
const detail::ReceivedMessage v = rpc_service_->Get();
auto recv_var_name = v.first;
if (recv_var_name == LISTEN_TERMINATE_MESSAGE) {
LOG(INFO) << "received terminate message and exit";
      exit_flag = true;  // no break: fall through so ShutDown() below runs
} else {
VLOG(3) << "received grad: " << recv_var_name;
auto var = v.second->GetVar();
if (var == nullptr) {
LOG(ERROR) << "Can not find server side var: " << recv_var_name;
PADDLE_THROW("Can not find server side var");
}
AsyncExecuteBlock(executor, grad_to_prepared_ctx[recv_var_name].get(),
v.second->GetMutableLocalScope());
}
if (exit_flag) {
rpc_service_->ShutDown();
break;
}
} // while(true)
}
void ListenAndServOp::RunImpl(const framework::Scope &scope,
const platform::Place &dev_place) const {
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(dev_place);
framework::Scope &recv_scope = scope.NewScope();
bool sync_mode = Attr<bool>("sync_mode");
PADDLE_ENFORCE(!rpc_service_);
std::string endpoint = Attr<std::string>("endpoint");
rpc_service_.reset(new detail::AsyncGRPCServer(endpoint));
rpc_service_.reset(new detail::AsyncGRPCServer(endpoint, sync_mode));
auto *optimize_block = Attr<framework::BlockDesc *>(kOptimizeBlock);
auto *prefetch_block = Attr<framework::BlockDesc *>(kPrefetchBlock);
@@ -202,7 +301,11 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
sleep(5);
// Write to a file of server selected port for python use.
SavePort(rpc_service_);
RunSyncLoop(&executor, program, &recv_scope, prefetch_block);
if (sync_mode) {
RunSyncLoop(&executor, program, &recv_scope, prefetch_block);
} else {
RunAsyncLoop(&executor, program, &recv_scope, prefetch_block);
}
}
class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker {
@@ -221,6 +324,12 @@ from send_op and send back variables to recv_op.
"IP address to listen on.")
.SetDefault("127.0.0.1:6164")
.AddCustomChecker([](const std::string &ip) { return !ip.empty(); });
AddAttr<std::vector<std::string>>(
"grad_to_block_id",
"['param1@GRAD.block0:1', 'param2@GRAD.blockn:2'] "
"a map from grad name to it's optimize block id")
.SetDefault({});
AddAttr<bool>("sync_mode", "if works at sync_mode or not").SetDefault(true);
AddAttr<framework::BlockDesc *>(kOptimizeBlock,
"BlockID to run on server side.");
AddAttr<framework::BlockDesc *>(kPrefetchBlock,
......
@@ -46,6 +46,11 @@ class ListenAndServOp : public framework::OperatorBase {
framework::Scope* recv_scope,
framework::BlockDesc* prefetch_block) const;
void RunAsyncLoop(framework::Executor* executor,
framework::ProgramDesc* program,
framework::Scope* recv_scope,
framework::BlockDesc* prefetch_block) const;
void Stop() override;
void RunImpl(const framework::Scope& scope,
......
@@ -41,6 +41,8 @@ class SendOp : public framework::OperatorBase {
std::vector<std::string> endpoints =
Attr<std::vector<std::string>>("endpoints");
bool sync_mode = Attr<bool>("sync_mode");
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto& ctx = *pool.Get(place);
@@ -64,11 +66,13 @@ class SendOp : public framework::OperatorBase {
}
PADDLE_ENFORCE(rpc_client->Wait());
for (auto& ep : endpoints) {
VLOG(3) << "batch barrier, ep: " << ep;
rpc_client->AsyncSendBatchBarrier(ep);
if (sync_mode) {
for (auto& ep : endpoints) {
VLOG(3) << "batch barrier, ep: " << ep;
rpc_client->AsyncSendBatchBarrier(ep);
}
PADDLE_ENFORCE(rpc_client->Wait());
}
PADDLE_ENFORCE(rpc_client->Wait());
if (outs.size() > 0) {
for (size_t i = 0; i < outs.size(); i++) {
@@ -112,6 +116,7 @@ This operator will send tensor to recv_op at the parameter server.
"Server endpoints in the order of input "
"variables for mapping")
.SetDefault({});
AddAttr<bool>("sync_mode", "work in sync_mode or not").SetDefault(true);
}
};
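On the trainer side, `send_op` now only emits the batch barrier (and performs the second `Wait`) in sync mode; async trainers push their gradients and move on without blocking on peers. A toy model of that client-side control flow, with a stub standing in for `detail::RPCClient` (every name in the stub is invented):

```cpp
#include <iostream>
#include <string>
#include <utility>
#include <vector>

// Invented stub standing in for detail::RPCClient.
struct FakeRpcClient {
  void AsyncSendVariable(const std::string &ep, const std::string &var) {
    std::cout << "send " << var << " -> " << ep << "\n";
  }
  void AsyncSendBatchBarrier(const std::string &ep) {
    std::cout << "batch barrier -> " << ep << "\n";
  }
  bool Wait() { return true; }  // pretend all in-flight RPCs finished
};

int main() {
  const bool sync_mode = false;  // async trainers skip the barrier round
  // One (endpoint, variable) pair per split gradient, like the epmap attr.
  std::vector<std::pair<std::string, std::string>> epmap = {
      {"127.0.0.1:6174", "x1@GRAD.block0"},
      {"127.0.0.1:6175", "x1@GRAD.block1"}};
  std::vector<std::string> endpoints = {"127.0.0.1:6174", "127.0.0.1:6175"};

  FakeRpcClient client;
  for (const auto &p : epmap) client.AsyncSendVariable(p.first, p.second);
  if (!client.Wait()) return 1;

  if (sync_mode) {  // mirrors the new conditional in SendOp
    for (const auto &ep : endpoints) client.AsyncSendBatchBarrier(ep);
    if (!client.Wait()) return 1;
  }
  return 0;
}
```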
......
@@ -137,6 +137,8 @@ void StartServerNet(bool is_sparse) {
attrs.insert({"GradList", std::vector<std::string>({"x1"})});
attrs.insert({"OptimizeBlock", optimize_block});
attrs.insert({"PrefetchBlock", prefetch_block});
attrs.insert({"grad_to_block_id", std::vector<std::string>({""})});
attrs.insert({"sync_mode", true});
listen_and_serv_op =
f::OpRegistry::CreateOp("listen_and_serv", {{"X", {"x1"}}}, {}, attrs);
listen_and_serv_op->Run(scope, place);
......
@@ -143,7 +143,8 @@ class DistributeTranspiler:
program=None,
pservers="127.0.0.1:6174",
trainers=1,
split_method=splitter.round_robin):
split_method=splitter.round_robin,
sync_mode=True):
"""
Transpile the program to distributed data-parallelism programs.
The main_program will be transformed to use a remote parameter server
@@ -184,6 +185,9 @@ class DistributeTranspiler:
:param split_method: A function to determine how to split variables
to different servers equally.
:type split_method: function
:param sync_mode: if True, the transpiler generates sync-mode pserver
    and trainer programs; if False, async update is used.
:type sync_mode: boolean, default True
"""
assert (callable(split_method))
if program is None:
@@ -191,6 +195,7 @@ class DistributeTranspiler:
self.origin_program = program
self.trainer_num = trainers
self.optimize_ops = optimize_ops
self.sync_mode = sync_mode
# TODO(typhoonzero): currently trainer_id is fetched from cluster system
# like Kubernetes, we should port this to use etcd later when developing
# fluid distributed training with fault-tolerance.
@@ -295,8 +300,11 @@ class DistributeTranspiler:
inputs={"X": send_inputs},
outputs={"Out": send_outputs,
"RPCClient": rpc_client_var},
attrs={"endpoints": pserver_endpoints,
"epmap": eplist})
attrs={
"endpoints": pserver_endpoints,
"epmap": eplist,
"sync_mode": self.sync_mode
})
# step4: Concat the parameters splits together after recv.
for varname, splited_var in param_var_mapping.iteritems():
if len(splited_var) <= 1:
@@ -356,7 +364,7 @@ class DistributeTranspiler:
type=v.type,
dtype=v.dtype,
shape=v.shape)
if self.trainer_num > 1:
if self.sync_mode and self.trainer_num > 1:
for trainer_id in xrange(self.trainer_num):
var = pserver_program.global_block().create_var(
name="%s.trainer_%d" % (orig_var_name, trainer_id),
@@ -402,13 +410,13 @@ class DistributeTranspiler:
for op in self.optimize_ops:
if op.type == "scale":
for in_name in op.input_arg_names:
if in_name.startswith("beta1_pow_acc") or\
in_name.startswith("beta2_pow_acc"):
if in_name.startswith("beta1_pow_acc") or \
in_name.startswith("beta2_pow_acc"):
global_ops.append(op)
def __append_optimize_op__(op, block):
def __append_optimize_op__(op, block, grad_to_block_id):
if self._is_opt_op(op):
self._append_pserver_ops(block, op, endpoint,
self._append_pserver_ops(block, op, endpoint, grad_to_block_id,
default_main_program())
else:
self._append_pserver_non_opt_ops(block, op)
@@ -422,16 +430,16 @@ class DistributeTranspiler:
self._append_pserver_non_opt_ops(lr_decay_block, op)
# append op to the current block
grad_to_block_id = []
pre_block_idx = pserver_program.num_blocks - 1
for idx, opt_op in enumerate(opt_op_on_pserver):
per_opt_block = pserver_program.create_block(pre_block_idx)
for _, op in enumerate(self.optimize_ops):
# optimizer is connected to itself
if ufind.is_connected(op, opt_op) and op not in global_ops:
__append_optimize_op__(op, per_opt_block)
__append_optimize_op__(op, per_opt_block, grad_to_block_id)
# append global ops
opt_state_block = None
if global_ops:
opt_state_block = pserver_program.create_block(
pserver_program.num_blocks - 1)
@@ -472,7 +480,9 @@ class DistributeTranspiler:
"OptimizeBlock": pserver_program.block(1),
"endpoint": endpoint,
"Fanin": self.trainer_num,
"PrefetchBlock": prefetch_block
"PrefetchBlock": prefetch_block,
"sync_mode": self.sync_mode,
"grad_to_block_id": grad_to_block_id
})
pserver_program.sync_with_cpp()
@@ -683,17 +693,6 @@ class DistributeTranspiler:
self.table_name)],
persistable=False)
# create grad vars in pserver program
table_grad_var = self.table_param_grad[1]
table_grad_list = [
pserver_program.global_block().create_var(
name="%s.trainer_%d.pserver_%d" %
(table_grad_var.name, index, pserver_index),
type=table_grad_var.type,
shape=table_grad_var.shape,
dtype=table_grad_var.dtype) for index in range(self.trainer_num)
]
# create table optimize block in pserver program
table_opt_op = [
op for op in self.optimize_ops
@@ -703,11 +702,24 @@ class DistributeTranspiler:
# only support sgd now
assert table_opt_op.type == "sgd"
# append sum op for table_grad_list
table_opt_block.append_op(
type="sum",
inputs={"X": table_grad_list},
outputs={"Out": [grad_var]})
if self.sync_mode:
# create grad vars in pserver program
table_grad_var = self.table_param_grad[1]
table_grad_list = [
pserver_program.global_block().create_var(
name="%s.trainer_%d.pserver_%d" %
(table_grad_var.name, index, pserver_index),
type=table_grad_var.type,
shape=table_grad_var.shape,
dtype=table_grad_var.dtype)
for index in range(self.trainer_num)
]
# append sum op for table_grad_list
table_opt_block.append_op(
type="sum",
inputs={"X": table_grad_list},
outputs={"Out": [grad_var]})
lr_var = pserver_program.global_block().vars[table_opt_op.input(
"LearningRate")[0]]
@@ -746,7 +758,7 @@ class DistributeTranspiler:
for varname, splited in block_map.iteritems():
orig_var = program.global_block().var(varname)
if len(splited) == 1:
if add_trainer_suffix:
if self.sync_mode and add_trainer_suffix:
new_var_name = "%s.trainer_%d" % \
(orig_var.name, self.trainer_id)
program.global_block().rename_var(varname, new_var_name)
@@ -770,7 +782,7 @@ class DistributeTranspiler:
if len(orig_shape) >= 2:
splited_shape.extend(orig_shape[1:])
new_var_name = ""
if add_trainer_suffix:
if self.sync_mode and add_trainer_suffix:
new_var_name = "%s.block%d.trainer_%d" % \
(varname, i, self.trainer_id)
else:
@@ -879,7 +891,7 @@ class DistributeTranspiler:
return orig_var_name
def _append_pserver_ops(self, optimize_block, opt_op, endpoint,
origin_program):
grad_to_block_id, origin_program):
program = optimize_block.program
pserver_block = program.global_block()
new_inputs = dict()
@@ -900,7 +912,9 @@ class DistributeTranspiler:
return
merged_var = \
pserver_block.vars[self._orig_varname(grad_block.name)]
if self.trainer_num > 1:
grad_to_block_id.append(merged_var.name + ":" + str(
optimize_block.idx))
if self.sync_mode and self.trainer_num > 1:
vars2merge = []
for i in xrange(self.trainer_num):
per_trainer_name = "%s.trainer_%d" % \
@@ -918,6 +932,7 @@ class DistributeTranspiler:
inputs={"X": merged_var},
outputs={"Out": merged_var},
attrs={"scale": 1.0 / float(self.trainer_num)})
new_inputs[key] = merged_var
elif key == "Param":
# param is already created on global program
......