提交 de9a411f 编写于 作者: F fengjiayi

adjust readers' inheritance relationships

1. Make PyReader and RandomDataGenerator inherited from FileReader.
2. Remove the memeber variable 'dims_' and realated checks in FileReader.
上级 c1ac9f61
...@@ -18,24 +18,6 @@ namespace paddle { ...@@ -18,24 +18,6 @@ namespace paddle {
namespace framework { namespace framework {
ReaderBase::~ReaderBase() {} ReaderBase::~ReaderBase() {}
FileReader::FileReader(const std::vector<DDim> &dims) : dims_(dims) {} void FileReader::ReadNext(std::vector<LoDTensor> *out) { ReadNextImpl(out); }
void FileReader::ReadNext(std::vector<LoDTensor> *out) {
ReadNextImpl(out);
if (out->empty()) {
return;
}
PADDLE_ENFORCE_EQ(out->size(), dims_.size());
for (size_t i = 0; i < dims_.size(); ++i) {
auto &actual = (*out)[i].dims();
auto &expect = dims_[i];
PADDLE_ENFORCE_EQ(actual.size(), expect.size());
for (int j = 0; j < actual.size(); ++j) {
// PADDLE_ENFORCE(actual[i] == expect[i] || expect[i] == -1);
}
}
}
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -48,15 +48,12 @@ class DecoratedReader : public ReaderBase { ...@@ -48,15 +48,12 @@ class DecoratedReader : public ReaderBase {
class FileReader : public ReaderBase { class FileReader : public ReaderBase {
public: public:
explicit FileReader(const std::vector<DDim>& dims); FileReader() : ReaderBase() {}
void ReadNext(std::vector<LoDTensor>* out) override; void ReadNext(std::vector<LoDTensor>* out) override;
protected: protected:
virtual void ReadNextImpl(std::vector<LoDTensor>* out) = 0; virtual void ReadNextImpl(std::vector<LoDTensor>* out) = 0;
private:
std::vector<DDim> dims_;
}; };
// The ReaderHolder is used as reader' unified wrapper, // The ReaderHolder is used as reader' unified wrapper,
......
...@@ -151,8 +151,8 @@ void DoubleBufferReader::ReadNext(std::vector<framework::LoDTensor>* out) { ...@@ -151,8 +151,8 @@ void DoubleBufferReader::ReadNext(std::vector<framework::LoDTensor>* out) {
} }
void DoubleBufferReader::ReInit() { void DoubleBufferReader::ReInit() {
reader_->ReInit();
EndPrefetcher(); EndPrefetcher();
reader_->ReInit();
StartPrefetcher(); StartPrefetcher();
} }
......
...@@ -19,14 +19,15 @@ namespace paddle { ...@@ -19,14 +19,15 @@ namespace paddle {
namespace operators { namespace operators {
namespace reader { namespace reader {
class PyReader : public framework::ReaderBase { class PyReader : public framework::FileReader {
public: public:
explicit PyReader(const std::shared_ptr<LoDTensorBlockingQueue>& queue) { explicit PyReader(const std::shared_ptr<LoDTensorBlockingQueue>& queue)
: framework::FileReader() {
PADDLE_ENFORCE(queue != nullptr, "LoDTensorBlockingQueue must not be null"); PADDLE_ENFORCE(queue != nullptr, "LoDTensorBlockingQueue must not be null");
queue_ = queue; queue_ = queue;
} }
void ReadNext(std::vector<framework::LoDTensor>* out) override { void ReadNextImpl(std::vector<framework::LoDTensor>* out) override {
bool success; bool success;
*out = queue_->Pop(&success); *out = queue_->Pop(&success);
if (!success) out->clear(); if (!success) out->clear();
......
...@@ -19,11 +19,11 @@ namespace operators { ...@@ -19,11 +19,11 @@ namespace operators {
namespace reader { namespace reader {
template <typename T> template <typename T>
class RandomDataGenerator : public framework::ReaderBase { class RandomDataGenerator : public framework::FileReader {
public: public:
RandomDataGenerator(const std::vector<framework::DDim>& shapes, float low, RandomDataGenerator(const std::vector<framework::DDim>& shapes, float low,
float high) float high)
: framework::ReaderBase(), low_(low), high_(high), shapes_(shapes) { : framework::FileReader(), low_(low), high_(high), shapes_(shapes) {
PADDLE_ENFORCE_LE(low, high, PADDLE_ENFORCE_LE(low, high,
"'low' shouldn't be greater than 'high'.(%f vs %f)", low, "'low' shouldn't be greater than 'high'.(%f vs %f)", low,
high); high);
...@@ -32,7 +32,7 @@ class RandomDataGenerator : public framework::ReaderBase { ...@@ -32,7 +32,7 @@ class RandomDataGenerator : public framework::ReaderBase {
dist_ = std::uniform_real_distribution<float>(low_, high_); dist_ = std::uniform_real_distribution<float>(low_, high_);
} }
void ReadNext(std::vector<framework::LoDTensor>* out) override { void ReadNextImpl(std::vector<framework::LoDTensor>* out) override {
out->clear(); out->clear();
out->reserve(shapes_.size()); out->reserve(shapes_.size());
for (const framework::DDim& shape : shapes_) { for (const framework::DDim& shape : shapes_) {
......
...@@ -21,9 +21,8 @@ namespace reader { ...@@ -21,9 +21,8 @@ namespace reader {
template <bool ThreadSafe> template <bool ThreadSafe>
class RecordIOFileReader : public framework::FileReader { class RecordIOFileReader : public framework::FileReader {
public: public:
explicit RecordIOFileReader(const std::string& filename, explicit RecordIOFileReader(const std::string& filename)
const std::vector<framework::DDim>& dims) : FileReader(),
: FileReader(dims),
scanner_(filename), scanner_(filename),
dev_ctx_(*platform::DeviceContextPool::Instance().Get( dev_ctx_(*platform::DeviceContextPool::Instance().Get(
platform::CPUPlace())) { platform::CPUPlace())) {
...@@ -58,20 +57,10 @@ class CreateRecordIOReaderOp : public framework::OperatorBase { ...@@ -58,20 +57,10 @@ class CreateRecordIOReaderOp : public framework::OperatorBase {
private: private:
void RunImpl(const framework::Scope& scope, void RunImpl(const framework::Scope& scope,
const platform::Place& dev_place) const override { const platform::Place& dev_place) const override {
const auto& shape_concat = Attr<std::vector<int>>("shape_concat");
const auto& ranks = Attr<std::vector<int>>("ranks");
PADDLE_ENFORCE(!shape_concat.empty() && !ranks.empty());
PADDLE_ENFORCE_EQ(std::accumulate(ranks.begin(), ranks.end(), 0),
static_cast<int>(shape_concat.size()),
"The accumulate of all ranks should be equal to the "
"shape concat's length.");
std::string filename = Attr<std::string>("filename"); std::string filename = Attr<std::string>("filename");
auto* out = scope.FindVar(Output("Out")) auto* out = scope.FindVar(Output("Out"))
->template GetMutable<framework::ReaderHolder>(); ->template GetMutable<framework::ReaderHolder>();
out->Reset(new RecordIOFileReader<true>(filename));
out->Reset(new RecordIOFileReader<true>(
filename, RestoreShapes(shape_concat, ranks)));
} }
}; };
......
...@@ -23,13 +23,12 @@ namespace reader { ...@@ -23,13 +23,12 @@ namespace reader {
class MultiFileReader : public framework::ReaderBase { class MultiFileReader : public framework::ReaderBase {
public: public:
MultiFileReader(const std::vector<std::string>& file_names, MultiFileReader(const std::vector<std::string>& file_names, size_t thread_num,
const std::vector<framework::DDim>& dims, size_t thread_num,
size_t buffer_size) size_t buffer_size)
: buffer_size_(buffer_size) { : buffer_size_(buffer_size) {
readers_.reserve(file_names.size()); readers_.reserve(file_names.size());
for (const std::string& f_name : file_names) { for (const std::string& f_name : file_names) {
readers_.emplace_back(CreateReaderByFileName(f_name, dims)); readers_.emplace_back(CreateReaderByFileName(f_name));
} }
prefetchers_.resize(thread_num); prefetchers_.resize(thread_num);
StartNewScheduler(); StartNewScheduler();
...@@ -180,9 +179,7 @@ class OpenFilesOp : public framework::OperatorBase { ...@@ -180,9 +179,7 @@ class OpenFilesOp : public framework::OperatorBase {
auto* out = scope.FindVar(Output("Out")) auto* out = scope.FindVar(Output("Out"))
->template GetMutable<framework::ReaderHolder>(); ->template GetMutable<framework::ReaderHolder>();
out->Reset(new MultiFileReader(file_names, out->Reset(new MultiFileReader(file_names, thread_num, buffer_size));
RestoreShapes(shape_concat, ranks),
thread_num, buffer_size));
} }
}; };
......
...@@ -39,7 +39,7 @@ std::unordered_map<std::string, FileReaderCreator>& FileReaderRegistry() { ...@@ -39,7 +39,7 @@ std::unordered_map<std::string, FileReaderCreator>& FileReaderRegistry() {
} }
std::unique_ptr<framework::ReaderBase> CreateReaderByFileName( std::unique_ptr<framework::ReaderBase> CreateReaderByFileName(
const std::string& file_name, const std::vector<framework::DDim>& dims) { const std::string& file_name) {
size_t separator_pos = file_name.find_last_of(kFileFormatSeparator); size_t separator_pos = file_name.find_last_of(kFileFormatSeparator);
PADDLE_ENFORCE_NE(separator_pos, std::string::npos, PADDLE_ENFORCE_NE(separator_pos, std::string::npos,
"File name illegal! A legal file name should be like: " "File name illegal! A legal file name should be like: "
...@@ -49,7 +49,7 @@ std::unique_ptr<framework::ReaderBase> CreateReaderByFileName( ...@@ -49,7 +49,7 @@ std::unique_ptr<framework::ReaderBase> CreateReaderByFileName(
auto itor = FileReaderRegistry().find(filetype); auto itor = FileReaderRegistry().find(filetype);
PADDLE_ENFORCE(itor != FileReaderRegistry().end(), PADDLE_ENFORCE(itor != FileReaderRegistry().end(),
"No file reader registered for '%s' format.", filetype); "No file reader registered for '%s' format.", filetype);
framework::ReaderBase* reader = (itor->second)(file_name, dims); framework::ReaderBase* reader = (itor->second)(file_name);
return std::unique_ptr<framework::ReaderBase>(reader); return std::unique_ptr<framework::ReaderBase>(reader);
} }
......
...@@ -25,22 +25,21 @@ namespace reader { ...@@ -25,22 +25,21 @@ namespace reader {
static constexpr char kFileFormatSeparator[] = "."; static constexpr char kFileFormatSeparator[] = ".";
using FileReaderCreator = std::function<framework::ReaderBase*( using FileReaderCreator =
const std::string&, const std::vector<framework::DDim>&)>; std::function<framework::ReaderBase*(const std::string&)>;
std::unordered_map<std::string, FileReaderCreator>& FileReaderRegistry(); std::unordered_map<std::string, FileReaderCreator>& FileReaderRegistry();
template <typename Reader> template <typename Reader>
int RegisterFileReader(const std::string& filetype) { int RegisterFileReader(const std::string& filetype) {
FileReaderRegistry()[filetype] = []( FileReaderRegistry()[filetype] = [](const std::string& fn) {
const std::string& fn, const std::vector<framework::DDim>& dims) { return new Reader(fn);
return new Reader(fn, dims);
}; };
return 0; return 0;
} }
std::unique_ptr<framework::ReaderBase> CreateReaderByFileName( std::unique_ptr<framework::ReaderBase> CreateReaderByFileName(
const std::string& file_name, const std::vector<framework::DDim>& dims); const std::string& file_name);
extern std::vector<framework::DDim> RestoreShapes( extern std::vector<framework::DDim> RestoreShapes(
const std::vector<int>& shape_concat, const std::vector<int>& ranks); const std::vector<int>& shape_concat, const std::vector<int>& ranks);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册