未验证 提交 b94b3ac0 编写于 作者: W wz1qqx 提交者: GitHub

[XPU]add layer_norm fuse pass (#54930)

上级 a215c46a
...@@ -403,12 +403,10 @@ void LayerNormFusePass::ApplyImpl(Graph* graph) const { ...@@ -403,12 +403,10 @@ void LayerNormFusePass::ApplyImpl(Graph* graph) const {
x_mean_out, x_mean_out,
x_sub_mean, x_sub_mean,
x_sub_mean_out, x_sub_mean_out,
sqr_pow,
x_sub_mean_sqr, x_sub_mean_sqr,
x_sub_mean_sqr_out, x_sub_mean_sqr_out,
std_dev, std_dev,
std_dev_out, std_dev_out,
eps,
std_dev_eps, std_dev_eps,
std_dev_eps_out, std_dev_eps_out,
std_dev_eps_sqrt, std_dev_eps_sqrt,
...@@ -417,9 +415,7 @@ void LayerNormFusePass::ApplyImpl(Graph* graph) const { ...@@ -417,9 +415,7 @@ void LayerNormFusePass::ApplyImpl(Graph* graph) const {
division_out, division_out,
scale, scale,
scale_out, scale_out,
shift, shift});
gamma,
beta});
found_layer_norm_count++; found_layer_norm_count++;
}; };
......
...@@ -24,16 +24,6 @@ ...@@ -24,16 +24,6 @@
#include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
namespace phi {
class DenseTensor;
} // namespace phi
namespace paddle {
namespace framework {
class Scope;
} // namespace framework
} // namespace paddle
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace ir { namespace ir {
...@@ -163,9 +153,8 @@ Squeeze2MatmulPattern::Squeeze2MatmulPattern(PDPattern* pattern, ...@@ -163,9 +153,8 @@ Squeeze2MatmulPattern::Squeeze2MatmulPattern(PDPattern* pattern,
->assert_more([](Node* node) { ->assert_more([](Node* node) {
auto squeeze2_in_x_shape = node->Var()->GetShape(); auto squeeze2_in_x_shape = node->Var()->GetShape();
size_t squeeze2_in_rank = squeeze2_in_x_shape.size(); size_t squeeze2_in_rank = squeeze2_in_x_shape.size();
bool nice_shape = return squeeze2_in_rank == 4 &&
squeeze2_in_x_shape[2] == 1 && squeeze2_in_x_shape[3] == 1; (squeeze2_in_x_shape[2] == 1 && squeeze2_in_x_shape[3] == 1);
return squeeze2_in_rank == 4 && nice_shape;
}); });
auto* squeeze2 = pattern->NewNode(squeeze2_repr()) auto* squeeze2 = pattern->NewNode(squeeze2_repr())
->assert_is_op("squeeze2") ->assert_is_op("squeeze2")
......
...@@ -532,6 +532,7 @@ XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) { ...@@ -532,6 +532,7 @@ XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) {
"fused_multi_transformer_xpu_pass", "fused_multi_transformer_xpu_pass",
"relu6_fuse_pass", "relu6_fuse_pass",
"sigmoid_elementmul_fuse_pass", "sigmoid_elementmul_fuse_pass",
"layer_norm_fuse_pass",
"matmul_weight_trans_pass", "matmul_weight_trans_pass",
"map_matmulv2_to_matmul_xpu_pass", "map_matmulv2_to_matmul_xpu_pass",
"reshape2_matmul_xpu_fuse_pass", "reshape2_matmul_xpu_fuse_pass",
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册