From 31d8db9fc6c69482814272f48f49b843e24f97eb Mon Sep 17 00:00:00 2001 From: yujianfeng Date: Wed, 13 May 2020 15:30:27 +0800 Subject: [PATCH] Add broadcast fusion pass --- .../pre_activate/ascend/ascend_backend_optimization.cc | 1 + mindspore/ccsrc/pre_activate/pass/communication_op_fusion.h | 6 ++++++ .../{common/ir_fusion => pass}/allreduce_fusion_test.cc | 0 3 files changed, 7 insertions(+) rename tests/ut/cpp/pre_activate/{common/ir_fusion => pass}/allreduce_fusion_test.cc (100%) diff --git a/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc b/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc index 356926e2f..16a82500b 100644 --- a/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc +++ b/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc @@ -276,6 +276,7 @@ void AscendBackendOptimization(const std::shared_ptr &kern auto other_pm = std::make_shared("other_pm"); other_pm->AddPass(std::make_shared()); other_pm->AddPass(std::make_shared()); + other_pm->AddPass(std::make_shared()); other_pm->AddPass(std::make_shared()); other_pm->AddPass(std::make_shared()); other_pm->AddPass(std::make_shared()); diff --git a/mindspore/ccsrc/pre_activate/pass/communication_op_fusion.h b/mindspore/ccsrc/pre_activate/pass/communication_op_fusion.h index af8b557d5..e98da1f0c 100644 --- a/mindspore/ccsrc/pre_activate/pass/communication_op_fusion.h +++ b/mindspore/ccsrc/pre_activate/pass/communication_op_fusion.h @@ -62,6 +62,12 @@ class AllGatherFusion : public CommunicationOpFusion { explicit AllGatherFusion(size_t groups = 1) : CommunicationOpFusion("all_gather_fusion", kAllGatherOpName, groups) {} ~AllGatherFusion() override = default; }; + +class BroadcastFusion : public CommunicationOpFusion { + public: + explicit BroadcastFusion(size_t groups = 1) : CommunicationOpFusion("broadcast_fusion", kBroadcastOpName, groups) {} + ~BroadcastFusion() override = default; +}; } // namespace opt } // namespace mindspore #endif // MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_COMMUNICATION_OP_FUSION_H_ diff --git a/tests/ut/cpp/pre_activate/common/ir_fusion/allreduce_fusion_test.cc b/tests/ut/cpp/pre_activate/pass/allreduce_fusion_test.cc similarity index 100% rename from tests/ut/cpp/pre_activate/common/ir_fusion/allreduce_fusion_test.cc rename to tests/ut/cpp/pre_activate/pass/allreduce_fusion_test.cc -- GitLab