未验证 提交 b94b3ac0 编写于 作者: W wz1qqx 提交者: GitHub

[XPU]add layer_norm fuse pass (#54930)

上级 a215c46a
......@@ -403,12 +403,10 @@ void LayerNormFusePass::ApplyImpl(Graph* graph) const {
x_mean_out,
x_sub_mean,
x_sub_mean_out,
sqr_pow,
x_sub_mean_sqr,
x_sub_mean_sqr_out,
std_dev,
std_dev_out,
eps,
std_dev_eps,
std_dev_eps_out,
std_dev_eps_sqrt,
......@@ -417,9 +415,7 @@ void LayerNormFusePass::ApplyImpl(Graph* graph) const {
division_out,
scale,
scale_out,
shift,
gamma,
beta});
shift});
found_layer_norm_count++;
};
......
......@@ -24,16 +24,6 @@
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h"
namespace phi {
class DenseTensor;
} // namespace phi
namespace paddle {
namespace framework {
class Scope;
} // namespace framework
} // namespace paddle
namespace paddle {
namespace framework {
namespace ir {
......@@ -163,9 +153,8 @@ Squeeze2MatmulPattern::Squeeze2MatmulPattern(PDPattern* pattern,
->assert_more([](Node* node) {
auto squeeze2_in_x_shape = node->Var()->GetShape();
size_t squeeze2_in_rank = squeeze2_in_x_shape.size();
bool nice_shape =
squeeze2_in_x_shape[2] == 1 && squeeze2_in_x_shape[3] == 1;
return squeeze2_in_rank == 4 && nice_shape;
return squeeze2_in_rank == 4 &&
(squeeze2_in_x_shape[2] == 1 && squeeze2_in_x_shape[3] == 1);
});
auto* squeeze2 = pattern->NewNode(squeeze2_repr())
->assert_is_op("squeeze2")
......
......@@ -532,6 +532,7 @@ XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) {
"fused_multi_transformer_xpu_pass",
"relu6_fuse_pass",
"sigmoid_elementmul_fuse_pass",
"layer_norm_fuse_pass",
"matmul_weight_trans_pass",
"map_matmulv2_to_matmul_xpu_pass",
"reshape2_matmul_xpu_fuse_pass",
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册