Commit efd5a849 authored by Yancey1989

update executor interface

Parent 800702cc
@@ -45,13 +45,13 @@ ExecutorPrepareContext::~ExecutorPrepareContext() {
 Executor::Executor(const platform::Place& place) : place_(place) {}
 
+void Executor::Close() {
 #ifdef PADDLE_WITH_DISTRIBUTE
-void Executor::Complete() {
   ::paddle::operators::distributed::RPCClient::GetInstance<
       ::paddle::operators::distributed::GRPCClient>()
       ->SendComplete();
-}
 #endif
+}
 
 void InitializeVariable(Variable* var, proto::VarType::Type var_type) {
   if (var_type == proto::VarType::LOD_TENSOR) {
......
@@ -48,7 +48,7 @@ class Executor {
   /*
    * Sending signal to pserver to mark current trainer completed.
    */
-  void Complete();
+  void Close();
 
 #endif
......
@@ -36,11 +36,15 @@ void GRPCClient::InitEventLoop() {
 }
 
 void GRPCClient::SendComplete() {
-  for (auto& it : channels_) {
-    VLOG(3) << "send complete message to " << it.first;
-    this->AsyncSendComplete(it.first);
+  std::unique_lock<std::mutex> lk(completed_mutex_);
+  if (!completed_) {
+    for (auto& it : channels_) {
+      VLOG(3) << "send complete message to " << it.first;
+      this->AsyncSendComplete(it.first);
+    }
+    PADDLE_ENFORCE(this->Wait(), "internal grpc error");
+    completed_ = true;
   }
-  this->Wait();
 }
 
 GRPCClient::~GRPCClient() {
......
@@ -188,7 +188,7 @@ class CheckpointNotifyProcessor : public BaseProcessor {
 class GRPCClient : public RPCClient {
  public:
-  GRPCClient() : ok_(true) {}
+  GRPCClient() : ok_(true), completed_(false) {}
   virtual ~GRPCClient();
 
   bool AsyncSendVar(const std::string& ep, const platform::DeviceContext& ctx,
@@ -247,6 +247,10 @@ class GRPCClient : public RPCClient {
   // mutex for GetChannel thread safety
   std::mutex chan_mutex_;
   DISABLE_COPY_AND_ASSIGN(GRPCClient);
+
+  // mutex for sending complete message only once
+  std::mutex completed_mutex_;
+  bool completed_;
 };
 
 }  // namespace distributed
......
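The two grpc_client hunks above make SendComplete() idempotent: completed_ starts out false in the constructor, completed_mutex_ serializes callers, and the complete message goes out on each pserver channel at most once no matter how many times the client is closed. A minimal Python sketch of the same guard pattern (illustrative names only, not part of this commit):

```python
import threading

class OnceSender(object):
    """Mirrors the completed_/completed_mutex_ guard in GRPCClient."""

    def __init__(self, channels):
        self._channels = channels
        self._completed = False                   # completed_
        self._completed_mutex = threading.Lock()  # completed_mutex_

    def send_complete(self):
        # Only the first caller sends; later calls are no-ops.
        with self._completed_mutex:
            if not self._completed:
                for endpoint in self._channels:
                    print("send complete message to %s" % endpoint)
                self._completed = True
```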
@@ -502,9 +502,7 @@ All parameter, weight, gradient are variables in Paddle.
   py::class_<framework::Executor>(m, "Executor")
       .def(py::init<const platform::Place &>())
-#ifdef PADDLE_WITH_DISTRIBUTE
-      .def("complete", &Executor::Complete)
-#endif
+      .def("close", &Executor::Close)
       .def("run", [](Executor &self, const ProgramDesc &prog, Scope *scope,
                      int block_id, bool create_local_scope, bool create_vars) {
         pybind11::gil_scoped_release release;
......
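With the #ifdef dropped from pybind, the close binding is registered unconditionally; the PADDLE_WITH_DISTRIBUTE guard now lives inside Executor::Close() itself, so Python code no longer needs to know how Paddle was built. A sketch of calling the raw binding the same way executor.py constructs it (assuming a built paddle.fluid.core):

```python
import paddle.fluid.core as core

# Wrap the concrete place in a generic Place, as Executor.__init__ does.
p = core.Place()
p.set_place(core.CPUPlace())
exe = core.Executor(p)

# Always available now; in a distributed build this triggers
# GRPCClient::SendComplete(), otherwise Close() is effectively a no-op.
exe.close()
```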
@@ -247,6 +247,7 @@ class Executor(object):
         p.set_place(place)
         self.executor = core.Executor(p)
         self.program_caches = dict()
+        self._closed = False
 
     def as_lodtensor(self, data):
         """
@@ -348,8 +349,23 @@
         ]
         return outs
 
-    def complete(self):
-        self.executor.complete()
+    def close(self):
+        """
+        Close this executor.
+
+        You can no longer use this executor after calling this method.
+        In distributed training, this method frees the resources on PServers
+        related to the current Trainer.
+
+        Example:
+            >>> cpu = core.CPUPlace()
+            >>> exe = Executor(cpu)
+            >>> ...
+            >>> exe.close()
+        """
+        if not self._closed:
+            self.executor.close()
+            self._closed = True
 
     def run(self,
             program=None,
@@ -402,6 +418,10 @@
             >>>     feed={'X': x},
             >>>     fetch_list=[loss.name])
         """
+
+        if self._closed:
+            raise RuntimeError("Attempted to use a closed Executor")
+
         if feed is None:
             feed = {}
         if not isinstance(feed, dict):
......
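On the Python side the two executor.py hunks cooperate: close() flips _closed exactly once and forwards to the C++ binding, and every subsequent run() fails fast before touching the program. A small behavioral sketch (assuming a paddle.fluid build from around this commit):

```python
import paddle.fluid as fluid

exe = fluid.Executor(fluid.CPUPlace())
exe.close()
exe.close()  # second call is a no-op thanks to the _closed flag

try:
    exe.run(fluid.default_main_program())
except RuntimeError as e:
    print(e)  # "Attempted to use a closed Executor"
```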