diff --git a/paddle/fluid/framework/ir/mkldnn/batch_norm_act_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/batch_norm_act_fuse_pass.cc index 7e28ccd24a80da738ec69f00efb5053dcdf1cde4..3fdb87f254403652a99983c29f9ba283a45eed2b 100644 --- a/paddle/fluid/framework/ir/mkldnn/batch_norm_act_fuse_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/batch_norm_act_fuse_pass.cc @@ -29,6 +29,55 @@ void FuseBatchNormActOneDNNPass::ApplyImpl(Graph *graph) const { FuseBatchNormAct(graph, act_type); } +FuseBatchNormActOneDNNPass::FuseBatchNormActOneDNNPass() { + AddOpCompat(OpCompat("batch_norm")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Scale") + .IsTensor() + .End() + .AddInput("Bias") + .IsTensor() + .End() + .AddInput("Mean") + .IsTensor() + .End() + .AddInput("Variance") + .IsTensor() + .End() + .AddOutput("Y") + .IsTensor() + .End() + .AddOutput("MeanOut") + .IsOptional() + .End() + .AddOutput("VarianceOut") + .IsOptional() + .End() + .AddOutput("SavedMean") + .IsOptional() + .End() + .AddOutput("SavedVariance") + .IsOptional() + .End() + .AddOutput("ReserveSpace") + .IsOptional() + .End() + .AddAttr("epsilon") + .IsNumGE(0.0f) + .IsNumLE(0.001f) + .End(); + + AddOpCompat(OpCompat("relu")) + .AddInput("X") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End(); +} + void FuseBatchNormActOneDNNPass::FuseBatchNormAct( Graph *graph, const std::string &act_type) const { PADDLE_ENFORCE_NOT_NULL( @@ -45,6 +94,11 @@ void FuseBatchNormActOneDNNPass::FuseBatchNormAct( auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph, Graph *g) { VLOG(4) << "Fuse BatchNorm with ReLU activation op."; + + if (!IsCompat(subgraph, g)) { + LOG(WARNING) << "Pass in op compat failed."; + return; + } // BN output GET_IR_NODE_FROM_SUBGRAPH(bn_out, bn_out, bn_act_pattern); // ACT output @@ -84,6 +138,11 @@ void FuseBatchNormActOneDNNPass::FuseBatchNormAct( bn_op->SetAttr("trainable_statistics", false); bn_op->SetOutput("Y", {act_out->Name()}); + if (!IsCompat(*bn_op)) { + 
LOG(WARNING) << "Fused batch_norm op compat check failed."; + return; + } + + IR_OP_VAR_LINK(batch_norm, act_out); + GraphSafeRemoveNodes(g, {act, bn_out}); + found_bn_act_count++; diff --git a/paddle/fluid/framework/ir/mkldnn/batch_norm_act_fuse_pass.h b/paddle/fluid/framework/ir/mkldnn/batch_norm_act_fuse_pass.h index 843e7e420b7be07f7fd63d8a9a7d39791b206333..ba6a65bce8a8cc0822df07ddbdf104ae7c645be9 100644 --- a/paddle/fluid/framework/ir/mkldnn/batch_norm_act_fuse_pass.h +++ b/paddle/fluid/framework/ir/mkldnn/batch_norm_act_fuse_pass.h @@ -31,6 +31,7 @@ namespace ir { */ class FuseBatchNormActOneDNNPass : public FusePassBase { public: + FuseBatchNormActOneDNNPass(); virtual ~FuseBatchNormActOneDNNPass() {} protected: diff --git a/paddle/fluid/framework/ir/mkldnn/batch_norm_act_fuse_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/batch_norm_act_fuse_pass_tester.cc index 38364721f651527da1da8839d574c1bee136fa4f..e13d44ac23222187a82753a027dd3585f423800b 100644 --- a/paddle/fluid/framework/ir/mkldnn/batch_norm_act_fuse_pass_tester.cc +++ b/paddle/fluid/framework/ir/mkldnn/batch_norm_act_fuse_pass_tester.cc @@ -32,6 +32,7 @@ void SetBatchNormAttrs(OpDesc* bn_op, bool is_test = true, bn_op->SetAttr("is_test", is_test); bn_op->SetAttr("trainable_statistics", trainable_stats); bn_op->SetAttr("fuse_with_relu", false); + bn_op->SetAttr("epsilon", 0.001f); } } diff --git a/paddle/fluid/operators/compat/batch_norm.pbtxt b/paddle/fluid/operators/compat/batch_norm.pbtxt index 772d66f00fcc9b4a985fc77520fbf191b333a802..ac2ccc6296c0ce9dc32a349d3c6914c5778d384a 100644 --- a/paddle/fluid/operators/compat/batch_norm.pbtxt +++ b/paddle/fluid/operators/compat/batch_norm.pbtxt @@ -42,6 +42,10 @@ extra { inputs { name: "MomentumTensor" } + attrs { + name: "@ENABLE_CACHE_RUNTIME_CONTEXT@" + type: BOOLEAN + } attrs { name: "is_test" type: BOOLEAN diff --git a/paddle/fluid/operators/compat/relu.pbtxt b/paddle/fluid/operators/compat/relu.pbtxt index 
359bd70c2a310c0ea64da383c416482dfd28403e..bd0e9988010143df268711987a2612b9bbf6457f 100644 --- a/paddle/fluid/operators/compat/relu.pbtxt +++ b/paddle/fluid/operators/compat/relu.pbtxt @@ -8,6 +8,10 @@ def { } } extra { + attrs { + name: "@ENABLE_CACHE_RUNTIME_CONTEXT@" + type: BOOLEAN + } attrs { name: "use_mkldnn" type: BOOLEAN