未验证 提交 8e02f290 编写于 作者: Y Yuang Liu 提交者: GitHub

Fused attention pass backward pattern (#49855)

上级 65bce2b3
...@@ -140,7 +140,116 @@ struct FusedAttentionGradPattern : public PatternBase { ...@@ -140,7 +140,116 @@ struct FusedAttentionGradPattern : public PatternBase {
bool do_dropout, // dropout the softmax(qk) or not bool do_dropout, // dropout the softmax(qk) or not
bool add_residual); // add residual to out linear or not bool add_residual); // add residual to out linear or not
// TODO(Yuang Liu): add backward pattern // post layer norm grad
PATTERN_DECL_NODE(post_layer_norm_grad_op);
PATTERN_DECL_NODE(post_layer_norm_grad_scale);
PATTERN_DECL_NODE(post_layer_norm_grad_bias);
PATTERN_DECL_NODE(post_layer_norm_grad_mean);
PATTERN_DECL_NODE(post_layer_norm_grad_variance);
PATTERN_DECL_NODE(post_layer_norm_grad_x);
PATTERN_DECL_NODE(post_layer_norm_grad_scale_grad);
PATTERN_DECL_NODE(post_layer_norm_grad_bias_grad);
PATTERN_DECL_NODE(post_layer_norm_grad_x_grad);
// residual grad
PATTERN_DECL_NODE(residual_ele_add_grad_op);
PATTERN_DECL_NODE(residual_ele_add_grad_x);
PATTERN_DECL_NODE(residual_ele_add_grad_bias);
PATTERN_DECL_NODE(residual_ele_add_grad_bias_grad);
PATTERN_DECL_NODE(residual_ele_add_grad_x_grad);
// out linear grad
PATTERN_DECL_NODE(out_linear_dropout_grad_op);
PATTERN_DECL_NODE(out_linear_dropout_grad_mask);
PATTERN_DECL_NODE(out_linear_dropout_grad_out);
PATTERN_DECL_NODE(out_linear_ele_add_grad_op);
PATTERN_DECL_NODE(out_linear_ele_add_grad_x);
PATTERN_DECL_NODE(out_linear_ele_add_grad_bias);
PATTERN_DECL_NODE(out_linear_ele_add_grad_x_grad);
PATTERN_DECL_NODE(out_linear_ele_add_grad_bias_grad);
PATTERN_DECL_NODE(out_linear_matmul_grad_op);
PATTERN_DECL_NODE(out_linear_matmul_grad_x);
PATTERN_DECL_NODE(out_linear_matmul_grad_w);
PATTERN_DECL_NODE(out_linear_matmul_grad_x_grad);
PATTERN_DECL_NODE(out_linear_matmul_grad_w_grad);
// core attention grad
PATTERN_DECL_NODE(qkv_reshape_grad_op);
PATTERN_DECL_NODE(qkv_reshape_grad_x_shape);
PATTERN_DECL_NODE(qkv_reshape_grad_out);
PATTERN_DECL_NODE(qkv_transpose_grad_op);
PATTERN_DECL_NODE(qkv_transpose_grad_x_shape);
PATTERN_DECL_NODE(qkv_transpose_grad_out);
PATTERN_DECL_NODE(qkv_matmul_grad_op);
PATTERN_DECL_NODE(qkv_matmul_grad_x);
PATTERN_DECL_NODE(qkv_matmul_grad_w);
PATTERN_DECL_NODE(qkv_matmul_grad_x_grad);
PATTERN_DECL_NODE(qkv_matmul_grad_w_grad);
PATTERN_DECL_NODE(attn_dropout_grad_op);
PATTERN_DECL_NODE(attn_dropout_grad_mask);
PATTERN_DECL_NODE(attn_dropout_grad_out);
PATTERN_DECL_NODE(qk_softmax_grad_op);
PATTERN_DECL_NODE(qk_softmax_grad_fwd_out);
PATTERN_DECL_NODE(qk_softmax_grad_out);
PATTERN_DECL_NODE(add_mask_ele_add_grad_op);
PATTERN_DECL_NODE(add_mask_ele_add_grad_x);
PATTERN_DECL_NODE(add_mask_ele_add_grad_bias);
PATTERN_DECL_NODE(add_mask_ele_add_grad_x_grad);
PATTERN_DECL_NODE(qk_scale_grad_op);
PATTERN_DECL_NODE(qk_scale_grad_out);
PATTERN_DECL_NODE(qk_matmul_grad_op);
PATTERN_DECL_NODE(qk_matmul_grad_x);
PATTERN_DECL_NODE(qk_matmul_grad_w);
PATTERN_DECL_NODE(qk_matmul_grad_x_grad);
PATTERN_DECL_NODE(qk_matmul_grad_w_grad);
// fuse qkv projection grad
PATTERN_DECL_NODE(fuse_qkv_split_grad_op); // concat op
PATTERN_DECL_NODE(fuse_qkv_split_grad_out);
PATTERN_DECL_NODE(fuse_qkv_transpose_grad_op);
PATTERN_DECL_NODE(fuse_qkv_transpose_grad_x_shape);
PATTERN_DECL_NODE(fuse_qkv_transpose_grad_out);
PATTERN_DECL_NODE(fuse_qkv_reshape_grad_op);
PATTERN_DECL_NODE(fuse_qkv_reshape_grad_x_shape);
PATTERN_DECL_NODE(fuse_qkv_reshape_grad_out);
PATTERN_DECL_NODE(fuse_qkv_ele_add_grad_op);
PATTERN_DECL_NODE(fuse_qkv_ele_add_grad_x);
PATTERN_DECL_NODE(fuse_qkv_ele_add_grad_bias);
PATTERN_DECL_NODE(fuse_qkv_ele_add_grad_x_grad);
PATTERN_DECL_NODE(fuse_qkv_ele_add_grad_bias_grad);
PATTERN_DECL_NODE(fuse_qkv_matmul_grad_op);
PATTERN_DECL_NODE(fuse_qkv_matmul_grad_x);
PATTERN_DECL_NODE(fuse_qkv_matmul_grad_w);
PATTERN_DECL_NODE(fuse_qkv_matmul_grad_x_grad);
PATTERN_DECL_NODE(fuse_qkv_matmul_grad_w_grad);
// pre layer norm grad
PATTERN_DECL_NODE(pre_layer_norm_grad_op);
PATTERN_DECL_NODE(pre_layer_norm_grad_scale);
PATTERN_DECL_NODE(pre_layer_norm_grad_bias);
PATTERN_DECL_NODE(pre_layer_norm_grad_mean);
PATTERN_DECL_NODE(pre_layer_norm_grad_variance);
PATTERN_DECL_NODE(pre_layer_norm_grad_x);
PATTERN_DECL_NODE(pre_layer_norm_grad_scale_grad);
PATTERN_DECL_NODE(pre_layer_norm_grad_bias_grad);
PATTERN_DECL_NODE(pre_layer_norm_grad_x_grad);
// grad accumulation
PATTERN_DECL_NODE(grad_accumulation_sum_op);
PATTERN_DECL_NODE(grad_accumulation_out);
}; };
} // namespace patterns } // namespace patterns
......
...@@ -114,9 +114,7 @@ class TestFusedAttentionPass(unittest.TestCase): ...@@ -114,9 +114,7 @@ class TestFusedAttentionPass(unittest.TestCase):
hidden_size = 768 hidden_size = 768
num_heads = 12 num_heads = 12
x_data = np.random.rand(batch_size, seq_len, hidden_size).astype( x_data = np.random.rand(batch_size, seq_len, seq_len).astype('float32')
'float32'
)
mask_data = np.random.rand( mask_data = np.random.rand(
batch_size, num_heads, seq_len, seq_len batch_size, num_heads, seq_len, seq_len
).astype('float32') ).astype('float32')
...@@ -127,7 +125,7 @@ class TestFusedAttentionPass(unittest.TestCase): ...@@ -127,7 +125,7 @@ class TestFusedAttentionPass(unittest.TestCase):
with paddle.static.program_guard(main_prog, startup_prog): with paddle.static.program_guard(main_prog, startup_prog):
data = paddle.static.data( data = paddle.static.data(
name="x", name="x",
shape=[-1, seq_len, hidden_size], shape=[-1, seq_len, seq_len],
dtype='float32', dtype='float32',
) )
if self.add_mask: if self.add_mask:
...@@ -138,6 +136,7 @@ class TestFusedAttentionPass(unittest.TestCase): ...@@ -138,6 +136,7 @@ class TestFusedAttentionPass(unittest.TestCase):
) )
else: else:
attn_mask = None attn_mask = None
data_linear = paddle.nn.Linear(seq_len, hidden_size)
multi_head_attn = MultiHeadAttention( multi_head_attn = MultiHeadAttention(
hidden_size, hidden_size,
num_heads, num_heads,
...@@ -146,7 +145,9 @@ class TestFusedAttentionPass(unittest.TestCase): ...@@ -146,7 +145,9 @@ class TestFusedAttentionPass(unittest.TestCase):
post_ln=self.post_ln, post_ln=self.post_ln,
attn_dropout=self.attn_dropout, attn_dropout=self.attn_dropout,
) )
out = multi_head_attn(data, attn_mask)
attn_input = data_linear(data)
out = multi_head_attn(attn_input, attn_mask)
loss = paddle.mean(out) loss = paddle.mean(out)
sgd_optimizer = paddle.fluid.optimizer.SGD(learning_rate=0.001) sgd_optimizer = paddle.fluid.optimizer.SGD(learning_rate=0.001)
...@@ -156,7 +157,13 @@ class TestFusedAttentionPass(unittest.TestCase): ...@@ -156,7 +157,13 @@ class TestFusedAttentionPass(unittest.TestCase):
pass_manager.apply([main_prog], [startup_prog]) pass_manager.apply([main_prog], [startup_prog])
ops = main_prog.global_block().ops ops = main_prog.global_block().ops
assert ops[0].type == 'reduce_mean' assert ops[2].type == 'reduce_mean'
assert ops[4].type == 'reduce_mean_grad'
# two ops for linear, one op for reduce mean
# one fill constant
# one op for reduce mean grad, two ops for linear bwd
# the eighth op should be the optimizer
assert ops[7].type == 'sgd'
if __name__ == "__main__": if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册