提交 164f2382 编写于 作者: Y Yu Yang

Polish code

上级 f9974a4a
......@@ -18,45 +18,9 @@ namespace paddle {
namespace framework {
ReaderBase::~ReaderBase() {}
std::vector<std::unique_ptr<ReaderBase>> ReaderBase::SplitReader(
const platform::PlaceList &places) {
std::vector<std::unique_ptr<ReaderBase>> readers;
FileReader::FileReader(const std::vector<DDim> &dims) : dims_(dims) {}
auto mutex = std::make_shared<std::mutex>();
for (size_t i = 0; i < places.size(); ++i) {
readers.emplace_back(new ThreadSafeReader(this, mutex));
}
return readers;
}
void ThreadSafeReader::ReadNext(std::vector<LoDTensor> *out) {
std::lock_guard<std::mutex> guard(*mutex_);
reader_->ReadNext(out);
}
void ThreadSafeReader::ReInit() {
std::lock_guard<std::mutex> guard(*mutex_);
reader_->ReInit();
}
bool ThreadSafeReader::HasNext() const {
std::lock_guard<std::mutex> guard(*mutex_);
return reader_->HasNext();
}
std::vector<std::unique_ptr<ReaderBase>> ThreadSafeReader::SplitReader(
const platform::PlaceList &places) {
std::vector<std::unique_ptr<ReaderBase>> readers;
for (size_t i = 0; i < places.size(); ++i) {
readers.emplace_back(new ThreadSafeReader(reader_, mutex_));
}
return readers;
}
FileReaderBase::FileReaderBase(const std::vector<DDim> &dims) : dims_(dims) {}
void FileReaderBase::ReadNext(std::vector<LoDTensor> *out) {
void FileReader::ReadNext(std::vector<LoDTensor> *out) {
ReadNextImpl(out);
PADDLE_ENFORCE_EQ(out->size(), dims_.size());
for (size_t i = 0; i < dims_.size(); ++i) {
......
......@@ -33,9 +33,6 @@ class ReaderBase {
virtual bool HasNext() const = 0;
virtual std::vector<std::unique_ptr<ReaderBase>> SplitReader(
const platform::PlaceList& places);
virtual ~ReaderBase();
};
......@@ -53,27 +50,9 @@ class DecoratedReader : public ReaderBase {
ReaderBase* reader_;
};
class ThreadSafeReader : public DecoratedReader {
public:
ThreadSafeReader(ReaderBase* reader, const std::shared_ptr<std::mutex>& mutex)
: DecoratedReader(reader), mutex_(mutex) {}
void ReadNext(std::vector<LoDTensor>* out) override;
void ReInit() override;
bool HasNext() const override;
std::vector<std::unique_ptr<ReaderBase>> SplitReader(
const platform::PlaceList& places) override;
private:
std::shared_ptr<std::mutex> mutex_;
};
class FileReaderBase : public ReaderBase {
class FileReader : public ReaderBase {
public:
explicit FileReaderBase(const std::vector<DDim>& dims);
explicit FileReader(const std::vector<DDim>& dims);
void ReadNext(std::vector<LoDTensor>* out) override;
......
......@@ -39,7 +39,6 @@ class DoubleBufferReader : public framework::DecoratedReader {
#ifdef PADDLE_WITH_CUDA
ctxs_.emplace_back(new platform::CUDADeviceContext(
boost::get<platform::CUDAPlace>(place_)));
#else
#endif
}
}
......
......@@ -18,11 +18,11 @@
namespace paddle {
namespace operators {
namespace reader {
class RecordIOFileReader : public framework::FileReaderBase {
class RecordIOFileReader : public framework::FileReader {
public:
explicit RecordIOFileReader(const std::string& filename,
const std::vector<framework::DDim>& dims)
: FileReaderBase(dims),
: FileReader(dims),
scanner_(filename),
dev_ctx_(*platform::DeviceContextPool::Instance().Get(
platform::CPUPlace())) {}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册