From 78320194aa25c0ced76a3a758413dedfd24e9550 Mon Sep 17 00:00:00 2001
From: ranqiu
Date: Tue, 17 Oct 2017 17:15:06 +0800
Subject: [PATCH] refine dot-product attention according to the comments

---
 .../paddle/trainer_config_helpers/networks.py | 41 ++++++++++---------
 1 file changed, 22 insertions(+), 19 deletions(-)

diff --git a/python/paddle/trainer_config_helpers/networks.py b/python/paddle/trainer_config_helpers/networks.py
index 0ecbacb7bbc..120c9d11a5e 100644
--- a/python/paddle/trainer_config_helpers/networks.py
+++ b/python/paddle/trainer_config_helpers/networks.py
@@ -1400,13 +1400,13 @@ def simple_attention(encoded_sequence,
 
 @wrap_name_default()
 def dot_product_attention(encoded_sequence,
-                          attending_sequence,
+                          attended_sequence,
                           transformed_state,
                           softmax_param_attr=None,
                           name=None):
     """
     Calculate and return a context vector with dot-product attention mechanism.
-    Size of the context vector equals to size of the attending_sequence.
+    The dimension of the context vector equals that of the attended_sequence.
 
     .. math::
 
@@ -1419,35 +1419,38 @@ def dot_product_attention(encoded_sequence,
         c_{i} & = \\sum_{j=1}^{T_{x}}a_{i,j}z_{j}
 
     where :math:`h_{j}` is the jth element of encoded_sequence,
-    :math:`z_{j}` is the jth element of attending_sequence,
-    :math:`s_{i-1}` is transformed_state
+    :math:`z_{j}` is the jth element of attended_sequence,
+    :math:`s_{i-1}` is transformed_state.
 
     The example usage is:
 
     .. code-block:: python
 
         context = dot_product_attention(encoded_sequence=enc_seq,
-                                        attending_sequence=att_seq,
+                                        attended_sequence=att_seq,
                                         transformed_state=state,)
 
-    :param name: name of the dot-product attention model.
+    :param name: A prefix attached to the name of each layer defined inside
+                 dot_product_attention.
     :type name: basestring
-    :param softmax_param_attr: parameter attribute of sequence softmax
+    :param softmax_param_attr: The parameter attribute of the sequence softmax
+                               layer that produces the attention weight.
     :type softmax_param_attr: ParameterAttribute
-    :param encoded_sequence: output of the encoder
+    :param encoded_sequence: The output hidden vectors of the encoder.
     :type encoded_sequence: LayerOutput
-    :param attending_sequence: attention weight is computed by a feed forward neural
-                               network which has two inputs : decoder's transformed
-                               hidden state of previous time step and encoder's output.
-                               attending_sequence is the sequence to be attended.
-    :type attending_sequence: LayerOutput
-    :param transformed_state: transformed hidden state of decoder in previous time step,
-                              its size should equal to encoded_sequence's. Here we do the
-                              transformation outside dot_product_attention for flexibility
-                              consideration.
+    :param attended_sequence: The sequence to be attended. Its attention weights
+                              are the dot products of transformed_state with the
+                              elements of encoded_sequence, normalized by a
+                              sequence softmax.
+    :type attended_sequence: LayerOutput
+    :param transformed_state: The transformed hidden state of the decoder from the
+                              previous time step. Since a dot product is computed
+                              between it and encoded_sequence, their dimensions
+                              must be equal. For flexibility, any transformation of
+                              the decoder's hidden state is done outside this function,
+                              so either the original or a transformed state may be passed.
     :type transformed_state: LayerOutput
-    :return: a context vector
+    :return: The context vector.
     :rtype: LayerOutput
     """
     assert transformed_state.size == encoded_sequence.size
@@ -1470,7 +1473,7 @@ def dot_product_attention(encoded_sequence,
 
     scaled = scaling_layer(
         weight=attention_weight,
-        input=attending_sequence,
+        input=attended_sequence,
         name='%s_scaling' % name)
 
     return pooling_layer(
-- 
GitLab
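
As a quick sanity check of the formula in the refined docstring, the computation for a single decoding step can be written directly in NumPy: the attention weights are the softmax of the dot products between transformed_state and the elements of encoded_sequence, and the context vector is the weighted sum of attended_sequence. The snippet below is only an illustrative sketch, not part of the patch: the helper name dot_product_attention_ref and the toy shapes are made up, and the real layer composes the same computation from PaddlePaddle layers (the diff above shows the final scaling_layer and pooling_layer steps).

.. code-block:: python

    import numpy as np

    def dot_product_attention_ref(encoded_sequence, attended_sequence, transformed_state):
        """Plain NumPy version of the docstring formula for one decoding step.

        encoded_sequence:  h, shape (T, d)  -- encoder output vectors h_j
        attended_sequence: z, shape (T, m)  -- the sequence to be attended, z_j
        transformed_state: s, shape (d,)    -- transformed decoder state s_{i-1}
        Returns the context vector c_i of shape (m,).
        """
        # a_{i,j} = exp(h_j . s_{i-1}) / sum_k exp(h_k . s_{i-1})
        scores = encoded_sequence @ transformed_state       # (T,)
        scores = scores - scores.max()                      # for numerical stability
        weights = np.exp(scores) / np.exp(scores).sum()     # softmax over the T positions
        # c_i = sum_j a_{i,j} * z_j
        return weights @ attended_sequence                  # (m,)

    # Toy shapes: T=4 time steps, encoder/state dim d=8, attended dim m=6.
    rng = np.random.default_rng(0)
    h = rng.standard_normal((4, 8))
    z = rng.standard_normal((4, 6))
    s = rng.standard_normal(8)
    print(dot_product_attention_ref(h, z, s).shape)  # (6,)

In the layer itself the normalization is a sequence softmax applied per input sequence, which is what softmax_param_attr parameterizes; the sketch handles a single sequence, so an ordinary softmax over the T positions is equivalent.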