提交 bc9d19c7 编写于 作者: F fengjiayi

fix a bug

上级 32478fe0
...@@ -159,23 +159,24 @@ void CustomReader::ReadNext(std::vector<framework::LoDTensor>* out) { ...@@ -159,23 +159,24 @@ void CustomReader::ReadNext(std::vector<framework::LoDTensor>* out) {
// The scope for CustomReader's sub-block should be independent and shouldn't // The scope for CustomReader's sub-block should be independent and shouldn't
// be any other computation scope's child. Otherwise, data preprocessing and // be any other computation scope's child. Otherwise, data preprocessing and
// compution cannot be concurrent. // compution cannot be concurrent.
framework::Scope& exe_scope = scope_.NewScope(); framework::Scope* exe_scope = &scope_.NewScope();
// 1. Copy LoDTensors from underlying reader's output to source variables. // 1. Copy LoDTensors from underlying reader's output to source variables.
for (size_t i = 0; i < source_var_names_.size(); ++i) { for (size_t i = 0; i < source_var_names_.size(); ++i) {
framework::Variable* var = exe_scope.Var(source_var_names_[i]); framework::Variable* var = exe_scope->Var(source_var_names_[i]);
framework::LoDTensor* tensor = var->GetMutable<framework::LoDTensor>(); framework::LoDTensor* tensor = var->GetMutable<framework::LoDTensor>();
tensor->ShareDataWith(underlying_outs[i]); tensor->ShareDataWith(underlying_outs[i]);
tensor->set_lod(underlying_outs[i].lod()); tensor->set_lod(underlying_outs[i].lod());
} }
// 2. Run the sub-block. // 2. Run the sub-block.
exe_.Run(program_, &exe_scope, sub_block_id_, false, true); exe_.Run(program_, exe_scope, sub_block_id_, false, true);
// 3. Copy LoDTensors from sink variables to out. // 3. Copy LoDTensors from sink variables to out.
out->resize(sink_var_names_.size()); out->resize(sink_var_names_.size());
for (size_t i = 0; i < sink_var_names_.size(); ++i) { for (size_t i = 0; i < sink_var_names_.size(); ++i) {
const auto& tensor = detail::Ref(exe_scope.FindVar(sink_var_names_[i])) const auto& tensor = detail::Ref(exe_scope->FindVar(sink_var_names_[i]))
.Get<framework::LoDTensor>(); .Get<framework::LoDTensor>();
framework::TensorCopySync(tensor, platform::CPUPlace(), &(*out)[i]); framework::TensorCopySync(tensor, platform::CPUPlace(), &(*out)[i]);
} }
scope_.DeleteScope(exe_scope);
} }
} // namespace reader } // namespace reader
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册