未验证 提交 2b848aef 编写于 作者: Y Yuang Liu 提交者: GitHub

Fused attention pass fwd, create the fused_attention op. (#50125)

上级 e6d29e00
...@@ -22,7 +22,6 @@ namespace patterns { ...@@ -22,7 +22,6 @@ namespace patterns {
PDNode* FusedAttentionPattern::operator()(PDNode* x, PDNode* FusedAttentionPattern::operator()(PDNode* x,
bool pre_layer_norm, bool pre_layer_norm,
bool post_layer_norm,
bool has_attn_mask, bool has_attn_mask,
bool do_dropout, bool do_dropout,
bool add_residual) { bool add_residual) {
...@@ -259,7 +258,7 @@ PDNode* FusedAttentionPattern::operator()(PDNode* x, ...@@ -259,7 +258,7 @@ PDNode* FusedAttentionPattern::operator()(PDNode* x,
out_linear_dropout_node->LinksFrom({out_linear_ele_add_out_node}) out_linear_dropout_node->LinksFrom({out_linear_ele_add_out_node})
.LinksTo({out_linear_dropout_mask_node, out_linear_dropout_out_node}); .LinksTo({out_linear_dropout_mask_node, out_linear_dropout_out_node});
if (!add_residual && !post_layer_norm) { if (!add_residual && pre_layer_norm) {
return out_linear_dropout_out_node; return out_linear_dropout_out_node;
} }
...@@ -276,7 +275,7 @@ PDNode* FusedAttentionPattern::operator()(PDNode* x, ...@@ -276,7 +275,7 @@ PDNode* FusedAttentionPattern::operator()(PDNode* x,
residual_ele_add_node->LinksFrom({x, out_linear_dropout_out_node}) residual_ele_add_node->LinksFrom({x, out_linear_dropout_out_node})
.LinksTo({residual_ele_add_out_node}); .LinksTo({residual_ele_add_out_node});
if (!post_layer_norm) { if (pre_layer_norm) {
return residual_ele_add_out_node; return residual_ele_add_out_node;
} }
} }
...@@ -323,13 +322,12 @@ PDNode* FusedAttentionPattern::operator()(PDNode* x, ...@@ -323,13 +322,12 @@ PDNode* FusedAttentionPattern::operator()(PDNode* x,
PDNode* FusedAttentionGradPattern::operator()(PDNode* x, PDNode* FusedAttentionGradPattern::operator()(PDNode* x,
bool pre_layer_norm, bool pre_layer_norm,
bool post_layer_norm,
bool has_attn_mask, bool has_attn_mask,
bool do_dropout, bool do_dropout,
bool add_residual) { bool add_residual) {
// post layer norm // post layer norm
PDNode* post_layer_norm_grad_out_node{nullptr}; PDNode* post_layer_norm_grad_out_node{nullptr};
if (post_layer_norm) { if (!pre_layer_norm) {
auto* post_layer_norm_grad_node = auto* post_layer_norm_grad_node =
pattern->NewNode(post_layer_norm_grad_op_repr()) pattern->NewNode(post_layer_norm_grad_op_repr())
->assert_is_op("layer_norm_grad"); ->assert_is_op("layer_norm_grad");
...@@ -375,7 +373,7 @@ PDNode* FusedAttentionGradPattern::operator()(PDNode* x, ...@@ -375,7 +373,7 @@ PDNode* FusedAttentionGradPattern::operator()(PDNode* x,
PDNode* residual_ele_add_grad_x_grad_node{nullptr}; PDNode* residual_ele_add_grad_x_grad_node{nullptr};
if (add_residual) { if (add_residual) {
PDNode* ele_add_grad_input = x; PDNode* ele_add_grad_input = x;
if (post_layer_norm) { if (!pre_layer_norm) {
ele_add_grad_input = post_layer_norm_grad_out_node; ele_add_grad_input = post_layer_norm_grad_out_node;
} }
auto* residual_ele_add_grad_node = auto* residual_ele_add_grad_node =
...@@ -404,7 +402,7 @@ PDNode* FusedAttentionGradPattern::operator()(PDNode* x, ...@@ -404,7 +402,7 @@ PDNode* FusedAttentionGradPattern::operator()(PDNode* x,
// get the real input x for dropout grad // get the real input x for dropout grad
PDNode* out_linear_grad_input_node = x; PDNode* out_linear_grad_input_node = x;
if (post_layer_norm && !add_residual) { if (!pre_layer_norm && !add_residual) {
out_linear_grad_input_node = post_layer_norm_grad_out_node; out_linear_grad_input_node = post_layer_norm_grad_out_node;
} else if (add_residual) { } else if (add_residual) {
out_linear_grad_input_node = residual_ele_add_grad_out_node; out_linear_grad_input_node = residual_ele_add_grad_out_node;
...@@ -769,11 +767,11 @@ PDNode* FusedAttentionGradPattern::operator()(PDNode* x, ...@@ -769,11 +767,11 @@ PDNode* FusedAttentionGradPattern::operator()(PDNode* x,
void FusedAttentionsPass::ApplyImpl(Graph* graph) const { void FusedAttentionsPass::ApplyImpl(Graph* graph) const {
FusePassBase::Init(name_scope_, graph); FusePassBase::Init(name_scope_, graph);
graph = PreMaskDropResPostFwd(graph); graph = PreMaskDropResFwd(graph);
graph = PreMaskDropResPostBwd(graph); graph = PreMaskDropResBwd(graph);
} }
ir::Graph* FusedAttentionsPass::PreMaskDropResPostFwd(Graph* graph) const { ir::Graph* FusedAttentionsPass::PreMaskDropResFwd(Graph* graph) const {
GraphPatternDetector gpd; GraphPatternDetector gpd;
auto* x = gpd.mutable_pattern() auto* x = gpd.mutable_pattern()
->NewNode(patterns::PDNodeName(name_scope_, "x")) ->NewNode(patterns::PDNodeName(name_scope_, "x"))
...@@ -784,7 +782,6 @@ ir::Graph* FusedAttentionsPass::PreMaskDropResPostFwd(Graph* graph) const { ...@@ -784,7 +782,6 @@ ir::Graph* FusedAttentionsPass::PreMaskDropResPostFwd(Graph* graph) const {
fused_attention_pattern(x, fused_attention_pattern(x,
/* pre_layer_norm */ true, /* pre_layer_norm */ true,
/* post_layer_norm */ true,
/* has_attn_mask */ true, /* has_attn_mask */ true,
/* do_dropout */ true, /* do_dropout */ true,
/* add_residual */ true); /* add_residual */ true);
...@@ -835,10 +832,191 @@ ir::Graph* FusedAttentionsPass::PreMaskDropResPostFwd(Graph* graph) const { ...@@ -835,10 +832,191 @@ ir::Graph* FusedAttentionsPass::PreMaskDropResPostFwd(Graph* graph) const {
fused_attention_pattern); fused_attention_pattern);
GET_IR_NODE_FROM_SUBGRAPH( GET_IR_NODE_FROM_SUBGRAPH(
residual_ele_add_op_node, residual_ele_add_op, fused_attention_pattern); residual_ele_add_op_node, residual_ele_add_op, fused_attention_pattern);
OpDesc fused_attention_op_desc(pre_layer_norm_op_node->Op()->Block());
fused_attention_op_desc.SetType("fused_attention");
fused_attention_op_desc.SetInput("X", {subgraph.at(x)->Name()});
fused_attention_op_desc.SetAttr("pre_layer_norm", true);
GET_IR_NODE_FROM_SUBGRAPH(pre_layer_norm_scale_node,
pre_layer_norm_scale,
fused_attention_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
pre_layer_norm_bias_node, pre_layer_norm_bias, fused_attention_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
pre_layer_norm_out_node, pre_layer_norm_out, fused_attention_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
pre_layer_norm_mean_node, pre_layer_norm_mean, fused_attention_pattern);
GET_IR_NODE_FROM_SUBGRAPH(pre_layer_norm_variance_node,
pre_layer_norm_variance,
fused_attention_pattern);
fused_attention_op_desc.SetInput("LnScale",
{pre_layer_norm_scale_node->Name()});
fused_attention_op_desc.SetInput("LnBias",
{pre_layer_norm_bias_node->Name()});
fused_attention_op_desc.SetOutput("LnOut",
{pre_layer_norm_out_node->Name()});
fused_attention_op_desc.SetOutput("LnMean",
{pre_layer_norm_mean_node->Name()});
fused_attention_op_desc.SetOutput("LnVariance",
{pre_layer_norm_variance_node->Name()});
fused_attention_op_desc.SetAttr(
"epsilon",
PADDLE_GET_CONST(float,
pre_layer_norm_op_node->Op()->GetAttr("epsilon")));
fused_attention_op_desc.SetAttr("transpose_qkv_wb", true);
std::vector<int> shape = PADDLE_GET_CONST(
std::vector<int>, fuse_qkv_reshape_op_node->Op()->GetAttr("shape"));
fused_attention_op_desc.SetAttr("num_heads", shape[2]);
GET_IR_NODE_FROM_SUBGRAPH( GET_IR_NODE_FROM_SUBGRAPH(
post_layer_norm_op_node, post_layer_norm_op, fused_attention_pattern); fuse_qkv_matmul_w_node, fuse_qkv_matmul_w, fused_attention_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
fuse_qkv_matmul_out_node, fuse_qkv_matmul_out, fused_attention_pattern);
GET_IR_NODE_FROM_SUBGRAPH(fuse_qkv_ele_add_bias_node,
fuse_qkv_ele_add_bias,
fused_attention_pattern);
GET_IR_NODE_FROM_SUBGRAPH(fuse_qkv_ele_add_out_node,
fuse_qkv_ele_add_out,
fused_attention_pattern);
GET_IR_NODE_FROM_SUBGRAPH(fuse_qkv_transpose_out_node,
fuse_qkv_transpose_out,
fused_attention_pattern);
fused_attention_op_desc.SetInput("QKVW", {fuse_qkv_matmul_w_node->Name()});
fused_attention_op_desc.SetInput("QKVBias",
{fuse_qkv_ele_add_bias_node->Name()});
fused_attention_op_desc.SetOutput("QKVOut",
{fuse_qkv_matmul_out_node->Name()});
fused_attention_op_desc.SetOutput("QKVBiasOut",
{fuse_qkv_ele_add_out_node->Name()});
fused_attention_op_desc.SetOutput("TransposeOut2",
{fuse_qkv_transpose_out_node->Name()});
// TODO(Yuang Liu): finish the handler GET_IR_NODE_FROM_SUBGRAPH(
qk_matmul_out_node, qk_matmul_out, fused_attention_pattern);
GET_IR_NODE_FROM_SUBGRAPH(add_mask_ele_add_mask_node,
add_mask_ele_add_mask,
fused_attention_pattern);
GET_IR_NODE_FROM_SUBGRAPH(add_mask_ele_add_out_node,
add_mask_ele_add_out,
fused_attention_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
qk_softmax_out_node, qk_softmax_out, fused_attention_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
attn_dropout_out_node, attn_dropout_out, fused_attention_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
attn_dropout_mask_node, attn_dropout_mask, fused_attention_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
qkv_matmul_out_node, qkv_matmul_out, fused_attention_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
qkv_reshape_out_node, qkv_reshape_out, fused_attention_pattern);
fused_attention_op_desc.SetOutput("QKOut", {qk_matmul_out_node->Name()});
fused_attention_op_desc.SetInput("SrcMask",
{add_mask_ele_add_mask_node->Name()});
fused_attention_op_desc.SetOutput("SrcMaskOut",
{add_mask_ele_add_out_node->Name()});
fused_attention_op_desc.SetOutput("SoftmaxOut",
{qk_softmax_out_node->Name()});
fused_attention_op_desc.SetAttr(
"attn_dropout_rate",
PADDLE_GET_CONST(float,
attn_dropout_op_node->Op()->GetAttr("dropout_prob")));
fused_attention_op_desc.SetAttr(
"is_test",
PADDLE_GET_CONST(bool, attn_dropout_op_node->Op()->GetAttr("is_test")));
fused_attention_op_desc.SetAttr(
"attn_dropout_fix_seed",
PADDLE_GET_CONST(bool,
attn_dropout_op_node->Op()->GetAttr("fix_seed")));
fused_attention_op_desc.SetAttr(
"attn_dropout_seed",
PADDLE_GET_CONST(int, attn_dropout_op_node->Op()->GetAttr("seed")));
fused_attention_op_desc.SetAttr(
"attn_dropout_implementation",
PADDLE_GET_CONST(
std::string,
attn_dropout_op_node->Op()->GetAttr("dropout_implementation")));
fused_attention_op_desc.SetOutput("AttnDropoutMaskOut",
{attn_dropout_mask_node->Name()});
fused_attention_op_desc.SetOutput("AttnDropoutOut",
{attn_dropout_out_node->Name()});
fused_attention_op_desc.SetOutput("QKTVOut", {qkv_matmul_out_node->Name()});
fused_attention_op_desc.SetOutput("FMHAOut",
{qkv_reshape_out_node->Name()});
GET_IR_NODE_FROM_SUBGRAPH(
out_linear_matmul_w_node, out_linear_matmul_w, fused_attention_pattern);
GET_IR_NODE_FROM_SUBGRAPH(out_linear_matmul_out_node,
out_linear_matmul_out,
fused_attention_pattern);
GET_IR_NODE_FROM_SUBGRAPH(out_linear_ele_add_bias_node,
out_linear_ele_add_bias,
fused_attention_pattern);
GET_IR_NODE_FROM_SUBGRAPH(out_linear_ele_add_out_node,
out_linear_ele_add_out,
fused_attention_pattern);
fused_attention_op_desc.SetInput("OutLinearW",
{out_linear_matmul_w_node->Name()});
fused_attention_op_desc.SetInput("OutLinearBias",
{out_linear_ele_add_bias_node->Name()});
fused_attention_op_desc.SetOutput("OutLinearOut",
{out_linear_matmul_out_node->Name()});
GET_IR_NODE_FROM_SUBGRAPH(out_linear_dropout_mask_node,
out_linear_dropout_mask,
fused_attention_pattern);
fused_attention_op_desc.SetAttr(
"dropout_rate",
PADDLE_GET_CONST(
float, out_linear_dropout_op_node->Op()->GetAttr("dropout_prob")));
fused_attention_op_desc.SetAttr(
"dropout_fix_seed",
PADDLE_GET_CONST(
bool, out_linear_dropout_op_node->Op()->GetAttr("fix_seed")));
fused_attention_op_desc.SetAttr(
"dropout_seed",
PADDLE_GET_CONST(int,
out_linear_dropout_op_node->Op()->GetAttr("seed")));
fused_attention_op_desc.SetAttr(
"dropout_implementation",
PADDLE_GET_CONST(std::string,
out_linear_dropout_op_node->Op()->GetAttr(
"dropout_implementation")));
fused_attention_op_desc.SetOutput("DropoutMaskOut",
{out_linear_dropout_mask_node->Name()});
GET_IR_NODE_FROM_SUBGRAPH(residual_ele_add_out_node,
residual_ele_add_out,
fused_attention_pattern);
fused_attention_op_desc.SetAttr("add_residual", true);
fused_attention_op_desc.SetOutput("Y", {residual_ele_add_out_node->Name()});
auto fused_attention_node = g->CreateOpNode(&fused_attention_op_desc);
IR_NODE_LINK_TO(subgraph.at(x), fused_attention_node);
IR_NODE_LINK_TO(pre_layer_norm_scale_node, fused_attention_node);
IR_NODE_LINK_TO(pre_layer_norm_bias_node, fused_attention_node);
IR_NODE_LINK_TO(fuse_qkv_matmul_w_node, fused_attention_node);
IR_NODE_LINK_TO(fuse_qkv_ele_add_bias_node, fused_attention_node);
IR_NODE_LINK_TO(add_mask_ele_add_mask_node, fused_attention_node);
IR_NODE_LINK_TO(out_linear_matmul_w_node, fused_attention_node);
IR_NODE_LINK_TO(out_linear_ele_add_bias_node, fused_attention_node);
IR_NODE_LINK_TO(fused_attention_node, pre_layer_norm_out_node);
IR_NODE_LINK_TO(fused_attention_node, pre_layer_norm_mean_node);
IR_NODE_LINK_TO(fused_attention_node, pre_layer_norm_variance_node);
IR_NODE_LINK_TO(fused_attention_node, fuse_qkv_matmul_out_node);
IR_NODE_LINK_TO(fused_attention_node, fuse_qkv_ele_add_out_node);
IR_NODE_LINK_TO(fused_attention_node, fuse_qkv_transpose_out_node);
IR_NODE_LINK_TO(fused_attention_node, qk_matmul_out_node);
IR_NODE_LINK_TO(fused_attention_node, add_mask_ele_add_out_node);
IR_NODE_LINK_TO(fused_attention_node, qk_softmax_out_node);
IR_NODE_LINK_TO(fused_attention_node, attn_dropout_mask_node);
IR_NODE_LINK_TO(fused_attention_node, attn_dropout_out_node);
IR_NODE_LINK_TO(fused_attention_node, qkv_matmul_out_node);
IR_NODE_LINK_TO(fused_attention_node, qkv_reshape_out_node);
IR_NODE_LINK_TO(fused_attention_node, out_linear_matmul_out_node);
IR_NODE_LINK_TO(fused_attention_node, out_linear_dropout_mask_node);
IR_NODE_LINK_TO(fused_attention_node, residual_ele_add_out_node);
GraphSafeRemoveNodes(g, GraphSafeRemoveNodes(g,
{pre_layer_norm_op_node, {pre_layer_norm_op_node,
...@@ -858,8 +1036,7 @@ ir::Graph* FusedAttentionsPass::PreMaskDropResPostFwd(Graph* graph) const { ...@@ -858,8 +1036,7 @@ ir::Graph* FusedAttentionsPass::PreMaskDropResPostFwd(Graph* graph) const {
out_linear_matmul_op_node, out_linear_matmul_op_node,
out_linear_ele_add_op_node, out_linear_ele_add_op_node,
out_linear_dropout_op_node, out_linear_dropout_op_node,
residual_ele_add_op_node, residual_ele_add_op_node});
post_layer_norm_op_node});
found_fused_attention++; found_fused_attention++;
}; };
...@@ -869,18 +1046,17 @@ ir::Graph* FusedAttentionsPass::PreMaskDropResPostFwd(Graph* graph) const { ...@@ -869,18 +1046,17 @@ ir::Graph* FusedAttentionsPass::PreMaskDropResPostFwd(Graph* graph) const {
return graph; return graph;
} }
ir::Graph* FusedAttentionsPass::PreMaskDropResPostBwd(Graph* graph) const { ir::Graph* FusedAttentionsPass::PreMaskDropResBwd(Graph* graph) const {
GraphPatternDetector gpd; GraphPatternDetector gpd;
auto* x = gpd.mutable_pattern() auto* x = gpd.mutable_pattern()
->NewNode(patterns::PDNodeName(name_scope_, "x")) ->NewNode(patterns::PDNodeName(name_scope_, "x"))
->AsInput() ->AsInput()
->assert_is_op_input("layer_norm_grad", "Y@GRAD"); ->assert_is_op_input("elementwise_add_grad", "Out@GRAD");
patterns::FusedAttentionGradPattern fused_attention_grad_pattern( patterns::FusedAttentionGradPattern fused_attention_grad_pattern(
gpd.mutable_pattern(), "fused_attention_grad_pattern"); gpd.mutable_pattern(), "fused_attention_grad_pattern");
fused_attention_grad_pattern(x, fused_attention_grad_pattern(x,
/* pre_layer_norm */ true, /* pre_layer_norm */ true,
/* post_layer_norm */ true,
/* has_attn_mask */ true, /* has_attn_mask */ true,
/* do_dropout */ true, /* do_dropout */ true,
/* add_residual */ true); /* add_residual */ true);
...@@ -891,9 +1067,6 @@ ir::Graph* FusedAttentionsPass::PreMaskDropResPostBwd(Graph* graph) const { ...@@ -891,9 +1067,6 @@ ir::Graph* FusedAttentionsPass::PreMaskDropResPostBwd(Graph* graph) const {
Graph* g) { Graph* g) {
VLOG(3) << "handle FusedMultiHeadAttention backward pass's fusion"; VLOG(3) << "handle FusedMultiHeadAttention backward pass's fusion";
GET_IR_NODE_FROM_SUBGRAPH(post_layer_norm_grad_op_node,
post_layer_norm_grad_op,
fused_attention_grad_pattern);
GET_IR_NODE_FROM_SUBGRAPH(residual_ele_add_grad_op_node, GET_IR_NODE_FROM_SUBGRAPH(residual_ele_add_grad_op_node,
residual_ele_add_grad_op, residual_ele_add_grad_op,
fused_attention_grad_pattern); fused_attention_grad_pattern);
...@@ -953,17 +1126,26 @@ ir::Graph* FusedAttentionsPass::PreMaskDropResPostBwd(Graph* graph) const { ...@@ -953,17 +1126,26 @@ ir::Graph* FusedAttentionsPass::PreMaskDropResPostBwd(Graph* graph) const {
// TODO(Yuang Liu): finish the handler // TODO(Yuang Liu): finish the handler
GraphSafeRemoveNodes( GraphSafeRemoveNodes(g,
g, {post_layer_norm_grad_op_node, residual_ele_add_grad_op_node, {residual_ele_add_grad_op_node,
out_linear_dropout_grad_op_node, out_linear_ele_add_grad_op_node, out_linear_dropout_grad_op_node,
out_linear_matmul_grad_op_node, qkv_reshape_grad_op_node, out_linear_ele_add_grad_op_node,
qkv_transpose_grad_op_node, qkv_matmul_grad_op_node, out_linear_matmul_grad_op_node,
attn_dropout_grad_op_node, qk_softmax_grad_op_node, qkv_reshape_grad_op_node,
add_mask_ele_add_grad_op_node, qk_scale_grad_op_node, qkv_transpose_grad_op_node,
qk_matmul_grad_op_node, fuse_qkv_split_grad_op_node, qkv_matmul_grad_op_node,
fuse_qkv_transpose_grad_op_node, fuse_qkv_reshape_grad_op_node, attn_dropout_grad_op_node,
fuse_qkv_ele_add_grad_op_node, fuse_qkv_matmul_grad_op_node, qk_softmax_grad_op_node,
pre_layer_norm_grad_op_node, grad_accumulation_sum_op_node}); add_mask_ele_add_grad_op_node,
qk_scale_grad_op_node,
qk_matmul_grad_op_node,
fuse_qkv_split_grad_op_node,
fuse_qkv_transpose_grad_op_node,
fuse_qkv_reshape_grad_op_node,
fuse_qkv_ele_add_grad_op_node,
fuse_qkv_matmul_grad_op_node,
pre_layer_norm_grad_op_node,
grad_accumulation_sum_op_node});
found_fused_attention++; found_fused_attention++;
}; };
......
...@@ -28,7 +28,7 @@ namespace patterns { ...@@ -28,7 +28,7 @@ namespace patterns {
// Declare patterns for multi head attention. // Declare patterns for multi head attention.
// Can detect: // Can detect:
// 1. Pre layer norm, post layer norm or sandwich layer norm. // 1. Pre layer norm or post layer norm.
// 2. Add attn mask for qk product before the softmax or not. // 2. Add attn mask for qk product before the softmax or not.
// 3. Do attn dropout or not. // 3. Do attn dropout or not.
// 4. Add residual to the out linear result or not. // 4. Add residual to the out linear result or not.
...@@ -37,11 +37,10 @@ struct FusedAttentionPattern : public PatternBase { ...@@ -37,11 +37,10 @@ struct FusedAttentionPattern : public PatternBase {
: PatternBase(pattern, name_scope, "fused_attention_pattern") {} : PatternBase(pattern, name_scope, "fused_attention_pattern") {}
PDNode* operator()(PDNode* x, PDNode* operator()(PDNode* x,
bool pre_layer_norm, // do pre ln or not bool pre_layer_norm, // do pre ln or not
bool post_layer_norm, // do post ln or not bool has_attn_mask, // add attn mask to qk or not
bool has_attn_mask, // add attn mask to qk or not bool do_dropout, // dropout the softmax(qk) or not
bool do_dropout, // dropout the softmax(qk) or not bool add_residual); // add residual to out linear or not
bool add_residual); // add residual to out linear or not
// pre layer norm // pre layer norm
PATTERN_DECL_NODE(pre_layer_norm_op); PATTERN_DECL_NODE(pre_layer_norm_op);
...@@ -134,11 +133,10 @@ struct FusedAttentionGradPattern : public PatternBase { ...@@ -134,11 +133,10 @@ struct FusedAttentionGradPattern : public PatternBase {
: PatternBase(pattern, name_scope, "fused_attention_pattern") {} : PatternBase(pattern, name_scope, "fused_attention_pattern") {}
PDNode* operator()(PDNode* x, PDNode* operator()(PDNode* x,
bool pre_layer_norm, // pre ln bool pre_layer_norm, // pre ln
bool post_layer_norm, // post ln bool has_attn_mask, // add attn mask to qk or not
bool has_attn_mask, // add attn mask to qk or not bool do_dropout, // dropout the softmax(qk) or not
bool do_dropout, // dropout the softmax(qk) or not bool add_residual); // add residual to out linear or not
bool add_residual); // add residual to out linear or not
// post layer norm grad // post layer norm grad
PATTERN_DECL_NODE(post_layer_norm_grad_op); PATTERN_DECL_NODE(post_layer_norm_grad_op);
...@@ -275,9 +273,9 @@ class FusedAttentionsPass : public FusePassBase { ...@@ -275,9 +273,9 @@ class FusedAttentionsPass : public FusePassBase {
// If true, the function name will have an abbreviation part. // If true, the function name will have an abbreviation part.
// If false, the function name won't contain an abbreviation for it. // If false, the function name won't contain an abbreviation for it.
ir::Graph* PreMaskDropResPostFwd(Graph* graph) const; ir::Graph* PreMaskDropResFwd(Graph* graph) const;
ir::Graph* PreMaskDropResPostBwd(Graph* graph) const; ir::Graph* PreMaskDropResBwd(Graph* graph) const;
}; };
} // namespace ir } // namespace ir
......
...@@ -31,7 +31,6 @@ class MultiHeadAttention(paddle.nn.Layer): ...@@ -31,7 +31,6 @@ class MultiHeadAttention(paddle.nn.Layer):
num_heads, num_heads,
add_residual=True, add_residual=True,
pre_ln=True, pre_ln=True,
post_ln=False,
attn_dropout=True, attn_dropout=True,
): ):
super(MultiHeadAttention, self).__init__() super(MultiHeadAttention, self).__init__()
...@@ -42,7 +41,6 @@ class MultiHeadAttention(paddle.nn.Layer): ...@@ -42,7 +41,6 @@ class MultiHeadAttention(paddle.nn.Layer):
self.add_residual = add_residual self.add_residual = add_residual
self.pre_ln = pre_ln self.pre_ln = pre_ln
self.post_ln = post_ln
self.attn_dropout = attn_dropout self.attn_dropout = attn_dropout
self.head_dim = embed_dim // num_heads self.head_dim = embed_dim // num_heads
...@@ -90,7 +88,7 @@ class MultiHeadAttention(paddle.nn.Layer): ...@@ -90,7 +88,7 @@ class MultiHeadAttention(paddle.nn.Layer):
if self.add_residual: if self.add_residual:
out = residual + out out = residual + out
if self.post_ln: if not self.pre_ln:
# post layer norm # post layer norm
out = self.norm2(out) out = self.norm2(out)
...@@ -104,7 +102,6 @@ class TestFusedAttentionPass(unittest.TestCase): ...@@ -104,7 +102,6 @@ class TestFusedAttentionPass(unittest.TestCase):
def setUp(self): def setUp(self):
self.add_residual = True self.add_residual = True
self.pre_ln = True self.pre_ln = True
self.post_ln = True
self.attn_dropout = True self.attn_dropout = True
self.add_mask = True self.add_mask = True
...@@ -120,6 +117,7 @@ class TestFusedAttentionPass(unittest.TestCase): ...@@ -120,6 +117,7 @@ class TestFusedAttentionPass(unittest.TestCase):
).astype('float32') ).astype('float32')
main_prog = paddle.static.Program() main_prog = paddle.static.Program()
main_prog.random_seed = 1234
startup_prog = paddle.static.Program() startup_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog): with paddle.static.program_guard(main_prog, startup_prog):
...@@ -142,7 +140,6 @@ class TestFusedAttentionPass(unittest.TestCase): ...@@ -142,7 +140,6 @@ class TestFusedAttentionPass(unittest.TestCase):
num_heads, num_heads,
add_residual=self.add_residual, add_residual=self.add_residual,
pre_ln=self.pre_ln, pre_ln=self.pre_ln,
post_ln=self.post_ln,
attn_dropout=self.attn_dropout, attn_dropout=self.attn_dropout,
) )
...@@ -157,13 +154,14 @@ class TestFusedAttentionPass(unittest.TestCase): ...@@ -157,13 +154,14 @@ class TestFusedAttentionPass(unittest.TestCase):
pass_manager.apply([main_prog], [startup_prog]) pass_manager.apply([main_prog], [startup_prog])
ops = main_prog.global_block().ops ops = main_prog.global_block().ops
assert ops[2].type == 'reduce_mean' assert ops[2].type == 'fused_attention'
assert ops[4].type == 'reduce_mean_grad' assert ops[3].type == 'reduce_mean'
assert ops[5].type == 'reduce_mean_grad'
# two ops for linear, one op for reduce mean # two ops for linear, one op for fused attention, one op for reduce mean
# one fill constant # one fill constant
# one op for reduce mean grad, two ops for linear bwd # one op for reduce mean grad, two ops for linear bwd
# the eighth op should be the optimizer # the ninth op should be the optimizer
assert ops[7].type == 'sgd' assert ops[8].type == 'sgd'
if __name__ == "__main__": if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册