diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index ac6289c5abe8f40ae9ee32aa3d58cdef3ff0e836..49f8cd5f90a019543b71b684bb5972ff43fe0847 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -442,15 +442,7 @@ class RuntimeInferShapeContext : public InferShapeContext { } std::vector GetRepeatedDims(const std::string& name) const override { - Variable* var = scope_.FindVar(name); - if (var->IsType()) { - return var->Get().shapes(); - } else { - PADDLE_THROW( - "Only ReaderHolder support 'GetRepeatedDims', but Variable %s's " - "type_id is %s.", - name, var->Type().name()); - } + PADDLE_THROW("Only compile time support this method"); } void SetDim(const std::string& name, const DDim& dim) override { @@ -467,15 +459,7 @@ class RuntimeInferShapeContext : public InferShapeContext { void SetRepeatedDims(const std::string& name, const std::vector& dims) override { - Variable* var = scope_.FindVar(name); - if (var->IsType()) { - var->GetMutable()->set_shapes(dims); - } else { - PADDLE_THROW( - "Only ReaderHolder support 'SetRepeatedDims', but Variable %s's " - "type_id is %s.", - name, var->Type().name()); - } + PADDLE_THROW("Only compile time support this method"); } proto::VarType::Type GetVarType(const std::string& name) const override { diff --git a/paddle/fluid/framework/reader.cc b/paddle/fluid/framework/reader.cc index 91879d6d45868bb37ca44baafb8b0e8677cd6d1a..31f686151e3630f2787f2df335e2ed1d66198c78 100644 --- a/paddle/fluid/framework/reader.cc +++ b/paddle/fluid/framework/reader.cc @@ -16,14 +16,6 @@ namespace paddle { namespace framework { - -DDim ReaderBase::shape(size_t idx) const { - PADDLE_ENFORCE_LT( - idx, shapes_.size(), - "Cannot get the %d'th shape, 'shapes_' only has %d elements.", idx, - shapes_.size()); - return shapes_[idx]; -} - +ReaderBase::~ReaderBase() {} } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/reader.h b/paddle/fluid/framework/reader.h index e281c9b13fb50fba64185b6db2949fc76f5c4834..2d8d30fc66274f09291b9e9f00b82c4eae4737c9 100644 --- a/paddle/fluid/framework/reader.h +++ b/paddle/fluid/framework/reader.h @@ -22,34 +22,18 @@ namespace framework { class ReaderBase { public: - explicit ReaderBase(const std::vector& shapes) : shapes_(shapes) { - PADDLE_ENFORCE(!shapes_.empty()); - } virtual void ReadNext(std::vector* out) = 0; virtual void ReInit() = 0; - DDim shape(size_t idx) const; - std::vector shapes() const { return shapes_; } - void set_shapes(const std::vector& shapes) { shapes_ = shapes; } - virtual bool HasNext() const = 0; - virtual ~ReaderBase() {} - - protected: - std::vector shapes_; -}; - -class FileReader : public ReaderBase { - public: - explicit FileReader(const std::vector& shapes) : ReaderBase(shapes) {} + virtual ~ReaderBase(); }; class DecoratedReader : public ReaderBase { public: - explicit DecoratedReader(ReaderBase* reader) - : ReaderBase(reader->shapes()), reader_(reader) { + explicit DecoratedReader(ReaderBase* reader) : ReaderBase(), reader_(reader) { PADDLE_ENFORCE_NOT_NULL(reader_); } @@ -72,12 +56,6 @@ class ReaderHolder { void ReadNext(std::vector* out) { reader_->ReadNext(out); } void ReInit() { reader_->ReInit(); } - DDim shape(size_t idx) const { return reader_->shape(idx); } - std::vector shapes() const { return reader_->shapes(); } - void set_shapes(const std::vector& shapes) { - reader_->set_shapes(shapes); - } - bool HasNext() const { return reader_->HasNext(); } private: diff --git a/paddle/fluid/operators/reader/create_random_data_generator_op.cc b/paddle/fluid/operators/reader/create_random_data_generator_op.cc index e62f952d0e89561c3eed56112dc9d1d78801b59e..95d8674c08b63e872926ff8708d0c734da33684c 100644 --- a/paddle/fluid/operators/reader/create_random_data_generator_op.cc +++ b/paddle/fluid/operators/reader/create_random_data_generator_op.cc @@ -19,11 +19,11 @@ namespace operators { namespace reader { template -class RandomDataGenerator : public framework::FileReader { +class RandomDataGenerator : public framework::ReaderBase { public: RandomDataGenerator(const std::vector& shapes, float min, float max) - : FileReader(shapes), min_(min), max_(max) { + : framework::ReaderBase(), min_(min), max_(max), shapes_(shapes) { PADDLE_ENFORCE_LE( min, max, "'min' shouldn't be greater than 'max'.(%f vs %f)", min, max); unsigned int seed = std::random_device()(); @@ -59,6 +59,7 @@ class RandomDataGenerator : public framework::FileReader { float max_; std::minstd_rand engine_; std::uniform_real_distribution dist_; + std::vector shapes_; }; template diff --git a/paddle/fluid/operators/reader/create_recordio_file_reader_op.cc b/paddle/fluid/operators/reader/create_recordio_file_reader_op.cc index c3eb247bbe2041ae5a673c4fd3c1284c71276f91..4992eb861721648cd943c3c4a9e975b7bd2a0be9 100644 --- a/paddle/fluid/operators/reader/create_recordio_file_reader_op.cc +++ b/paddle/fluid/operators/reader/create_recordio_file_reader_op.cc @@ -18,11 +18,10 @@ namespace paddle { namespace operators { namespace reader { -class RecordIOFileReader : public framework::FileReader { +class RecordIOFileReader : public framework::ReaderBase { public: - RecordIOFileReader(const std::string& filename, - const std::vector& shapes) - : FileReader(shapes), + explicit RecordIOFileReader(const std::string& filename) + : ReaderBase(), scanner_(filename), dev_ctx_(*platform::DeviceContextPool::Instance().Get( platform::CPUPlace())) {} @@ -54,12 +53,11 @@ class CreateRecordIOReaderOp : public framework::OperatorBase { int(shape_concat.size()), "The accumulate of all ranks should be equal to the " "shape concat's length."); - std::vector shapes = RestoreShapes(shape_concat, ranks); std::string filename = Attr("filename"); auto* out = scope.FindVar(Output("Out")) ->template GetMutable(); - out->Reset(new RecordIOFileReader(filename, shapes)); + out->Reset(new RecordIOFileReader(filename)); } };