提交 efd5a849 编写于 作者: Y Yancey1989

update executor interface

上级 800702cc
......@@ -45,13 +45,13 @@ ExecutorPrepareContext::~ExecutorPrepareContext() {
Executor::Executor(const platform::Place& place) : place_(place) {}
void Executor::Close() {
#ifdef PADDLE_WITH_DISTRIBUTE
void Executor::Complete() {
::paddle::operators::distributed::RPCClient::GetInstance<
::paddle::operators::distributed::GRPCClient>()
->SendComplete();
}
#endif
}
void InitializeVariable(Variable* var, proto::VarType::Type var_type) {
if (var_type == proto::VarType::LOD_TENSOR) {
......
......@@ -48,7 +48,7 @@ class Executor {
/*
* Sending signal to pserver to mark current trainer completed.
*/
void Complete();
void Close();
#endif
......
......@@ -36,11 +36,15 @@ void GRPCClient::InitEventLoop() {
}
void GRPCClient::SendComplete() {
for (auto& it : channels_) {
VLOG(3) << "send complete message to " << it.first;
this->AsyncSendComplete(it.first);
std::unique_lock<std::mutex> lk(completed_mutex_);
if (!completed_) {
for (auto& it : channels_) {
VLOG(3) << "send complete message to " << it.first;
this->AsyncSendComplete(it.first);
}
PADDLE_ENFORCE(this->Wait(), "internal grpc error");
completed_ = true;
}
this->Wait();
}
GRPCClient::~GRPCClient() {
......
......@@ -188,7 +188,7 @@ class CheckpointNotifyProcessor : public BaseProcessor {
class GRPCClient : public RPCClient {
public:
GRPCClient() : ok_(true) {}
GRPCClient() : ok_(true), completed_(false) {}
virtual ~GRPCClient();
bool AsyncSendVar(const std::string& ep, const platform::DeviceContext& ctx,
......@@ -247,6 +247,10 @@ class GRPCClient : public RPCClient {
// mutex for GetChannel thread safety
std::mutex chan_mutex_;
DISABLE_COPY_AND_ASSIGN(GRPCClient);
// mutex for sending complete message only once
std::mutex completed_mutex_;
bool completed_;
};
} // namespace distributed
......
......@@ -502,9 +502,7 @@ All parameter, weight, gradient are variables in Paddle.
py::class_<framework::Executor>(m, "Executor")
.def(py::init<const platform::Place &>())
#ifdef PADDLE_WITH_DISTRIBUTE
.def("complete", &Executor::Complete)
#endif
.def("close", &Executor::Close)
.def("run", [](Executor &self, const ProgramDesc &prog, Scope *scope,
int block_id, bool create_local_scope, bool create_vars) {
pybind11::gil_scoped_release release;
......
......@@ -247,6 +247,7 @@ class Executor(object):
p.set_place(place)
self.executor = core.Executor(p)
self.program_caches = dict()
self._closed = False
def as_lodtensor(self, data):
"""
......@@ -348,8 +349,23 @@ class Executor(object):
]
return outs
def complete(self):
self.executor.complete()
def close(self):
"""
Close this executor.
You can no long use this executor after calling this method.
For the distributed training, this method would free the resource on PServers related to
the current Trainer.
Example:
>>> cpu = core.CPUPlace()
>>> exe = Executor(cpu)
>>> ...
>>> exe.close()
"""
if not self._closed:
self.executor.close()
self._closed = True
def run(self,
program=None,
......@@ -402,6 +418,10 @@ class Executor(object):
>>> feed={'X': x},
>>> fetch_list=[loss.name])
"""
if self._closed:
raise RuntimeError("Attempted to use a closed Executor")
if feed is None:
feed = {}
if not isinstance(feed, dict):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册