Unverified commit fcec564c, authored by Yuang Liu and committed by GitHub

Fused attn pass single ut (#50227)

Parent: 8fb2dce9
@@ -123,23 +123,23 @@ PDNode* FusedAttentionPattern::operator()(PDNode* x,
                     fuse_qkv_split_out_v_node});
 
   // core attention pattern
+  auto* qk_scale_node =
+      pattern->NewNode(qk_scale_op_repr())->assert_is_op("scale");
+  auto* qk_scale_out_node =
+      pattern->NewNode(qk_scale_out_repr())->assert_is_op_output("scale");
+  fuse_qkv_split_out_q_node->assert_is_op_input("scale", "X");
+  qk_scale_node->LinksFrom({fuse_qkv_split_out_q_node})
+      .LinksTo({qk_scale_out_node});
+
   auto* qk_matmul_node =
       pattern->NewNode(qk_matmul_op_repr())->assert_is_op("matmul_v2");
   auto* qk_matmul_out_node =
       pattern->NewNode(qk_matmul_out_repr())->assert_is_op_output("matmul_v2");
-  fuse_qkv_split_out_q_node->assert_is_op_input("matmul_v2", "X");
+  qk_scale_out_node->assert_is_op_input("matmul_v2", "X");
   fuse_qkv_split_out_k_node->assert_is_op_input("matmul_v2", "Y");
-  qk_matmul_node
-      ->LinksFrom({fuse_qkv_split_out_q_node, fuse_qkv_split_out_k_node})
+  qk_matmul_node->LinksFrom({qk_scale_out_node, fuse_qkv_split_out_k_node})
       .LinksTo({qk_matmul_out_node});
 
-  auto* qk_scale_node =
-      pattern->NewNode(qk_scale_op_repr())->assert_is_op("scale");
-  auto* qk_scale_out_node =
-      pattern->NewNode(qk_scale_out_repr())->assert_is_op_output("scale");
-  qk_matmul_out_node->assert_is_op_input("scale", "X");
-  qk_scale_node->LinksFrom({qk_matmul_out_node}).LinksTo({qk_scale_out_node});
-
   PDNode* add_mask_ele_add_out_node{nullptr};
   if (has_attn_mask) {
     auto* add_mask_ele_add_node = pattern->NewNode(add_mask_ele_add_op_repr())

@@ -149,9 +149,9 @@ PDNode* FusedAttentionPattern::operator()(PDNode* x,
                                       ->assert_is_op_input("elementwise_add", "Y");
     add_mask_ele_add_out_node = pattern->NewNode(add_mask_ele_add_out_repr())
                                     ->assert_is_op_output("elementwise_add");
-    qk_scale_out_node->assert_is_op_input("elementwise_add", "X");
+    qk_matmul_out_node->assert_is_op_input("elementwise_add", "X");
     add_mask_ele_add_node
-        ->LinksFrom({qk_scale_out_node, add_mask_ele_add_mask_node})
+        ->LinksFrom({qk_matmul_out_node, add_mask_ele_add_mask_node})
         .LinksTo({add_mask_ele_add_out_node});
   }

@@ -164,8 +164,8 @@ PDNode* FusedAttentionPattern::operator()(PDNode* x,
     qk_softmax_node->LinksFrom({add_mask_ele_add_out_node})
        .LinksTo({qk_softmax_out_node});
   } else {
-    qk_scale_out_node->assert_is_op_input("softmax", "X");
-    qk_softmax_node->LinksFrom({qk_scale_out_node})
+    qk_matmul_out_node->assert_is_op_input("softmax", "X");
+    qk_softmax_node->LinksFrom({qk_matmul_out_node})
        .LinksTo({qk_softmax_out_node});
   }
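
These hunks move the scale op in front of the matmul in the forward pattern: the pass now matches graphs where q is scaled before q*k^T, which is the op order the reference model in this commit emits. Below is a minimal editorial sketch (the helper name is illustrative, not part of the pass) of the core-attention sequence the rewritten pattern expects:

import paddle
import paddle.nn.functional as F

# Sketch of the matched op order: scale -> matmul_v2 -> [elementwise_add] -> softmax.
# q, k, v are assumed to be [batch, num_heads, seq_len, head_dim].
def core_attention(q, k, v, head_dim, attn_mask=None):
    q = paddle.scale(q, scale=head_dim**-0.5)             # scale applied to q first
    product = paddle.matmul(x=q, y=k, transpose_y=True)   # matmul_v2
    if attn_mask is not None:
        product = product + attn_mask                     # elementwise_add
    weights = F.softmax(product)                          # softmax
    return paddle.matmul(weights, v)

Since (s * q) @ k^T equals s * (q @ k^T), this ordering is numerically equivalent to scaling the product as the old pattern assumed; only the graph op order differs.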
@@ -575,16 +575,8 @@ PDNode* FusedAttentionGradPattern::operator()(PDNode* x,
         .LinksTo({add_mask_ele_add_grad_x_grad_node});
   }
 
-  PDNode* qk_scale_grad_input_node =
+  PDNode* qk_matmul_grad_input_node =
       has_attn_mask ? add_mask_ele_add_grad_x_grad_node : qk_softmax_grad_out;
-  auto* qk_scale_grad_node =
-      pattern->NewNode(qk_scale_grad_op_repr())->assert_is_op("scale");
-  auto* qk_scale_grad_out_node =
-      pattern->NewNode(qk_scale_grad_out_repr())->assert_is_op_output("scale");
-  qk_scale_grad_input_node->assert_is_op_input("scale", "X");
-  qk_scale_grad_node->LinksFrom({qk_scale_grad_input_node})
-      .LinksTo({qk_scale_grad_out_node});
-
   auto* qk_matmul_grad_node = pattern->NewNode(qk_matmul_grad_op_repr())
                                   ->assert_is_op("matmul_v2_grad");
   auto* qk_matmul_grad_x_node = pattern->NewNode(qk_matmul_grad_x_repr())

@@ -597,24 +589,32 @@ PDNode* FusedAttentionGradPattern::operator()(PDNode* x,
   auto* qk_matmul_grad_w_grad_node =
       pattern->NewNode(qk_matmul_grad_w_grad_repr())
           ->assert_is_op_output("matmul_v2_grad", "Y@GRAD");
-  qk_scale_grad_out_node->assert_is_op_input("matmul_v2_grad", "Out@GRAD");
+  qk_matmul_grad_input_node->assert_is_op_input("matmul_v2_grad", "Out@GRAD");
   qk_matmul_grad_node
-      ->LinksFrom({qk_scale_grad_out_node,
+      ->LinksFrom({qk_matmul_grad_input_node,
                    qk_matmul_grad_x_node,
                    qk_matmul_grad_w_node})
       .LinksTo({qk_matmul_grad_x_grad_node, qk_matmul_grad_w_grad_node});
 
+  auto* qk_scale_grad_node =
+      pattern->NewNode(qk_scale_grad_op_repr())->assert_is_op("scale");
+  auto* qk_scale_grad_out_node =
+      pattern->NewNode(qk_scale_grad_out_repr())->assert_is_op_output("scale");
+  qk_matmul_grad_x_grad_node->assert_is_op_input("scale", "X");
+  qk_scale_grad_node->LinksFrom({qk_matmul_grad_x_grad_node})
+      .LinksTo({qk_scale_grad_out_node});
+
   // fuse qkv projection
   auto* fuse_qkv_split_grad_node =
       pattern->NewNode(fuse_qkv_split_grad_op_repr())->assert_is_op("concat");
   auto* fuse_qkv_split_grad_out_node =
       pattern->NewNode(fuse_qkv_split_grad_out_repr())
           ->assert_is_op_output("concat");
-  qk_matmul_grad_x_grad_node->assert_is_op_input("concat");  // q grad
+  qk_scale_grad_out_node->assert_is_op_input("concat");      // q grad
   qk_matmul_grad_w_grad_node->assert_is_op_input("concat");  // k grad
   qkv_matmul_grad_w_grad_node->assert_is_op_input("concat");  // v grad
   fuse_qkv_split_grad_node
-      ->LinksFrom({qk_matmul_grad_x_grad_node,
+      ->LinksFrom({qk_scale_grad_out_node,
                    qk_matmul_grad_w_grad_node,
                    qkv_matmul_grad_w_grad_node})
       .LinksTo({fuse_qkv_split_grad_out_node});
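
In the backward pattern the ordering is mirrored: the incoming gradient now feeds matmul_v2_grad directly, and the scale op is applied to its X@GRAD output before that result is concatenated back as the q gradient. A small editorial NumPy sketch, with made-up shapes, of why the two orderings give the same q gradient:

import numpy as np

# Old graph: scale the incoming gradient, then take the matmul grad.
# New graph: matmul grad first, then scale its X@GRAD output.
# Both compute s * (d_out @ k), so the q gradient is unchanged.
np.random.seed(0)
s = 64.0 ** -0.5                      # head_dim**-0.5 with head_dim = 64
k = np.random.rand(4, 8)              # stand-in for k, [seq, head_dim]
d_out = np.random.rand(4, 4)          # gradient flowing into q @ k^T

d_q_old = (s * d_out) @ k             # scale_grad -> matmul_v2_grad
d_q_new = s * (d_out @ k)             # matmul_v2_grad -> scale (new pattern)
assert np.allclose(d_q_old, d_q_new)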
@@ -894,7 +894,7 @@ ir::Graph* FusedAttentionsPass::PreMaskDropResFwd(
   fused_attention_op_desc.SetAttr("transpose_qkv_wb", true);
   std::vector<int> shape = PADDLE_GET_CONST(
       std::vector<int>, fuse_qkv_reshape_op_node->Op()->GetAttr("shape"));
-  fused_attention_op_desc.SetAttr("num_heads", shape[2]);
+  fused_attention_op_desc.SetAttr("num_heads", shape[2] / 3);
   GET_IR_NODE_FROM_SUBGRAPH(
       fuse_qkv_matmul_out_node, fuse_qkv_matmul_out, fused_attention_pattern);
   GET_IR_NODE_FROM_SUBGRAPH(fuse_qkv_ele_add_bias_node,

@@ -1337,7 +1337,7 @@ ir::Graph* FusedAttentionsPass::PreMaskDropResBwd(
   std::vector<int> shape =
       PADDLE_GET_CONST(std::vector<int>,
                        fuse_qkv_reshape_grad_op_node->Op()->GetAttr("shape"));
-  fused_attention_grad_op_desc.SetAttr("num_heads", shape[2]);
+  fused_attention_grad_op_desc.SetAttr("num_heads", shape[2] / 3);
   fused_attention_grad_op_desc.SetAttr("pre_layer_norm", true);
   fused_attention_grad_op_desc.SetAttr("transpose_qkv_wb", true);
@@ -53,7 +53,7 @@ class MultiHeadAttention(paddle.nn.Layer):
         self.qkv_proj = paddle.nn.Linear(embed_dim, 3 * embed_dim)
         self.out_proj = paddle.nn.Linear(embed_dim, embed_dim)
-        self.dropout = paddle.nn.Dropout(0.1, mode="upscale_in_train")
+        self.dropout = paddle.nn.Dropout(1e-10, mode="upscale_in_train")
 
     def forward(self, x, attn_mask=None):
         residual = x

@@ -64,13 +64,13 @@ class MultiHeadAttention(paddle.nn.Layer):
         # compute qkv
         qkv = self.qkv_proj(x)
-        qkv = paddle.reshape(qkv, [0, 0, self.num_heads, 3 * self.head_dim])
+        qkv = paddle.reshape(qkv, [0, 0, 3 * self.num_heads, self.head_dim])
         qkv = paddle.transpose(qkv, [0, 2, 1, 3])
-        q, k, v = paddle.split(qkv, num_or_sections=3, axis=-1)
+        q, k, v = paddle.split(qkv, num_or_sections=3, axis=1)
 
         # compute core attention
+        q = paddle.scale(q, scale=self.head_dim**-0.5)
         product = paddle.matmul(x=q, y=k, transpose_y=True)
-        product = paddle.scale(product, scale=self.head_dim**-0.5)
         if attn_mask is not None:
             product = product + attn_mask
         weights = F.softmax(product)
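
The reference model now reshapes the fused QKV projection to [0, 0, 3 * num_heads, head_dim], splits along the heads axis (axis=1 after the transpose), and scales q before the matmul, matching the layout that the transpose_qkv_wb path of the pass expects. A shape trace with the unit test's sizes (the variable names are only for illustration):

import paddle

batch, seq, num_heads, head_dim = 2, 1024, 12, 64            # hidden_size = 768
qkv = paddle.rand([batch, seq, 3 * num_heads * head_dim])

qkv = paddle.reshape(qkv, [0, 0, 3 * num_heads, head_dim])   # [2, 1024, 36, 64]
qkv = paddle.transpose(qkv, [0, 2, 1, 3])                    # [2, 36, 1024, 64]
q, k, v = paddle.split(qkv, num_or_sections=3, axis=1)       # each [2, 12, 1024, 64]
assert q.shape == [batch, num_heads, seq, head_dim]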
@@ -104,21 +104,28 @@ class TestFusedAttentionPass(unittest.TestCase):
         self.pre_ln = True
         self.attn_dropout = True
         self.add_mask = True
+        self.x_data = None
+        self.mask_data = None
 
-    def test_pass(self):
+    def get_rst(self, use_pass=False):
         batch_size = 2
         seq_len = 1024
         hidden_size = 768
         num_heads = 12
-        x_data = np.random.rand(batch_size, seq_len, seq_len).astype('float32')
-        mask_data = np.random.rand(
-            batch_size, num_heads, seq_len, seq_len
-        ).astype('float32')
+        np.random.seed(1234)
+        if self.x_data is None:
+            self.x_data = np.random.rand(batch_size, seq_len, seq_len).astype(
+                'float32'
+            )
+            self.mask_data = np.random.rand(
+                batch_size, num_heads, seq_len, seq_len
+            ).astype('float32')
 
         main_prog = paddle.static.Program()
         main_prog.random_seed = 1234
         startup_prog = paddle.static.Program()
+        startup_prog.random_seed = 1234
 
         with paddle.static.program_guard(main_prog, startup_prog):
             data = paddle.static.data(

@@ -150,29 +157,36 @@ class TestFusedAttentionPass(unittest.TestCase):
             sgd_optimizer = paddle.fluid.optimizer.SGD(learning_rate=0.001)
             sgd_optimizer.minimize(loss)
 
-        pass_manager = PassManager([new_pass("fused_attention")])
-        pass_manager.apply([main_prog], [startup_prog])
-
-        ops = main_prog.global_block().ops
-        assert ops[2].type == 'fused_attention'
-        assert ops[3].type == 'reduce_mean'
-        assert ops[5].type == 'reduce_mean_grad'
-        assert ops[6].type == 'fused_attention_grad'
-        # two ops for linear, one op for reduce mean
-        # one fill constant
-        # one op for reduce mean grad, two ops for linear bwd
-        # the eighth op should be the optimizer
-        assert ops[9].type == 'sgd'
+        if use_pass:
+            pass_manager = PassManager([new_pass("fused_attention")])
+            pass_manager.apply([main_prog], [startup_prog])
+
+            ops = main_prog.global_block().ops
+            assert ops[2].type == 'fused_attention'
+            assert ops[3].type == 'reduce_mean'
+            assert ops[5].type == 'reduce_mean_grad'
+            assert ops[6].type == 'fused_attention_grad'
+            # two ops for linear, one op for reduce mean
+            # one fill constant
+            # one op for reduce mean grad, two ops for linear bwd
+            # the eighth op should be the optimizer
+            assert ops[9].type == 'sgd'
 
         exe = paddle.static.Executor()
         exe.run(startup_prog)
-        rst = exe.run(
-            main_prog,
-            feed={'x': x_data, 'attn_mask': mask_data},
-            fetch_list=[loss],
-        )
+        for i in range(2):
+            rst = exe.run(
+                main_prog,
+                feed={'x': self.x_data, 'attn_mask': self.mask_data},
+                fetch_list=[loss],
+            )
+        return rst
+
+    def test_pass(self):
+        fused_rst = self.get_rst(use_pass=True)
+        non_fused_rst = self.get_rst()
+        assert np.allclose(fused_rst, non_fused_rst)
 
 
 if __name__ == "__main__":
+    np.random.seed(0)
     unittest.main()