提交 d55919c6 编写于 作者: F fengjiayi

Impl ResetAll and fix errors

上级 6d6f49cd
...@@ -26,7 +26,7 @@ void ReaderBase::ReadNext(std::vector<LoDTensor> *out) { ...@@ -26,7 +26,7 @@ void ReaderBase::ReadNext(std::vector<LoDTensor> *out) {
void ReaderBase::InsertDecoratedReader( void ReaderBase::InsertDecoratedReader(
const std::shared_ptr<ReaderBase> &decorated_reader) { const std::shared_ptr<ReaderBase> &decorated_reader) {
std::lock_guard<std::mutex> guard(mu_)); std::lock_guard<std::mutex> guard(mu_);
decorated_readers_.emplace_back(decorated_reader); decorated_readers_.emplace_back(decorated_reader);
} }
......
...@@ -104,7 +104,13 @@ class ReaderHolder { ...@@ -104,7 +104,13 @@ class ReaderHolder {
} }
void ResetAll() { void ResetAll() {
// TODO(fengjiayi): The interface of reseting all. auto end_readers = reader_->GetEndPoints();
for (auto* reader : end_readers) {
reader->Shutdown();
}
for (auto* reader : end_readers) {
reader->Start();
}
} }
void Shutdown() { void Shutdown() {
......
...@@ -21,13 +21,12 @@ class StubDecoratedReader : public paddle::framework::DecoratedReader { ...@@ -21,13 +21,12 @@ class StubDecoratedReader : public paddle::framework::DecoratedReader {
explicit StubDecoratedReader(const std::shared_ptr<ReaderBase> &reader) explicit StubDecoratedReader(const std::shared_ptr<ReaderBase> &reader)
: DecoratedReader(reader) {} : DecoratedReader(reader) {}
void ReadNext(std::vector<paddle::framework::LoDTensor> *out) override {} void ReadNextImpl(std::vector<paddle::framework::LoDTensor> *out) override {}
}; };
class StubRootReader : public paddle::framework::ReaderBase { class StubRootReader : public paddle::framework::ReaderBase {
public: public:
void ReadNext(std::vector<paddle::framework::LoDTensor> *out) override {} void ReadNextImpl(std::vector<paddle::framework::LoDTensor> *out) override {}
void ReInit() override {}
}; };
TEST(READER, decorate_chain) { TEST(READER, decorate_chain) {
......
...@@ -375,9 +375,6 @@ def open_recordio_file(filename, ...@@ -375,9 +375,6 @@ def open_recordio_file(filename,
if pass_num > 1: if pass_num > 1:
main_prog_var = multi_pass(reader=main_prog_var, pass_num=pass_num) main_prog_var = multi_pass(reader=main_prog_var, pass_num=pass_num)
if for_parallel:
main_prog_var = parallel(reader=main_prog_var)
return monkey_patch_reader_methods(main_prog_var) return monkey_patch_reader_methods(main_prog_var)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册