From ac84dce9f12b37e066320a11c7a4aff16fc93e13 Mon Sep 17 00:00:00 2001 From: jiangcheng Date: Mon, 30 Jan 2023 20:48:37 +0800 Subject: [PATCH] [CINN] fix build_cinn_pass collect inplace var bug (#50072) --- paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc | 10 ++++++++-- paddle/fluid/framework/paddle2cinn/build_cinn_pass.h | 4 +++- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc b/paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc index 2d7a95da42..d7be74b6f8 100644 --- a/paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc +++ b/paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc @@ -163,9 +163,14 @@ std::unordered_set OpTransInfo::GetDenyVarNames( } std::unordered_set OpTransInfo::GetInplaceVarNames( - const GraphNodeSet& cluster_inputs, const GraphNodeSet& cluster_outputs) { + const GraphNodeSet& cluster_internals, + const GraphNodeSet& cluster_inputs, + const GraphNodeSet& cluster_outputs) { std::unordered_set all_inputs, all_outputs; + for (auto* var : cluster_internals) { + all_inputs.insert(var->Name()); + } for (auto* var : cluster_inputs) { all_inputs.insert(var->Name()); } @@ -478,7 +483,8 @@ std::unique_ptr CreateNewSubGraph(const GraphNodeSet& cluster, subgraph->GetOrInit(kMemOptVarInfoFromMainGraph); auto inplace_var_names = std::make_unique>( - OpTransInfo::GetInplaceVarNames(cluster_inputs, cluster_outputs)); + OpTransInfo::GetInplaceVarNames( + cluster_internals, cluster_inputs, cluster_outputs)); VLOG_IF(4, !inplace_var_names->empty()) << "Inplace var in cluster are: " << GetDebugInfo(*inplace_var_names); subgraph->Set>(kInplaceVarNames, diff --git a/paddle/fluid/framework/paddle2cinn/build_cinn_pass.h b/paddle/fluid/framework/paddle2cinn/build_cinn_pass.h index 1797d07faf..4d7b784e32 100644 --- a/paddle/fluid/framework/paddle2cinn/build_cinn_pass.h +++ b/paddle/fluid/framework/paddle2cinn/build_cinn_pass.h @@ -69,7 +69,9 @@ class OpTransInfo { const GraphNodeSet& cluster) const; static std::unordered_set GetInplaceVarNames( - const GraphNodeSet& cluster_inputs, const GraphNodeSet& cluster_outputs); + const GraphNodeSet& cluster_internals, + const GraphNodeSet& cluster_inputs, + const GraphNodeSet& cluster_outputs); private: DyOpCondT dynamic_op_cond_; -- GitLab