Commit 580340ee authored by: Qiyang Min, committed by: Wu Yi

Shut down pserver gracefully when SIGINT or SIGTERM is sent (#10984)

* 1. Implement StopAll in ListenAndServOp
2. Make the pserver receive SIGINT and SIGTERM from outside
3. Add unit tests for listen_and_serv_op in Python

* 1. Add a blocking queue set to record the queues
2. Wake all blocking queues on exit and exit gracefully

* 1. Remove comment lines from blocking_queue.h
2. Implement SignalHandler and move all global variables and functions into it

* 1. Make the code follow the style check
2. Move the SignalHandler out of the unnamed namespace

* 1. Make yapf happy

* 1. Call Stop() in the destructor to release the resources allocated by ListenAndServOp
2. Change exit status to EXIT_SUCCESS after handling the signal from outside
3. Remove the misuse of REMOVE_ITEM in the unit tests

* 1. Use DISABLE_COPY_AND_ASSIGN
2. Use the #pragma once macro only
Parent 3d6934e3
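For readers skimming the diff below, here is a minimal, self-contained C++ sketch of the shutdown pattern the commit introduces: register SIGINT/SIGTERM handlers, flip a program-exit flag, wake every registered blocking queue with a terminate message, and exit. SimpleQueue, kTerminateMessage, and the main() harness are illustrative stand-ins for Paddle's detail::ReceivedQueue, LISTEN_TERMINATE_MESSAGE, and the pserver loops; treat this as a sketch of the idea under those assumptions, not the actual implementation.

// Minimal sketch of the shutdown pattern this commit introduces. SimpleQueue,
// kTerminateMessage, and main() are illustrative stand-ins, not commit code.
#include <csignal>
#include <cstdlib>
#include <condition_variable>
#include <deque>
#include <iostream>
#include <memory>
#include <mutex>
#include <string>
#include <thread>
#include <unordered_set>

// A trivial blocking queue: Pop() blocks until Push() supplies a message.
class SimpleQueue {
 public:
  void Push(const std::string &msg) {
    {
      std::lock_guard<std::mutex> lock(mu_);
      items_.push_back(msg);
    }
    cv_.notify_one();
  }
  std::string Pop() {
    std::unique_lock<std::mutex> lock(mu_);
    cv_.wait(lock, [this] { return !items_.empty(); });
    std::string msg = items_.front();
    items_.pop_front();
    return msg;
  }

 private:
  std::mutex mu_;
  std::condition_variable cv_;
  std::deque<std::string> items_;
};

const char kTerminateMessage[] = "TERMINATE@RECV";

class SignalHandler {
 public:
  typedef std::shared_ptr<SimpleQueue> BlockingQueue;

  static void RegisterBlockingQueue(const BlockingQueue &q) {
    blocking_queue_set_.insert(q);
  }
  static bool IsProgramExit() { return program_exit_flag_; }

  // Signal handler: flip the exit flag, wake every blocked consumer with a
  // terminate message, then exit with EXIT_SUCCESS. Like the commit, it does
  // the work directly inside the handler, which is a simplification rather
  // than a strictly async-signal-safe design.
  static void StopAndExit(int signal_num) {
    std::cerr << "caught signal " << signal_num << ", exiting" << std::endl;
    program_exit_flag_ = true;
    for (const BlockingQueue &q : blocking_queue_set_) {
      q->Push(kTerminateMessage);
    }
    std::exit(EXIT_SUCCESS);
  }

 private:
  static bool program_exit_flag_;
  static std::unordered_set<BlockingQueue> blocking_queue_set_;
};

bool SignalHandler::program_exit_flag_ = false;
std::unordered_set<SignalHandler::BlockingQueue>
    SignalHandler::blocking_queue_set_;

int main() {
  auto queue = std::make_shared<SimpleQueue>();
  SignalHandler::RegisterBlockingQueue(queue);
  std::signal(SIGINT, SignalHandler::StopAndExit);
  std::signal(SIGTERM, SignalHandler::StopAndExit);

  // Consumer loop, mirroring AsyncUpdateThread: a wake-up that delivers the
  // terminate message (or finds the exit flag set) ends the loop.
  std::thread worker([queue] {
    while (!SignalHandler::IsProgramExit()) {
      std::string msg = queue->Pop();
      if (msg == kTerminateMessage || SignalHandler::IsProgramExit()) break;
      std::cout << "handling " << msg << std::endl;
    }
  });

  queue->Push("some_grad_var");
  std::raise(SIGTERM);  // simulate `kill`; StopAndExit runs and exits
  worker.join();        // never reached: std::exit() ends the process first
  return 0;
}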
@@ -13,8 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include <stdio.h> // for removing the port file
#include <csignal>
#include <cstdlib>
#include <fstream>
#include <ostream>
#include <thread> // NOLINT
#include <vector>
@@ -28,7 +29,6 @@ void RunServer(std::shared_ptr<detail::AsyncGRPCServer> service) {
service->RunSyncUpdate();
VLOG(4) << "RunServer thread end";
}
static void split(const std::string &str, char sep,
std::vector<std::string> *pieces) {
pieces->clear();
@@ -59,7 +59,7 @@ static void ParallelExecuteBlocks(
int run_block = idx; // thread local
try {
executor->RunPreparedContext(prepared[run_block].get(), scope);
} catch (std::exception &e) {
} catch (const std::exception &e) {
LOG(ERROR) << "run sub program error " << e.what();
}
}));
@@ -75,8 +75,11 @@ ListenAndServOp::ListenAndServOp(const std::string &type,
const framework::AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
ListenAndServOp::~ListenAndServOp() { Stop(); }
void ListenAndServOp::Stop() {
rpc_service_->Push(LISTEN_TERMINATE_MESSAGE);
rpc_service_->ShutDown();
server_thread_->join();
auto file_path = string::Sprintf("/tmp/paddle.%d.port", ::getpid());
remove(file_path.c_str());
@@ -122,7 +125,7 @@ void ListenAndServOp::RunSyncLoop(framework::Executor *executor,
// Record received sparse variables, so that
// we could reset those after execute optimize program
std::vector<framework::Variable *> sparse_vars;
while (!exit_flag) {
while (!exit_flag && !SignalHandler::IsProgramExit()) {
// Get from multiple trainers, we don't care about the order in which
// the gradients arrives, just add suffix 0~n and merge the gradient.
rpc_service_->SetCond(0);
@@ -187,7 +190,7 @@ void ListenAndServOp::RunSyncLoop(framework::Executor *executor,
// mini-batch.
// TODO(Yancey1989): move the reset action into an operator, we couldn't
// have any hide logic in the operator.
for (auto &var : sparse_vars) {
for (framework::Variable *var : sparse_vars) {
var->GetMutable<framework::SelectedRows>()->mutable_rows()->clear();
}
@@ -204,8 +207,12 @@ static void AsyncUpdateThread(
framework::Executor *executor,
framework::ExecutorPrepareContext *prepared) {
VLOG(3) << "update thread for " << var_name << " started";
while (!exit_flag) {
while (!exit_flag && !SignalHandler::IsProgramExit()) {
const detail::ReceivedMessage v = queue->Pop();
if (SignalHandler::IsProgramExit()) {
VLOG(3) << "update thread for " << var_name << " exit";
break;
}
auto recv_var_name = v.first;
VLOG(4) << "async update " << recv_var_name;
auto var = v.second->GetVar();
@@ -217,7 +224,7 @@
try {
executor->RunPreparedContext(prepared,
v.second->GetMutableLocalScope());
} catch (std::exception &e) {
} catch (const std::exception &e) {
LOG(ERROR) << "run sub program error " << e.what();
}
});
@@ -236,7 +243,7 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
auto grad_to_block_id_str =
Attr<std::vector<std::string>>("grad_to_block_id");
for (auto &grad_and_id : grad_to_block_id_str) {
for (const auto &grad_and_id : grad_to_block_id_str) {
std::vector<std::string> pieces;
split(grad_and_id, ':', &pieces);
VLOG(3) << "after split, grad = " << pieces[0] << ", id=" << pieces[1];
@@ -244,7 +251,11 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
PADDLE_ENFORCE_EQ(grad_to_block_id.count(pieces[0]), 0);
int block_id = std::stoi(pieces[1]);
grad_to_block_id[pieces[0]] = block_id;
grad_to_queue[pieces[0]] = std::make_shared<detail::ReceivedQueue>();
std::shared_ptr<detail::ReceivedQueue> queue =
std::make_shared<detail::ReceivedQueue>();
grad_to_queue[pieces[0]] = queue;
// record blocking queue in SignalHandler
SignalHandler::RegisterBlockingQueue(queue);
id_to_grad[block_id] = pieces[0];
}
size_t num_blocks = program->Size();
@@ -276,9 +287,8 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
executor, grad_to_prepared_ctx[grad_name].get());
}));
}
VLOG(3) << "RunAsyncLoop into while";
while (!exit_flag) {
while (!exit_flag && !SignalHandler::IsProgramExit()) {
const detail::ReceivedMessage v = rpc_service_->Get();
auto recv_var_name = v.first;
if (recv_var_name == LISTEN_TERMINATE_MESSAGE) {
@@ -333,6 +343,10 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
VLOG(3) << "wait server thread to become ready...";
rpc_service_->WaitServerReady();
// register SIGINT(from ctrl+C) and SIGTERM(from kill) signal handlers
signal(SIGINT, SignalHandler::StopAndExit);
signal(SIGTERM, SignalHandler::StopAndExit);
// Write to a file of server selected port for python use.
std::string file_path = string::Sprintf("/tmp/paddle.%d.selected_port",
static_cast<int>(::getpid()));
@@ -348,12 +362,9 @@ class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() {
AddInput("X", "(Tensor) Variables that server recv.").AsDuplicable();
AddComment(R"DOC(
ListenAndServ operator
This operator will start a RPC server which can receive variables
from send_op and send back variables to recv_op.
)DOC");
AddComment(R"DOC(" + "ListenAndServ operator" + "\n" + "This operator" +
" will start a RPC server which can receive variables from send_op and send" +
"back variables to recv_op.)DOC");
AddAttr<std::string>("endpoint",
"(string, default 127.0.0.1:6164)"
"IP address to listen on.")
@@ -374,6 +385,29 @@ from send_op and send back variables to recv_op.
}
};
bool SignalHandler::program_exit_flag_ = false;
SignalHandler::BlockingQueueSet SignalHandler::blocking_queue_set_{};
void SignalHandler::StopAndExit(int signal_num) {
VLOG(3) << "Catch interrupt signal: " << signal_num << ", program will exit";
program_exit_flag_ = true;
// awake all blocking queues
for (BlockingQueueSet::iterator iter = blocking_queue_set_.begin();
iter != blocking_queue_set_.end(); iter++) {
iter->get()->Push(
std::make_pair(std::string(LISTEN_TERMINATE_MESSAGE), nullptr));
}
exit(EXIT_SUCCESS);
}
void SignalHandler::RegisterBlockingQueue(BlockingQueue &queue) {
blocking_queue_set_.insert(queue);
}
} // namespace operators
} // namespace paddle
@@ -16,7 +16,7 @@ limitations under the License. */
#include <stdint.h>
#include <atomic>
#include <ostream>
#include <set>
#include <string>
#include "paddle/fluid/framework/executor.h"
@@ -40,6 +40,8 @@ class ListenAndServOp : public framework::OperatorBase {
const framework::VariableNameMap& outputs,
const framework::AttributeMap& attrs);
virtual ~ListenAndServOp();
void RunSyncLoop(framework::Executor* executor,
framework::ProgramDesc* program,
framework::Scope* recv_scope,
@@ -68,5 +70,25 @@ class ListenAndServOp : public framework::OperatorBase {
static std::atomic_int selected_port_;
};
class SignalHandler {
public:
typedef std::shared_ptr<detail::ReceivedQueue> BlockingQueue;
typedef std::unordered_set<BlockingQueue> BlockingQueueSet;
public:
static void StopAndExit(int signal_num);
static void RegisterBlockingQueue(BlockingQueue&);
static inline bool IsProgramExit() { return program_exit_flag_; }
private:
static bool program_exit_flag_;
static BlockingQueueSet blocking_queue_set_;
DISABLE_COPY_AND_ASSIGN(SignalHandler);
};
} // namespace operators
} // namespace paddle
@@ -48,3 +48,5 @@ foreach(TEST_OP ${TEST_OPS})
endforeach(TEST_OP)
py_test_modules(test_warpctc_op MODULES test_warpctc_op ENVS FLAGS_warpctc_dir=${WARPCTC_LIB_DIR} SERIAL)
py_test_modules(test_dist_train MODULES test_dist_train SERIAL)
# tests that need to finish within a fixed timeout
set_tests_properties(test_listen_and_serv_op PROPERTIES TIMEOUT 20)
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.fluid as fluid
import os
import signal
import subprocess
import time
import unittest
from multiprocessing import Process
from op_test import OpTest
def run_pserver(use_cuda, sync_mode, ip, port, trainer_count, trainer_id):
x = fluid.layers.data(name='x', shape=[1], dtype='float32')
y_predict = fluid.layers.fc(input=x, size=1, act=None)
y = fluid.layers.data(name='y', shape=[1], dtype='float32')
# loss function
cost = fluid.layers.square_error_cost(input=y_predict, label=y)
avg_cost = fluid.layers.mean(cost)
# optimizer
sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.001)
sgd_optimizer.minimize(avg_cost)
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
exe = fluid.Executor(place)
port = os.getenv("PADDLE_INIT_PORT", port)
pserver_ips = os.getenv("PADDLE_INIT_PSERVERS", ip) # ip,ip...
eplist = []
for ip in pserver_ips.split(","):
eplist.append(':'.join([ip, port]))
pserver_endpoints = ",".join(eplist) # ip:port,ip:port...
trainers = int(os.getenv("TRAINERS", trainer_count))
current_endpoint = os.getenv("POD_IP", ip) + ":" + port
trainer_id = int(os.getenv("PADDLE_INIT_TRAINER_ID", trainer_id))
t = fluid.DistributeTranspiler()
t.transpile(
trainer_id,
pservers=pserver_endpoints,
trainers=trainers,
sync_mode=sync_mode)
pserver_prog = t.get_pserver_program(current_endpoint)
pserver_startup = t.get_startup_program(current_endpoint, pserver_prog)
exe.run(pserver_startup)
exe.run(pserver_prog)
class TestListenAndServOp(OpTest):
def setUp(self):
self.sleep_time = 5
self.ip = "127.0.0.1"
self.port = "6173"
self.trainer_count = 1
self.trainer_id = 1
def _raise_signal(self, parent_pid, raised_signal):
time.sleep(self.sleep_time)
ps_command = subprocess.Popen(
"ps -o pid --ppid %d --noheaders" % parent_pid,
shell=True,
stdout=subprocess.PIPE)
ps_output = ps_command.stdout.read()
retcode = ps_command.wait()
assert retcode == 0, "ps command returned %d" % retcode
for pid_str in ps_output.split("\n")[:-1]:
try:
os.kill(int(pid_str), raised_signal)
except Exception:
continue
def _start_pserver(self, use_cuda, sync_mode):
p = Process(
target=run_pserver,
args=(use_cuda, sync_mode, self.ip, self.port, self.trainer_count,
self.trainer_id))
p.start()
def test_handle_signal_in_serv_op(self):
# run pserver on CPU in sync mode
self._start_pserver(False, True)
# raise SIGINT to pserver
self._raise_signal(os.getpid(), signal.SIGINT)
# run pserver on CPU in async mode
self._start_pserver(False, False)
# raise SIGTERM to pserver
self._raise_signal(os.getpid(), signal.SIGTERM)
if __name__ == '__main__':
unittest.main()