提交 1b699234 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!3643 Throw exception if different communication ops which are divided to the...

!3643 Throw exception if different communication ops which are divided into the same segment share the same input
Merge pull request !3643 from huanghui/communication-op-fusion
......@@ -16,6 +16,7 @@
#include "backend/optimizer/pass/communication_op_fusion.h"
#include <vector>
#include <set>
#include <memory>
#include <unordered_map>
......@@ -89,6 +90,13 @@ std::string GetFusionGroupKey(const AnfNodePtr &node) {
}
return group + op + std::to_string(fusion);
}
void CheckInputs(const std::vector<AnfNodePtr> &fusion_inputs) {
std::set<AnfNodePtr> inputs_set(fusion_inputs.begin(), fusion_inputs.end());
if (inputs_set.size() < fusion_inputs.size()) {
MS_LOG(EXCEPTION) << "Different communication op in one segment cannot share the same input";
}
}
} // namespace
bool CommunicationOpFusion::GetSplitSegments(const CommunicationOpInfo &communication_op_info, size_t *segment_num,
......@@ -163,6 +171,7 @@ AnfNodePtr CommunicationOpFusion::CreateFusedCommunicationOp(const FuncGraphPtr
MS_EXCEPTION_IF_NULL(cnode);
fusion_inputs.insert(fusion_inputs.end(), cnode->inputs().begin() + 1, cnode->inputs().end());
}
CheckInputs(fusion_inputs);
AnfNodePtr fused_node = func_graph->NewCNode(fusion_inputs);
MS_EXCEPTION_IF_NULL(fused_node);
auto kernel_info = std::make_shared<device::KernelInfo>();
......@@ -172,9 +181,6 @@ AnfNodePtr CommunicationOpFusion::CreateFusedCommunicationOp(const FuncGraphPtr
for (size_t idx = start_index; idx <= end_index; ++idx) {
auto cnode = communication_op_info.communication_op_nodes[idx];
MS_EXCEPTION_IF_NULL(cnode);
AnfAlgo::CopyNodeAttr("fusion", cnode, fused_node);
AnfAlgo::CopyNodeAttr("op", cnode, fused_node);
AnfAlgo::CopyNodeAttr("group", cnode, fused_node);
abstract_list.push_back(cnode->abstract());
}
auto kernel_build_info = GenerateKernelBuildInfo(communication_op_info, start_index, end_index);
......@@ -182,6 +188,9 @@ AnfNodePtr CommunicationOpFusion::CreateFusedCommunicationOp(const FuncGraphPtr
auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(abstract_list);
MS_EXCEPTION_IF_NULL(abstract_tuple);
fused_node->set_abstract(abstract_tuple);
AnfAlgo::CopyNodeAttr("fusion", communication_op_info.communication_op_nodes[end_index], fused_node);
AnfAlgo::CopyNodeAttr("op", communication_op_info.communication_op_nodes[end_index], fused_node);
AnfAlgo::CopyNodeAttr("group", communication_op_info.communication_op_nodes[end_index], fused_node);
return fused_node;
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册