提交 614c33fb 编写于 作者: F fengjiayi

fix a potential bug in the c++ reader

上级 ccc54188
...@@ -65,12 +65,25 @@ class ReaderHolder { ...@@ -65,12 +65,25 @@ class ReaderHolder {
ReaderBase* Get() const { return reader_.get(); } ReaderBase* Get() const { return reader_.get(); }
void ReadNext(std::vector<LoDTensor>* out) { reader_->ReadNext(out); } void ReadNext(std::vector<LoDTensor>* out) {
void ReInit() { reader_->ReInit(); } PADDLE_ENFORCE_NOT_NULL(reader_);
reader_->ReadNext(out);
}
void ReInit() {
PADDLE_ENFORCE_NOT_NULL(reader_);
reader_->ReInit();
}
DDim shape(size_t idx) const { return reader_->shape(idx); } DDim shape(size_t idx) const {
std::vector<DDim> shapes() const { return reader_->shapes(); } PADDLE_ENFORCE_NOT_NULL(reader_);
return reader_->shape(idx);
}
std::vector<DDim> shapes() const {
PADDLE_ENFORCE_NOT_NULL(reader_);
return reader_->shapes();
}
void set_shapes(const std::vector<DDim>& shapes) { void set_shapes(const std::vector<DDim>& shapes) {
PADDLE_ENFORCE_NOT_NULL(reader_);
reader_->set_shapes(shapes); reader_->set_shapes(shapes);
} }
......
...@@ -49,6 +49,10 @@ FileReaderMakerBase::FileReaderMakerBase( ...@@ -49,6 +49,10 @@ FileReaderMakerBase::FileReaderMakerBase(
} }
void FileReaderInferShape::operator()(framework::InferShapeContext* ctx) const { void FileReaderInferShape::operator()(framework::InferShapeContext* ctx) const {
PADDLE_ENFORCE(
!ctx->IsRuntime(),
"'FileReaderInferShape' should only be invoked during compile time.");
PADDLE_ENFORCE(ctx->HasOutput("Out"), PADDLE_ENFORCE(ctx->HasOutput("Out"),
"The output file reader should not be null."); "The output file reader should not be null.");
const auto shape_concat = ctx->Attrs().Get<std::vector<int>>("shape_concat"); const auto shape_concat = ctx->Attrs().Get<std::vector<int>>("shape_concat");
...@@ -56,16 +60,14 @@ void FileReaderInferShape::operator()(framework::InferShapeContext* ctx) const { ...@@ -56,16 +60,14 @@ void FileReaderInferShape::operator()(framework::InferShapeContext* ctx) const {
std::vector<framework::DDim> shapes = RestoreShapes(shape_concat, ranks); std::vector<framework::DDim> shapes = RestoreShapes(shape_concat, ranks);
ctx->SetReaderDims("Out", shapes); ctx->SetReaderDims("Out", shapes);
if (ctx->IsRuntime()) { const auto lod_levels = ctx->Attrs().Get<std::vector<int>>("lod_levels");
const auto lod_levels = ctx->Attrs().Get<std::vector<int>>("lod_levels"); PADDLE_ENFORCE_EQ(lod_levels.size(), shapes.size(),
PADDLE_ENFORCE_EQ(lod_levels.size(), shapes.size(), "The number of 'lod_levels'(%d) doesn't match the number "
"The number of 'lod_levels'(%d) doesn't match the number " "of 'shapes'(%d).",
"of 'shapes'(%d).", lod_levels.size(), shapes.size());
lod_levels.size(), shapes.size()); framework::VarDesc* reader =
framework::VarDesc* reader = boost::get<framework::VarDesc*>(ctx->GetOutputVarPtrs("Out")[0]);
boost::get<framework::VarDesc*>(ctx->GetOutputVarPtrs("Out")[0]); reader->SetLoDLevels(lod_levels);
reader->SetLoDLevels(lod_levels);
}
} }
void FileReaderInferVarType::operator()(const framework::OpDesc& op_desc, void FileReaderInferVarType::operator()(const framework::OpDesc& op_desc,
...@@ -77,19 +79,21 @@ void FileReaderInferVarType::operator()(const framework::OpDesc& op_desc, ...@@ -77,19 +79,21 @@ void FileReaderInferVarType::operator()(const framework::OpDesc& op_desc,
void DecoratedReaderInferShape::operator()( void DecoratedReaderInferShape::operator()(
framework::InferShapeContext* ctx) const { framework::InferShapeContext* ctx) const {
PADDLE_ENFORCE(!ctx->IsRuntime(),
"'DecoratedReaderInferShape' should only be invoked during "
"compile time.");
PADDLE_ENFORCE(ctx->HasInput("UnderlyingReader"), PADDLE_ENFORCE(ctx->HasInput("UnderlyingReader"),
"Input(UnderlyingReader) should not be null."); "Input(UnderlyingReader) should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"), PADDLE_ENFORCE(ctx->HasOutput("Out"),
"The output decorated reader should not be null."); "The output decorated reader should not be null.");
ctx->SetReaderDims("Out", ctx->GetReaderDims("UnderlyingReader")); ctx->SetReaderDims("Out", ctx->GetReaderDims("UnderlyingReader"));
if (ctx->IsRuntime()) { framework::VarDesc* in_reader = boost::get<framework::VarDesc*>(
framework::VarDesc* in_reader = boost::get<framework::VarDesc*>( ctx->GetInputVarPtrs("UnderlyingReader")[0]);
ctx->GetInputVarPtrs("UnderlyingReader")[0]); framework::VarDesc* out_reader =
framework::VarDesc* out_reader = boost::get<framework::VarDesc*>(ctx->GetOutputVarPtrs("Out")[0]);
boost::get<framework::VarDesc*>(ctx->GetOutputVarPtrs("Out")[0]); out_reader->SetLoDLevels(in_reader->GetLoDLevels());
out_reader->SetLoDLevels(in_reader->GetLoDLevels());
}
} }
void DecoratedReaderInferVarType::operator()( void DecoratedReaderInferVarType::operator()(
const framework::OpDesc& op_desc, framework::BlockDesc* block) const { const framework::OpDesc& op_desc, framework::BlockDesc* block) const {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册