diff --git a/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.cc b/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.cc index da26f82008fb31123248d02e4694cf5b227ed474..3640e9f7dbfa5fac3c09b455ece6f98603a832b2 100644 --- a/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.cc @@ -74,7 +74,8 @@ FeedFetchList ScopeBufferedSSAGraphExecutor::Run( } ++drop_scope_counter_; - if (drop_scope_counter_ == strategy_.num_iteration_per_drop_scope_) { + if (drop_scope_counter_ == strategy_.num_iteration_per_drop_scope_ || + DropScopeOrNot()) { DropLocalExeScopes(); } @@ -93,6 +94,19 @@ FeedFetchList ScopeBufferedSSAGraphExecutor::Run( } } +bool ScopeBufferedSSAGraphExecutor::DropScopeOrNot() const { + for (auto &var : tensor_array_vars_) { + auto tensor_array = var->GetMutable(); + for (LoDTensor &tensor : *tensor_array) { + if (tensor.IsInitialized()) { + return true; + } + } + tensor_array->clear(); + } + return false; +} + void ScopeBufferedSSAGraphExecutor::InitVariables() { for (auto &info : tmp_var_infos_) { for (auto &pair : info) { @@ -165,6 +179,9 @@ void ScopeBufferedSSAGraphExecutor::PrepareLocalExeScopes() { Variable *tmp_var = local_scope->Var(info.name_); preserve_vars_[idx].emplace(tmp_var); tmp_var_infos_[idx].emplace_back(tmp_var, info.type_); + if (info.type_ == proto::VarType::LOD_TENSOR_ARRAY) { + tensor_array_vars_.emplace_back(tmp_var); + } } } } diff --git a/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h b/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h index 1e1d663a4363e2e7c065b527c25b7dbf2d323333..17493a89a660588b0e0f8f8da42518961b008773 100644 --- a/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h +++ b/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h @@ -61,6 +61,8 @@ class ScopeBufferedSSAGraphExecutor : public SSAGraphExecutor { private: void InitVariables(); + bool DropScopeOrNot() const; + size_t drop_scope_counter_{0}; ExecutionStrategy strategy_; std::unique_ptr underlying_executor_; @@ -71,8 +73,11 @@ class ScopeBufferedSSAGraphExecutor : public SSAGraphExecutor { std::vector>> tmp_var_infos_; + std::vector tensor_array_vars_; + std::vector var_infos_; std::vector places_; + ScopeBufferedMonitor scope_monitor_; }; } // namespace details