From 225efa671fd1b234e67752ad9a1cd4aecdffe58b Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Mon, 12 Mar 2018 16:10:19 +0800 Subject: [PATCH] Remove dims in base class --- paddle/fluid/framework/operator.cc | 20 ++------------ paddle/fluid/framework/reader.cc | 10 +------ paddle/fluid/framework/reader.h | 26 ++----------------- .../reader/create_random_data_generator_op.cc | 5 ++-- .../reader/create_recordio_file_reader_op.cc | 10 +++---- 5 files changed, 12 insertions(+), 59 deletions(-) diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index ac6289c5a..49f8cd5f9 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -442,15 +442,7 @@ class RuntimeInferShapeContext : public InferShapeContext { } std::vector GetRepeatedDims(const std::string& name) const override { - Variable* var = scope_.FindVar(name); - if (var->IsType()) { - return var->Get().shapes(); - } else { - PADDLE_THROW( - "Only ReaderHolder support 'GetRepeatedDims', but Variable %s's " - "type_id is %s.", - name, var->Type().name()); - } + PADDLE_THROW("Only compile time support this method"); } void SetDim(const std::string& name, const DDim& dim) override { @@ -467,15 +459,7 @@ class RuntimeInferShapeContext : public InferShapeContext { void SetRepeatedDims(const std::string& name, const std::vector& dims) override { - Variable* var = scope_.FindVar(name); - if (var->IsType()) { - var->GetMutable()->set_shapes(dims); - } else { - PADDLE_THROW( - "Only ReaderHolder support 'SetRepeatedDims', but Variable %s's " - "type_id is %s.", - name, var->Type().name()); - } + PADDLE_THROW("Only compile time support this method"); } proto::VarType::Type GetVarType(const std::string& name) const override { diff --git a/paddle/fluid/framework/reader.cc b/paddle/fluid/framework/reader.cc index 91879d6d4..31f686151 100644 --- a/paddle/fluid/framework/reader.cc +++ b/paddle/fluid/framework/reader.cc @@ -16,14 +16,6 @@ namespace paddle { namespace framework { - -DDim ReaderBase::shape(size_t idx) const { - PADDLE_ENFORCE_LT( - idx, shapes_.size(), - "Cannot get the %d'th shape, 'shapes_' only has %d elements.", idx, - shapes_.size()); - return shapes_[idx]; -} - +ReaderBase::~ReaderBase() {} } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/reader.h b/paddle/fluid/framework/reader.h index e281c9b13..2d8d30fc6 100644 --- a/paddle/fluid/framework/reader.h +++ b/paddle/fluid/framework/reader.h @@ -22,34 +22,18 @@ namespace framework { class ReaderBase { public: - explicit ReaderBase(const std::vector& shapes) : shapes_(shapes) { - PADDLE_ENFORCE(!shapes_.empty()); - } virtual void ReadNext(std::vector* out) = 0; virtual void ReInit() = 0; - DDim shape(size_t idx) const; - std::vector shapes() const { return shapes_; } - void set_shapes(const std::vector& shapes) { shapes_ = shapes; } - virtual bool HasNext() const = 0; - virtual ~ReaderBase() {} - - protected: - std::vector shapes_; -}; - -class FileReader : public ReaderBase { - public: - explicit FileReader(const std::vector& shapes) : ReaderBase(shapes) {} + virtual ~ReaderBase(); }; class DecoratedReader : public ReaderBase { public: - explicit DecoratedReader(ReaderBase* reader) - : ReaderBase(reader->shapes()), reader_(reader) { + explicit DecoratedReader(ReaderBase* reader) : ReaderBase(), reader_(reader) { PADDLE_ENFORCE_NOT_NULL(reader_); } @@ -72,12 +56,6 @@ class ReaderHolder { void ReadNext(std::vector* out) { reader_->ReadNext(out); } void ReInit() { reader_->ReInit(); } - DDim shape(size_t idx) const { return reader_->shape(idx); } - std::vector shapes() const { return reader_->shapes(); } - void set_shapes(const std::vector& shapes) { - reader_->set_shapes(shapes); - } - bool HasNext() const { return reader_->HasNext(); } private: diff --git a/paddle/fluid/operators/reader/create_random_data_generator_op.cc b/paddle/fluid/operators/reader/create_random_data_generator_op.cc index e62f952d0..95d8674c0 100644 --- a/paddle/fluid/operators/reader/create_random_data_generator_op.cc +++ b/paddle/fluid/operators/reader/create_random_data_generator_op.cc @@ -19,11 +19,11 @@ namespace operators { namespace reader { template -class RandomDataGenerator : public framework::FileReader { +class RandomDataGenerator : public framework::ReaderBase { public: RandomDataGenerator(const std::vector& shapes, float min, float max) - : FileReader(shapes), min_(min), max_(max) { + : framework::ReaderBase(), min_(min), max_(max), shapes_(shapes) { PADDLE_ENFORCE_LE( min, max, "'min' shouldn't be greater than 'max'.(%f vs %f)", min, max); unsigned int seed = std::random_device()(); @@ -59,6 +59,7 @@ class RandomDataGenerator : public framework::FileReader { float max_; std::minstd_rand engine_; std::uniform_real_distribution dist_; + std::vector shapes_; }; template diff --git a/paddle/fluid/operators/reader/create_recordio_file_reader_op.cc b/paddle/fluid/operators/reader/create_recordio_file_reader_op.cc index c3eb247bb..4992eb861 100644 --- a/paddle/fluid/operators/reader/create_recordio_file_reader_op.cc +++ b/paddle/fluid/operators/reader/create_recordio_file_reader_op.cc @@ -18,11 +18,10 @@ namespace paddle { namespace operators { namespace reader { -class RecordIOFileReader : public framework::FileReader { +class RecordIOFileReader : public framework::ReaderBase { public: - RecordIOFileReader(const std::string& filename, - const std::vector& shapes) - : FileReader(shapes), + explicit RecordIOFileReader(const std::string& filename) + : ReaderBase(), scanner_(filename), dev_ctx_(*platform::DeviceContextPool::Instance().Get( platform::CPUPlace())) {} @@ -54,12 +53,11 @@ class CreateRecordIOReaderOp : public framework::OperatorBase { int(shape_concat.size()), "The accumulate of all ranks should be equal to the " "shape concat's length."); - std::vector shapes = RestoreShapes(shape_concat, ranks); std::string filename = Attr("filename"); auto* out = scope.FindVar(Output("Out")) ->template GetMutable(); - out->Reset(new RecordIOFileReader(filename, shapes)); + out->Reset(new RecordIOFileReader(filename)); } }; -- GitLab