From d55919c656e19cd9600e1d009e6cdff878b5e28e Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Mon, 9 Jul 2018 16:53:42 +0800 Subject: [PATCH] Impl ResetAll and fix errors --- paddle/fluid/framework/reader.cc | 2 +- paddle/fluid/framework/reader.h | 8 +++++++- paddle/fluid/framework/reader_test.cc | 5 ++--- python/paddle/fluid/layers/io.py | 3 --- 4 files changed, 10 insertions(+), 8 deletions(-) diff --git a/paddle/fluid/framework/reader.cc b/paddle/fluid/framework/reader.cc index f8877e5cb..5897d320a 100644 --- a/paddle/fluid/framework/reader.cc +++ b/paddle/fluid/framework/reader.cc @@ -26,7 +26,7 @@ void ReaderBase::ReadNext(std::vector *out) { void ReaderBase::InsertDecoratedReader( const std::shared_ptr &decorated_reader) { - std::lock_guard guard(mu_)); + std::lock_guard guard(mu_); decorated_readers_.emplace_back(decorated_reader); } diff --git a/paddle/fluid/framework/reader.h b/paddle/fluid/framework/reader.h index 93cd6243f..6c4432cb7 100644 --- a/paddle/fluid/framework/reader.h +++ b/paddle/fluid/framework/reader.h @@ -104,7 +104,13 @@ class ReaderHolder { } void ResetAll() { - // TODO(fengjiayi): The interface of reseting all. + auto end_readers = reader_->GetEndPoints(); + for (auto* reader : end_readers) { + reader->Shutdown(); + } + for (auto* reader : end_readers) { + reader->Start(); + } } void Shutdown() { diff --git a/paddle/fluid/framework/reader_test.cc b/paddle/fluid/framework/reader_test.cc index c05be8670..f0d07cb7c 100644 --- a/paddle/fluid/framework/reader_test.cc +++ b/paddle/fluid/framework/reader_test.cc @@ -21,13 +21,12 @@ class StubDecoratedReader : public paddle::framework::DecoratedReader { explicit StubDecoratedReader(const std::shared_ptr &reader) : DecoratedReader(reader) {} - void ReadNext(std::vector *out) override {} + void ReadNextImpl(std::vector *out) override {} }; class StubRootReader : public paddle::framework::ReaderBase { public: - void ReadNext(std::vector *out) override {} - void ReInit() override {} + void ReadNextImpl(std::vector *out) override {} }; TEST(READER, decorate_chain) { diff --git a/python/paddle/fluid/layers/io.py b/python/paddle/fluid/layers/io.py index 234625265..977abde21 100644 --- a/python/paddle/fluid/layers/io.py +++ b/python/paddle/fluid/layers/io.py @@ -375,9 +375,6 @@ def open_recordio_file(filename, if pass_num > 1: main_prog_var = multi_pass(reader=main_prog_var, pass_num=pass_num) - if for_parallel: - main_prog_var = parallel(reader=main_prog_var) - return monkey_patch_reader_methods(main_prog_var) -- GitLab