From 55ce696986d3f8ba8c4236ea180090cc2a1e0911 Mon Sep 17 00:00:00 2001 From: chengduo <30176695+chengduoZH@users.noreply.github.com> Date: Tue, 24 Sep 2019 10:04:40 +0800 Subject: [PATCH] clean tensor array (#19930) test=develop --- .../scope_buffered_ssa_graph_executor.cc | 19 ++++++++++++++++++- .../scope_buffered_ssa_graph_executor.h | 5 +++++ 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.cc b/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.cc index da26f82008..3640e9f7db 100644 --- a/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.cc @@ -74,7 +74,8 @@ FeedFetchList ScopeBufferedSSAGraphExecutor::Run( } ++drop_scope_counter_; - if (drop_scope_counter_ == strategy_.num_iteration_per_drop_scope_) { + if (drop_scope_counter_ == strategy_.num_iteration_per_drop_scope_ || + DropScopeOrNot()) { DropLocalExeScopes(); } @@ -93,6 +94,19 @@ FeedFetchList ScopeBufferedSSAGraphExecutor::Run( } } +bool ScopeBufferedSSAGraphExecutor::DropScopeOrNot() const { + for (auto &var : tensor_array_vars_) { + auto tensor_array = var->GetMutable(); + for (LoDTensor &tensor : *tensor_array) { + if (tensor.IsInitialized()) { + return true; + } + } + tensor_array->clear(); + } + return false; +} + void ScopeBufferedSSAGraphExecutor::InitVariables() { for (auto &info : tmp_var_infos_) { for (auto &pair : info) { @@ -165,6 +179,9 @@ void ScopeBufferedSSAGraphExecutor::PrepareLocalExeScopes() { Variable *tmp_var = local_scope->Var(info.name_); preserve_vars_[idx].emplace(tmp_var); tmp_var_infos_[idx].emplace_back(tmp_var, info.type_); + if (info.type_ == proto::VarType::LOD_TENSOR_ARRAY) { + tensor_array_vars_.emplace_back(tmp_var); + } } } } diff --git a/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h b/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h index 1e1d663a43..17493a89a6 100644 --- a/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h +++ b/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h @@ -61,6 +61,8 @@ class ScopeBufferedSSAGraphExecutor : public SSAGraphExecutor { private: void InitVariables(); + bool DropScopeOrNot() const; + size_t drop_scope_counter_{0}; ExecutionStrategy strategy_; std::unique_ptr underlying_executor_; @@ -71,8 +73,11 @@ class ScopeBufferedSSAGraphExecutor : public SSAGraphExecutor { std::vector>> tmp_var_infos_; + std::vector tensor_array_vars_; + std::vector var_infos_; std::vector places_; + ScopeBufferedMonitor scope_monitor_; }; } // namespace details -- GitLab