From 9941ec1298ce5b18b2858d9a8b5b785b97d1f9b8 Mon Sep 17 00:00:00 2001 From: WangZhen <23097963+0x45f@users.noreply.github.com> Date: Tue, 20 Sep 2022 14:24:28 +0800 Subject: [PATCH] [JitLayer]Erase out vars in scope to avoid data rewritinig (#46249) * [JitLayer]Erase out vars to avoid data rewrittinig * Fix code comments --- paddle/fluid/jit/engine/executor_engine.cc | 7 +++++-- paddle/fluid/jit/engine/pe_engine.cc | 7 +++++-- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/jit/engine/executor_engine.cc b/paddle/fluid/jit/engine/executor_engine.cc index 58d80426e5f..1cde715b8f0 100644 --- a/paddle/fluid/jit/engine/executor_engine.cc +++ b/paddle/fluid/jit/engine/executor_engine.cc @@ -44,14 +44,17 @@ std::vector ExecutorEngine::operator()( std::vector ExecutorEngine::operator()( const std::vector &inputs) { utils::ShareIntoScope(info_->InputArgNames(), inputs, &scope_); + const auto out_names = info_->OutputArgNames(); inner_exe_.Run(info_->ProgramDesc(), &scope_, /*blockID=*/0, false, true, - info_->OutputArgNames()); + out_names); std::vector outputs; - utils::FetchOuts(info_->OutputArgNames(), scope_, &outputs); + utils::FetchOuts(out_names, scope_, &outputs); + // Erase output vars to avoid data rewriting. + scope_.EraseVars(out_names); return outputs; } diff --git a/paddle/fluid/jit/engine/pe_engine.cc b/paddle/fluid/jit/engine/pe_engine.cc index 78e48667547..576687c0efa 100644 --- a/paddle/fluid/jit/engine/pe_engine.cc +++ b/paddle/fluid/jit/engine/pe_engine.cc @@ -96,12 +96,15 @@ std::vector PEEngine::operator()(const std::vector &inputs) { std::vector PEEngine::operator()( const std::vector &inputs) { utils::ShareIntoScope(info_->InputArgNames(), inputs, &scope_); + const auto out_names = info_->OutputArgNames(); // need to recreate tmp variables in new scope inner_pe_->PrepareVariables(&scope_); - inner_pe_->RunWithoutFetch(info_->OutputArgNames()); + inner_pe_->RunWithoutFetch(out_names); std::vector outputs; - utils::FetchOuts(info_->OutputArgNames(), scope_, &outputs); + utils::FetchOuts(out_names, scope_, &outputs); + // Erase output vars to avoid data rewriting. + scope_.EraseVars(out_names); scope_.DropKids(); return outputs; } -- GitLab