Unverified commit 8e02f290, authored by Yuang Liu, committed by GitHub

Fused attention pass backward pattern (#49855)

Parent: 65bce2b3
@@ -327,8 +327,441 @@ PDNode* FusedAttentionGradPattern::operator()(PDNode* x,
bool has_attn_mask,
bool do_dropout,
bool add_residual) {
// TODO(Yuang Liu): finish the backward pattern
return nullptr;
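// The backward pattern is traced from the incoming output gradient `x` back
// through the network: post layer_norm_grad -> residual elementwise_add_grad
// -> out-linear dropout/add/matmul grads -> core attention grads (reshape,
// transpose, matmul, dropout, softmax, mask add, scale, matmul) -> fused QKV
// projection grads -> pre layer_norm_grad -> gradient accumulation.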
// post layer norm
PDNode* post_layer_norm_grad_out_node{nullptr};
if (post_layer_norm) {
auto* post_layer_norm_grad_node =
pattern->NewNode(post_layer_norm_grad_op_repr())
->assert_is_op("layer_norm_grad");
auto* post_layer_norm_grad_bias_node =
pattern->NewNode(post_layer_norm_grad_bias_repr())
->assert_is_op_input("layer_norm_grad", "Bias");
auto* post_layer_norm_grad_scale_node =
pattern->NewNode(post_layer_norm_grad_scale_repr())
->assert_is_op_input("layer_norm_grad", "Scale");
auto* post_layer_norm_grad_mean_node =
pattern->NewNode(post_layer_norm_grad_mean_repr())
->assert_is_op_input("layer_norm_grad", "Mean");
auto* post_layer_norm_grad_variance_node =
pattern->NewNode(post_layer_norm_grad_variance_repr())
->assert_is_op_input("layer_norm_grad", "Variance");
auto* post_layer_norm_grad_x_node =
pattern->NewNode(post_layer_norm_grad_x_repr())
->assert_is_op_input("layer_norm_grad", "X");
post_layer_norm_grad_out_node =
pattern->NewNode(post_layer_norm_grad_x_grad_repr())
->assert_is_op_output("layer_norm_grad", "X@GRAD");
auto* post_layer_norm_grad_bias_grad_node =
pattern->NewNode(post_layer_norm_grad_bias_grad_repr())
->assert_is_op_output("layer_norm_grad", "Bias@GRAD");
auto* post_layer_norm_grad_scale_grad_node =
pattern->NewNode(post_layer_norm_grad_scale_grad_repr())
->assert_is_op_output("layer_norm_grad", "Scale@GRAD");
post_layer_norm_grad_node
->LinksFrom({x,
post_layer_norm_grad_bias_node,
post_layer_norm_grad_scale_node,
post_layer_norm_grad_mean_node,
post_layer_norm_grad_variance_node,
post_layer_norm_grad_x_node})
.LinksTo({post_layer_norm_grad_out_node,
post_layer_norm_grad_bias_grad_node,
post_layer_norm_grad_scale_grad_node});
}
// add residual
PDNode* residual_ele_add_grad_out_node{nullptr};
PDNode* residual_ele_add_grad_x_node{nullptr};
PDNode* residual_ele_add_grad_x_grad_node{nullptr};
if (add_residual) {
PDNode* ele_add_grad_input = x;
if (post_layer_norm) {
ele_add_grad_input = post_layer_norm_grad_out_node;
}
auto* residual_ele_add_grad_node =
pattern->NewNode(residual_ele_add_grad_op_repr())
->assert_is_op("elementwise_add_grad");
residual_ele_add_grad_x_node =
pattern->NewNode(residual_ele_add_grad_x_repr())
->assert_is_op_input("elementwise_add_grad", "X");
auto* residual_ele_add_grad_bias_node =
pattern->NewNode(residual_ele_add_grad_bias_repr())
->assert_is_op_input("elementwise_add_grad", "Y");
residual_ele_add_grad_out_node =
pattern->NewNode(residual_ele_add_grad_bias_grad_repr())
->assert_is_op_output("elementwise_add_grad", "Y@GRAD");
residual_ele_add_grad_x_grad_node =
pattern->NewNode(residual_ele_add_grad_x_grad_repr())
->assert_is_op_output("elementwise_add_grad", "X@GRAD");
ele_add_grad_input->assert_is_op_input("elementwise_add_grad", "Out@GRAD");
residual_ele_add_grad_node
->LinksFrom({ele_add_grad_input,
residual_ele_add_grad_x_node,
residual_ele_add_grad_bias_node})
.LinksTo({residual_ele_add_grad_x_grad_node,
residual_ele_add_grad_out_node});
}
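// Note: the residual add's Y@GRAD (held by the *_bias_grad repr above) keeps
// flowing down the out-linear branch, while its X@GRAD is the residual-branch
// gradient that is later accumulated by the trailing sum op.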
// get the real input x for dropout grad
PDNode* out_linear_grad_input_node = x;
if (post_layer_norm && !add_residual) {
out_linear_grad_input_node = post_layer_norm_grad_out_node;
} else if (add_residual) {
out_linear_grad_input_node = residual_ele_add_grad_out_node;
}
// out linear part
auto* out_linear_dropout_grad_node =
pattern->NewNode(out_linear_dropout_grad_op_repr())
->assert_is_op("dropout_grad");
auto* out_linear_dropout_grad_mask_node =
pattern->NewNode(out_linear_dropout_grad_mask_repr())
->assert_is_op_input("dropout_grad", "Mask");
auto* out_linear_dropout_grad_out_node =
pattern->NewNode(out_linear_dropout_grad_out_repr())
->assert_is_op_output("dropout_grad", "X@GRAD");
out_linear_grad_input_node->assert_is_op_input("dropout_grad", "Out@GRAD");
out_linear_dropout_grad_node
->LinksFrom(
{out_linear_grad_input_node, out_linear_dropout_grad_mask_node})
.LinksTo({out_linear_dropout_grad_out_node});
auto* out_linear_ele_add_grad_node =
pattern->NewNode(out_linear_ele_add_grad_op_repr())
->assert_is_op("elementwise_add_grad");
auto* out_linear_ele_add_grad_x_node =
pattern->NewNode(out_linear_ele_add_grad_x_repr())
->assert_is_op_input("elementwise_add_grad", "X");
auto* out_linear_ele_add_grad_bias_node =
pattern->NewNode(out_linear_ele_add_grad_bias_repr())
->assert_is_op_input("elementwise_add_grad", "Y");
auto* out_linear_ele_add_grad_x_grad_node =
pattern->NewNode(out_linear_ele_add_grad_x_grad_repr())
->assert_is_op_output("elementwise_add_grad", "X@GRAD");
auto* out_linear_ele_add_grad_bias_grad_node =
pattern->NewNode(out_linear_ele_add_grad_bias_grad_repr())
->assert_is_op_output("elementwise_add_grad", "Y@GRAD");
out_linear_dropout_grad_out_node->assert_is_op_input("elementwise_add_grad",
"Out@GRAD");
out_linear_ele_add_grad_node
->LinksFrom({out_linear_dropout_grad_out_node,
out_linear_ele_add_grad_x_node,
out_linear_ele_add_grad_bias_node})
.LinksTo({out_linear_ele_add_grad_x_grad_node,
out_linear_ele_add_grad_bias_grad_node});
auto* out_linear_matmul_grad_node =
pattern->NewNode(out_linear_matmul_grad_op_repr())
->assert_is_op("matmul_v2_grad");
auto* out_linear_matmul_grad_x_node =
pattern->NewNode(out_linear_matmul_grad_x_repr())
->assert_is_op_input("matmul_v2_grad", "X");
auto* out_linear_matmul_grad_w_node =
pattern->NewNode(out_linear_matmul_grad_w_repr())
->assert_is_op_input("matmul_v2_grad", "Y");
auto* out_linear_matmul_grad_x_grad_node =
pattern->NewNode(out_linear_matmul_grad_x_grad_repr())
->assert_is_op_output("matmul_v2_grad", "X@GRAD");
auto* out_linear_matmul_grad_w_grad_node =
pattern->NewNode(out_linear_matmul_grad_w_grad_repr())
->assert_is_op_output("matmul_v2_grad", "Y@GRAD");
out_linear_ele_add_grad_x_grad_node->assert_is_op_input("matmul_v2_grad",
"Out@GRAD");
out_linear_matmul_grad_node
->LinksFrom({out_linear_ele_add_grad_x_grad_node,
out_linear_matmul_grad_x_node,
out_linear_matmul_grad_w_node})
.LinksTo({out_linear_matmul_grad_x_grad_node,
out_linear_matmul_grad_w_grad_node});
// core attention part
auto* qkv_reshape_grad_node = pattern->NewNode(qkv_reshape_grad_op_repr())
->assert_is_op("reshape2_grad");
auto* qkv_reshape_grad_x_shape_node =
pattern->NewNode(qkv_reshape_grad_x_shape_repr())
->assert_is_op_input("reshape2_grad", "XShape");
auto* qkv_reshape_grad_out_node =
pattern->NewNode(qkv_reshape_grad_out_repr())
->assert_is_op_output("reshape2_grad", "X@GRAD");
out_linear_matmul_grad_x_grad_node->assert_is_op_input("reshape2_grad",
"Out@GRAD");
qkv_reshape_grad_node
->LinksFrom(
{out_linear_matmul_grad_x_grad_node, qkv_reshape_grad_x_shape_node})
.LinksTo({qkv_reshape_grad_out_node});
auto* qkv_transpose_grad_node = pattern->NewNode(qkv_transpose_grad_op_repr())
->assert_is_op("transpose2_grad");
auto* qkv_transpose_grad_x_shape_node =
pattern->NewNode(qkv_transpose_grad_x_shape_repr())
->assert_is_op_input("transpose2_grad", "XShape");
auto* qkv_transpose_grad_out_node =
pattern->NewNode(qkv_transpose_grad_out_repr())
->assert_is_op_output("transpose2_grad", "X@GRAD");
qkv_reshape_grad_out_node->assert_is_op_input("transpose2_grad", "Out@GRAD");
qkv_transpose_grad_node
->LinksFrom({qkv_reshape_grad_out_node, qkv_transpose_grad_x_shape_node})
.LinksTo({qkv_transpose_grad_out_node});
auto* qkv_matmul_grad_node = pattern->NewNode(qkv_matmul_grad_op_repr())
->assert_is_op("matmul_v2_grad");
auto* qkv_matmul_grad_x_node =
pattern->NewNode(qkv_matmul_grad_x_repr())
->assert_is_op_input("matmul_v2_grad", "X");
auto* qkv_matmul_grad_w_node =
pattern->NewNode(qkv_matmul_grad_w_repr())
->assert_is_op_input("matmul_v2_grad", "Y");
auto* qkv_matmul_grad_x_grad_node =
pattern->NewNode(qkv_matmul_grad_x_grad_repr())
->assert_is_op_output("matmul_v2_grad", "X@GRAD");
auto* qkv_matmul_grad_w_grad_node =
pattern->NewNode(qkv_matmul_grad_w_grad_repr())
->assert_is_op_output("matmul_v2_grad", "Y@GRAD");
qkv_transpose_grad_out_node->assert_is_op_input("matmul_v2_grad", "Out@GRAD");
qkv_matmul_grad_node
->LinksFrom({qkv_transpose_grad_out_node,
qkv_matmul_grad_x_node,
qkv_matmul_grad_w_node})
.LinksTo({qkv_matmul_grad_x_grad_node, qkv_matmul_grad_w_grad_node});
PDNode* attn_dropout_grad_out_node{nullptr};
if (do_dropout) {
auto* attn_dropout_grad_node = pattern->NewNode(attn_dropout_grad_op_repr())
->assert_is_op("dropout_grad");
auto* attn_dropout_grad_mask_node =
pattern->NewNode(attn_dropout_grad_mask_repr())
->assert_is_op_input("dropout_grad", "Mask");
attn_dropout_grad_out_node =
pattern->NewNode(attn_dropout_grad_out_repr())
->assert_is_op_output("dropout_grad", "X@GRAD");
qkv_matmul_grad_x_grad_node->assert_is_op_input("dropout_grad", "Out@GRAD");
attn_dropout_grad_node
->LinksFrom({qkv_matmul_grad_x_grad_node, attn_dropout_grad_mask_node})
.LinksTo({attn_dropout_grad_out_node});
}
PDNode* qk_softmax_grad_input_node =
do_dropout ? attn_dropout_grad_out_node : qkv_matmul_grad_x_grad_node;
auto* qk_softmax_grad_node =
pattern->NewNode(qk_softmax_grad_op_repr())->assert_is_op("softmax_grad");
auto* qk_softmax_grad_fwd_out_node =
pattern->NewNode(qk_softmax_grad_fwd_out_repr())
->assert_is_op_input("softmax_grad", "Out");
auto* qk_softmax_grad_out =
pattern->NewNode(qk_softmax_grad_out_repr())
->assert_is_op_output("softmax_grad", "X@GRAD");
qk_softmax_grad_input_node->assert_is_op_input("softmax_grad", "Out@GRAD");
qk_softmax_grad_node
->LinksFrom({qk_softmax_grad_input_node, qk_softmax_grad_fwd_out_node})
.LinksTo({qk_softmax_grad_out});
PDNode* add_mask_ele_add_grad_x_grad_node{nullptr};
if (has_attn_mask) {
auto* add_mask_ele_add_grad_node =
pattern->NewNode(add_mask_ele_add_grad_op_repr())
->assert_is_op("elementwise_add_grad");
auto* add_mask_ele_add_grad_x_node =
pattern->NewNode(add_mask_ele_add_grad_x_repr())
->assert_is_op_input("elementwise_add_grad", "X");
auto* add_mask_ele_add_grad_bias_node =
pattern->NewNode(add_mask_ele_add_grad_bias_repr())
->assert_is_op_input("elementwise_add_grad", "Y");
add_mask_ele_add_grad_x_grad_node =
pattern->NewNode(add_mask_ele_add_grad_x_grad_repr())
->assert_is_op_output("elementwise_add_grad", "X@GRAD");
qk_softmax_grad_out->assert_is_op_input("elementwise_add_grad", "Out@GRAD");
add_mask_ele_add_grad_node
->LinksFrom({add_mask_ele_add_grad_x_node,
add_mask_ele_add_grad_bias_node,
qk_softmax_grad_out})
.LinksTo({add_mask_ele_add_grad_x_grad_node});
}
PDNode* qk_scale_grad_input_node =
has_attn_mask ? add_mask_ele_add_grad_x_grad_node : qk_softmax_grad_out;
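// scale is linear in its input, so its backward shows up in the graph as
// another plain scale op; the pattern therefore matches "scale" here rather
// than a dedicated *_grad op.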
auto* qk_scale_grad_node =
pattern->NewNode(qk_scale_grad_op_repr())->assert_is_op("scale");
auto* qk_scale_grad_out_node =
pattern->NewNode(qk_scale_grad_out_repr())->assert_is_op_output("scale");
qk_scale_grad_input_node->assert_is_op_input("scale", "X");
qk_scale_grad_node->LinksFrom({qk_scale_grad_input_node})
.LinksTo({qk_scale_grad_out_node});
auto* qk_matmul_grad_node = pattern->NewNode(qk_matmul_grad_op_repr())
->assert_is_op("matmul_v2_grad");
auto* qk_matmul_grad_x_node = pattern->NewNode(qk_matmul_grad_x_repr())
->assert_is_op_input("matmul_v2_grad", "X");
auto* qk_matmul_grad_w_node = pattern->NewNode(qk_matmul_grad_w_repr())
->assert_is_op_input("matmul_v2_grad", "Y");
auto* qk_matmul_grad_x_grad_node =
pattern->NewNode(qk_matmul_grad_x_grad_repr())
->assert_is_op_output("matmul_v2_grad", "X@GRAD");
auto* qk_matmul_grad_w_grad_node =
pattern->NewNode(qk_matmul_grad_w_grad_repr())
->assert_is_op_output("matmul_v2_grad", "Y@GRAD");
qk_scale_grad_out_node->assert_is_op_input("matmul_v2_grad", "Out@GRAD");
qk_matmul_grad_node
->LinksFrom({qk_scale_grad_out_node,
qk_matmul_grad_x_node,
qk_matmul_grad_w_node})
.LinksTo({qk_matmul_grad_x_grad_node, qk_matmul_grad_w_grad_node});
// fuse qkv projection
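// The forward pass splits the fused QKV tensor into Q, K and V, so the
// corresponding gradient op is a concat joining the Q, K and V grads.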
auto* fuse_qkv_split_grad_node =
pattern->NewNode(fuse_qkv_split_grad_op_repr())->assert_is_op("concat");
auto* fuse_qkv_split_grad_out_node =
pattern->NewNode(fuse_qkv_split_grad_out_repr())
->assert_is_op_output("concat");
qk_matmul_grad_x_grad_node->assert_is_op_input("concat"); // q grad
qk_matmul_grad_w_grad_node->assert_is_op_input("concat"); // k grad
qkv_matmul_grad_w_grad_node->assert_is_op_input("concat"); // v grad
fuse_qkv_split_grad_node
->LinksFrom({qk_matmul_grad_x_grad_node,
qk_matmul_grad_w_grad_node,
qkv_matmul_grad_w_grad_node})
.LinksTo({fuse_qkv_split_grad_out_node});
auto* fuse_qkv_transpose_grad_node =
pattern->NewNode(fuse_qkv_transpose_grad_op_repr())
->assert_is_op("transpose2_grad");
auto* fuse_qkv_transpose_grad_x_shape_node =
pattern->NewNode(fuse_qkv_transpose_grad_x_shape_repr())
->assert_is_op_input("transpose2_grad", "XShape");
auto* fuse_qkv_transpose_grad_out_node =
pattern->NewNode(fuse_qkv_transpose_grad_out_repr())
->assert_is_op_output("transpose2_grad", "X@GRAD");
fuse_qkv_split_grad_out_node->assert_is_op_input("transpose2_grad",
"Out@GRAD");
fuse_qkv_transpose_grad_node
->LinksFrom(
{fuse_qkv_split_grad_out_node, fuse_qkv_transpose_grad_x_shape_node})
.LinksTo({fuse_qkv_transpose_grad_out_node});
auto* fuse_qkv_reshape_grad_node =
pattern->NewNode(fuse_qkv_reshape_grad_op_repr())
->assert_is_op("reshape2_grad");
auto* fuse_qkv_reshape_grad_x_shape_node =
pattern->NewNode(fuse_qkv_reshape_grad_x_shape_repr())
->assert_is_op_input("reshape2_grad", "XShape");
auto* fuse_qkv_reshape_grad_out_node =
pattern->NewNode(fuse_qkv_reshape_grad_out_repr())
->assert_is_op_output("reshape2_grad", "X@GRAD");
fuse_qkv_transpose_grad_out_node->assert_is_op_input("reshape2_grad",
"Out@GRAD");
fuse_qkv_reshape_grad_node
->LinksFrom({fuse_qkv_transpose_grad_out_node,
fuse_qkv_reshape_grad_x_shape_node})
.LinksTo({fuse_qkv_reshape_grad_out_node});
auto* fuse_qkv_ele_add_grad_node =
pattern->NewNode(fuse_qkv_ele_add_grad_op_repr())
->assert_is_op("elementwise_add_grad");
auto* fuse_qkv_ele_add_grad_x_node =
pattern->NewNode(fuse_qkv_ele_add_grad_x_repr())
->assert_is_op_input("elementwise_add_grad", "X");
auto* fuse_qkv_ele_add_grad_bias_node =
pattern->NewNode(fuse_qkv_ele_add_grad_bias_repr())
->assert_is_op_input("elementwise_add_grad", "Y");
auto* fuse_qkv_ele_add_grad_x_grad_node =
pattern->NewNode(fuse_qkv_ele_add_grad_x_grad_repr())
->assert_is_op_output("elementwise_add_grad", "X@GRAD");
auto* fuse_qkv_ele_add_grad_bias_grad_node =
pattern->NewNode(fuse_qkv_ele_add_grad_bias_grad_repr())
->assert_is_op_output("elementwise_add_grad", "Y@GRAD");
fuse_qkv_reshape_grad_out_node->assert_is_op_input("elementwise_add_grad",
"Out@GRAD");
fuse_qkv_ele_add_grad_node
->LinksFrom({fuse_qkv_reshape_grad_out_node,
fuse_qkv_ele_add_grad_x_node,
fuse_qkv_ele_add_grad_bias_node})
.LinksTo({fuse_qkv_ele_add_grad_x_grad_node,
fuse_qkv_ele_add_grad_bias_grad_node});
auto* fuse_qkv_matmul_grad_node =
pattern->NewNode(fuse_qkv_matmul_grad_op_repr())
->assert_is_op("matmul_v2_grad");
auto* fuse_qkv_matmul_grad_x_node =
pattern->NewNode(fuse_qkv_matmul_grad_x_repr())
->assert_is_op_input("matmul_v2_grad", "X");
auto* fuse_qkv_matmul_grad_w_node =
pattern->NewNode(fuse_qkv_matmul_grad_w_repr())
->assert_is_op_input("matmul_v2_grad", "Y");
auto* fuse_qkv_matmul_grad_x_grad_node =
pattern->NewNode(fuse_qkv_matmul_grad_x_grad_repr())
->assert_is_op_output("matmul_v2_grad", "X@GRAD");
auto* fuse_qkv_matmul_grad_w_grad_node =
pattern->NewNode(fuse_qkv_matmul_grad_w_grad_repr())
->assert_is_op_output("matmul_v2_grad", "Y@GRAD");
fuse_qkv_ele_add_grad_x_grad_node->assert_is_op_input("matmul_v2_grad",
"Out@GRAD");
fuse_qkv_matmul_grad_node
->LinksFrom({fuse_qkv_ele_add_grad_x_grad_node,
fuse_qkv_matmul_grad_x_node,
fuse_qkv_matmul_grad_w_node})
.LinksTo(
{fuse_qkv_matmul_grad_x_grad_node, fuse_qkv_matmul_grad_w_grad_node});
if (!pre_layer_norm) {
return fuse_qkv_matmul_grad_x_grad_node;
}
// pre layer norm
auto* pre_layer_norm_grad_node =
pattern->NewNode(pre_layer_norm_grad_op_repr())
->assert_is_op("layer_norm_grad");
auto* pre_layer_norm_grad_scale_node =
pattern->NewNode(pre_layer_norm_grad_scale_repr())
->assert_is_op_input("layer_norm_grad", "Scale");
auto* pre_layer_norm_grad_bias_node =
pattern->NewNode(pre_layer_norm_grad_bias_repr())
->assert_is_op_input("layer_norm_grad", "Bias");
auto* pre_layer_norm_grad_mean_node =
pattern->NewNode(pre_layer_norm_grad_mean_repr())
->assert_is_op_input("layer_norm_grad", "Mean");
auto* pre_layer_norm_grad_variance_node =
pattern->NewNode(pre_layer_norm_grad_variance_repr())
->assert_is_op_input("layer_norm_grad", "Variance");
auto* pre_layer_norm_grad_x_node =
add_residual ? residual_ele_add_grad_x_node
: pattern->NewNode(pre_layer_norm_grad_x_repr())
->assert_is_op_input("layer_norm_grad", "X");
auto* pre_layer_norm_grad_scale_grad_node =
pattern->NewNode(pre_layer_norm_grad_scale_grad_repr())
->assert_is_op_output("layer_norm_grad", "Scale@GRAD");
auto* pre_layer_norm_grad_bias_grad_node =
pattern->NewNode(pre_layer_norm_grad_bias_grad_repr())
->assert_is_op_output("layer_norm_grad", "Bias@GRAD");
auto* pre_layer_norm_grad_x_grad_node =
pattern->NewNode(pre_layer_norm_grad_x_grad_repr())
->assert_is_op_output("layer_norm_grad", "X@GRAD");
fuse_qkv_matmul_grad_x_grad_node->assert_is_op_input("layer_norm_grad",
"Y@GRAD");
pre_layer_norm_grad_node
->LinksFrom({fuse_qkv_matmul_grad_x_grad_node,
pre_layer_norm_grad_scale_node,
pre_layer_norm_grad_bias_node,
pre_layer_norm_grad_mean_node,
pre_layer_norm_grad_variance_node,
pre_layer_norm_grad_x_node})
.LinksTo({pre_layer_norm_grad_scale_grad_node,
pre_layer_norm_grad_bias_grad_node,
pre_layer_norm_grad_x_grad_node});
if (!add_residual) {
return pre_layer_norm_grad_x_grad_node;
}
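// With the residual connection present, the X@GRAD of pre layer_norm and the
// residual-branch gradient are accumulated by a sum op, which becomes the
// terminal node of the matched backward pattern.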
auto* grad_accumulation_sum_node =
pattern->NewNode(grad_accumulation_sum_op_repr())->assert_is_op("sum");
auto* grad_accumulation_sum_out_node =
pattern->NewNode(grad_accumulation_out_repr())
->assert_is_op_output("sum");
grad_accumulation_sum_node
->LinksFrom(
{pre_layer_norm_grad_x_grad_node, residual_ele_add_grad_x_grad_node})
.LinksTo({grad_accumulation_sum_out_node});
return grad_accumulation_sum_out_node;
}
} // namespace patterns
@@ -437,7 +870,107 @@ ir::Graph* FusedAttentionsPass::PreMaskDropResPostFwd(Graph* graph) const {
}
ir::Graph* FusedAttentionsPass::PreMaskDropResPostBwd(Graph* graph) const {
// TODO(Yuang Liu): finish the pass
GraphPatternDetector gpd;
auto* x = gpd.mutable_pattern()
->NewNode(patterns::PDNodeName(name_scope_, "x"))
->AsInput()
->assert_is_op_input("layer_norm_grad", "Y@GRAD");
patterns::FusedAttentionGradPattern fused_attention_grad_pattern(
gpd.mutable_pattern(), "fused_attention_grad_pattern");
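// This pass targets the full configuration suggested by its name: pre
// layer_norm + attention mask + attention dropout + residual add + post
// layer_norm, so every switch of the grad pattern is enabled.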
fused_attention_grad_pattern(x,
/* pre_layer_norm */ true,
/* post_layer_norm */ true,
/* has_attn_mask */ true,
/* do_dropout */ true,
/* add_residual */ true);
int found_fused_attention = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
VLOG(3) << "handle FusedMultiHeadAttention backward pass's fusion";
GET_IR_NODE_FROM_SUBGRAPH(post_layer_norm_grad_op_node,
post_layer_norm_grad_op,
fused_attention_grad_pattern);
GET_IR_NODE_FROM_SUBGRAPH(residual_ele_add_grad_op_node,
residual_ele_add_grad_op,
fused_attention_grad_pattern);
GET_IR_NODE_FROM_SUBGRAPH(out_linear_dropout_grad_op_node,
out_linear_dropout_grad_op,
fused_attention_grad_pattern);
GET_IR_NODE_FROM_SUBGRAPH(out_linear_ele_add_grad_op_node,
out_linear_ele_add_grad_op,
fused_attention_grad_pattern);
GET_IR_NODE_FROM_SUBGRAPH(out_linear_matmul_grad_op_node,
out_linear_matmul_grad_op,
fused_attention_grad_pattern);
GET_IR_NODE_FROM_SUBGRAPH(qkv_reshape_grad_op_node,
qkv_reshape_grad_op,
fused_attention_grad_pattern);
GET_IR_NODE_FROM_SUBGRAPH(qkv_transpose_grad_op_node,
qkv_transpose_grad_op,
fused_attention_grad_pattern);
GET_IR_NODE_FROM_SUBGRAPH(qkv_matmul_grad_op_node,
qkv_matmul_grad_op,
fused_attention_grad_pattern);
GET_IR_NODE_FROM_SUBGRAPH(attn_dropout_grad_op_node,
attn_dropout_grad_op,
fused_attention_grad_pattern);
GET_IR_NODE_FROM_SUBGRAPH(qk_softmax_grad_op_node,
qk_softmax_grad_op,
fused_attention_grad_pattern);
GET_IR_NODE_FROM_SUBGRAPH(add_mask_ele_add_grad_op_node,
add_mask_ele_add_grad_op,
fused_attention_grad_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
qk_scale_grad_op_node, qk_scale_grad_op, fused_attention_grad_pattern);
GET_IR_NODE_FROM_SUBGRAPH(qk_matmul_grad_op_node,
qk_matmul_grad_op,
fused_attention_grad_pattern);
GET_IR_NODE_FROM_SUBGRAPH(fuse_qkv_split_grad_op_node,
fuse_qkv_split_grad_op,
fused_attention_grad_pattern);
GET_IR_NODE_FROM_SUBGRAPH(fuse_qkv_transpose_grad_op_node,
fuse_qkv_transpose_grad_op,
fused_attention_grad_pattern);
GET_IR_NODE_FROM_SUBGRAPH(fuse_qkv_reshape_grad_op_node,
fuse_qkv_reshape_grad_op,
fused_attention_grad_pattern);
GET_IR_NODE_FROM_SUBGRAPH(fuse_qkv_ele_add_grad_op_node,
fuse_qkv_ele_add_grad_op,
fused_attention_grad_pattern);
GET_IR_NODE_FROM_SUBGRAPH(fuse_qkv_matmul_grad_op_node,
fuse_qkv_matmul_grad_op,
fused_attention_grad_pattern);
GET_IR_NODE_FROM_SUBGRAPH(pre_layer_norm_grad_op_node,
pre_layer_norm_grad_op,
fused_attention_grad_pattern);
GET_IR_NODE_FROM_SUBGRAPH(grad_accumulation_sum_op_node,
grad_accumulation_sum_op,
fused_attention_grad_pattern);
// TODO(Yuang Liu): finish the handler
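// For now the handler only detects the backward subgraph: it removes the
// matched grad ops and counts the match; emitting the fused attention grad
// op is left for the TODO above.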
GraphSafeRemoveNodes(
g, {post_layer_norm_grad_op_node, residual_ele_add_grad_op_node,
out_linear_dropout_grad_op_node, out_linear_ele_add_grad_op_node,
out_linear_matmul_grad_op_node, qkv_reshape_grad_op_node,
qkv_transpose_grad_op_node, qkv_matmul_grad_op_node,
attn_dropout_grad_op_node, qk_softmax_grad_op_node,
add_mask_ele_add_grad_op_node, qk_scale_grad_op_node,
qk_matmul_grad_op_node, fuse_qkv_split_grad_op_node,
fuse_qkv_transpose_grad_op_node, fuse_qkv_reshape_grad_op_node,
fuse_qkv_ele_add_grad_op_node, fuse_qkv_matmul_grad_op_node,
pre_layer_norm_grad_op_node, grad_accumulation_sum_op_node});
found_fused_attention++;
};
gpd(graph, handler);
AddStatis(found_fused_attention);
return graph;
}
@@ -140,7 +140,116 @@ struct FusedAttentionGradPattern : public PatternBase {
bool do_dropout, // dropout the softmax(qk) or not
bool add_residual); // add residual to out linear or not
// TODO(Yuang Liu): add backward pattern
// post layer norm grad
PATTERN_DECL_NODE(post_layer_norm_grad_op);
PATTERN_DECL_NODE(post_layer_norm_grad_scale);
PATTERN_DECL_NODE(post_layer_norm_grad_bias);
PATTERN_DECL_NODE(post_layer_norm_grad_mean);
PATTERN_DECL_NODE(post_layer_norm_grad_variance);
PATTERN_DECL_NODE(post_layer_norm_grad_x);
PATTERN_DECL_NODE(post_layer_norm_grad_scale_grad);
PATTERN_DECL_NODE(post_layer_norm_grad_bias_grad);
PATTERN_DECL_NODE(post_layer_norm_grad_x_grad);
// residual grad
PATTERN_DECL_NODE(residual_ele_add_grad_op);
PATTERN_DECL_NODE(residual_ele_add_grad_x);
PATTERN_DECL_NODE(residual_ele_add_grad_bias);
PATTERN_DECL_NODE(residual_ele_add_grad_bias_grad);
PATTERN_DECL_NODE(residual_ele_add_grad_x_grad);
// out linear grad
PATTERN_DECL_NODE(out_linear_dropout_grad_op);
PATTERN_DECL_NODE(out_linear_dropout_grad_mask);
PATTERN_DECL_NODE(out_linear_dropout_grad_out);
PATTERN_DECL_NODE(out_linear_ele_add_grad_op);
PATTERN_DECL_NODE(out_linear_ele_add_grad_x);
PATTERN_DECL_NODE(out_linear_ele_add_grad_bias);
PATTERN_DECL_NODE(out_linear_ele_add_grad_x_grad);
PATTERN_DECL_NODE(out_linear_ele_add_grad_bias_grad);
PATTERN_DECL_NODE(out_linear_matmul_grad_op);
PATTERN_DECL_NODE(out_linear_matmul_grad_x);
PATTERN_DECL_NODE(out_linear_matmul_grad_w);
PATTERN_DECL_NODE(out_linear_matmul_grad_x_grad);
PATTERN_DECL_NODE(out_linear_matmul_grad_w_grad);
// core attention grad
PATTERN_DECL_NODE(qkv_reshape_grad_op);
PATTERN_DECL_NODE(qkv_reshape_grad_x_shape);
PATTERN_DECL_NODE(qkv_reshape_grad_out);
PATTERN_DECL_NODE(qkv_transpose_grad_op);
PATTERN_DECL_NODE(qkv_transpose_grad_x_shape);
PATTERN_DECL_NODE(qkv_transpose_grad_out);
PATTERN_DECL_NODE(qkv_matmul_grad_op);
PATTERN_DECL_NODE(qkv_matmul_grad_x);
PATTERN_DECL_NODE(qkv_matmul_grad_w);
PATTERN_DECL_NODE(qkv_matmul_grad_x_grad);
PATTERN_DECL_NODE(qkv_matmul_grad_w_grad);
PATTERN_DECL_NODE(attn_dropout_grad_op);
PATTERN_DECL_NODE(attn_dropout_grad_mask);
PATTERN_DECL_NODE(attn_dropout_grad_out);
PATTERN_DECL_NODE(qk_softmax_grad_op);
PATTERN_DECL_NODE(qk_softmax_grad_fwd_out);
PATTERN_DECL_NODE(qk_softmax_grad_out);
PATTERN_DECL_NODE(add_mask_ele_add_grad_op);
PATTERN_DECL_NODE(add_mask_ele_add_grad_x);
PATTERN_DECL_NODE(add_mask_ele_add_grad_bias);
PATTERN_DECL_NODE(add_mask_ele_add_grad_x_grad);
PATTERN_DECL_NODE(qk_scale_grad_op);
PATTERN_DECL_NODE(qk_scale_grad_out);
PATTERN_DECL_NODE(qk_matmul_grad_op);
PATTERN_DECL_NODE(qk_matmul_grad_x);
PATTERN_DECL_NODE(qk_matmul_grad_w);
PATTERN_DECL_NODE(qk_matmul_grad_x_grad);
PATTERN_DECL_NODE(qk_matmul_grad_w_grad);
// fuse qkv projection grad
PATTERN_DECL_NODE(fuse_qkv_split_grad_op); // concat op
PATTERN_DECL_NODE(fuse_qkv_split_grad_out);
PATTERN_DECL_NODE(fuse_qkv_transpose_grad_op);
PATTERN_DECL_NODE(fuse_qkv_transpose_grad_x_shape);
PATTERN_DECL_NODE(fuse_qkv_transpose_grad_out);
PATTERN_DECL_NODE(fuse_qkv_reshape_grad_op);
PATTERN_DECL_NODE(fuse_qkv_reshape_grad_x_shape);
PATTERN_DECL_NODE(fuse_qkv_reshape_grad_out);
PATTERN_DECL_NODE(fuse_qkv_ele_add_grad_op);
PATTERN_DECL_NODE(fuse_qkv_ele_add_grad_x);
PATTERN_DECL_NODE(fuse_qkv_ele_add_grad_bias);
PATTERN_DECL_NODE(fuse_qkv_ele_add_grad_x_grad);
PATTERN_DECL_NODE(fuse_qkv_ele_add_grad_bias_grad);
PATTERN_DECL_NODE(fuse_qkv_matmul_grad_op);
PATTERN_DECL_NODE(fuse_qkv_matmul_grad_x);
PATTERN_DECL_NODE(fuse_qkv_matmul_grad_w);
PATTERN_DECL_NODE(fuse_qkv_matmul_grad_x_grad);
PATTERN_DECL_NODE(fuse_qkv_matmul_grad_w_grad);
// pre layer norm grad
PATTERN_DECL_NODE(pre_layer_norm_grad_op);
PATTERN_DECL_NODE(pre_layer_norm_grad_scale);
PATTERN_DECL_NODE(pre_layer_norm_grad_bias);
PATTERN_DECL_NODE(pre_layer_norm_grad_mean);
PATTERN_DECL_NODE(pre_layer_norm_grad_variance);
PATTERN_DECL_NODE(pre_layer_norm_grad_x);
PATTERN_DECL_NODE(pre_layer_norm_grad_scale_grad);
PATTERN_DECL_NODE(pre_layer_norm_grad_bias_grad);
PATTERN_DECL_NODE(pre_layer_norm_grad_x_grad);
// grad accumulation
PATTERN_DECL_NODE(grad_accumulation_sum_op);
PATTERN_DECL_NODE(grad_accumulation_out);
};
} // namespace patterns
@@ -114,9 +114,7 @@ class TestFusedAttentionPass(unittest.TestCase):
hidden_size = 768
num_heads = 12
x_data = np.random.rand(batch_size, seq_len, hidden_size).astype(
'float32'
)
x_data = np.random.rand(batch_size, seq_len, seq_len).astype('float32')
mask_data = np.random.rand(
batch_size, num_heads, seq_len, seq_len
).astype('float32')
@@ -127,7 +125,7 @@ with paddle.static.program_guard(main_prog, startup_prog):
with paddle.static.program_guard(main_prog, startup_prog):
data = paddle.static.data(
name="x",
shape=[-1, seq_len, hidden_size],
shape=[-1, seq_len, seq_len],
dtype='float32',
)
if self.add_mask:
@@ -138,6 +136,7 @@ )
)
else:
attn_mask = None
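# A Linear layer is placed in front of the attention block, presumably so the
# attention input has its own gradient path and the full backward pattern
# (including the residual gradient accumulation) can be matched.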
data_linear = paddle.nn.Linear(seq_len, hidden_size)
multi_head_attn = MultiHeadAttention(
hidden_size,
num_heads,
@@ -146,7 +145,9 @@ post_ln=self.post_ln,
post_ln=self.post_ln,
attn_dropout=self.attn_dropout,
)
out = multi_head_attn(data, attn_mask)
attn_input = data_linear(data)
out = multi_head_attn(attn_input, attn_mask)
loss = paddle.mean(out)
sgd_optimizer = paddle.fluid.optimizer.SGD(learning_rate=0.001)
@@ -156,7 +157,13 @@ pass_manager.apply([main_prog], [startup_prog])
pass_manager.apply([main_prog], [startup_prog])
ops = main_prog.global_block().ops
assert ops[0].type == 'reduce_mean'
assert ops[2].type == 'reduce_mean'
assert ops[4].type == 'reduce_mean_grad'
# ops[0:2]: the two ops for the linear layer
# ops[2]: reduce mean, ops[3]: fill constant
# ops[4]: reduce mean grad, ops[5:7]: the two ops for the linear bwd
# so the eighth op, ops[7], should be the optimizer
assert ops[7].type == 'sgd'
if __name__ == "__main__":