Unverified commit fcec564c, authored by Yuang Liu, committed by GitHub

Fused attn pass single ut (#50227)

Parent: 8fb2dce9
@@ -123,23 +123,23 @@ PDNode* FusedAttentionPattern::operator()(PDNode* x,
                   fuse_qkv_split_out_v_node});

   // core attention pattern
+  auto* qk_scale_node =
+      pattern->NewNode(qk_scale_op_repr())->assert_is_op("scale");
+  auto* qk_scale_out_node =
+      pattern->NewNode(qk_scale_out_repr())->assert_is_op_output("scale");
+  fuse_qkv_split_out_q_node->assert_is_op_input("scale", "X");
+  qk_scale_node->LinksFrom({fuse_qkv_split_out_q_node})
+      .LinksTo({qk_scale_out_node});
   auto* qk_matmul_node =
       pattern->NewNode(qk_matmul_op_repr())->assert_is_op("matmul_v2");
   auto* qk_matmul_out_node =
       pattern->NewNode(qk_matmul_out_repr())->assert_is_op_output("matmul_v2");
-  fuse_qkv_split_out_q_node->assert_is_op_input("matmul_v2", "X");
+  qk_scale_out_node->assert_is_op_input("matmul_v2", "X");
   fuse_qkv_split_out_k_node->assert_is_op_input("matmul_v2", "Y");
-  qk_matmul_node
-      ->LinksFrom({fuse_qkv_split_out_q_node, fuse_qkv_split_out_k_node})
+  qk_matmul_node->LinksFrom({qk_scale_out_node, fuse_qkv_split_out_k_node})
       .LinksTo({qk_matmul_out_node});
-  auto* qk_scale_node =
-      pattern->NewNode(qk_scale_op_repr())->assert_is_op("scale");
-  auto* qk_scale_out_node =
-      pattern->NewNode(qk_scale_out_repr())->assert_is_op_output("scale");
-  qk_matmul_out_node->assert_is_op_input("scale", "X");
-  qk_scale_node->LinksFrom({qk_matmul_out_node}).LinksTo({qk_scale_out_node});

   PDNode* add_mask_ele_add_out_node{nullptr};
   if (has_attn_mask) {
     auto* add_mask_ele_add_node = pattern->NewNode(add_mask_ele_add_op_repr())
@@ -149,9 +149,9 @@ PDNode* FusedAttentionPattern::operator()(PDNode* x,
             ->assert_is_op_input("elementwise_add", "Y");
     add_mask_ele_add_out_node = pattern->NewNode(add_mask_ele_add_out_repr())
                                     ->assert_is_op_output("elementwise_add");
-    qk_scale_out_node->assert_is_op_input("elementwise_add", "X");
+    qk_matmul_out_node->assert_is_op_input("elementwise_add", "X");
     add_mask_ele_add_node
-        ->LinksFrom({qk_scale_out_node, add_mask_ele_add_mask_node})
+        ->LinksFrom({qk_matmul_out_node, add_mask_ele_add_mask_node})
         .LinksTo({add_mask_ele_add_out_node});
   }
@@ -164,8 +164,8 @@ PDNode* FusedAttentionPattern::operator()(PDNode* x,
     qk_softmax_node->LinksFrom({add_mask_ele_add_out_node})
         .LinksTo({qk_softmax_out_node});
   } else {
-    qk_scale_out_node->assert_is_op_input("softmax", "X");
-    qk_softmax_node->LinksFrom({qk_scale_out_node})
+    qk_matmul_out_node->assert_is_op_input("softmax", "X");
+    qk_softmax_node->LinksFrom({qk_matmul_out_node})
         .LinksTo({qk_softmax_out_node});
   }
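Note (reviewer sketch, not part of the diff): the reordered forward pattern now expects the scale to be applied to q before the QK matmul, i.e. a scale -> matmul_v2 -> (elementwise_add) -> softmax chain. A minimal standalone static-graph sketch of the op sequence the pattern is meant to match; names and shapes are purely illustrative:

    # illustrative sketch of the graph the updated forward pattern matches
    import paddle
    import paddle.nn.functional as F

    paddle.enable_static()
    main_prog = paddle.static.Program()
    with paddle.static.program_guard(main_prog):
        # [batch, num_heads, seq_len, head_dim], values only for illustration
        q = paddle.static.data('q', [2, 12, 1024, 64], 'float32')
        k = paddle.static.data('k', [2, 12, 1024, 64], 'float32')
        mask = paddle.static.data('mask', [2, 12, 1024, 1024], 'float32')

        q = paddle.scale(q, scale=64**-0.5)              # qk_scale node
        product = paddle.matmul(q, k, transpose_y=True)  # qk_matmul node
        product = product + mask                         # add_mask_ele_add node
        weights = F.softmax(product)                     # qk_softmax node

    # expected to contain: ['scale', 'matmul_v2', 'elementwise_add', 'softmax']
    print([op.type for op in main_prog.global_block().ops])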
@@ -575,16 +575,8 @@ PDNode* FusedAttentionGradPattern::operator()(PDNode* x,
         .LinksTo({add_mask_ele_add_grad_x_grad_node});
   }
-  PDNode* qk_scale_grad_input_node =
+  PDNode* qk_matmul_grad_input_node =
       has_attn_mask ? add_mask_ele_add_grad_x_grad_node : qk_softmax_grad_out;
-  auto* qk_scale_grad_node =
-      pattern->NewNode(qk_scale_grad_op_repr())->assert_is_op("scale");
-  auto* qk_scale_grad_out_node =
-      pattern->NewNode(qk_scale_grad_out_repr())->assert_is_op_output("scale");
-  qk_scale_grad_input_node->assert_is_op_input("scale", "X");
-  qk_scale_grad_node->LinksFrom({qk_scale_grad_input_node})
-      .LinksTo({qk_scale_grad_out_node});
   auto* qk_matmul_grad_node = pattern->NewNode(qk_matmul_grad_op_repr())
                                   ->assert_is_op("matmul_v2_grad");
   auto* qk_matmul_grad_x_node = pattern->NewNode(qk_matmul_grad_x_repr())
@@ -597,24 +589,32 @@ PDNode* FusedAttentionGradPattern::operator()(PDNode* x,
   auto* qk_matmul_grad_w_grad_node =
       pattern->NewNode(qk_matmul_grad_w_grad_repr())
           ->assert_is_op_output("matmul_v2_grad", "Y@GRAD");
-  qk_scale_grad_out_node->assert_is_op_input("matmul_v2_grad", "Out@GRAD");
+  qk_matmul_grad_input_node->assert_is_op_input("matmul_v2_grad", "Out@GRAD");
   qk_matmul_grad_node
-      ->LinksFrom({qk_scale_grad_out_node,
+      ->LinksFrom({qk_matmul_grad_input_node,
                    qk_matmul_grad_x_node,
                    qk_matmul_grad_w_node})
       .LinksTo({qk_matmul_grad_x_grad_node, qk_matmul_grad_w_grad_node});
+  auto* qk_scale_grad_node =
+      pattern->NewNode(qk_scale_grad_op_repr())->assert_is_op("scale");
+  auto* qk_scale_grad_out_node =
+      pattern->NewNode(qk_scale_grad_out_repr())->assert_is_op_output("scale");
+  qk_matmul_grad_x_grad_node->assert_is_op_input("scale", "X");
+  qk_scale_grad_node->LinksFrom({qk_matmul_grad_x_grad_node})
+      .LinksTo({qk_scale_grad_out_node});

   // fuse qkv projection
   auto* fuse_qkv_split_grad_node =
       pattern->NewNode(fuse_qkv_split_grad_op_repr())->assert_is_op("concat");
   auto* fuse_qkv_split_grad_out_node =
       pattern->NewNode(fuse_qkv_split_grad_out_repr())
           ->assert_is_op_output("concat");
-  qk_matmul_grad_x_grad_node->assert_is_op_input("concat");   // q grad
+  qk_scale_grad_out_node->assert_is_op_input("concat");       // q grad
   qk_matmul_grad_w_grad_node->assert_is_op_input("concat");   // k grad
   qkv_matmul_grad_w_grad_node->assert_is_op_input("concat");  // v grad
   fuse_qkv_split_grad_node
-      ->LinksFrom({qk_matmul_grad_x_grad_node,
+      ->LinksFrom({qk_scale_grad_out_node,
                    qk_matmul_grad_w_grad_node,
                    qkv_matmul_grad_w_grad_node})
       .LinksTo({fuse_qkv_split_grad_out_node});
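Note (illustrative sketch, not from the commit): with the forward order now q -> scale -> matmul, the backward graph visits those ops in reverse, so matmul_v2_grad comes first and the scale grad (which is itself a scale op) comes last. That is why the q gradient fed into the concat above is qk_scale_grad_out_node rather than qk_matmul_grad_x_grad_node. A small NumPy sketch of that chain-rule order, with made-up shapes:

    # chain rule for forward "scale(q) then matmul", traversed in reverse
    import numpy as np

    head_dim = 8
    q = np.random.rand(4, head_dim).astype('float32')
    k = np.random.rand(4, head_dim).astype('float32')
    alpha = head_dim**-0.5

    # forward: scale then matmul
    scaled_q = alpha * q              # qk_scale
    product = scaled_q @ k.T          # qk_matmul

    # backward, in reverse order
    d_product = np.ones_like(product) # grad arriving from softmax / mask add
    d_scaled_q = d_product @ k        # qk_matmul_grad -> X@GRAD
    d_k = d_product.T @ scaled_q      # qk_matmul_grad -> Y@GRAD
    d_q = alpha * d_scaled_q          # qk_scale_grad (a scale op) -> q grad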
@@ -894,7 +894,7 @@ ir::Graph* FusedAttentionsPass::PreMaskDropResFwd(
   fused_attention_op_desc.SetAttr("transpose_qkv_wb", true);
   std::vector<int> shape = PADDLE_GET_CONST(
       std::vector<int>, fuse_qkv_reshape_op_node->Op()->GetAttr("shape"));
-  fused_attention_op_desc.SetAttr("num_heads", shape[2]);
+  fused_attention_op_desc.SetAttr("num_heads", shape[2] / 3);
   GET_IR_NODE_FROM_SUBGRAPH(
       fuse_qkv_matmul_out_node, fuse_qkv_matmul_out, fused_attention_pattern);
   GET_IR_NODE_FROM_SUBGRAPH(fuse_qkv_ele_add_bias_node,
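Note: the unit test now reshapes the fused QKV projection to [0, 0, 3 * num_heads, head_dim], so the third reshape dimension packs the Q, K and V heads together and the pass recovers the real head count by dividing by 3. A worked example with the test's sizes (12 heads, head_dim 64, purely illustrative):

    # why the pass stores shape[2] / 3 as num_heads
    num_heads, head_dim = 12, 64
    shape = [0, 0, 3 * num_heads, head_dim]  # the "shape" attr read from the reshape op
    assert shape[2] == 36
    assert shape[2] // 3 == num_heads        # what SetAttr("num_heads", shape[2] / 3) records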
@@ -1337,7 +1337,7 @@ ir::Graph* FusedAttentionsPass::PreMaskDropResBwd(
   std::vector<int> shape =
       PADDLE_GET_CONST(std::vector<int>,
                        fuse_qkv_reshape_grad_op_node->Op()->GetAttr("shape"));
-  fused_attention_grad_op_desc.SetAttr("num_heads", shape[2]);
+  fused_attention_grad_op_desc.SetAttr("num_heads", shape[2] / 3);
   fused_attention_grad_op_desc.SetAttr("pre_layer_norm", true);
   fused_attention_grad_op_desc.SetAttr("transpose_qkv_wb", true);
@@ -53,7 +53,7 @@ class MultiHeadAttention(paddle.nn.Layer):
         self.qkv_proj = paddle.nn.Linear(embed_dim, 3 * embed_dim)
         self.out_proj = paddle.nn.Linear(embed_dim, embed_dim)
-        self.dropout = paddle.nn.Dropout(0.1, mode="upscale_in_train")
+        self.dropout = paddle.nn.Dropout(1e-10, mode="upscale_in_train")

     def forward(self, x, attn_mask=None):
         residual = x
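Note (my reading of the change): the test now compares the fused and non-fused losses numerically, so the dropout probability is lowered to an effectively-zero 1e-10. The dropout op stays in the graph for the pattern to match, but it no longer perturbs the comparison. A quick dygraph check of that assumption:

    # with p ~ 0, "upscale_in_train" dropout is numerically an identity
    import numpy as np
    import paddle

    x = paddle.ones([4, 4])
    drop = paddle.nn.Dropout(1e-10, mode="upscale_in_train")
    drop.train()
    y = drop(x)
    assert np.allclose(y.numpy(), x.numpy())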
@@ -64,13 +64,13 @@ class MultiHeadAttention(paddle.nn.Layer):
         # compute qkv
         qkv = self.qkv_proj(x)
-        qkv = paddle.reshape(qkv, [0, 0, self.num_heads, 3 * self.head_dim])
+        qkv = paddle.reshape(qkv, [0, 0, 3 * self.num_heads, self.head_dim])
         qkv = paddle.transpose(qkv, [0, 2, 1, 3])
-        q, k, v = paddle.split(qkv, num_or_sections=3, axis=-1)
+        q, k, v = paddle.split(qkv, num_or_sections=3, axis=1)

         # compute core attention
+        q = paddle.scale(q, scale=self.head_dim**-0.5)
         product = paddle.matmul(x=q, y=k, transpose_y=True)
-        product = paddle.scale(product, scale=self.head_dim**-0.5)
         if attn_mask is not None:
             product = product + attn_mask
         weights = F.softmax(product)
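Note (shape walk-through, illustrative only): the new layout packs the three projections head-wise, which is what makes the axis=1 split here and the shape[2] / 3 head count in the pass line up. Using the test's sizes:

    # qkv layout after the change: reshape, transpose, then split the packed heads
    import paddle

    batch_size, seq_len, num_heads, head_dim = 2, 1024, 12, 64
    qkv = paddle.rand([batch_size, seq_len, 3 * num_heads * head_dim])

    qkv = paddle.reshape(qkv, [0, 0, 3 * num_heads, head_dim])  # [2, 1024, 36, 64]
    qkv = paddle.transpose(qkv, [0, 2, 1, 3])                   # [2, 36, 1024, 64]
    q, k, v = paddle.split(qkv, num_or_sections=3, axis=1)      # 3 x [2, 12, 1024, 64]
    assert q.shape == [batch_size, num_heads, seq_len, head_dim]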
@@ -104,21 +104,28 @@ class TestFusedAttentionPass(unittest.TestCase):
         self.pre_ln = True
         self.attn_dropout = True
         self.add_mask = True
+        self.x_data = None
+        self.mask_data = None

-    def test_pass(self):
+    def get_rst(self, use_pass=False):
         batch_size = 2
         seq_len = 1024
         hidden_size = 768
         num_heads = 12
-        x_data = np.random.rand(batch_size, seq_len, seq_len).astype('float32')
-        mask_data = np.random.rand(
-            batch_size, num_heads, seq_len, seq_len
-        ).astype('float32')
+        np.random.seed(1234)
+        if self.x_data is None:
+            self.x_data = np.random.rand(batch_size, seq_len, seq_len).astype(
+                'float32'
+            )
+            self.mask_data = np.random.rand(
+                batch_size, num_heads, seq_len, seq_len
+            ).astype('float32')

         main_prog = paddle.static.Program()
+        main_prog.random_seed = 1234
         startup_prog = paddle.static.Program()
+        startup_prog.random_seed = 1234

         with paddle.static.program_guard(main_prog, startup_prog):
             data = paddle.static.data(
@@ -150,29 +157,36 @@ class TestFusedAttentionPass(unittest.TestCase):
             sgd_optimizer = paddle.fluid.optimizer.SGD(learning_rate=0.001)
             sgd_optimizer.minimize(loss)

-        pass_manager = PassManager([new_pass("fused_attention")])
-        pass_manager.apply([main_prog], [startup_prog])
-        ops = main_prog.global_block().ops
-        assert ops[2].type == 'fused_attention'
-        assert ops[3].type == 'reduce_mean'
-        assert ops[5].type == 'reduce_mean_grad'
-        assert ops[6].type == 'fused_attention_grad'
-        # two ops for linear, one op for reduce mean
-        # one fill constant
-        # one op for reduce mean grad, two ops for linear bwd
-        # the eighth op should be the optimizer
-        assert ops[9].type == 'sgd'
+        if use_pass:
+            pass_manager = PassManager([new_pass("fused_attention")])
+            pass_manager.apply([main_prog], [startup_prog])
+            ops = main_prog.global_block().ops
+            assert ops[2].type == 'fused_attention'
+            assert ops[3].type == 'reduce_mean'
+            assert ops[5].type == 'reduce_mean_grad'
+            assert ops[6].type == 'fused_attention_grad'
+            # two ops for linear, one op for reduce mean
+            # one fill constant
+            # one op for reduce mean grad, two ops for linear bwd
+            # the eighth op should be the optimizer
+            assert ops[9].type == 'sgd'

         exe = paddle.static.Executor()
         exe.run(startup_prog)
-        rst = exe.run(
-            main_prog,
-            feed={'x': x_data, 'attn_mask': mask_data},
-            fetch_list=[loss],
-        )
+        for i in range(2):
+            rst = exe.run(
+                main_prog,
+                feed={'x': self.x_data, 'attn_mask': self.mask_data},
+                fetch_list=[loss],
+            )
+        return rst
+
+    def test_pass(self):
+        fused_rst = self.get_rst(use_pass=True)
+        non_fused_rst = self.get_rst()
+        assert np.allclose(fused_rst, non_fused_rst)


 if __name__ == "__main__":
     np.random.seed(0)
     unittest.main()