diff --git a/paddle/fluid/inference/analysis/passes/memory_optimize_pass.cc b/paddle/fluid/inference/analysis/passes/memory_optimize_pass.cc index 2202b94bee727c4b4daaf9c76443507cd0804485..3fa417c2ea6311a1c1886c3cc887a32e45aad3d1 100644 --- a/paddle/fluid/inference/analysis/passes/memory_optimize_pass.cc +++ b/paddle/fluid/inference/analysis/passes/memory_optimize_pass.cc @@ -52,11 +52,11 @@ typedef struct { // The traversal order also affect the lifecycles, so different sort_kind is // used. void MemoryOptimizePass::CollectLifeCycle( - std::unordered_map* lifecycles, + Graph* graph, std::unordered_map* lifecycles, int sort_kind) const { - max_lifecycle_ = 0; + int max_lifecycle = 0; for (auto* op_node : framework::ir::TopologyVarientSort( - *graph_, static_cast(sort_kind))) { + *graph, static_cast(sort_kind))) { if (!op_node->IsOp()) continue; auto reads = op_node->inputs; auto writes = op_node->outputs; @@ -77,20 +77,20 @@ void MemoryOptimizePass::CollectLifeCycle( if (node->Var()->Persistable()) continue; std::string var = node->Name(); if (!lifecycles->count(var)) { - (*lifecycles)[var] = std::make_pair(max_lifecycle_, max_lifecycle_); + (*lifecycles)[var] = std::make_pair(max_lifecycle, max_lifecycle); } else { (*lifecycles)[var].second = - std::max(max_lifecycle_, lifecycles->at(var).second); // max() + std::max(max_lifecycle, lifecycles->at(var).second); // max() } } } - ++max_lifecycle_; + ++max_lifecycle; } } void MemoryOptimizePass::CollectVarMemorySize( - space_table_t* space_table) const { + Graph* graph, space_table_t* space_table) const { const int fake_batch_size = 1; auto valid_var = [&](framework::ir::Node* node) -> bool { @@ -130,7 +130,7 @@ void MemoryOptimizePass::CollectVarMemorySize( // although it's not always the case. so black list is the best compromise // between performance and underlying principle. std::unordered_set black_list; - for (auto* node : graph_->Nodes()) { + for (auto* node : graph->Nodes()) { if (node->IsVar() && node->Var()->GetType() == framework::proto::VarType::Type::VarType_Type_LOD_TENSOR) { @@ -141,7 +141,7 @@ void MemoryOptimizePass::CollectVarMemorySize( } // Collect tensors from graph. - for (auto* node : graph_->Nodes()) { + for (auto* node : graph->Nodes()) { if (node->IsVar() && node->Var()->GetType() == framework::proto::VarType::Type::VarType_Type_LOD_TENSOR && @@ -304,7 +304,10 @@ void MemoryOptimizePass::RunImpl(Argument* argument) { // 3. Perform reuse plan: Replace all var's name in the model according to the // mapping table. if (!argument->enable_memory_optim()) return; - graph_ = argument->main_graph_ptr(); + // Because of pass is a singleton, graph can not be member + // variables,otherwise,errors will be caused under multithreading + // conditions. + auto graph = argument->main_graph_ptr(); int sort_kind = 0; std::unordered_map lifecycles; @@ -312,10 +315,10 @@ void MemoryOptimizePass::RunImpl(Argument* argument) { std::unordered_map node2cluster; std::unordered_map cluster_size; - CollectLifeCycle(&lifecycles, sort_kind); - CollectVarMemorySize(&space_table); + CollectLifeCycle(graph, &lifecycles, sort_kind); + CollectVarMemorySize(graph, &space_table); MakeSimpleReusePlan(lifecycles, space_table, &node2cluster, &cluster_size); - UpdateOpDescsByReuse(graph_, node2cluster, sort_kind); + UpdateOpDescsByReuse(graph, node2cluster, sort_kind); return; } diff --git a/paddle/fluid/inference/analysis/passes/memory_optimize_pass.h b/paddle/fluid/inference/analysis/passes/memory_optimize_pass.h index 6d20aee295b7c1d2fe533bf8cd5195cb105afe2a..57052243d2f189ec6f722d5820cba223dd914e4a 100644 --- a/paddle/fluid/inference/analysis/passes/memory_optimize_pass.h +++ b/paddle/fluid/inference/analysis/passes/memory_optimize_pass.h @@ -57,17 +57,15 @@ class MemoryOptimizePass : public AnalysisPass { private: void CollectLifeCycle( + framework::ir::Graph *graph, std::unordered_map *lifecycles, int sort_kind) const; - void CollectVarMemorySize(space_table_t *space_table) const; + void CollectVarMemorySize(framework::ir::Graph *graph, + space_table_t *space_table) const; public: std::string repr() const override; - - private: - mutable framework::ir::Graph *graph_{nullptr}; - mutable int max_lifecycle_{-1}; }; } // namespace analysis