提交 580340ee 编写于 作者: Q Qiyang Min 提交者: Wu Yi

Shutdown pserver gracefully when SIGINT and SIGTERM was sent (#10984)

* 1. implement StopAll in ListenAndServOp
2. make pserver receive the SIGINT and SIGTERM from outside
3. add unittests for listen_and_serv_op in python

* 1. add blocking queue set to record queue
2. aware all blocking queue when exit and exit gracefully

* 1. Remove comment lines from blocking_queue.h
2. Implement SignalHandler and move all global vars and funcs into it

* 1. Make code follows the style check
2. Move the SignalHandler out of the unnamed namespace

* 1. Make yapf happy

* 1. Call Stop() in destructor to release the resource allocated by ListendAndServOp
2. Change exit status to EXIT_SUCCESS after handling the signal from outside
3. Remove the mis-usage of REMOVE_ITEM in unittests

* 1. use DISABLE_COPY_AND_ASSIGN
2. use program once macro only
上级 3d6934e3
...@@ -13,8 +13,9 @@ See the License for the specific language governing permissions and ...@@ -13,8 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include <stdio.h> // for removing the port file #include <stdio.h> // for removing the port file
#include <csignal>
#include <cstdlib>
#include <fstream> #include <fstream>
#include <ostream>
#include <thread> // NOLINT #include <thread> // NOLINT
#include <vector> #include <vector>
...@@ -28,7 +29,6 @@ void RunServer(std::shared_ptr<detail::AsyncGRPCServer> service) { ...@@ -28,7 +29,6 @@ void RunServer(std::shared_ptr<detail::AsyncGRPCServer> service) {
service->RunSyncUpdate(); service->RunSyncUpdate();
VLOG(4) << "RunServer thread end"; VLOG(4) << "RunServer thread end";
} }
static void split(const std::string &str, char sep, static void split(const std::string &str, char sep,
std::vector<std::string> *pieces) { std::vector<std::string> *pieces) {
pieces->clear(); pieces->clear();
...@@ -59,7 +59,7 @@ static void ParallelExecuteBlocks( ...@@ -59,7 +59,7 @@ static void ParallelExecuteBlocks(
int run_block = idx; // thread local int run_block = idx; // thread local
try { try {
executor->RunPreparedContext(prepared[run_block].get(), scope); executor->RunPreparedContext(prepared[run_block].get(), scope);
} catch (std::exception &e) { } catch (const std::exception &e) {
LOG(ERROR) << "run sub program error " << e.what(); LOG(ERROR) << "run sub program error " << e.what();
} }
})); }));
...@@ -75,8 +75,11 @@ ListenAndServOp::ListenAndServOp(const std::string &type, ...@@ -75,8 +75,11 @@ ListenAndServOp::ListenAndServOp(const std::string &type,
const framework::AttributeMap &attrs) const framework::AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {} : OperatorBase(type, inputs, outputs, attrs) {}
ListenAndServOp::~ListenAndServOp() { Stop(); }
void ListenAndServOp::Stop() { void ListenAndServOp::Stop() {
rpc_service_->Push(LISTEN_TERMINATE_MESSAGE); rpc_service_->Push(LISTEN_TERMINATE_MESSAGE);
rpc_service_->ShutDown();
server_thread_->join(); server_thread_->join();
auto file_path = string::Sprintf("/tmp/paddle.%d.port", ::getpid()); auto file_path = string::Sprintf("/tmp/paddle.%d.port", ::getpid());
remove(file_path.c_str()); remove(file_path.c_str());
...@@ -122,7 +125,7 @@ void ListenAndServOp::RunSyncLoop(framework::Executor *executor, ...@@ -122,7 +125,7 @@ void ListenAndServOp::RunSyncLoop(framework::Executor *executor,
// Record received sparse variables, so that // Record received sparse variables, so that
// we could reset those after execute optimize program // we could reset those after execute optimize program
std::vector<framework::Variable *> sparse_vars; std::vector<framework::Variable *> sparse_vars;
while (!exit_flag) { while (!exit_flag && !SignalHandler::IsProgramExit()) {
// Get from multiple trainers, we don't care about the order in which // Get from multiple trainers, we don't care about the order in which
// the gradients arrives, just add suffix 0~n and merge the gradient. // the gradients arrives, just add suffix 0~n and merge the gradient.
rpc_service_->SetCond(0); rpc_service_->SetCond(0);
...@@ -187,7 +190,7 @@ void ListenAndServOp::RunSyncLoop(framework::Executor *executor, ...@@ -187,7 +190,7 @@ void ListenAndServOp::RunSyncLoop(framework::Executor *executor,
// mini-batch. // mini-batch.
// TODO(Yancey1989): move the reset action into an operator, we couldn't // TODO(Yancey1989): move the reset action into an operator, we couldn't
// have any hide logic in the operator. // have any hide logic in the operator.
for (auto &var : sparse_vars) { for (framework::Variable *var : sparse_vars) {
var->GetMutable<framework::SelectedRows>()->mutable_rows()->clear(); var->GetMutable<framework::SelectedRows>()->mutable_rows()->clear();
} }
...@@ -204,8 +207,12 @@ static void AsyncUpdateThread( ...@@ -204,8 +207,12 @@ static void AsyncUpdateThread(
framework::Executor *executor, framework::Executor *executor,
framework::ExecutorPrepareContext *prepared) { framework::ExecutorPrepareContext *prepared) {
VLOG(3) << "update thread for " << var_name << " started"; VLOG(3) << "update thread for " << var_name << " started";
while (!exit_flag) { while (!exit_flag && !SignalHandler::IsProgramExit()) {
const detail::ReceivedMessage v = queue->Pop(); const detail::ReceivedMessage v = queue->Pop();
if (SignalHandler::IsProgramExit()) {
VLOG(3) << "update thread for " << var_name << " exit";
break;
}
auto recv_var_name = v.first; auto recv_var_name = v.first;
VLOG(4) << "async update " << recv_var_name; VLOG(4) << "async update " << recv_var_name;
auto var = v.second->GetVar(); auto var = v.second->GetVar();
...@@ -217,7 +224,7 @@ static void AsyncUpdateThread( ...@@ -217,7 +224,7 @@ static void AsyncUpdateThread(
try { try {
executor->RunPreparedContext(prepared, executor->RunPreparedContext(prepared,
v.second->GetMutableLocalScope()); v.second->GetMutableLocalScope());
} catch (std::exception &e) { } catch (const std::exception &e) {
LOG(ERROR) << "run sub program error " << e.what(); LOG(ERROR) << "run sub program error " << e.what();
} }
}); });
...@@ -236,7 +243,7 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor, ...@@ -236,7 +243,7 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
auto grad_to_block_id_str = auto grad_to_block_id_str =
Attr<std::vector<std::string>>("grad_to_block_id"); Attr<std::vector<std::string>>("grad_to_block_id");
for (auto &grad_and_id : grad_to_block_id_str) { for (const auto &grad_and_id : grad_to_block_id_str) {
std::vector<std::string> pieces; std::vector<std::string> pieces;
split(grad_and_id, ':', &pieces); split(grad_and_id, ':', &pieces);
VLOG(3) << "after split, grad = " << pieces[0] << ", id=" << pieces[1]; VLOG(3) << "after split, grad = " << pieces[0] << ", id=" << pieces[1];
...@@ -244,7 +251,11 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor, ...@@ -244,7 +251,11 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
PADDLE_ENFORCE_EQ(grad_to_block_id.count(pieces[0]), 0); PADDLE_ENFORCE_EQ(grad_to_block_id.count(pieces[0]), 0);
int block_id = std::stoi(pieces[1]); int block_id = std::stoi(pieces[1]);
grad_to_block_id[pieces[0]] = block_id; grad_to_block_id[pieces[0]] = block_id;
grad_to_queue[pieces[0]] = std::make_shared<detail::ReceivedQueue>(); std::shared_ptr<detail::ReceivedQueue> queue =
std::make_shared<detail::ReceivedQueue>();
grad_to_queue[pieces[0]] = queue;
// record blocking queue in SignalHandler
SignalHandler::RegisterBlockingQueue(queue);
id_to_grad[block_id] = pieces[0]; id_to_grad[block_id] = pieces[0];
} }
size_t num_blocks = program->Size(); size_t num_blocks = program->Size();
...@@ -276,9 +287,8 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor, ...@@ -276,9 +287,8 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
executor, grad_to_prepared_ctx[grad_name].get()); executor, grad_to_prepared_ctx[grad_name].get());
})); }));
} }
VLOG(3) << "RunAsyncLoop into while"; VLOG(3) << "RunAsyncLoop into while";
while (!exit_flag) { while (!exit_flag && !SignalHandler::IsProgramExit()) {
const detail::ReceivedMessage v = rpc_service_->Get(); const detail::ReceivedMessage v = rpc_service_->Get();
auto recv_var_name = v.first; auto recv_var_name = v.first;
if (recv_var_name == LISTEN_TERMINATE_MESSAGE) { if (recv_var_name == LISTEN_TERMINATE_MESSAGE) {
...@@ -333,6 +343,10 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, ...@@ -333,6 +343,10 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
VLOG(3) << "wait server thread to become ready..."; VLOG(3) << "wait server thread to become ready...";
rpc_service_->WaitServerReady(); rpc_service_->WaitServerReady();
// register SIGINT(from ctrl+C) and SIGTERM(from kill) signal handlers
signal(SIGINT, SignalHandler::StopAndExit);
signal(SIGTERM, SignalHandler::StopAndExit);
// Write to a file of server selected port for python use. // Write to a file of server selected port for python use.
std::string file_path = string::Sprintf("/tmp/paddle.%d.selected_port", std::string file_path = string::Sprintf("/tmp/paddle.%d.selected_port",
static_cast<int>(::getpid())); static_cast<int>(::getpid()));
...@@ -348,12 +362,9 @@ class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -348,12 +362,9 @@ class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
void Make() { void Make() {
AddInput("X", "(Tensor) Variables that server recv.").AsDuplicable(); AddInput("X", "(Tensor) Variables that server recv.").AsDuplicable();
AddComment(R"DOC( AddComment(R"DOC(" + "ListenAndServ operator" + "\n" + "This operator" +
ListenAndServ operator " will start a RPC server which can receive variables from send_op and send" +
"back variables to recv_op.)DOC");
This operator will start a RPC server which can receive variables
from send_op and send back variables to recv_op.
)DOC");
AddAttr<std::string>("endpoint", AddAttr<std::string>("endpoint",
"(string, default 127.0.0.1:6164)" "(string, default 127.0.0.1:6164)"
"IP address to listen on.") "IP address to listen on.")
...@@ -374,6 +385,29 @@ from send_op and send back variables to recv_op. ...@@ -374,6 +385,29 @@ from send_op and send back variables to recv_op.
} }
}; };
bool SignalHandler::program_exit_flag_ = false;
SignalHandler::BlockingQueueSet SignalHandler::blocking_queue_set_{};
void SignalHandler::StopAndExit(int signal_num) {
VLOG(3) << "Catch interrupt signal: " << signal_num << ", program will exit";
program_exit_flag_ = true;
// awake all blocking queues
for (BlockingQueueSet::iterator iter = blocking_queue_set_.begin();
iter != blocking_queue_set_.end(); iter++) {
iter->get()->Push(
std::make_pair(std::string(LISTEN_TERMINATE_MESSAGE), nullptr));
}
exit(EXIT_SUCCESS);
}
void SignalHandler::RegisterBlockingQueue(BlockingQueue &queue) {
blocking_queue_set_.insert(queue);
}
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
......
...@@ -16,7 +16,7 @@ limitations under the License. */ ...@@ -16,7 +16,7 @@ limitations under the License. */
#include <stdint.h> #include <stdint.h>
#include <atomic> #include <atomic>
#include <ostream> #include <set>
#include <string> #include <string>
#include "paddle/fluid/framework/executor.h" #include "paddle/fluid/framework/executor.h"
...@@ -40,6 +40,8 @@ class ListenAndServOp : public framework::OperatorBase { ...@@ -40,6 +40,8 @@ class ListenAndServOp : public framework::OperatorBase {
const framework::VariableNameMap& outputs, const framework::VariableNameMap& outputs,
const framework::AttributeMap& attrs); const framework::AttributeMap& attrs);
virtual ~ListenAndServOp();
void RunSyncLoop(framework::Executor* executor, void RunSyncLoop(framework::Executor* executor,
framework::ProgramDesc* program, framework::ProgramDesc* program,
framework::Scope* recv_scope, framework::Scope* recv_scope,
...@@ -68,5 +70,25 @@ class ListenAndServOp : public framework::OperatorBase { ...@@ -68,5 +70,25 @@ class ListenAndServOp : public framework::OperatorBase {
static std::atomic_int selected_port_; static std::atomic_int selected_port_;
}; };
class SignalHandler {
public:
typedef std::shared_ptr<detail::ReceivedQueue> BlockingQueue;
typedef std::unordered_set<BlockingQueue> BlockingQueueSet;
public:
static void StopAndExit(int signal_num);
static void RegisterBlockingQueue(BlockingQueue&);
static inline bool IsProgramExit() { return program_exit_flag_; }
private:
static bool program_exit_flag_;
static BlockingQueueSet blocking_queue_set_;
DISABLE_COPY_AND_ASSIGN(SignalHandler);
};
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -48,3 +48,5 @@ foreach(TEST_OP ${TEST_OPS}) ...@@ -48,3 +48,5 @@ foreach(TEST_OP ${TEST_OPS})
endforeach(TEST_OP) endforeach(TEST_OP)
py_test_modules(test_warpctc_op MODULES test_warpctc_op ENVS FLAGS_warpctc_dir=${WARPCTC_LIB_DIR} SERIAL) py_test_modules(test_warpctc_op MODULES test_warpctc_op ENVS FLAGS_warpctc_dir=${WARPCTC_LIB_DIR} SERIAL)
py_test_modules(test_dist_train MODULES test_dist_train SERIAL) py_test_modules(test_dist_train MODULES test_dist_train SERIAL)
# tests that need to be done in fixed timeout
set_tests_properties(test_listen_and_serv_op PROPERTIES TIMEOUT 20)
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.fluid as fluid
import os
import signal
import subprocess
import time
import unittest
from multiprocessing import Process
from op_test import OpTest
def run_pserver(use_cuda, sync_mode, ip, port, trainer_count, trainer_id):
x = fluid.layers.data(name='x', shape=[1], dtype='float32')
y_predict = fluid.layers.fc(input=x, size=1, act=None)
y = fluid.layers.data(name='y', shape=[1], dtype='float32')
# loss function
cost = fluid.layers.square_error_cost(input=y_predict, label=y)
avg_cost = fluid.layers.mean(cost)
# optimizer
sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.001)
sgd_optimizer.minimize(avg_cost)
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
exe = fluid.Executor(place)
port = os.getenv("PADDLE_INIT_PORT", port)
pserver_ips = os.getenv("PADDLE_INIT_PSERVERS", ip) # ip,ip...
eplist = []
for ip in pserver_ips.split(","):
eplist.append(':'.join([ip, port]))
pserver_endpoints = ",".join(eplist) # ip:port,ip:port...
trainers = int(os.getenv("TRAINERS", trainer_count))
current_endpoint = os.getenv("POD_IP", ip) + ":" + port
trainer_id = int(os.getenv("PADDLE_INIT_TRAINER_ID", trainer_id))
t = fluid.DistributeTranspiler()
t.transpile(
trainer_id,
pservers=pserver_endpoints,
trainers=trainers,
sync_mode=sync_mode)
pserver_prog = t.get_pserver_program(current_endpoint)
pserver_startup = t.get_startup_program(current_endpoint, pserver_prog)
exe.run(pserver_startup)
exe.run(pserver_prog)
class TestListenAndServOp(OpTest):
def setUp(self):
self.sleep_time = 5
self.ip = "127.0.0.1"
self.port = "6173"
self.trainer_count = 1
self.trainer_id = 1
def _raise_signal(self, parent_pid, raised_signal):
time.sleep(self.sleep_time)
ps_command = subprocess.Popen(
"ps -o pid --ppid %d --noheaders" % parent_pid,
shell=True,
stdout=subprocess.PIPE)
ps_output = ps_command.stdout.read()
retcode = ps_command.wait()
assert retcode == 0, "ps command returned %d" % retcode
for pid_str in ps_output.split("\n")[:-1]:
try:
os.kill(int(pid_str), raised_signal)
except Exception:
continue
def _start_pserver(self, use_cuda, sync_mode):
p = Process(
target=run_pserver,
args=(use_cuda, sync_mode, self.ip, self.port, self.trainer_count,
self.trainer_id))
p.start()
def test_handle_signal_in_serv_op(self):
# run pserver on CPU in sync mode
self._start_pserver(False, True)
# raise SIGINT to pserver
self._raise_signal(os.getpid(), signal.SIGINT)
# run pserver on CPU in async mode
self._start_pserver(False, False)
# raise SIGTERM to pserver
self._raise_signal(os.getpid(), signal.SIGTERM)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册