提交 e4e9d362 编写于 作者: F fengjiayi

fix a potential bug

上级 239546a6
......@@ -21,25 +21,22 @@ namespace reader {
class CustomReader : public framework::DecoratedReader {
public:
CustomReader(ReaderBase* reader, const framework::BlockDesc* sub_block,
CustomReader(ReaderBase* reader, const framework::BlockDesc& sub_block,
const platform::Place& dev_place,
const std::vector<std::string>& source_var_names,
const std::vector<std::string>& sink_var_names)
: DecoratedReader(reader),
sub_block_(sub_block),
program_(*sub_block.Program()),
sub_block_id_(sub_block.ID()),
exe_(framework::Executor(dev_place)),
source_var_names_(source_var_names),
sink_var_names_(sink_var_names) {}
void ReadNext(std::vector<framework::LoDTensor>* out) override;
void UpdateBlockAndScope(const framework::BlockDesc* sub_block,
const framework::Scope* scope) {
sub_block_ = sub_block;
}
private:
const framework::BlockDesc* sub_block_;
const framework::ProgramDesc program_;
int sub_block_id_;
framework::Executor exe_;
std::vector<std::string> source_var_names_;
......@@ -57,14 +54,12 @@ class CreateCustomReaderOp : public framework::OperatorBase {
->template GetMutable<framework::ReaderHolder>();
auto* sub_block = Attr<framework::BlockDesc*>("sub_block");
if (out->Get() != nullptr) {
auto* custom_reader = reinterpret_cast<CustomReader*>(out->Get());
custom_reader->UpdateBlockAndScope(sub_block, &scope);
return;
}
const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader"))
->Get<framework::ReaderHolder>();
out->Reset(
new CustomReader(underlying_reader.Get(), sub_block, dev_place,
new CustomReader(underlying_reader.Get(), *sub_block, dev_place,
Attr<std::vector<std::string>>("source_var_names"),
Attr<std::vector<std::string>>("sink_var_names")));
}
......@@ -159,8 +154,7 @@ void CustomReader::ReadNext(std::vector<framework::LoDTensor>* out) {
tensor->set_lod(underlying_outs[i].lod());
}
// 2. Run the sub-block.
framework::ProgramDesc* program = sub_block_->Program();
exe_.Run(*program, scope, sub_block_->ID(), false, true);
exe_.Run(program_, scope, sub_block_id_, false, true);
// 3. Copy LoDTensors from sink variables to out.
out->resize(sink_var_names_.size());
for (size_t i = 0; i < sink_var_names_.size(); ++i) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册