未验证 提交 ac84dce9 编写于 作者: J jiangcheng 提交者: GitHub

[CINN] fix build_cinn_pass collect inplace var bug (#50072)

上级 b1d44bfc
......@@ -163,9 +163,14 @@ std::unordered_set<std::string> OpTransInfo::GetDenyVarNames(
}
std::unordered_set<std::string> OpTransInfo::GetInplaceVarNames(
const GraphNodeSet& cluster_inputs, const GraphNodeSet& cluster_outputs) {
const GraphNodeSet& cluster_internals,
const GraphNodeSet& cluster_inputs,
const GraphNodeSet& cluster_outputs) {
std::unordered_set<std::string> all_inputs, all_outputs;
for (auto* var : cluster_internals) {
all_inputs.insert(var->Name());
}
for (auto* var : cluster_inputs) {
all_inputs.insert(var->Name());
}
......@@ -478,7 +483,8 @@ std::unique_ptr<Graph> CreateNewSubGraph(const GraphNodeSet& cluster,
subgraph->GetOrInit<Name2VarInfoMap>(kMemOptVarInfoFromMainGraph);
auto inplace_var_names = std::make_unique<std::unordered_set<std::string>>(
OpTransInfo::GetInplaceVarNames(cluster_inputs, cluster_outputs));
OpTransInfo::GetInplaceVarNames(
cluster_internals, cluster_inputs, cluster_outputs));
VLOG_IF(4, !inplace_var_names->empty())
<< "Inplace var in cluster are: " << GetDebugInfo(*inplace_var_names);
subgraph->Set<std::unordered_set<std::string>>(kInplaceVarNames,
......
......@@ -69,7 +69,9 @@ class OpTransInfo {
const GraphNodeSet& cluster) const;
static std::unordered_set<std::string> GetInplaceVarNames(
const GraphNodeSet& cluster_inputs, const GraphNodeSet& cluster_outputs);
const GraphNodeSet& cluster_internals,
const GraphNodeSet& cluster_inputs,
const GraphNodeSet& cluster_outputs);
private:
DyOpCondT dynamic_op_cond_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册