From 4ddd595f2f52dfe3c1296d72ba59f212bb129499 Mon Sep 17 00:00:00 2001 From: Wilber Date: Wed, 16 Jun 2021 21:35:13 +0800 Subject: [PATCH] add compat check for skip_layernorm (#33505) --- .../framework/ir/skip_layernorm_fuse_pass.cc | 5 +++ .../framework/ir/skip_layernorm_fuse_pass.h | 43 +++++++++++++++++++ 2 files changed, 48 insertions(+) diff --git a/paddle/fluid/framework/ir/skip_layernorm_fuse_pass.cc b/paddle/fluid/framework/ir/skip_layernorm_fuse_pass.cc index 232e1d8da4d..3c851f13b4d 100644 --- a/paddle/fluid/framework/ir/skip_layernorm_fuse_pass.cc +++ b/paddle/fluid/framework/ir/skip_layernorm_fuse_pass.cc @@ -129,6 +129,11 @@ void SkipLayerNormFusePass::ApplyImpl(ir::Graph *graph) const { return; } + if (!IsCompat(subgraph, graph)) { + LOG(WARNING) << "skip_layernorm pass in op compat failed."; + return; + } + VLOG(4) << "handle SkipLayerNorm fuse"; GET_IR_NODE_FROM_SUBGRAPH(elementwise, elementwise, fused_pattern); GET_IR_NODE_FROM_SUBGRAPH(elementwise_out, elementwise_out, fused_pattern); diff --git a/paddle/fluid/framework/ir/skip_layernorm_fuse_pass.h b/paddle/fluid/framework/ir/skip_layernorm_fuse_pass.h index 3a3e5005239..804d0abdd6f 100644 --- a/paddle/fluid/framework/ir/skip_layernorm_fuse_pass.h +++ b/paddle/fluid/framework/ir/skip_layernorm_fuse_pass.h @@ -33,6 +33,49 @@ class Graph; class SkipLayerNormFusePass : public FusePassBase { public: + SkipLayerNormFusePass() { + AddOpCompat(OpCompat("elementwise_add")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Y") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("axis") + .IsIntIn({0, -1}) + .End(); + + AddOpCompat(OpCompat("layer_norm")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Scale") + .IsTensor() + .End() + .AddInput("Bias") + .IsTensor() + .End() + .AddOutput("Y") + .IsTensor() + .End() + .AddOutput("Mean") + .IsTensor() + .End() + .AddOutput("Variance") + .IsTensor() + .End() + .AddAttr("epsilon") + .IsNumGE(0.0f) + .IsNumLE(0.001f) + .End() + .AddAttr("begin_norm_axis") + .IsNumGT(0) + .End(); + } + virtual ~SkipLayerNormFusePass() {} protected: -- GitLab