From 9d8d5317d509170122f225637e6c9eee9627e888 Mon Sep 17 00:00:00 2001 From: feng_shuai Date: Fri, 11 Jun 2021 17:54:13 +0800 Subject: [PATCH] fc_elementwise_layer_fuse_pass (#33467) * fc_elementwise_layer_fuse_pass * fc_ele_layernorm_pass * fc_elementwise_layernorm_pass * fc_elementwise_layernorm_pass_amend --- .../ir/fc_elementwise_layernorm_fuse_pass.cc | 69 +++++++++++++++++++ .../ir/fc_elementwise_layernorm_fuse_pass.h | 1 + 2 files changed, 70 insertions(+) diff --git a/paddle/fluid/framework/ir/fc_elementwise_layernorm_fuse_pass.cc b/paddle/fluid/framework/ir/fc_elementwise_layernorm_fuse_pass.cc index ef5b3c3c96e..6f7a52fce59 100644 --- a/paddle/fluid/framework/ir/fc_elementwise_layernorm_fuse_pass.cc +++ b/paddle/fluid/framework/ir/fc_elementwise_layernorm_fuse_pass.cc @@ -136,6 +136,70 @@ static bool IsEqual(const std::vector &x, const std::vector &y) { return true; } +FCElementwiseLayerNormFusePass::FCElementwiseLayerNormFusePass() { + AddOpCompat(OpCompat("fc")) + .AddInput("Input") + .IsTensor() + .End() + .AddInput("W") + .IsTensor() + .End() + .AddInput("Bias") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("in_num_col_dims") + .IsNumGE(1) + .End() + .AddAttr("activation_type") + .IsStringIn({"relu", ""}) + .End(); + + AddOpCompat(OpCompat("layer_norm")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Scale") + .IsTensor() + .End() + .AddInput("Bias") + .IsTensor() + .End() + .AddOutput("Y") + .IsTensor() + .End() + .AddOutput("Mean") + .IsOptional() + .End() + .AddOutput("Variance") + .IsOptional() + .End() + + .AddAttr("epsilon") + .IsNumGE(0.0f) + .IsNumLE(0.001f) + .End() + .AddAttr("begin_norm_axis") + .IsNumGT(0) + .End(); + + AddOpCompat(OpCompat("elementwise_add")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Y") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("axis") + .IsNumEQ(-1) + .End(); +} + void FCElementwiseLayerNormFusePass::ApplyImpl(ir::Graph *graph) const { PADDLE_ENFORCE_NOT_NULL(graph, platform::errors::InvalidArgument( @@ -159,6 +223,11 @@ void FCElementwiseLayerNormFusePass::ApplyImpl(ir::Graph *graph) const { return; } + if (!IsCompat(subgraph, graph)) { + LOG(WARNING) << "Pass in op compat failed."; + return; + } + VLOG(4) << "handle FCElementwiseLayerNorm fuse"; GET_IR_NODE_FROM_SUBGRAPH(fc, fc, fused_pattern); GET_IR_NODE_FROM_SUBGRAPH(fc_w, fc_w, fused_pattern); diff --git a/paddle/fluid/framework/ir/fc_elementwise_layernorm_fuse_pass.h b/paddle/fluid/framework/ir/fc_elementwise_layernorm_fuse_pass.h index 12e4c44b84e..0e8f9866c76 100644 --- a/paddle/fluid/framework/ir/fc_elementwise_layernorm_fuse_pass.h +++ b/paddle/fluid/framework/ir/fc_elementwise_layernorm_fuse_pass.h @@ -24,6 +24,7 @@ class Graph; class FCElementwiseLayerNormFusePass : public FusePassBase { public: + FCElementwiseLayerNormFusePass(); virtual ~FCElementwiseLayerNormFusePass() {} protected: -- GitLab