Commit 7461b359 authored by ranqiu

Refine multi-head attention

Parent 947c5285
@@ -1557,15 +1557,15 @@ def multi_head_attention(query,
     for i in range(head_num):
         with mixed_layer(size=key_proj_size) as sub_query_proj:
             sub_query_proj += identity_projection(
-                query_proj, offset=key_proj_size * i)
+                query_proj, offset=key_proj_size * i, size=key_proj_size)

         with mixed_layer(size=key_proj_size) as sub_key_proj:
             sub_key_proj += identity_projection(
-                key_proj, offset=key_proj_size * i)
+                key_proj, offset=key_proj_size * i, size=key_proj_size)

         with mixed_layer(size=value_proj_size) as sub_value_proj:
             sub_value_proj += identity_projection(
-                value_proj, offset=value_proj_size * i)
+                value_proj, offset=value_proj_size * i, size=value_proj_size)

         if attention_type == 'dot-product attention':
             m = linear_comb_layer(
@@ -1603,11 +1603,7 @@ def multi_head_attention(query,
         head_list.append(head)

-    multi_head = concat_layer(head_list)
-
-    with mixed_layer(
-            size=value_proj_size * head_num, name='%s_proj' % name) as attended:
-        attended += full_matrix_projection(multi_head)
+    attended = concat_layer(head_list)

     return attended
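Taken together, the two hunks make each head read an explicit slice of the shared query/key/value projections (identity_projection now receives both offset and size), and the concatenated head outputs are returned directly instead of being passed through a trailing full_matrix_projection. The following is a minimal NumPy sketch, not the PaddlePaddle API, of the computation the refined layer performs for one sequence; the scaled-softmax dot product stands in for the collapsed 'dot-product attention' branch, and the function and array names are illustrative assumptions.

import numpy as np

def multi_head_attention_sketch(query_proj, key_proj, value_proj,
                                head_num, key_proj_size, value_proj_size):
    # query_proj: (query_len, head_num * key_proj_size)
    # key_proj:   (key_len,   head_num * key_proj_size)
    # value_proj: (key_len,   head_num * value_proj_size)
    head_list = []
    for i in range(head_num):
        # identity_projection(input, offset=..., size=...) acts like a
        # contiguous column slice of the shared projection.
        sub_q = query_proj[:, key_proj_size * i: key_proj_size * (i + 1)]
        sub_k = key_proj[:, key_proj_size * i: key_proj_size * (i + 1)]
        sub_v = value_proj[:, value_proj_size * i: value_proj_size * (i + 1)]

        # Stand-in for the 'dot-product attention' branch (its details are
        # collapsed in the diff): softmax-normalized dot-product weights.
        scores = sub_q @ sub_k.T / np.sqrt(key_proj_size)
        weights = np.exp(scores - scores.max(axis=1, keepdims=True))
        weights /= weights.sum(axis=1, keepdims=True)

        head_list.append(weights @ sub_v)  # (query_len, value_proj_size)

    # After this commit the heads are concatenated and returned as-is; the
    # former full_matrix_projection over the concatenation is removed.
    return np.concatenate(head_list, axis=1)  # (query_len, head_num * value_proj_size)

One visible consequence of the second hunk is that the layer's output width is now head_num * value_proj_size, so any projection back to another width has to be added by the caller.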