diff --git a/paddle/fluid/framework/ir/fc_elementwise_layernorm_fuse_pass.cc b/paddle/fluid/framework/ir/fc_elementwise_layernorm_fuse_pass.cc index ef5b3c3c96e2374ef0cabc1ed8fc4bbab9577388..6f7a52fce5933042e8929c36da065442d1b02371 100644 --- a/paddle/fluid/framework/ir/fc_elementwise_layernorm_fuse_pass.cc +++ b/paddle/fluid/framework/ir/fc_elementwise_layernorm_fuse_pass.cc @@ -136,6 +136,70 @@ static bool IsEqual(const std::vector &x, const std::vector &y) { return true; } +FCElementwiseLayerNormFusePass::FCElementwiseLayerNormFusePass() { + AddOpCompat(OpCompat("fc")) + .AddInput("Input") + .IsTensor() + .End() + .AddInput("W") + .IsTensor() + .End() + .AddInput("Bias") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("in_num_col_dims") + .IsNumGE(1) + .End() + .AddAttr("activation_type") + .IsStringIn({"relu", ""}) + .End(); + + AddOpCompat(OpCompat("layer_norm")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Scale") + .IsTensor() + .End() + .AddInput("Bias") + .IsTensor() + .End() + .AddOutput("Y") + .IsTensor() + .End() + .AddOutput("Mean") + .IsOptional() + .End() + .AddOutput("Variance") + .IsOptional() + .End() + + .AddAttr("epsilon") + .IsNumGE(0.0f) + .IsNumLE(0.001f) + .End() + .AddAttr("begin_norm_axis") + .IsNumGT(0) + .End(); + + AddOpCompat(OpCompat("elementwise_add")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Y") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("axis") + .IsNumEQ(-1) + .End(); +} + void FCElementwiseLayerNormFusePass::ApplyImpl(ir::Graph *graph) const { PADDLE_ENFORCE_NOT_NULL(graph, platform::errors::InvalidArgument( @@ -159,6 +223,11 @@ void FCElementwiseLayerNormFusePass::ApplyImpl(ir::Graph *graph) const { return; } + if (!IsCompat(subgraph, graph)) { + LOG(WARNING) << "Pass in op compat failed."; + return; + } + VLOG(4) << "handle FCElementwiseLayerNorm fuse"; GET_IR_NODE_FROM_SUBGRAPH(fc, fc, fused_pattern); GET_IR_NODE_FROM_SUBGRAPH(fc_w, fc_w, fused_pattern); diff --git a/paddle/fluid/framework/ir/fc_elementwise_layernorm_fuse_pass.h b/paddle/fluid/framework/ir/fc_elementwise_layernorm_fuse_pass.h index 12e4c44b84e87bb710774ebba0ba2853d8b37f5e..0e8f9866c765c2fb9d8c0199a2a02fccee2c6c12 100644 --- a/paddle/fluid/framework/ir/fc_elementwise_layernorm_fuse_pass.h +++ b/paddle/fluid/framework/ir/fc_elementwise_layernorm_fuse_pass.h @@ -24,6 +24,7 @@ class Graph; class FCElementwiseLayerNormFusePass : public FusePassBase { public: + FCElementwiseLayerNormFusePass(); virtual ~FCElementwiseLayerNormFusePass() {} protected: