From 6d13992e8818e5ce18d984fd583b3322cbc76219 Mon Sep 17 00:00:00 2001 From: jiangcheng Date: Sun, 29 Jan 2023 19:28:04 +0800 Subject: [PATCH] [CINN] BuildCinnPass collect inplace var from all cluster instead op (#50057) --- .../framework/paddle2cinn/build_cinn_pass.cc | 32 +++++++++---------- .../framework/paddle2cinn/build_cinn_pass.h | 2 +- 2 files changed, 16 insertions(+), 18 deletions(-) diff --git a/paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc b/paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc index 4d438122d14..2d7a95da420 100644 --- a/paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc +++ b/paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc @@ -163,25 +163,23 @@ std::unordered_set OpTransInfo::GetDenyVarNames( } std::unordered_set OpTransInfo::GetInplaceVarNames( - const GraphNodeSet& cluster) { - std::unordered_set inplace_var_set; + const GraphNodeSet& cluster_inputs, const GraphNodeSet& cluster_outputs) { + std::unordered_set all_inputs, all_outputs; - for (auto* op : cluster) { - // skip if not op - if (!op->IsOp() || !op->Op()) { - continue; - } - const auto& op_desc = *op->Op(); - - // check whether input and output have same argument - auto inputs = op_desc.InputArgumentNames(); - std::unordered_set input_set(inputs.begin(), inputs.end()); - for (auto& name : op_desc.OutputArgumentNames()) { - if (input_set.count(name)) { - inplace_var_set.insert(name); - } + for (auto* var : cluster_inputs) { + all_inputs.insert(var->Name()); + } + for (auto* var : cluster_outputs) { + all_outputs.insert(var->Name()); + } + + std::unordered_set inplace_var_set; + for (const auto& var_name : all_inputs) { + if (all_outputs.count(var_name)) { + inplace_var_set.insert(var_name); } } + return inplace_var_set; } @@ -480,7 +478,7 @@ std::unique_ptr CreateNewSubGraph(const GraphNodeSet& cluster, subgraph->GetOrInit(kMemOptVarInfoFromMainGraph); auto inplace_var_names = std::make_unique>( - OpTransInfo::GetInplaceVarNames(cluster)); + OpTransInfo::GetInplaceVarNames(cluster_inputs, cluster_outputs)); VLOG_IF(4, !inplace_var_names->empty()) << "Inplace var in cluster are: " << GetDebugInfo(*inplace_var_names); subgraph->Set>(kInplaceVarNames, diff --git a/paddle/fluid/framework/paddle2cinn/build_cinn_pass.h b/paddle/fluid/framework/paddle2cinn/build_cinn_pass.h index 7e5152048d9..1797d07faf5 100644 --- a/paddle/fluid/framework/paddle2cinn/build_cinn_pass.h +++ b/paddle/fluid/framework/paddle2cinn/build_cinn_pass.h @@ -69,7 +69,7 @@ class OpTransInfo { const GraphNodeSet& cluster) const; static std::unordered_set GetInplaceVarNames( - const GraphNodeSet& cluster); + const GraphNodeSet& cluster_inputs, const GraphNodeSet& cluster_outputs); private: DyOpCondT dynamic_op_cond_; -- GitLab