From 4d15b107f37e082538ed3e7768349683d59c577a Mon Sep 17 00:00:00 2001 From: ranqiu Date: Thu, 19 Oct 2017 10:53:03 +0800 Subject: [PATCH] Add multi-head attention --- .../paddle/trainer_config_helpers/networks.py | 140 +++++++++++++++++- 1 file changed, 136 insertions(+), 4 deletions(-) diff --git a/python/paddle/trainer_config_helpers/networks.py b/python/paddle/trainer_config_helpers/networks.py index 120c9d11a..c291a4ea1 100644 --- a/python/paddle/trainer_config_helpers/networks.py +++ b/python/paddle/trainer_config_helpers/networks.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import math from activations import LinearActivation, ReluActivation, SoftmaxActivation, \ IdentityActivation, TanhActivation, SequenceSoftmaxActivation @@ -26,9 +26,9 @@ __all__ = [ 'sequence_conv_pool', 'simple_lstm', "simple_img_conv_pool", "img_conv_bn_pool", 'lstmemory_group', 'lstmemory_unit', 'small_vgg', 'img_conv_group', 'vgg_16_network', 'gru_unit', 'gru_group', 'simple_gru', - 'simple_attention', 'dot_product_attention', 'simple_gru2', - 'bidirectional_gru', 'text_conv_pool', 'bidirectional_lstm', 'inputs', - 'outputs' + 'simple_attention', 'dot_product_attention', 'multi_head_attention', + 'simple_gru2', 'bidirectional_gru', 'text_conv_pool', 'bidirectional_lstm', + 'inputs', 'outputs' ] ###################################################### @@ -1480,6 +1480,138 @@ def dot_product_attention(encoded_sequence, input=scaled, pooling_type=SumPooling(), name="%s_pooling" % name) +@wrap_name_default() +def multi_head_attention(query, + key, + value, + key_proj_size, + value_proj_size, + head_num, + attention_type, + softmax_param_attr=None, + name=None): + """ + Calculate and return a context vector with dot-product attention mechanism. + The dimension of the context vector equals to value_proj_size * head_num. + + Please refer to **Attention Is All You Need** for more details. The link is + as follows: + https://arxiv.org/abs/1706.03762. + + The example usage is: + + .. code-block:: python + + context = multi_head_attention(query=decoder_state, + key=enc_seq, + value=enc_seq, + key_proj_size=64, + value_pro_size=64, + head_num=8, + attention_type='dot-product attention') + + :param name: A prefix attached to the name of each layer that defined inside + the multi_head_attention. + :type name: basestring + :param softmax_param_attr: The parameter attribute of sequence softmax + that is used to produce attention weight. + :type softmax_param_attr: ParameterAttribute + :param query: query is used to calculate attention weights over values at current step. + :type query: LayerOutput + :param key: key is used to calculate the attention weight of the corresponding value. + :type key: LayerOutput + :param value: value is the sequence to be attended. + :type value: LayerOutput + :param key_proj_size: The dimension of the linear projection performed on key and query. + :type key_proj_size: int + :param value_proj_size: The dimension of the linear projection performed on value. + :type value_proj_size: int + :param head_num: The number of attention heads. + :type head_num: int + :param attention_type: The type of the attention mechanism used in each attention + heads. Now, we only support scaled dot-product attention and ### + additive attention. + :type attention_type: basestring + :return: The context vector. + :rtype: LayerOutput + """ + assert attention_type in ['dot-product attention', 'additive attention'] + + with mixed_layer( + size=key_proj_size * head_num, + name='%s_query_proj' % name) as query_proj: + query_proj += full_matrix_projection(query) + query_proj = expand_layer(input=query_proj, expand_as=key) + + with mixed_layer( + size=key_proj_size * head_num, + name='%s_key_proj' % name) as key_proj: + key_proj += full_matrix_projection(key) + + with mixed_layer( + size=value_proj_size * head_num, + name='%s_value_proj' % name) as value_proj: + value_proj += full_matrix_projection(value) + + head_list = [] + for i in range(head_num): + with mixed_layer(size=key_proj_size) as sub_query_proj: + sub_query_proj += identity_projection( + query_proj, offset=key_proj_size * i) + + with mixed_layer(size=key_proj_size) as sub_key_proj: + sub_key_proj += identity_projection( + key_proj, offset=key_proj_size * i) + + with mixed_layer(size=value_proj_size) as sub_value_proj: + sub_value_proj += identity_projection( + value_proj, offset=value_proj_size * i) + + if attention_type == 'dot-product attention': + m = linear_comb_layer( + weights=sub_query_proj, + vectors=sub_key_proj, + name='%s_dot-product_%d' % (name, i)) + m = slope_intercept_layer( + input=m, + slope=math.sqrt(1.0 / key_proj_size), + name='%s_dot-product_scaling_%d' % (name, i)) + else: + with mixed_layer( + size=key_proj_size, + act=TanhActivation(), + name='%s_combine_%d' % (name, i)) as m: + m += identity_projection(sub_query_proj) + m += identity_projection(sub_key_proj) + + attention_weight = fc_layer( + input=m, + size=1, + act=SequenceSoftmaxActivation(), + param_attr=softmax_param_attr, + name="%s_softmax_%d" % (name, i), + bias_attr=False) + + scaled = scaling_layer( + weight=attention_weight, + input=sub_value_proj, + name='%s_scaling_%d' % (name, i)) + head = pooling_layer( + input=scaled, + pooling_type=SumPooling(), + name="%s_pooling_%d" % (name, i)) + + head_list.append(head) + + multi_head = concat_layer(head_list) + + with mixed_layer( + size=value_proj_size * head_num, name='%s_proj' % name) as attended: + attended += full_matrix_projection(multi_head) + + return attended + + def inputs(layers, *args): """ Declare the inputs of network. The order of input should be as same as -- GitLab