diff --git a/paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc b/paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc index 2d7a95da4201a67ddab677f8704668c1839f0fd6..d7be74b6f8a176d94f26649471deed7e02e360a9 100644 --- a/paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc +++ b/paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc @@ -163,9 +163,14 @@ std::unordered_set OpTransInfo::GetDenyVarNames( } std::unordered_set OpTransInfo::GetInplaceVarNames( - const GraphNodeSet& cluster_inputs, const GraphNodeSet& cluster_outputs) { + const GraphNodeSet& cluster_internals, + const GraphNodeSet& cluster_inputs, + const GraphNodeSet& cluster_outputs) { std::unordered_set all_inputs, all_outputs; + for (auto* var : cluster_internals) { + all_inputs.insert(var->Name()); + } for (auto* var : cluster_inputs) { all_inputs.insert(var->Name()); } @@ -478,7 +483,8 @@ std::unique_ptr CreateNewSubGraph(const GraphNodeSet& cluster, subgraph->GetOrInit(kMemOptVarInfoFromMainGraph); auto inplace_var_names = std::make_unique>( - OpTransInfo::GetInplaceVarNames(cluster_inputs, cluster_outputs)); + OpTransInfo::GetInplaceVarNames( + cluster_internals, cluster_inputs, cluster_outputs)); VLOG_IF(4, !inplace_var_names->empty()) << "Inplace var in cluster are: " << GetDebugInfo(*inplace_var_names); subgraph->Set>(kInplaceVarNames, diff --git a/paddle/fluid/framework/paddle2cinn/build_cinn_pass.h b/paddle/fluid/framework/paddle2cinn/build_cinn_pass.h index 1797d07faf5c72d4e784c37e018fae6eabd92063..4d7b784e32cedd3a1fa765bc891ce32e6c0a168e 100644 --- a/paddle/fluid/framework/paddle2cinn/build_cinn_pass.h +++ b/paddle/fluid/framework/paddle2cinn/build_cinn_pass.h @@ -69,7 +69,9 @@ class OpTransInfo { const GraphNodeSet& cluster) const; static std::unordered_set GetInplaceVarNames( - const GraphNodeSet& cluster_inputs, const GraphNodeSet& cluster_outputs); + const GraphNodeSet& cluster_internals, + const GraphNodeSet& cluster_inputs, + const GraphNodeSet& cluster_outputs); private: DyOpCondT dynamic_op_cond_;