Commit 78320194 authored by: R ranqiu

refine dot-product attention according to the comments

Parent 4545a058
@@ -1400,13 +1400,13 @@ def simple_attention(encoded_sequence,
 
 @wrap_name_default()
 def dot_product_attention(encoded_sequence,
-                          attending_sequence,
+                          attended_sequence,
                           transformed_state,
                           softmax_param_attr=None,
                           name=None):
     """
     Calculate and return a context vector with dot-product attention mechanism.
-    Size of the context vector equals to size of the attending_sequence.
+    The dimension of the context vector equals to that of the attended_sequence.
 
     .. math::
 
@@ -1419,35 +1419,38 @@ def dot_product_attention(encoded_sequence,
         c_{i} & = \\sum_{j=1}^{T_{x}}a_{i,j}z_{j}
 
     where :math:`h_{j}` is the jth element of encoded_sequence,
-    :math:`z_{j}` is the jth element of attending_sequence,
-    :math:`s_{i-1}` is transformed_state
+    :math:`z_{j}` is the jth element of attended_sequence,
+    :math:`s_{i-1}` is transformed_state.
 
     The example usage is:
 
     .. code-block:: python
 
         context = dot_product_attention(encoded_sequence=enc_seq,
-                                        attending_sequence=att_seq,
+                                        attended_sequence=att_seq,
                                         transformed_state=state,)
 
-    :param name: name of the dot-product attention model.
+    :param name: A prefix attached to the name of each layer that defined inside
+                 the dot_product_attention.
     :type name: basestring
-    :param softmax_param_attr: parameter attribute of sequence softmax
+    :param softmax_param_attr: The parameter attribute of sequence softmax
                                that is used to produce attention weight.
     :type softmax_param_attr: ParameterAttribute
-    :param encoded_sequence: output of the encoder
+    :param encoded_sequence: The output hidden vectors of the encoder.
     :type encoded_sequence: LayerOutput
-    :param attending_sequence: attention weight is computed by a feed forward neural
-                               network which has two inputs : decoder's transformed
-                               hidden state of previous time step and encoder's output.
-                               attending_sequence is the sequence to be attended.
-    :type attending_sequence: LayerOutput
-    :param transformed_state: transformed hidden state of decoder in previous time step,
-                              its size should equal to encoded_sequence's. Here we do the
-                              transformation outside dot_product_attention for flexibility
-                              consideration.
+    :param attended_sequence: The attention weight is computed by a feed forward neural
+                              network which has two inputs : decoder's transformed hidden
+                              state of previous time step and encoder's output.
+                              attended_sequence is the sequence to be attended.
+    :type attended_sequence: LayerOutput
+    :param transformed_state: The transformed hidden state of decoder in previous time step.
+                              Since the dot-product operation will be performed on it and the
+                              encoded_sequence, their dimensions must be equal. For flexibility,
+                              we suppose transformations of the decoder's hidden state have been
+                              done outside dot_product_attention and no more will be performed
+                              inside. Then users can use either the original or transformed one.
     :type transformed_state: LayerOutput
-    :return: a context vector
+    :return: The context vector.
     :rtype: LayerOutput
     """
     assert transformed_state.size == encoded_sequence.size
@@ -1470,7 +1473,7 @@ def dot_product_attention(encoded_sequence,
 
     scaled = scaling_layer(
         weight=attention_weight,
-        input=attending_sequence,
+        input=attended_sequence,
         name='%s_scaling' % name)
 
     return pooling_layer(
...
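For readers unfamiliar with the layer composition the diff touches (dot-product scores, sequence softmax, scaling, sum pooling), the following is a minimal NumPy sketch of the computation the refined docstring describes. It is an illustration only, not the PaddlePaddle implementation; the function name, array shapes, and the single-step formulation are assumptions made for clarity. It shows why transformed_state must have the same dimension as encoded_sequence (the two are dotted together) and why the context vector takes the dimension of attended_sequence (it is a weighted sum over that sequence).

```python
# Minimal NumPy sketch (not the PaddlePaddle layer itself) of the computation
# described in the docstring: dot-product scores between the transformed decoder
# state s_{i-1} and each encoder vector h_j, a softmax over the sequence, and a
# weighted sum of the attended_sequence z_j. Names and shapes are illustrative.
import numpy as np


def dot_product_attention_step(transformed_state, encoded_sequence, attended_sequence):
    """Compute one decoding step of dot-product attention.

    transformed_state : (d,)   decoder state s_{i-1}, transformed outside this function
    encoded_sequence  : (T, d) encoder outputs h_1..h_T (same dimension as the state)
    attended_sequence : (T, m) sequence z_1..z_T to be attended
    returns           : (m,)   context vector c_i
    """
    assert transformed_state.shape[-1] == encoded_sequence.shape[-1]

    # e_{i,j} = s_{i-1} . h_j  (dot-product score for every position j)
    scores = encoded_sequence @ transformed_state           # (T,)

    # a_{i,j} = softmax_j(e_{i,j})  (attention weights over the sequence)
    weights = np.exp(scores - scores.max())
    weights /= weights.sum()                                 # (T,)

    # c_i = sum_j a_{i,j} z_j  (weighted sum of the attended sequence)
    return weights @ attended_sequence                       # (m,)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    h = rng.standard_normal((5, 8))    # encoder outputs, T=5, d=8
    z = rng.standard_normal((5, 6))    # attended sequence, m=6
    s = rng.standard_normal(8)         # transformed decoder state
    c = dot_product_attention_step(s, h, z)
    print(c.shape)                     # (6,) -- same dimension as attended_sequence
```

In the layer above, the same three stages are realized by the dot-product/softmax producing attention_weight, scaling_layer weighting attended_sequence, and pooling_layer summing over the sequence.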