未验证 提交 6d13992e 编写于 作者: J jiangcheng 提交者: GitHub

[CINN] BuildCinnPass collect inplace var from all cluster instead op (#50057)

上级 f8557cd9
...@@ -163,25 +163,23 @@ std::unordered_set<std::string> OpTransInfo::GetDenyVarNames( ...@@ -163,25 +163,23 @@ std::unordered_set<std::string> OpTransInfo::GetDenyVarNames(
} }
std::unordered_set<std::string> OpTransInfo::GetInplaceVarNames( std::unordered_set<std::string> OpTransInfo::GetInplaceVarNames(
const GraphNodeSet& cluster) { const GraphNodeSet& cluster_inputs, const GraphNodeSet& cluster_outputs) {
std::unordered_set<std::string> inplace_var_set; std::unordered_set<std::string> all_inputs, all_outputs;
for (auto* op : cluster) { for (auto* var : cluster_inputs) {
// skip if not op all_inputs.insert(var->Name());
if (!op->IsOp() || !op->Op()) {
continue;
} }
const auto& op_desc = *op->Op(); for (auto* var : cluster_outputs) {
all_outputs.insert(var->Name());
// check whether input and output have same argument
auto inputs = op_desc.InputArgumentNames();
std::unordered_set<std::string> input_set(inputs.begin(), inputs.end());
for (auto& name : op_desc.OutputArgumentNames()) {
if (input_set.count(name)) {
inplace_var_set.insert(name);
} }
std::unordered_set<std::string> inplace_var_set;
for (const auto& var_name : all_inputs) {
if (all_outputs.count(var_name)) {
inplace_var_set.insert(var_name);
} }
} }
return inplace_var_set; return inplace_var_set;
} }
...@@ -480,7 +478,7 @@ std::unique_ptr<Graph> CreateNewSubGraph(const GraphNodeSet& cluster, ...@@ -480,7 +478,7 @@ std::unique_ptr<Graph> CreateNewSubGraph(const GraphNodeSet& cluster,
subgraph->GetOrInit<Name2VarInfoMap>(kMemOptVarInfoFromMainGraph); subgraph->GetOrInit<Name2VarInfoMap>(kMemOptVarInfoFromMainGraph);
auto inplace_var_names = std::make_unique<std::unordered_set<std::string>>( auto inplace_var_names = std::make_unique<std::unordered_set<std::string>>(
OpTransInfo::GetInplaceVarNames(cluster)); OpTransInfo::GetInplaceVarNames(cluster_inputs, cluster_outputs));
VLOG_IF(4, !inplace_var_names->empty()) VLOG_IF(4, !inplace_var_names->empty())
<< "Inplace var in cluster are: " << GetDebugInfo(*inplace_var_names); << "Inplace var in cluster are: " << GetDebugInfo(*inplace_var_names);
subgraph->Set<std::unordered_set<std::string>>(kInplaceVarNames, subgraph->Set<std::unordered_set<std::string>>(kInplaceVarNames,
......
...@@ -69,7 +69,7 @@ class OpTransInfo { ...@@ -69,7 +69,7 @@ class OpTransInfo {
const GraphNodeSet& cluster) const; const GraphNodeSet& cluster) const;
static std::unordered_set<std::string> GetInplaceVarNames( static std::unordered_set<std::string> GetInplaceVarNames(
const GraphNodeSet& cluster); const GraphNodeSet& cluster_inputs, const GraphNodeSet& cluster_outputs);
private: private:
DyOpCondT dynamic_op_cond_; DyOpCondT dynamic_op_cond_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册