From 7461b3597770e3b7fdd39a130e36d049c4e34f05 Mon Sep 17 00:00:00 2001 From: ranqiu Date: Sun, 12 Nov 2017 20:26:51 +0800 Subject: [PATCH] Refine multi-head attention --- python/paddle/trainer_config_helpers/networks.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/python/paddle/trainer_config_helpers/networks.py b/python/paddle/trainer_config_helpers/networks.py index 7afca8d7782..e23da2068cc 100644 --- a/python/paddle/trainer_config_helpers/networks.py +++ b/python/paddle/trainer_config_helpers/networks.py @@ -1557,15 +1557,15 @@ def multi_head_attention(query, for i in range(head_num): with mixed_layer(size=key_proj_size) as sub_query_proj: sub_query_proj += identity_projection( - query_proj, offset=key_proj_size * i) + query_proj, offset=key_proj_size * i, size=key_proj_size) with mixed_layer(size=key_proj_size) as sub_key_proj: sub_key_proj += identity_projection( - key_proj, offset=key_proj_size * i) + key_proj, offset=key_proj_size * i, size=key_proj_size) with mixed_layer(size=value_proj_size) as sub_value_proj: sub_value_proj += identity_projection( - value_proj, offset=value_proj_size * i) + value_proj, offset=value_proj_size * i, size=value_proj_size) if attention_type == 'dot-product attention': m = linear_comb_layer( @@ -1603,11 +1603,7 @@ def multi_head_attention(query, head_list.append(head) - multi_head = concat_layer(head_list) - - with mixed_layer( - size=value_proj_size * head_num, name='%s_proj' % name) as attended: - attended += full_matrix_projection(multi_head) + attended = concat_layer(head_list) return attended -- GitLab