提交 31d8db9f 编写于 作者: Y yujianfeng

Add broadcast fusion pass

上级 905467be
......@@ -276,6 +276,7 @@ void AscendBackendOptimization(const std::shared_ptr<session::KernelGraph> &kern
auto other_pm = std::make_shared<PassManager>("other_pm");
other_pm->AddPass(std::make_shared<AllReduceFusion>());
other_pm->AddPass(std::make_shared<AllGatherFusion>());
other_pm->AddPass(std::make_shared<BroadcastFusion>());
other_pm->AddPass(std::make_shared<ParameterTransOpFusion>());
other_pm->AddPass(std::make_shared<RefreshParameterFormat>());
other_pm->AddPass(std::make_shared<BufferFusion>());
......
......@@ -62,6 +62,12 @@ class AllGatherFusion : public CommunicationOpFusion {
explicit AllGatherFusion(size_t groups = 1) : CommunicationOpFusion("all_gather_fusion", kAllGatherOpName, groups) {}
~AllGatherFusion() override = default;
};
class BroadcastFusion : public CommunicationOpFusion {
public:
explicit BroadcastFusion(size_t groups = 1) : CommunicationOpFusion("broadcast_fusion", kBroadcastOpName, groups) {}
~BroadcastFusion() override = default;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_COMMUNICATION_OP_FUSION_H_
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册