diff --git a/paddle/fluid/framework/reader.h b/paddle/fluid/framework/reader.h index 1be3f4ef1f46bd8d72a285afa69b52d6f519ccf5..e820c3d07e85fd1dea9080786b48ad031330ee00 100644 --- a/paddle/fluid/framework/reader.h +++ b/paddle/fluid/framework/reader.h @@ -65,12 +65,25 @@ class ReaderHolder { ReaderBase* Get() const { return reader_.get(); } - void ReadNext(std::vector* out) { reader_->ReadNext(out); } - void ReInit() { reader_->ReInit(); } + void ReadNext(std::vector* out) { + PADDLE_ENFORCE_NOT_NULL(reader_); + reader_->ReadNext(out); + } + void ReInit() { + PADDLE_ENFORCE_NOT_NULL(reader_); + reader_->ReInit(); + } - DDim shape(size_t idx) const { return reader_->shape(idx); } - std::vector shapes() const { return reader_->shapes(); } + DDim shape(size_t idx) const { + PADDLE_ENFORCE_NOT_NULL(reader_); + return reader_->shape(idx); + } + std::vector shapes() const { + PADDLE_ENFORCE_NOT_NULL(reader_); + return reader_->shapes(); + } void set_shapes(const std::vector& shapes) { + PADDLE_ENFORCE_NOT_NULL(reader_); reader_->set_shapes(shapes); } diff --git a/paddle/fluid/operators/reader/reader_op_registry.cc b/paddle/fluid/operators/reader/reader_op_registry.cc index 7ea4f4b8d9feecac5bc2d0338bbbe9ab7a532040..f80769d7cd2d35261cd55fc1d6c8c20197f5e88c 100644 --- a/paddle/fluid/operators/reader/reader_op_registry.cc +++ b/paddle/fluid/operators/reader/reader_op_registry.cc @@ -49,6 +49,10 @@ FileReaderMakerBase::FileReaderMakerBase( } void FileReaderInferShape::operator()(framework::InferShapeContext* ctx) const { + PADDLE_ENFORCE( + !ctx->IsRuntime(), + "'FileReaderInferShape' should only be invoked during compile time."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), "The output file reader should not be null."); const auto shape_concat = ctx->Attrs().Get>("shape_concat"); @@ -56,16 +60,14 @@ void FileReaderInferShape::operator()(framework::InferShapeContext* ctx) const { std::vector shapes = RestoreShapes(shape_concat, ranks); ctx->SetReaderDims("Out", shapes); - if (ctx->IsRuntime()) { - const auto lod_levels = ctx->Attrs().Get>("lod_levels"); - PADDLE_ENFORCE_EQ(lod_levels.size(), shapes.size(), - "The number of 'lod_levels'(%d) doesn't match the number " - "of 'shapes'(%d).", - lod_levels.size(), shapes.size()); - framework::VarDesc* reader = - boost::get(ctx->GetOutputVarPtrs("Out")[0]); - reader->SetLoDLevels(lod_levels); - } + const auto lod_levels = ctx->Attrs().Get>("lod_levels"); + PADDLE_ENFORCE_EQ(lod_levels.size(), shapes.size(), + "The number of 'lod_levels'(%d) doesn't match the number " + "of 'shapes'(%d).", + lod_levels.size(), shapes.size()); + framework::VarDesc* reader = + boost::get(ctx->GetOutputVarPtrs("Out")[0]); + reader->SetLoDLevels(lod_levels); } void FileReaderInferVarType::operator()(const framework::OpDesc& op_desc, @@ -77,19 +79,21 @@ void FileReaderInferVarType::operator()(const framework::OpDesc& op_desc, void DecoratedReaderInferShape::operator()( framework::InferShapeContext* ctx) const { + PADDLE_ENFORCE(!ctx->IsRuntime(), + "'DecoratedReaderInferShape' should only be invoked during " + "compile time."); + PADDLE_ENFORCE(ctx->HasInput("UnderlyingReader"), "Input(UnderlyingReader) should not be null."); PADDLE_ENFORCE(ctx->HasOutput("Out"), "The output decorated reader should not be null."); ctx->SetReaderDims("Out", ctx->GetReaderDims("UnderlyingReader")); - if (ctx->IsRuntime()) { - framework::VarDesc* in_reader = boost::get( - ctx->GetInputVarPtrs("UnderlyingReader")[0]); - framework::VarDesc* out_reader = - boost::get(ctx->GetOutputVarPtrs("Out")[0]); - out_reader->SetLoDLevels(in_reader->GetLoDLevels()); - } + framework::VarDesc* in_reader = boost::get( + ctx->GetInputVarPtrs("UnderlyingReader")[0]); + framework::VarDesc* out_reader = + boost::get(ctx->GetOutputVarPtrs("Out")[0]); + out_reader->SetLoDLevels(in_reader->GetLoDLevels()); } void DecoratedReaderInferVarType::operator()( const framework::OpDesc& op_desc, framework::BlockDesc* block) const {