From e05f4ff26700dd34aa6d3c6da7061c62c5fa39c9 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Sun, 6 Nov 2016 23:04:02 -0600 Subject: [PATCH] Fix SRL hang when exit. (#291) * Fix SRL hang on exit. * The error occurred when Async Load was enabled in TestDataProvider. * It happened because DataProvider was calling getNextBatchInternal in one thread while being destructed in another thread. * Add a wait routine when destructing DataProvider. * Also fix another bug that occurred when destructing a TestDataProvider that had not read any test data. Fix #286 * Follow review comments: using a mutex is cool! --- demo/semantic_role_labeling/.gitignore | 10 ++++++++++ paddle/gserver/dataproviders/DataProvider.cpp | 3 ++- .../gserver/dataproviders/PyDataProvider2.cpp | 19 +++++++++++++++++-- paddle/gserver/tests/test_PyDataProvider2.cpp | 17 +++++++++++++++++ 4 files changed, 46 insertions(+), 3 deletions(-) create mode 100644 demo/semantic_role_labeling/.gitignore diff --git a/demo/semantic_role_labeling/.gitignore b/demo/semantic_role_labeling/.gitignore new file mode 100644 index 0000000000..cd90ca7bbe --- /dev/null +++ b/demo/semantic_role_labeling/.gitignore @@ -0,0 +1,10 @@ +*.pyc +train.log +data/feature +data/conll05st-release/ +data/src.dict +data/test.wsj.props +data/test.wsj.seq_pair +data/test.wsj.words +data/tgt.dict +output diff --git a/paddle/gserver/dataproviders/DataProvider.cpp b/paddle/gserver/dataproviders/DataProvider.cpp index 8cefbb30ad..2cfb5a3a18 100644 --- a/paddle/gserver/dataproviders/DataProvider.cpp +++ b/paddle/gserver/dataproviders/DataProvider.cpp @@ -131,9 +131,10 @@ void DoubleBuffer::asyncLoadBatch() { taskReadySem_.wait(); if (stopping_) break; - while (batchSize_ == 0) { + while (batchSize_ == 0 && !stopping_) { usleep(5); } + if (stopping_) break; do { DataBatch newBatch; diff --git a/paddle/gserver/dataproviders/PyDataProvider2.cpp b/paddle/gserver/dataproviders/PyDataProvider2.cpp index ca8b07af49..90391a7c30 100644 --- a/paddle/gserver/dataproviders/PyDataProvider2.cpp +++ 
b/paddle/gserver/dataproviders/PyDataProvider2.cpp @@ -433,26 +433,34 @@ private: inline void resetImpl(bool startNewThread) { DBG << "Reseting " << startNewThread; + exit_.store(true); if (loadThread_) { // is loading. - exit_.store(true); loadThread_->join(); loadThread_.reset(); } { PyGuard g; callingContexts_.clear(); + this->pullCV_.notify_one(); + } + + std::lock_guard guard(mutexForReset_); + { + PyGuard g; dataPool_.clear(); } poolActualSize_ = 0; - exit_ = false; + if (startNewThread && cache_->reset()) { DBG << "Start new thread."; loadThread_.reset(new std::thread([this] { + exit_ = false; loadThread(); })); callingContextCreated_.wait(); } DBG << "Reset done"; + exit_ = false; } private: @@ -465,6 +473,8 @@ private: std::condition_variable pullCV_; std::mutex mtx_; + std::mutex mutexForReset_; + ThreadBarrier callingContextCreated_; std::unique_ptr cache_; @@ -529,6 +539,7 @@ public: * Loading a batch of data. */ int64_t getNextBatchInternal(int64_t size_, DataBatch *batch) { + std::lock_guard guard(mutexForReset_); REGISTER_TIMER("PyDP2.getNextBatchInternal") CHECK_GE(size_, 0); size_t size = (size_t) size_; @@ -554,6 +565,10 @@ public: } else { // loading from cache. poolPtr = this->cache_->load(); } + if (exit_) { + // PyDataProvider is destructing. 
+ return 0; + } CHECK(poolPtr != nullptr); std::deque& pool = *poolPtr; diff --git a/paddle/gserver/tests/test_PyDataProvider2.cpp b/paddle/gserver/tests/test_PyDataProvider2.cpp index 6bf1e32925..b9867a728d 100644 --- a/paddle/gserver/tests/test_PyDataProvider2.cpp +++ b/paddle/gserver/tests/test_PyDataProvider2.cpp @@ -353,6 +353,23 @@ TEST(PyDataProvider2, test_check) { } } +TEST(PyDataProvider2, multiThread) { + paddle::DataConfig config; + config.set_type("py2"); + config.set_files(FLAGS_train_list.c_str()); + config.set_load_data_module("test_PyDataProvider2"); + config.set_load_data_object("test_dense_no_seq"); + config.set_async_load_data(true); + + std::unique_ptr provider( + paddle::DataProvider::create(config, false)); + provider->reset(); + paddle::DataBatch batch; + provider->getNextBatch(100, &batch); + provider->reset(); + provider.reset(); +} + int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv); paddle::initMain(argc, argv); -- GitLab