diff --git a/python/paddle/trainer_config_helpers/networks.py b/python/paddle/trainer_config_helpers/networks.py index 7afca8d77823e64337dfca59a3a6dad9a5baec90..e23da2068ccb731ddfa748374db7f55173016a2b 100644 --- a/python/paddle/trainer_config_helpers/networks.py +++ b/python/paddle/trainer_config_helpers/networks.py @@ -1557,15 +1557,15 @@ def multi_head_attention(query, for i in range(head_num): with mixed_layer(size=key_proj_size) as sub_query_proj: sub_query_proj += identity_projection( - query_proj, offset=key_proj_size * i) + query_proj, offset=key_proj_size * i, size=key_proj_size) with mixed_layer(size=key_proj_size) as sub_key_proj: sub_key_proj += identity_projection( - key_proj, offset=key_proj_size * i) + key_proj, offset=key_proj_size * i, size=key_proj_size) with mixed_layer(size=value_proj_size) as sub_value_proj: sub_value_proj += identity_projection( - value_proj, offset=value_proj_size * i) + value_proj, offset=value_proj_size * i, size=value_proj_size) if attention_type == 'dot-product attention': m = linear_comb_layer( @@ -1603,11 +1603,7 @@ def multi_head_attention(query, head_list.append(head) - multi_head = concat_layer(head_list) - - with mixed_layer( - size=value_proj_size * head_num, name='%s_proj' % name) as attended: - attended += full_matrix_projection(multi_head) + attended = concat_layer(head_list) return attended