未验证 提交 6d13992e 编写于 作者: J jiangcheng 提交者: GitHub

[CINN] BuildCinnPass collect inplace var from all cluster instead op (#50057)

上级 f8557cd9
......@@ -163,25 +163,23 @@ std::unordered_set<std::string> OpTransInfo::GetDenyVarNames(
}
std::unordered_set<std::string> OpTransInfo::GetInplaceVarNames(
const GraphNodeSet& cluster) {
std::unordered_set<std::string> inplace_var_set;
const GraphNodeSet& cluster_inputs, const GraphNodeSet& cluster_outputs) {
std::unordered_set<std::string> all_inputs, all_outputs;
for (auto* op : cluster) {
// skip if not op
if (!op->IsOp() || !op->Op()) {
continue;
}
const auto& op_desc = *op->Op();
// check whether input and output have same argument
auto inputs = op_desc.InputArgumentNames();
std::unordered_set<std::string> input_set(inputs.begin(), inputs.end());
for (auto& name : op_desc.OutputArgumentNames()) {
if (input_set.count(name)) {
inplace_var_set.insert(name);
}
for (auto* var : cluster_inputs) {
all_inputs.insert(var->Name());
}
for (auto* var : cluster_outputs) {
all_outputs.insert(var->Name());
}
std::unordered_set<std::string> inplace_var_set;
for (const auto& var_name : all_inputs) {
if (all_outputs.count(var_name)) {
inplace_var_set.insert(var_name);
}
}
return inplace_var_set;
}
......@@ -480,7 +478,7 @@ std::unique_ptr<Graph> CreateNewSubGraph(const GraphNodeSet& cluster,
subgraph->GetOrInit<Name2VarInfoMap>(kMemOptVarInfoFromMainGraph);
auto inplace_var_names = std::make_unique<std::unordered_set<std::string>>(
OpTransInfo::GetInplaceVarNames(cluster));
OpTransInfo::GetInplaceVarNames(cluster_inputs, cluster_outputs));
VLOG_IF(4, !inplace_var_names->empty())
<< "Inplace var in cluster are: " << GetDebugInfo(*inplace_var_names);
subgraph->Set<std::unordered_set<std::string>>(kInplaceVarNames,
......
......@@ -69,7 +69,7 @@ class OpTransInfo {
const GraphNodeSet& cluster) const;
static std::unordered_set<std::string> GetInplaceVarNames(
const GraphNodeSet& cluster);
const GraphNodeSet& cluster_inputs, const GraphNodeSet& cluster_outputs);
private:
DyOpCondT dynamic_op_cond_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册