From 164f2382afe6ded95c95f4fb731a1d932d578026 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Tue, 13 Mar 2018 17:56:53 +0800 Subject: [PATCH] Polish code --- paddle/fluid/framework/reader.cc | 40 +------------------ paddle/fluid/framework/reader.h | 25 +----------- .../reader/create_double_buffer_reader_op.cc | 1 - .../reader/create_recordio_file_reader_op.cc | 4 +- 4 files changed, 6 insertions(+), 64 deletions(-) diff --git a/paddle/fluid/framework/reader.cc b/paddle/fluid/framework/reader.cc index c3fb657a3..fa00c08e0 100644 --- a/paddle/fluid/framework/reader.cc +++ b/paddle/fluid/framework/reader.cc @@ -18,45 +18,9 @@ namespace paddle { namespace framework { ReaderBase::~ReaderBase() {} -std::vector> ReaderBase::SplitReader( - const platform::PlaceList &places) { - std::vector> readers; +FileReader::FileReader(const std::vector &dims) : dims_(dims) {} - auto mutex = std::make_shared(); - for (size_t i = 0; i < places.size(); ++i) { - readers.emplace_back(new ThreadSafeReader(this, mutex)); - } - - return readers; -} - -void ThreadSafeReader::ReadNext(std::vector *out) { - std::lock_guard guard(*mutex_); - reader_->ReadNext(out); -} - -void ThreadSafeReader::ReInit() { - std::lock_guard guard(*mutex_); - reader_->ReInit(); -} - -bool ThreadSafeReader::HasNext() const { - std::lock_guard guard(*mutex_); - return reader_->HasNext(); -} - -std::vector> ThreadSafeReader::SplitReader( - const platform::PlaceList &places) { - std::vector> readers; - for (size_t i = 0; i < places.size(); ++i) { - readers.emplace_back(new ThreadSafeReader(reader_, mutex_)); - } - return readers; -} - -FileReaderBase::FileReaderBase(const std::vector &dims) : dims_(dims) {} - -void FileReaderBase::ReadNext(std::vector *out) { +void FileReader::ReadNext(std::vector *out) { ReadNextImpl(out); PADDLE_ENFORCE_EQ(out->size(), dims_.size()); for (size_t i = 0; i < dims_.size(); ++i) { diff --git a/paddle/fluid/framework/reader.h b/paddle/fluid/framework/reader.h index 8989bddd1..3573b99be 100644 --- a/paddle/fluid/framework/reader.h +++ b/paddle/fluid/framework/reader.h @@ -33,9 +33,6 @@ class ReaderBase { virtual bool HasNext() const = 0; - virtual std::vector> SplitReader( - const platform::PlaceList& places); - virtual ~ReaderBase(); }; @@ -53,27 +50,9 @@ class DecoratedReader : public ReaderBase { ReaderBase* reader_; }; -class ThreadSafeReader : public DecoratedReader { - public: - ThreadSafeReader(ReaderBase* reader, const std::shared_ptr& mutex) - : DecoratedReader(reader), mutex_(mutex) {} - - void ReadNext(std::vector* out) override; - - void ReInit() override; - - bool HasNext() const override; - - std::vector> SplitReader( - const platform::PlaceList& places) override; - - private: - std::shared_ptr mutex_; -}; - -class FileReaderBase : public ReaderBase { +class FileReader : public ReaderBase { public: - explicit FileReaderBase(const std::vector& dims); + explicit FileReader(const std::vector& dims); void ReadNext(std::vector* out) override; diff --git a/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc b/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc index 706f6fd59..d0de09294 100644 --- a/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc +++ b/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc @@ -39,7 +39,6 @@ class DoubleBufferReader : public framework::DecoratedReader { #ifdef PADDLE_WITH_CUDA ctxs_.emplace_back(new platform::CUDADeviceContext( boost::get(place_))); -#else #endif } } diff --git a/paddle/fluid/operators/reader/create_recordio_file_reader_op.cc b/paddle/fluid/operators/reader/create_recordio_file_reader_op.cc index 819e09a36..c4aa29c72 100644 --- a/paddle/fluid/operators/reader/create_recordio_file_reader_op.cc +++ b/paddle/fluid/operators/reader/create_recordio_file_reader_op.cc @@ -18,11 +18,11 @@ namespace paddle { namespace operators { namespace reader { -class RecordIOFileReader : public framework::FileReaderBase { +class RecordIOFileReader : public framework::FileReader { public: explicit RecordIOFileReader(const std::string& filename, const std::vector& dims) - : FileReaderBase(dims), + : FileReader(dims), scanner_(filename), dev_ctx_(*platform::DeviceContextPool::Instance().Get( platform::CPUPlace())) {} -- GitLab