diff --git a/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc b/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc index 8f3e5a5737e151e4f7dc7aceba6113c3d664dd10..220b309200d26f3fc2d6a327fc530e6dd09664b3 100644 --- a/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc +++ b/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc @@ -287,6 +287,7 @@ void AscendBackendOptimization(const std::shared_ptr &kern auto other_pm = std::make_shared("other_pm"); other_pm->AddPass(std::make_shared()); other_pm->AddPass(std::make_shared()); + other_pm->AddPass(std::make_shared()); other_pm->AddPass(std::make_shared()); other_pm->AddPass(std::make_shared()); other_pm->AddPass(std::make_shared()); diff --git a/mindspore/ccsrc/pre_activate/pass/communication_op_fusion.h b/mindspore/ccsrc/pre_activate/pass/communication_op_fusion.h index e01d18161646265eace61e83a5c28d69112934b5..d00180f97f023f81c96d92a063d5556221cb79b3 100644 --- a/mindspore/ccsrc/pre_activate/pass/communication_op_fusion.h +++ b/mindspore/ccsrc/pre_activate/pass/communication_op_fusion.h @@ -68,6 +68,13 @@ class BroadcastFusion : public CommunicationOpFusion { explicit BroadcastFusion(size_t groups = 1) : CommunicationOpFusion("broadcast_fusion", kBroadcastOpName, groups) {} ~BroadcastFusion() override = default; }; + +class ReduceScatterFusion : public CommunicationOpFusion { + public: + explicit ReduceScatterFusion(size_t groups = 1) + : CommunicationOpFusion("reduce_scatter_fusion", kReduceScatterOpName, groups) {} + ~ReduceScatterFusion() override = default; +}; } // namespace opt } // namespace mindspore #endif // MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_COMMUNICATION_OP_FUSION_H_