未验证 提交 ac84dce9 编写于 作者: J jiangcheng 提交者: GitHub

[CINN] fix build_cinn_pass collect inplace var bug (#50072)

上级 b1d44bfc
...@@ -163,9 +163,14 @@ std::unordered_set<std::string> OpTransInfo::GetDenyVarNames( ...@@ -163,9 +163,14 @@ std::unordered_set<std::string> OpTransInfo::GetDenyVarNames(
} }
std::unordered_set<std::string> OpTransInfo::GetInplaceVarNames( std::unordered_set<std::string> OpTransInfo::GetInplaceVarNames(
const GraphNodeSet& cluster_inputs, const GraphNodeSet& cluster_outputs) { const GraphNodeSet& cluster_internals,
const GraphNodeSet& cluster_inputs,
const GraphNodeSet& cluster_outputs) {
std::unordered_set<std::string> all_inputs, all_outputs; std::unordered_set<std::string> all_inputs, all_outputs;
for (auto* var : cluster_internals) {
all_inputs.insert(var->Name());
}
for (auto* var : cluster_inputs) { for (auto* var : cluster_inputs) {
all_inputs.insert(var->Name()); all_inputs.insert(var->Name());
} }
...@@ -478,7 +483,8 @@ std::unique_ptr<Graph> CreateNewSubGraph(const GraphNodeSet& cluster, ...@@ -478,7 +483,8 @@ std::unique_ptr<Graph> CreateNewSubGraph(const GraphNodeSet& cluster,
subgraph->GetOrInit<Name2VarInfoMap>(kMemOptVarInfoFromMainGraph); subgraph->GetOrInit<Name2VarInfoMap>(kMemOptVarInfoFromMainGraph);
auto inplace_var_names = std::make_unique<std::unordered_set<std::string>>( auto inplace_var_names = std::make_unique<std::unordered_set<std::string>>(
OpTransInfo::GetInplaceVarNames(cluster_inputs, cluster_outputs)); OpTransInfo::GetInplaceVarNames(
cluster_internals, cluster_inputs, cluster_outputs));
VLOG_IF(4, !inplace_var_names->empty()) VLOG_IF(4, !inplace_var_names->empty())
<< "Inplace var in cluster are: " << GetDebugInfo(*inplace_var_names); << "Inplace var in cluster are: " << GetDebugInfo(*inplace_var_names);
subgraph->Set<std::unordered_set<std::string>>(kInplaceVarNames, subgraph->Set<std::unordered_set<std::string>>(kInplaceVarNames,
......
...@@ -69,7 +69,9 @@ class OpTransInfo { ...@@ -69,7 +69,9 @@ class OpTransInfo {
const GraphNodeSet& cluster) const; const GraphNodeSet& cluster) const;
static std::unordered_set<std::string> GetInplaceVarNames( static std::unordered_set<std::string> GetInplaceVarNames(
const GraphNodeSet& cluster_inputs, const GraphNodeSet& cluster_outputs); const GraphNodeSet& cluster_internals,
const GraphNodeSet& cluster_inputs,
const GraphNodeSet& cluster_outputs);
private: private:
DyOpCondT dynamic_op_cond_; DyOpCondT dynamic_op_cond_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册