提交 114eb175 编写于 作者: S sneaxiy

fix executor bug

上级 612e1a31
...@@ -337,6 +337,7 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope, ...@@ -337,6 +337,7 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
std::unique_ptr<GarbageCollector<Tensor>> gc; std::unique_ptr<GarbageCollector<Tensor>> gc;
if (max_memory_size >= 0) { if (max_memory_size >= 0) {
ctx->ResetReferenceCount();
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
if (platform::is_gpu_place(place_)) { if (platform::is_gpu_place(place_)) {
gc.reset(new DefaultStreamGarbageCollector<Tensor>( gc.reset(new DefaultStreamGarbageCollector<Tensor>(
...@@ -357,11 +358,11 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope, ...@@ -357,11 +358,11 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
std::vector<std::string> erase_vars; std::vector<std::string> erase_vars;
for (auto& input : op->Inputs()) { for (auto& input : op->Inputs()) {
for (auto& input_name : input.second) { for (auto& input_name : input.second) {
auto it = ctx->ref_cnts_.find(input_name); auto it = ctx->cur_ref_cnts_.find(input_name);
if (it == ctx->ref_cnts_.end()) continue; if (it == ctx->cur_ref_cnts_.end()) continue;
if (it->second == 1) { // should delete it if (it->second == 1) { // should delete it
erase_vars.emplace_back(input_name); erase_vars.emplace_back(input_name);
ctx->ref_cnts_.erase(input_name); ctx->cur_ref_cnts_.erase(input_name);
} else { } else {
--(it->second); --(it->second);
} }
...@@ -370,11 +371,11 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope, ...@@ -370,11 +371,11 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
for (auto& output : op->Outputs()) { for (auto& output : op->Outputs()) {
for (auto& output_name : output.second) { for (auto& output_name : output.second) {
auto it = ctx->ref_cnts_.find(output_name); auto it = ctx->cur_ref_cnts_.find(output_name);
if (it == ctx->ref_cnts_.end()) continue; if (it == ctx->cur_ref_cnts_.end()) continue;
if (it->second == 1) { if (it->second == 1) {
erase_vars.emplace_back(output_name); erase_vars.emplace_back(output_name);
ctx->ref_cnts_.erase(output_name); ctx->cur_ref_cnts_.erase(output_name);
} else { } else {
--(it->second); --(it->second);
} }
......
...@@ -72,11 +72,14 @@ struct ExecutorPrepareContext { ...@@ -72,11 +72,14 @@ struct ExecutorPrepareContext {
ExecutorPrepareContext(const framework::ProgramDesc& prog, size_t block_id); ExecutorPrepareContext(const framework::ProgramDesc& prog, size_t block_id);
~ExecutorPrepareContext(); ~ExecutorPrepareContext();
void ResetReferenceCount() { cur_ref_cnts_ = ref_cnts_; }
const framework::ProgramDesc& prog_; const framework::ProgramDesc& prog_;
size_t block_id_; size_t block_id_;
std::vector<std::unique_ptr<OperatorBase>> ops_; std::vector<std::unique_ptr<OperatorBase>> ops_;
std::unordered_map<std::string, int> ref_cnts_; std::unordered_map<std::string, int> ref_cnts_;
std::unordered_map<std::string, int> cur_ref_cnts_;
}; };
class Executor { class Executor {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册