未验证 提交 9941ec12 编写于 作者: W WangZhen 提交者: GitHub

[JitLayer]Erase out vars in scope to avoid data rewritinig (#46249)

* [JitLayer]Erase out vars to avoid data rewrittinig

* Fix code comments
上级 82399bdf
...@@ -44,14 +44,17 @@ std::vector<Tensor> ExecutorEngine::operator()( ...@@ -44,14 +44,17 @@ std::vector<Tensor> ExecutorEngine::operator()(
std::vector<DenseTensor> ExecutorEngine::operator()( std::vector<DenseTensor> ExecutorEngine::operator()(
const std::vector<DenseTensor> &inputs) { const std::vector<DenseTensor> &inputs) {
utils::ShareIntoScope(info_->InputArgNames(), inputs, &scope_); utils::ShareIntoScope(info_->InputArgNames(), inputs, &scope_);
const auto out_names = info_->OutputArgNames();
inner_exe_.Run(info_->ProgramDesc(), inner_exe_.Run(info_->ProgramDesc(),
&scope_, &scope_,
/*blockID=*/0, /*blockID=*/0,
false, false,
true, true,
info_->OutputArgNames()); out_names);
std::vector<DenseTensor> outputs; std::vector<DenseTensor> outputs;
utils::FetchOuts(info_->OutputArgNames(), scope_, &outputs); utils::FetchOuts(out_names, scope_, &outputs);
// Erase output vars to avoid data rewriting.
scope_.EraseVars(out_names);
return outputs; return outputs;
} }
......
...@@ -96,12 +96,15 @@ std::vector<Tensor> PEEngine::operator()(const std::vector<Tensor> &inputs) { ...@@ -96,12 +96,15 @@ std::vector<Tensor> PEEngine::operator()(const std::vector<Tensor> &inputs) {
std::vector<DenseTensor> PEEngine::operator()( std::vector<DenseTensor> PEEngine::operator()(
const std::vector<DenseTensor> &inputs) { const std::vector<DenseTensor> &inputs) {
utils::ShareIntoScope(info_->InputArgNames(), inputs, &scope_); utils::ShareIntoScope(info_->InputArgNames(), inputs, &scope_);
const auto out_names = info_->OutputArgNames();
// need to recreate tmp variables in new scope // need to recreate tmp variables in new scope
inner_pe_->PrepareVariables(&scope_); inner_pe_->PrepareVariables(&scope_);
inner_pe_->RunWithoutFetch(info_->OutputArgNames()); inner_pe_->RunWithoutFetch(out_names);
std::vector<DenseTensor> outputs; std::vector<DenseTensor> outputs;
utils::FetchOuts(info_->OutputArgNames(), scope_, &outputs); utils::FetchOuts(out_names, scope_, &outputs);
// Erase output vars to avoid data rewriting.
scope_.EraseVars(out_names);
scope_.DropKids(); scope_.DropKids();
return outputs; return outputs;
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册