Commit f2240293 authored by: R ranqiu

Refine multi_head_attention

Parent d29901b8
```diff
@@ -1586,9 +1586,9 @@ def multi_head_attention(query,
             value_proj, offset=value_proj_size * i, size=value_proj_size)
         if attention_type == 'dot-product attention':
-            m = linear_comb_layer(
-                weights=sub_query_proj,
-                vectors=sub_key_proj,
+            m = dot_prod_layer(
+                input1=sub_query_proj,
+                input2=sub_key_proj,
                 name='%s_dot-product_%d' % (name, i))
             m = slope_intercept_layer(
                 input=m,
...
```
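The change swaps `linear_comb_layer` for `dot_prod_layer`, so each query/key projection pair produces a plain dot product as the attention score. A minimal NumPy sketch of what a dot-product score computes (the function name and arrays here are illustrative, not PaddlePaddle API):

```python
import numpy as np

def dot_product_score(sub_query_proj, sub_key_proj):
    # Element-wise product summed over the feature dimension:
    # one scalar score per query/key row pair.
    return np.sum(sub_query_proj * sub_key_proj, axis=-1)

q = np.array([[1.0, 2.0], [0.5, -1.0]])
k = np.array([[0.5, 0.5], [2.0, 1.0]])
print(dot_product_score(q, k))  # -> [1.5 0. ]
```

The subsequent `slope_intercept_layer` in the diff can then rescale these scores (e.g. by 1/sqrt(d)) before the softmax.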