提交 164f2382 编写于 作者: Y Yu Yang

Polish code

上级 f9974a4a
...@@ -18,45 +18,9 @@ namespace paddle { ...@@ -18,45 +18,9 @@ namespace paddle {
namespace framework { namespace framework {
ReaderBase::~ReaderBase() {} ReaderBase::~ReaderBase() {}
std::vector<std::unique_ptr<ReaderBase>> ReaderBase::SplitReader( FileReader::FileReader(const std::vector<DDim> &dims) : dims_(dims) {}
const platform::PlaceList &places) {
std::vector<std::unique_ptr<ReaderBase>> readers;
auto mutex = std::make_shared<std::mutex>(); void FileReader::ReadNext(std::vector<LoDTensor> *out) {
for (size_t i = 0; i < places.size(); ++i) {
readers.emplace_back(new ThreadSafeReader(this, mutex));
}
return readers;
}
void ThreadSafeReader::ReadNext(std::vector<LoDTensor> *out) {
std::lock_guard<std::mutex> guard(*mutex_);
reader_->ReadNext(out);
}
void ThreadSafeReader::ReInit() {
std::lock_guard<std::mutex> guard(*mutex_);
reader_->ReInit();
}
bool ThreadSafeReader::HasNext() const {
std::lock_guard<std::mutex> guard(*mutex_);
return reader_->HasNext();
}
std::vector<std::unique_ptr<ReaderBase>> ThreadSafeReader::SplitReader(
const platform::PlaceList &places) {
std::vector<std::unique_ptr<ReaderBase>> readers;
for (size_t i = 0; i < places.size(); ++i) {
readers.emplace_back(new ThreadSafeReader(reader_, mutex_));
}
return readers;
}
FileReaderBase::FileReaderBase(const std::vector<DDim> &dims) : dims_(dims) {}
void FileReaderBase::ReadNext(std::vector<LoDTensor> *out) {
ReadNextImpl(out); ReadNextImpl(out);
PADDLE_ENFORCE_EQ(out->size(), dims_.size()); PADDLE_ENFORCE_EQ(out->size(), dims_.size());
for (size_t i = 0; i < dims_.size(); ++i) { for (size_t i = 0; i < dims_.size(); ++i) {
......
...@@ -33,9 +33,6 @@ class ReaderBase { ...@@ -33,9 +33,6 @@ class ReaderBase {
virtual bool HasNext() const = 0; virtual bool HasNext() const = 0;
virtual std::vector<std::unique_ptr<ReaderBase>> SplitReader(
const platform::PlaceList& places);
virtual ~ReaderBase(); virtual ~ReaderBase();
}; };
...@@ -53,27 +50,9 @@ class DecoratedReader : public ReaderBase { ...@@ -53,27 +50,9 @@ class DecoratedReader : public ReaderBase {
ReaderBase* reader_; ReaderBase* reader_;
}; };
class ThreadSafeReader : public DecoratedReader { class FileReader : public ReaderBase {
public:
ThreadSafeReader(ReaderBase* reader, const std::shared_ptr<std::mutex>& mutex)
: DecoratedReader(reader), mutex_(mutex) {}
void ReadNext(std::vector<LoDTensor>* out) override;
void ReInit() override;
bool HasNext() const override;
std::vector<std::unique_ptr<ReaderBase>> SplitReader(
const platform::PlaceList& places) override;
private:
std::shared_ptr<std::mutex> mutex_;
};
class FileReaderBase : public ReaderBase {
public: public:
explicit FileReaderBase(const std::vector<DDim>& dims); explicit FileReader(const std::vector<DDim>& dims);
void ReadNext(std::vector<LoDTensor>* out) override; void ReadNext(std::vector<LoDTensor>* out) override;
......
...@@ -39,7 +39,6 @@ class DoubleBufferReader : public framework::DecoratedReader { ...@@ -39,7 +39,6 @@ class DoubleBufferReader : public framework::DecoratedReader {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
ctxs_.emplace_back(new platform::CUDADeviceContext( ctxs_.emplace_back(new platform::CUDADeviceContext(
boost::get<platform::CUDAPlace>(place_))); boost::get<platform::CUDAPlace>(place_)));
#else
#endif #endif
} }
} }
......
...@@ -18,11 +18,11 @@ ...@@ -18,11 +18,11 @@
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace reader { namespace reader {
class RecordIOFileReader : public framework::FileReaderBase { class RecordIOFileReader : public framework::FileReader {
public: public:
explicit RecordIOFileReader(const std::string& filename, explicit RecordIOFileReader(const std::string& filename,
const std::vector<framework::DDim>& dims) const std::vector<framework::DDim>& dims)
: FileReaderBase(dims), : FileReader(dims),
scanner_(filename), scanner_(filename),
dev_ctx_(*platform::DeviceContextPool::Instance().Get( dev_ctx_(*platform::DeviceContextPool::Instance().Get(
platform::CPUPlace())) {} platform::CPUPlace())) {}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册