Commit 4d15b107 authored by ranqiu

Add multi-head attention

Parent 7ad15259
...@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from activations import LinearActivation, ReluActivation, SoftmaxActivation, \
    IdentityActivation, TanhActivation, SequenceSoftmaxActivation
...@@ -26,9 +26,9 @@ __all__ = [
    'sequence_conv_pool', 'simple_lstm', "simple_img_conv_pool",
    "img_conv_bn_pool", 'lstmemory_group', 'lstmemory_unit', 'small_vgg',
    'img_conv_group', 'vgg_16_network', 'gru_unit', 'gru_group', 'simple_gru',
    'simple_attention', 'dot_product_attention', 'multi_head_attention',
    'simple_gru2', 'bidirectional_gru', 'text_conv_pool', 'bidirectional_lstm',
    'inputs', 'outputs'
]
######################################################
...@@ -1480,6 +1480,138 @@ def dot_product_attention(encoded_sequence,
        input=scaled, pooling_type=SumPooling(), name="%s_pooling" % name)
@wrap_name_default()
def multi_head_attention(query,
key,
value,
key_proj_size,
value_proj_size,
head_num,
attention_type,
softmax_param_attr=None,
name=None):
"""
Calculate and return a context vector with a multi-head attention mechanism.
The dimension of the context vector is equal to value_proj_size * head_num.
Please refer to **Attention Is All You Need** for more details. The link is
as follows:
https://arxiv.org/abs/1706.03762.
The example usage is:
.. code-block:: python
context = multi_head_attention(query=decoder_state,
key=enc_seq,
value=enc_seq,
key_proj_size=64,
value_proj_size=64,
head_num=8,
attention_type='dot-product attention')
:param name: A prefix attached to the name of each layer defined inside
             the multi_head_attention.
:type name: basestring
:param softmax_param_attr: The parameter attribute of the sequence softmax
                           layer that is used to produce the attention weights.
:type softmax_param_attr: ParameterAttribute
:param query: The query is used to calculate attention weights over the values at the current step.
:type query: LayerOutput
:param key: The key is used to calculate the attention weight of the corresponding value.
:type key: LayerOutput
:param value: The value is the sequence to be attended to.
:type value: LayerOutput
:param key_proj_size: The dimension of the linear projection performed on key and query.
:type key_proj_size: int
:param value_proj_size: The dimension of the linear projection performed on value.
:type value_proj_size: int
:param head_num: The number of attention heads.
:type head_num: int
:param attention_type: The type of the attention mechanism used in each attention
                       head. Currently, only scaled dot-product attention and
                       additive attention are supported.
:type attention_type: basestring
:return: The context vector.
:rtype: LayerOutput
"""
assert attention_type in ['dot-product attention', 'additive attention']
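    # Linearly project query, key and value for all heads at once: the packed
    # projection widths are key_proj_size * head_num (query and key) and
    # value_proj_size * head_num (value). The query is a single vector per
    # step, so it is expanded to the key's sequence length after projection.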
with mixed_layer(
size=key_proj_size * head_num,
name='%s_query_proj' % name) as query_proj:
query_proj += full_matrix_projection(query)
query_proj = expand_layer(input=query_proj, expand_as=key)
with mixed_layer(
size=key_proj_size * head_num,
name='%s_key_proj' % name) as key_proj:
key_proj += full_matrix_projection(key)
with mixed_layer(
size=value_proj_size * head_num,
name='%s_value_proj' % name) as value_proj:
value_proj += full_matrix_projection(value)
head_list = []
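    # Compute each attention head independently on its slice of the packed
    # projections; the slices are taken out with identity_projection offsets.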
for i in range(head_num):
with mixed_layer(size=key_proj_size) as sub_query_proj:
sub_query_proj += identity_projection(
query_proj, offset=key_proj_size * i)
with mixed_layer(size=key_proj_size) as sub_key_proj:
sub_key_proj += identity_projection(
key_proj, offset=key_proj_size * i)
with mixed_layer(size=value_proj_size) as sub_value_proj:
sub_value_proj += identity_projection(
value_proj, offset=value_proj_size * i)
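        # Per-position attention score for this head: either a scaled dot
        # product of the projected query and key, or an additive (tanh)
        # combination of the two projections.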
if attention_type == 'dot-product attention':
m = linear_comb_layer(
weights=sub_query_proj,
vectors=sub_key_proj,
name='%s_dot-product_%d' % (name, i))
m = slope_intercept_layer(
input=m,
slope=math.sqrt(1.0 / key_proj_size),
name='%s_dot-product_scaling_%d' % (name, i))
else:
with mixed_layer(
size=key_proj_size,
act=TanhActivation(),
name='%s_combine_%d' % (name, i)) as m:
m += identity_projection(sub_query_proj)
m += identity_projection(sub_key_proj)
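        # Normalize the scores over the sequence with a sequence softmax to
        # obtain the attention weights for this head.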
attention_weight = fc_layer(
input=m,
size=1,
act=SequenceSoftmaxActivation(),
param_attr=softmax_param_attr,
name="%s_softmax_%d" % (name, i),
bias_attr=False)
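        # Weight each projected value vector by its attention weight and sum
        # over the sequence to get this head's context vector.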
scaled = scaling_layer(
weight=attention_weight,
input=sub_value_proj,
name='%s_scaling_%d' % (name, i))
head = pooling_layer(
input=scaled,
pooling_type=SumPooling(),
name="%s_pooling_%d" % (name, i))
head_list.append(head)
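    # Concatenate the per-head context vectors and apply a final linear
    # projection to form the returned context vector.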
multi_head = concat_layer(head_list)
with mixed_layer(
size=value_proj_size * head_num, name='%s_proj' % name) as attended:
attended += full_matrix_projection(multi_head)
return attended
def inputs(layers, *args):
    """
    Declare the inputs of network. The order of input should be as same as
......