From baa19feadfabc9a15eb1b02edb38d91983db87d0 Mon Sep 17 00:00:00 2001 From: sneaxiy Date: Sun, 30 Sep 2018 10:56:20 +0000 Subject: [PATCH] fix_eager_deletion --- .../framework/details/reference_count_pass.cc | 18 +++++++++--------- paddle/fluid/framework/parallel_executor.cc | 7 +++++++ 2 files changed, 16 insertions(+), 9 deletions(-) diff --git a/paddle/fluid/framework/details/reference_count_pass.cc b/paddle/fluid/framework/details/reference_count_pass.cc index b1ce551ce73..2d1f688d64e 100644 --- a/paddle/fluid/framework/details/reference_count_pass.cc +++ b/paddle/fluid/framework/details/reference_count_pass.cc @@ -80,15 +80,15 @@ std::unique_ptr ReferenceCountPass::ApplyImpl( // This is weird but there is really some variables without var_desc // in computation_op if (var_desc == nullptr) { - if (compute_op->Node()->Op()->Block()->FindVar(var_name) == nullptr) - continue; - } else { - if (var_desc->Persistable()) continue; - auto var_type = var_desc->Proto()->type().type(); - if (var_type != proto::VarType::LOD_TENSOR && - var_type != proto::VarType::SELECTED_ROWS) { - continue; - } + var_desc = compute_op->Node()->Op()->Block()->FindVar(var_name); + if (var_desc == nullptr) continue; + } + + if (var_desc->Persistable()) continue; + auto var_type = var_desc->Proto()->type().type(); + if (var_type != proto::VarType::LOD_TENSOR && + var_type != proto::VarType::SELECTED_ROWS) { + continue; } // compute op only runs in one device diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc index f5a54c0f48c..274a3e686c6 100644 --- a/paddle/fluid/framework/parallel_executor.cc +++ b/paddle/fluid/framework/parallel_executor.cc @@ -319,6 +319,13 @@ void ParallelExecutor::Run(const std::vector &fetch_tensors, #ifdef PADDLE_WITH_CUDA if (!gcs_.empty()) { ResetReferenceCount(); + for (auto &pair : cur_ref_cnts_) { + auto &name_map = *(pair.second); + for (auto &fetch_name : fetch_tensors) { + name_map.erase(fetch_name); + } + name_map.erase(fetched_var_name); + } } #endif auto fetch_data = member_->executor_->Run(fetch_tensors); -- GitLab