提交 e05f4ff2 编写于 作者: Y Yu Yang 提交者: hedaoyuan

Fix SRL hang when exit. (#291)

* Fix SRL hang when exit.

* Error occurred when enable Async Load in TestDataProvider.
  * It because DataProvider is calling getNextBatchInternal in one thread, and destructing DataProvider in other thread.
  * Add wait routine in DataProvider destructing.
* Also fix another bug, when destructing TestDataProvider and do not read any test data.

Fix #286

* Follow comments, Use mutex is cool!
上级 c64cd6fe
*.pyc
train.log
data/feature
data/conll05st-release/
data/src.dict
data/test.wsj.props
data/test.wsj.seq_pair
data/test.wsj.words
data/tgt.dict
output
......@@ -131,9 +131,10 @@ void DoubleBuffer::asyncLoadBatch() {
taskReadySem_.wait();
if (stopping_) break;
while (batchSize_ == 0) {
while (batchSize_ == 0 && !stopping_) {
usleep(5);
}
if (stopping_) break;
do {
DataBatch newBatch;
......
......@@ -433,26 +433,34 @@ private:
inline void resetImpl(bool startNewThread) {
DBG << "Reseting " << startNewThread;
exit_.store(true);
if (loadThread_) { // is loading.
exit_.store(true);
loadThread_->join();
loadThread_.reset();
}
{
PyGuard g;
callingContexts_.clear();
this->pullCV_.notify_one();
}
std::lock_guard<std::mutex> guard(mutexForReset_);
{
PyGuard g;
dataPool_.clear();
}
poolActualSize_ = 0;
exit_ = false;
if (startNewThread && cache_->reset()) {
DBG << "Start new thread.";
loadThread_.reset(new std::thread([this] {
exit_ = false;
loadThread();
}));
callingContextCreated_.wait();
}
DBG << "Reset done";
exit_ = false;
}
private:
......@@ -465,6 +473,8 @@ private:
std::condition_variable pullCV_;
std::mutex mtx_;
std::mutex mutexForReset_;
ThreadBarrier callingContextCreated_;
std::unique_ptr<IPyDataProviderCache> cache_;
......@@ -529,6 +539,7 @@ public:
* Loading a batch of data.
*/
int64_t getNextBatchInternal(int64_t size_, DataBatch *batch) {
std::lock_guard<std::mutex> guard(mutexForReset_);
REGISTER_TIMER("PyDP2.getNextBatchInternal")
CHECK_GE(size_, 0);
size_t size = (size_t) size_;
......@@ -554,6 +565,10 @@ public:
} else { // loading from cache.
poolPtr = this->cache_->load();
}
if (exit_) {
// PyDataProvider is destructing.
return 0;
}
CHECK(poolPtr != nullptr);
std::deque<PyObjectPtr>& pool = *poolPtr;
......
......@@ -353,6 +353,23 @@ TEST(PyDataProvider2, test_check) {
}
}
TEST(PyDataProvider2, multiThread) {
paddle::DataConfig config;
config.set_type("py2");
config.set_files(FLAGS_train_list.c_str());
config.set_load_data_module("test_PyDataProvider2");
config.set_load_data_object("test_dense_no_seq");
config.set_async_load_data(true);
std::unique_ptr<paddle::DataProvider> provider(
paddle::DataProvider::create(config, false));
provider->reset();
paddle::DataBatch batch;
provider->getNextBatch(100, &batch);
provider->reset();
provider.reset();
}
int main(int argc, char** argv) {
testing::InitGoogleTest(&argc, argv);
paddle::initMain(argc, argv);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册