From 5925b82e06dd1fc3ef9499353a89b91da6974df4 Mon Sep 17 00:00:00 2001 From: JingZhuangzhuang <75348594+JZZ-NOTE@users.noreply.github.com> Date: Wed, 12 Jan 2022 10:40:00 +0800 Subject: [PATCH] multithread memory optimize error fix (#37894) (#38737) * multithread_memory_optimize --- .../analysis/passes/memory_optimize_pass.cc | 29 ++++++++++--------- .../analysis/passes/memory_optimize_pass.h | 8 ++--- 2 files changed, 19 insertions(+), 18 deletions(-) diff --git a/paddle/fluid/inference/analysis/passes/memory_optimize_pass.cc b/paddle/fluid/inference/analysis/passes/memory_optimize_pass.cc index 2202b94bee..3fa417c2ea 100644 --- a/paddle/fluid/inference/analysis/passes/memory_optimize_pass.cc +++ b/paddle/fluid/inference/analysis/passes/memory_optimize_pass.cc @@ -52,11 +52,11 @@ typedef struct { // The traversal order also affect the lifecycles, so different sort_kind is // used. void MemoryOptimizePass::CollectLifeCycle( - std::unordered_map* lifecycles, + Graph* graph, std::unordered_map* lifecycles, int sort_kind) const { - max_lifecycle_ = 0; + int max_lifecycle = 0; for (auto* op_node : framework::ir::TopologyVarientSort( - *graph_, static_cast(sort_kind))) { + *graph, static_cast(sort_kind))) { if (!op_node->IsOp()) continue; auto reads = op_node->inputs; auto writes = op_node->outputs; @@ -77,20 +77,20 @@ void MemoryOptimizePass::CollectLifeCycle( if (node->Var()->Persistable()) continue; std::string var = node->Name(); if (!lifecycles->count(var)) { - (*lifecycles)[var] = std::make_pair(max_lifecycle_, max_lifecycle_); + (*lifecycles)[var] = std::make_pair(max_lifecycle, max_lifecycle); } else { (*lifecycles)[var].second = - std::max(max_lifecycle_, lifecycles->at(var).second); // max() + std::max(max_lifecycle, lifecycles->at(var).second); // max() } } } - ++max_lifecycle_; + ++max_lifecycle; } } void MemoryOptimizePass::CollectVarMemorySize( - space_table_t* space_table) const { + Graph* graph, space_table_t* space_table) const { const int fake_batch_size = 1; auto valid_var = [&](framework::ir::Node* node) -> bool { @@ -130,7 +130,7 @@ void MemoryOptimizePass::CollectVarMemorySize( // although it's not always the case. so black list is the best compromise // between performance and underlying principle. std::unordered_set black_list; - for (auto* node : graph_->Nodes()) { + for (auto* node : graph->Nodes()) { if (node->IsVar() && node->Var()->GetType() == framework::proto::VarType::Type::VarType_Type_LOD_TENSOR) { @@ -141,7 +141,7 @@ void MemoryOptimizePass::CollectVarMemorySize( } // Collect tensors from graph. - for (auto* node : graph_->Nodes()) { + for (auto* node : graph->Nodes()) { if (node->IsVar() && node->Var()->GetType() == framework::proto::VarType::Type::VarType_Type_LOD_TENSOR && @@ -304,7 +304,10 @@ void MemoryOptimizePass::RunImpl(Argument* argument) { // 3. Perform reuse plan: Replace all var's name in the model according to the // mapping table. if (!argument->enable_memory_optim()) return; - graph_ = argument->main_graph_ptr(); + // Because of pass is a singleton, graph can not be member + // variables,otherwise,errors will be caused under multithreading + // conditions. + auto graph = argument->main_graph_ptr(); int sort_kind = 0; std::unordered_map lifecycles; @@ -312,10 +315,10 @@ void MemoryOptimizePass::RunImpl(Argument* argument) { std::unordered_map node2cluster; std::unordered_map cluster_size; - CollectLifeCycle(&lifecycles, sort_kind); - CollectVarMemorySize(&space_table); + CollectLifeCycle(graph, &lifecycles, sort_kind); + CollectVarMemorySize(graph, &space_table); MakeSimpleReusePlan(lifecycles, space_table, &node2cluster, &cluster_size); - UpdateOpDescsByReuse(graph_, node2cluster, sort_kind); + UpdateOpDescsByReuse(graph, node2cluster, sort_kind); return; } diff --git a/paddle/fluid/inference/analysis/passes/memory_optimize_pass.h b/paddle/fluid/inference/analysis/passes/memory_optimize_pass.h index 6d20aee295..57052243d2 100644 --- a/paddle/fluid/inference/analysis/passes/memory_optimize_pass.h +++ b/paddle/fluid/inference/analysis/passes/memory_optimize_pass.h @@ -57,17 +57,15 @@ class MemoryOptimizePass : public AnalysisPass { private: void CollectLifeCycle( + framework::ir::Graph *graph, std::unordered_map *lifecycles, int sort_kind) const; - void CollectVarMemorySize(space_table_t *space_table) const; + void CollectVarMemorySize(framework::ir::Graph *graph, + space_table_t *space_table) const; public: std::string repr() const override; - - private: - mutable framework::ir::Graph *graph_{nullptr}; - mutable int max_lifecycle_{-1}; }; } // namespace analysis -- GitLab