diff --git a/paddle/fluid/framework/ir/skip_layernorm_fuse_pass.cc b/paddle/fluid/framework/ir/skip_layernorm_fuse_pass.cc index 232e1d8da4ded39df732912bc86edb9a1fb54317..3c851f13b4d4d5447918945f3adb39b4b9c6c77f 100644 --- a/paddle/fluid/framework/ir/skip_layernorm_fuse_pass.cc +++ b/paddle/fluid/framework/ir/skip_layernorm_fuse_pass.cc @@ -129,6 +129,11 @@ void SkipLayerNormFusePass::ApplyImpl(ir::Graph *graph) const { return; } + if (!IsCompat(subgraph, graph)) { + LOG(WARNING) << "skip_layernorm pass in op compat failed."; + return; + } + VLOG(4) << "handle SkipLayerNorm fuse"; GET_IR_NODE_FROM_SUBGRAPH(elementwise, elementwise, fused_pattern); GET_IR_NODE_FROM_SUBGRAPH(elementwise_out, elementwise_out, fused_pattern); diff --git a/paddle/fluid/framework/ir/skip_layernorm_fuse_pass.h b/paddle/fluid/framework/ir/skip_layernorm_fuse_pass.h index 3a3e50052396a538aebb9027cb444b819129af95..804d0abdd6f06c7c1fbb995907409f0b7fbd3ae2 100644 --- a/paddle/fluid/framework/ir/skip_layernorm_fuse_pass.h +++ b/paddle/fluid/framework/ir/skip_layernorm_fuse_pass.h @@ -33,6 +33,49 @@ class Graph; class SkipLayerNormFusePass : public FusePassBase { public: + SkipLayerNormFusePass() { + AddOpCompat(OpCompat("elementwise_add")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Y") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("axis") + .IsIntIn({0, -1}) + .End(); + + AddOpCompat(OpCompat("layer_norm")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Scale") + .IsTensor() + .End() + .AddInput("Bias") + .IsTensor() + .End() + .AddOutput("Y") + .IsTensor() + .End() + .AddOutput("Mean") + .IsTensor() + .End() + .AddOutput("Variance") + .IsTensor() + .End() + .AddAttr("epsilon") + .IsNumGE(0.0f) + .IsNumLE(0.001f) + .End() + .AddAttr("begin_norm_axis") + .IsNumGT(0) + .End(); + } + virtual ~SkipLayerNormFusePass() {} protected: