提交 36d1aadf 编写于 作者: H huanghui

Fix BatchNorm fission not working when BatchNorm's outputs are 0, 1, 2, 4

上级 09318086
......@@ -99,6 +99,7 @@ namespace {
void AddAscendBackendOptionalIRFusion(PassManager *ir_fusion_pm) {
MS_EXCEPTION_IF_NULL(ir_fusion_pm);
ir_fusion_pm->AddPass(std::make_shared<BatchNormBertFission>());
ir_fusion_pm->AddPass(std::make_shared<SingleBatchNormFission>());
ir_fusion_pm->AddPass(std::make_shared<SquareSumFusion>());
ir_fusion_pm->AddPass(std::make_shared<ClipByNormNoDivSquareSumFusion>());
ir_fusion_pm->AddPass(std::make_shared<LambUpdateWithLRRuleFusion>());
......@@ -225,7 +226,6 @@ void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGrap
ir_fusion_pm->AddPass(std::make_shared<FusedBatchNormFusion>());
ir_fusion_pm->AddPass(std::make_shared<FusedBatchNormMixPrecisionFusion0>());
ir_fusion_pm->AddPass(std::make_shared<FusedBatchNormMixPrecisionFusion1>());
ir_fusion_pm->AddPass(std::make_shared<SingleBatchNormFission>());
}
ir_fusion_pm->AddPass(std::make_shared<AddMemcpyAsync>());
ir_fusion_pm->AddPass(std::make_shared<InsertPadForNMSWithMask>());
......
......@@ -24,7 +24,7 @@ namespace mindspore {
namespace opt {
namespace {
const std::vector<int> kOutputIndex{0, 1, 2, 3, 4};
constexpr size_t kBatchNormRealOutputNum = 5;
constexpr size_t kBatchNormLeastOutputNum = 1;
constexpr size_t kBatchNormRealInputNum = 3;
bool GetBatchNormOutputs(const FuncGraphPtr &func_graph, const AnfNodePtr &bn, std::vector<AnfNodePtr> *bn_outputs) {
......@@ -56,7 +56,7 @@ bool GetBatchNormOutputs(const FuncGraphPtr &func_graph, const AnfNodePtr &bn, s
bn_outputs->push_back(output);
output_num++;
}
return output_num == kBatchNormRealOutputNum;
return output_num > kBatchNormLeastOutputNum;
}
AnfNodePtr CreateBNTrainingReduce(const FuncGraphPtr &func_graph, const AnfNodePtr &bn) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册