diff --git a/paddle/fluid/framework/ir/fused_attention_pass.cc b/paddle/fluid/framework/ir/fused_attention_pass.cc
index dcf5f05e643ebd267e956b3d5661c70674cab447..bebed80c469c1e6dfff3bd229b86e94b45945ceb 100644
--- a/paddle/fluid/framework/ir/fused_attention_pass.cc
+++ b/paddle/fluid/framework/ir/fused_attention_pass.cc
@@ -123,23 +123,23 @@ PDNode* FusedAttentionPattern::operator()(PDNode* x,
                            fuse_qkv_split_out_v_node});
 
   // core attention pattern
+  auto* qk_scale_node =
+      pattern->NewNode(qk_scale_op_repr())->assert_is_op("scale");
+  auto* qk_scale_out_node =
+      pattern->NewNode(qk_scale_out_repr())->assert_is_op_output("scale");
+  fuse_qkv_split_out_q_node->assert_is_op_input("scale", "X");
+  qk_scale_node->LinksFrom({fuse_qkv_split_out_q_node})
+      .LinksTo({qk_scale_out_node});
+
   auto* qk_matmul_node =
       pattern->NewNode(qk_matmul_op_repr())->assert_is_op("matmul_v2");
   auto* qk_matmul_out_node =
       pattern->NewNode(qk_matmul_out_repr())->assert_is_op_output("matmul_v2");
-  fuse_qkv_split_out_q_node->assert_is_op_input("matmul_v2", "X");
+  qk_scale_out_node->assert_is_op_input("matmul_v2", "X");
   fuse_qkv_split_out_k_node->assert_is_op_input("matmul_v2", "Y");
-  qk_matmul_node
-      ->LinksFrom({fuse_qkv_split_out_q_node, fuse_qkv_split_out_k_node})
+  qk_matmul_node->LinksFrom({qk_scale_out_node, fuse_qkv_split_out_k_node})
       .LinksTo({qk_matmul_out_node});
 
-  auto* qk_scale_node =
-      pattern->NewNode(qk_scale_op_repr())->assert_is_op("scale");
-  auto* qk_scale_out_node =
-      pattern->NewNode(qk_scale_out_repr())->assert_is_op_output("scale");
-  qk_matmul_out_node->assert_is_op_input("scale", "X");
-  qk_scale_node->LinksFrom({qk_matmul_out_node}).LinksTo({qk_scale_out_node});
-
   PDNode* add_mask_ele_add_out_node{nullptr};
   if (has_attn_mask) {
     auto* add_mask_ele_add_node = pattern->NewNode(add_mask_ele_add_op_repr())
@@ -149,9 +149,9 @@ PDNode* FusedAttentionPattern::operator()(PDNode* x,
                                       ->assert_is_op("elementwise_add");
     auto* add_mask_ele_add_mask_node =
         pattern->NewNode(add_mask_ele_add_mask_repr())
             ->assert_is_op_input("elementwise_add", "Y");
     add_mask_ele_add_out_node = pattern->NewNode(add_mask_ele_add_out_repr())
                                     ->assert_is_op_output("elementwise_add");
-    qk_scale_out_node->assert_is_op_input("elementwise_add", "X");
+    qk_matmul_out_node->assert_is_op_input("elementwise_add", "X");
     add_mask_ele_add_node
-        ->LinksFrom({qk_scale_out_node, add_mask_ele_add_mask_node})
+        ->LinksFrom({qk_matmul_out_node, add_mask_ele_add_mask_node})
         .LinksTo({add_mask_ele_add_out_node});
   }
@@ -164,8 +164,8 @@ PDNode* FusedAttentionPattern::operator()(PDNode* x,
     qk_softmax_node->LinksFrom({add_mask_ele_add_out_node})
         .LinksTo({qk_softmax_out_node});
   } else {
-    qk_scale_out_node->assert_is_op_input("softmax", "X");
-    qk_softmax_node->LinksFrom({qk_scale_out_node})
+    qk_matmul_out_node->assert_is_op_input("softmax", "X");
+    qk_softmax_node->LinksFrom({qk_matmul_out_node})
         .LinksTo({qk_softmax_out_node});
   }
 
@@ -575,16 +575,8 @@ PDNode* FusedAttentionGradPattern::operator()(PDNode* x,
         .LinksTo({add_mask_ele_add_grad_x_grad_node});
   }
 
-  PDNode* qk_scale_grad_input_node =
+  PDNode* qk_matmul_grad_input_node =
       has_attn_mask ? add_mask_ele_add_grad_x_grad_node : qk_softmax_grad_out;
-  auto* qk_scale_grad_node =
-      pattern->NewNode(qk_scale_grad_op_repr())->assert_is_op("scale");
-  auto* qk_scale_grad_out_node =
-      pattern->NewNode(qk_scale_grad_out_repr())->assert_is_op_output("scale");
-  qk_scale_grad_input_node->assert_is_op_input("scale", "X");
-  qk_scale_grad_node->LinksFrom({qk_scale_grad_input_node})
-      .LinksTo({qk_scale_grad_out_node});
-
   auto* qk_matmul_grad_node = pattern->NewNode(qk_matmul_grad_op_repr())
                                   ->assert_is_op("matmul_v2_grad");
   auto* qk_matmul_grad_x_node = pattern->NewNode(qk_matmul_grad_x_repr())
@@ -597,24 +589,32 @@ PDNode* FusedAttentionGradPattern::operator()(PDNode* x,
   auto* qk_matmul_grad_w_grad_node =
       pattern->NewNode(qk_matmul_grad_w_grad_repr())
           ->assert_is_op_output("matmul_v2_grad", "Y@GRAD");
-  qk_scale_grad_out_node->assert_is_op_input("matmul_v2_grad", "Out@GRAD");
+  qk_matmul_grad_input_node->assert_is_op_input("matmul_v2_grad", "Out@GRAD");
   qk_matmul_grad_node
-      ->LinksFrom({qk_scale_grad_out_node,
+      ->LinksFrom({qk_matmul_grad_input_node,
                    qk_matmul_grad_x_node,
                    qk_matmul_grad_w_node})
       .LinksTo({qk_matmul_grad_x_grad_node, qk_matmul_grad_w_grad_node});
 
+  auto* qk_scale_grad_node =
+      pattern->NewNode(qk_scale_grad_op_repr())->assert_is_op("scale");
+  auto* qk_scale_grad_out_node =
+      pattern->NewNode(qk_scale_grad_out_repr())->assert_is_op_output("scale");
+  qk_matmul_grad_x_grad_node->assert_is_op_input("scale", "X");
+  qk_scale_grad_node->LinksFrom({qk_matmul_grad_x_grad_node})
+      .LinksTo({qk_scale_grad_out_node});
+
   // fuse qkv projection
   auto* fuse_qkv_split_grad_node =
       pattern->NewNode(fuse_qkv_split_grad_op_repr())->assert_is_op("concat");
   auto* fuse_qkv_split_grad_out_node =
       pattern->NewNode(fuse_qkv_split_grad_out_repr())
           ->assert_is_op_output("concat");
-  qk_matmul_grad_x_grad_node->assert_is_op_input("concat");   // q grad
+  qk_scale_grad_out_node->assert_is_op_input("concat");       // q grad
   qk_matmul_grad_w_grad_node->assert_is_op_input("concat");   // k grad
   qkv_matmul_grad_w_grad_node->assert_is_op_input("concat");  // v grad
   fuse_qkv_split_grad_node
-      ->LinksFrom({qk_matmul_grad_x_grad_node,
+      ->LinksFrom({qk_scale_grad_out_node,
                    qk_matmul_grad_w_grad_node,
                    qkv_matmul_grad_w_grad_node})
       .LinksTo({fuse_qkv_split_grad_out_node});
@@ -894,7 +894,7 @@ ir::Graph* FusedAttentionsPass::PreMaskDropResFwd(
   fused_attention_op_desc.SetAttr("transpose_qkv_wb", true);
   std::vector<int> shape = PADDLE_GET_CONST(
       std::vector<int>, fuse_qkv_reshape_op_node->Op()->GetAttr("shape"));
-  fused_attention_op_desc.SetAttr("num_heads", shape[2]);
+  fused_attention_op_desc.SetAttr("num_heads", shape[2] / 3);
   GET_IR_NODE_FROM_SUBGRAPH(
       fuse_qkv_matmul_out_node, fuse_qkv_matmul_out, fused_attention_pattern);
   GET_IR_NODE_FROM_SUBGRAPH(fuse_qkv_ele_add_bias_node,
@@ -1337,7 +1337,7 @@ ir::Graph* FusedAttentionsPass::PreMaskDropResBwd(
   std::vector<int> shape =
       PADDLE_GET_CONST(std::vector<int>,
                        fuse_qkv_reshape_grad_op_node->Op()->GetAttr("shape"));
-  fused_attention_grad_op_desc.SetAttr("num_heads", shape[2]);
+  fused_attention_grad_op_desc.SetAttr("num_heads", shape[2] / 3);
   fused_attention_grad_op_desc.SetAttr("pre_layer_norm", true);
   fused_attention_grad_op_desc.SetAttr("transpose_qkv_wb", true);
 
diff --git a/python/paddle/fluid/tests/unittests/test_fused_attention_pass.py b/python/paddle/fluid/tests/unittests/test_fused_attention_pass.py
index 98085c223a0cbf9ba0bab9625cba4882432bfca5..f27daa2d0c119e023471d2029a4bccc5f5742c71 100644
--- a/python/paddle/fluid/tests/unittests/test_fused_attention_pass.py
+++ b/python/paddle/fluid/tests/unittests/test_fused_attention_pass.py
@@ -53,7 +53,7 @@ class MultiHeadAttention(paddle.nn.Layer):
         self.qkv_proj = paddle.nn.Linear(embed_dim, 3 * embed_dim)
         self.out_proj = paddle.nn.Linear(embed_dim, embed_dim)
 
-        self.dropout = paddle.nn.Dropout(0.1, mode="upscale_in_train")
+        self.dropout = paddle.nn.Dropout(1e-10, mode="upscale_in_train")
 
     def forward(self, x, attn_mask=None):
         residual = x
@@ -64,13 +64,13 @@ class MultiHeadAttention(paddle.nn.Layer):
 
         # compute qkv
         qkv = self.qkv_proj(x)
-        qkv = paddle.reshape(qkv, [0, 0, self.num_heads, 3 * self.head_dim])
+        qkv = paddle.reshape(qkv, [0, 0, 3 * self.num_heads, self.head_dim])
         qkv = paddle.transpose(qkv, [0, 2, 1, 3])
-        q, k, v = paddle.split(qkv, num_or_sections=3, axis=-1)
+        q, k, v = paddle.split(qkv, num_or_sections=3, axis=1)
 
         # compute core attention
+        q = paddle.scale(q, scale=self.head_dim**-0.5)
         product = paddle.matmul(x=q, y=k, transpose_y=True)
-        product = paddle.scale(product, scale=self.head_dim**-0.5)
         if attn_mask is not None:
             product = product + attn_mask
         weights = F.softmax(product)
@@ -104,21 +104,28 @@ class TestFusedAttentionPass(unittest.TestCase):
         self.pre_ln = True
         self.attn_dropout = True
         self.add_mask = True
+        self.x_data = None
+        self.mask_data = None
 
-    def test_pass(self):
+    def get_rst(self, use_pass=False):
         batch_size = 2
         seq_len = 1024
         hidden_size = 768
         num_heads = 12
 
-        x_data = np.random.rand(batch_size, seq_len, seq_len).astype('float32')
-        mask_data = np.random.rand(
-            batch_size, num_heads, seq_len, seq_len
-        ).astype('float32')
+        np.random.seed(1234)
+        if self.x_data is None:
+            self.x_data = np.random.rand(batch_size, seq_len, seq_len).astype(
+                'float32'
+            )
+            self.mask_data = np.random.rand(
+                batch_size, num_heads, seq_len, seq_len
+            ).astype('float32')
 
         main_prog = paddle.static.Program()
         main_prog.random_seed = 1234
         startup_prog = paddle.static.Program()
+        startup_prog.random_seed = 1234
 
         with paddle.static.program_guard(main_prog, startup_prog):
             data = paddle.static.data(
@@ -150,29 +157,36 @@ class TestFusedAttentionPass(unittest.TestCase):
             sgd_optimizer = paddle.fluid.optimizer.SGD(learning_rate=0.001)
             sgd_optimizer.minimize(loss)
 
-        pass_manager = PassManager([new_pass("fused_attention")])
-        pass_manager.apply([main_prog], [startup_prog])
-
-        ops = main_prog.global_block().ops
-        assert ops[2].type == 'fused_attention'
-        assert ops[3].type == 'reduce_mean'
-        assert ops[5].type == 'reduce_mean_grad'
-        assert ops[6].type == 'fused_attention_grad'
-        # two ops for linear, one op for reduce mean
-        # one fill constant
-        # one op for reduce mean grad, two ops for linear bwd
-        # the eighth op should be the optimizer
-        assert ops[9].type == 'sgd'
+        if use_pass:
+            pass_manager = PassManager([new_pass("fused_attention")])
+            pass_manager.apply([main_prog], [startup_prog])
+
+            ops = main_prog.global_block().ops
+            assert ops[2].type == 'fused_attention'
+            assert ops[3].type == 'reduce_mean'
+            assert ops[5].type == 'reduce_mean_grad'
+            assert ops[6].type == 'fused_attention_grad'
+            # two ops for linear, one op for reduce mean
+            # one fill constant
+            # one op for reduce mean grad, two ops for linear bwd
+            # the eighth op should be the optimizer
+            assert ops[9].type == 'sgd'
 
         exe = paddle.static.Executor()
         exe.run(startup_prog)
-        rst = exe.run(
-            main_prog,
-            feed={'x': x_data, 'attn_mask': mask_data},
-            fetch_list=[loss],
-        )
+        for i in range(2):
+            rst = exe.run(
+                main_prog,
+                feed={'x': self.x_data, 'attn_mask': self.mask_data},
+                fetch_list=[loss],
+            )
+        return rst
+
+    def test_pass(self):
+        fused_rst = self.get_rst(use_pass=True)
+        non_fused_rst = self.get_rst()
+        assert np.allclose(fused_rst, non_fused_rst)
 
 
 if __name__ == "__main__":
-    np.random.seed(0)
     unittest.main()
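
Note (reviewer sketch, not part of the patch): the pattern rewrite above leans on two facts that a short NumPy check can illustrate. Because scale is an elementwise multiplication, scaling Q before the matmul_v2 yields the same attention scores as scaling the QK^T product afterwards, so moving the scale op ahead of the matmul in the forward and grad patterns preserves semantics. And with transpose_qkv_wb set to true, the QKV tensor is reshaped to [0, 0, 3 * num_heads, head_dim], so entry 2 of the reshape "shape" attribute holds 3 * num_heads and the pass divides it by 3 to recover num_heads. The shapes below are illustrative only and are not taken from the patch.

import numpy as np

batch_size, num_heads, seq_len, head_dim = 2, 12, 16, 64  # illustrative sizes
q = np.random.rand(batch_size, num_heads, seq_len, head_dim).astype('float32')
k = np.random.rand(batch_size, num_heads, seq_len, head_dim).astype('float32')
scale = head_dim**-0.5

# Scaling Q first and scaling the QK^T product afterwards agree up to
# float32 rounding, so reordering the scale op does not change the scores.
scores_scale_q = np.matmul(q * scale, np.swapaxes(k, -1, -2))
scores_scale_prod = scale * np.matmul(q, np.swapaxes(k, -1, -2))
assert np.allclose(scores_scale_q, scores_scale_prod, atol=1e-5)

# With transpose_qkv_wb=True the reshape target is
# [0, 0, 3 * num_heads, head_dim], so shape[2] is 3 * num_heads and the
# pass must divide by 3 to obtain num_heads.
reshape_shape = [0, 0, 3 * num_heads, head_dim]
assert reshape_shape[2] // 3 == num_heads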