提交 06242634 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!1940 Fix issue where a BatchNorm with only 4 outputs would not be fissioned

Merge pull request !1940 from huanghui/single-batchnorm-fission-pass
...@@ -99,6 +99,7 @@ namespace { ...@@ -99,6 +99,7 @@ namespace {
void AddAscendBackendOptionalIRFusion(PassManager *ir_fusion_pm) { void AddAscendBackendOptionalIRFusion(PassManager *ir_fusion_pm) {
MS_EXCEPTION_IF_NULL(ir_fusion_pm); MS_EXCEPTION_IF_NULL(ir_fusion_pm);
ir_fusion_pm->AddPass(std::make_shared<BatchNormBertFission>()); ir_fusion_pm->AddPass(std::make_shared<BatchNormBertFission>());
ir_fusion_pm->AddPass(std::make_shared<SingleBatchNormFission>());
ir_fusion_pm->AddPass(std::make_shared<SquareSumFusion>()); ir_fusion_pm->AddPass(std::make_shared<SquareSumFusion>());
ir_fusion_pm->AddPass(std::make_shared<ClipByNormNoDivSquareSumFusion>()); ir_fusion_pm->AddPass(std::make_shared<ClipByNormNoDivSquareSumFusion>());
ir_fusion_pm->AddPass(std::make_shared<LambUpdateWithLRRuleFusion>()); ir_fusion_pm->AddPass(std::make_shared<LambUpdateWithLRRuleFusion>());
...@@ -225,7 +226,6 @@ void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGrap ...@@ -225,7 +226,6 @@ void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGrap
ir_fusion_pm->AddPass(std::make_shared<FusedBatchNormFusion>()); ir_fusion_pm->AddPass(std::make_shared<FusedBatchNormFusion>());
ir_fusion_pm->AddPass(std::make_shared<FusedBatchNormMixPrecisionFusion0>()); ir_fusion_pm->AddPass(std::make_shared<FusedBatchNormMixPrecisionFusion0>());
ir_fusion_pm->AddPass(std::make_shared<FusedBatchNormMixPrecisionFusion1>()); ir_fusion_pm->AddPass(std::make_shared<FusedBatchNormMixPrecisionFusion1>());
ir_fusion_pm->AddPass(std::make_shared<SingleBatchNormFission>());
} }
ir_fusion_pm->AddPass(std::make_shared<AddMemcpyAsync>()); ir_fusion_pm->AddPass(std::make_shared<AddMemcpyAsync>());
ir_fusion_pm->AddPass(std::make_shared<InsertPadForNMSWithMask>()); ir_fusion_pm->AddPass(std::make_shared<InsertPadForNMSWithMask>());
......
...@@ -24,7 +24,7 @@ namespace mindspore { ...@@ -24,7 +24,7 @@ namespace mindspore {
namespace opt { namespace opt {
namespace { namespace {
const std::vector<int> kOutputIndex{0, 1, 2, 3, 4}; const std::vector<int> kOutputIndex{0, 1, 2, 3, 4};
constexpr size_t kBatchNormRealOutputNum = 5; constexpr size_t kBatchNormLeastOutputNum = 1;
constexpr size_t kBatchNormRealInputNum = 3; constexpr size_t kBatchNormRealInputNum = 3;
bool GetBatchNormOutputs(const FuncGraphPtr &func_graph, const AnfNodePtr &bn, std::vector<AnfNodePtr> *bn_outputs) { bool GetBatchNormOutputs(const FuncGraphPtr &func_graph, const AnfNodePtr &bn, std::vector<AnfNodePtr> *bn_outputs) {
...@@ -56,7 +56,7 @@ bool GetBatchNormOutputs(const FuncGraphPtr &func_graph, const AnfNodePtr &bn, s ...@@ -56,7 +56,7 @@ bool GetBatchNormOutputs(const FuncGraphPtr &func_graph, const AnfNodePtr &bn, s
bn_outputs->push_back(output); bn_outputs->push_back(output);
output_num++; output_num++;
} }
return output_num == kBatchNormRealOutputNum; return output_num > kBatchNormLeastOutputNum;
} }
AnfNodePtr CreateBNTrainingReduce(const FuncGraphPtr &func_graph, const AnfNodePtr &bn) { AnfNodePtr CreateBNTrainingReduce(const FuncGraphPtr &func_graph, const AnfNodePtr &bn) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册