From a9d30244f9642bf3d941244db5565136f2799e9f Mon Sep 17 00:00:00 2001 From: liujiaxiang Date: Wed, 12 Aug 2020 11:08:15 +0800 Subject: [PATCH] add big bird transformer based on PGL --- examples/xformer/README.md | 28 ++ .../sparse_scaled_dot_product_attention.py | 224 +++++++++++ .../xformer/transformer_encoder_sparse.py | 361 ++++++++++++++++++ 3 files changed, 613 insertions(+) create mode 100644 examples/xformer/README.md create mode 100644 examples/xformer/sparse_scaled_dot_product_attention.py create mode 100644 examples/xformer/transformer_encoder_sparse.py diff --git a/examples/xformer/README.md b/examples/xformer/README.md new file mode 100644 index 0000000..23aee9c --- /dev/null +++ b/examples/xformer/README.md @@ -0,0 +1,28 @@ +# X-Transformer + +Models based on Transformers are wildly successful for a wide variety of Natural Language Processing (NLP) tasks and consequently are a mainstay of modern NLP research. Transformer is constituted of a self-attention and a feed-forward module. The self-attention mechanism allows each token in the input sequence to attend independently to every other token in the sequence. From the view of graph representation, the generalized attention mechanism can be described by a directed graph whose vertex is the token. So, the attention module can be implemented by a graph library, especially recently the efficient attention implementation, e.g. [BigBird](https://arxiv.org/abs/2007.14062) \ [LongFormer](https://arxiv.org/abs/2004.05150) \ [Sparse Transformer](https://arxiv.org/abs/1904.10509). + +We have showcased the [BigBird](https://arxiv.org/abs/2007.14062) implementation and tested the performence as show below, and the [LongFormer](https://arxiv.org/abs/2004.05150) \ [Sparse Transformer](https://arxiv.org/abs/1904.10509) can be easily implemented by revised the correspoding code. + + + +## Dependencies + +- [paddlepaddle >= 1.7](https://github.com/PaddlePaddle/paddle) +- [pgl 1.1](https://github.com/PaddlePaddle/PGL) + + +## Performance + +We have evaluate the implemented method on a summarization dataset CNN/DM. The experiment was conducted on two P40 GPU cards. + +| CNN/DM | BatchSize | R1 | R2 | R3 | speed(steps/s) | +| ------------------ | --------- | ----------------- | ----------------- | ----------------- | ------ | +| LEAD | - | 40.42 | 17.62 | 36.67 | - | +| Oracle | - | 52.59 | 31.24 | 48.87 | - | +| non-sparse, L=512 | 32 | 42.175 | 19.392 | 38.613 | 0.6359 | +| L=2048 | 10 | 41.334 | 18.369 | 37.752 | 0.8246 | +| L=1024 | 20 | 41.453 | 18.529 | 37.872 | 0.6432 | +| L=768 | 26 | 41.611 | 18.735 | 38.051 | 0.6517 | +| L=512 | 40 | 41.742 | 18.733 | 38.127 | 0.6213 | + diff --git a/examples/xformer/sparse_scaled_dot_product_attention.py b/examples/xformer/sparse_scaled_dot_product_attention.py new file mode 100644 index 0000000..c15501d --- /dev/null +++ b/examples/xformer/sparse_scaled_dot_product_attention.py @@ -0,0 +1,224 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +######################################################################## +# # +# Copyright (c) 2019 Baidu.com, Inc. All Rights Reserved # +# # +######################################################################## +import paddle.fluid as fluid +import paddle.fluid.layers as L +import paddle.fluid.layers as layers +from pgl.utils import paddle_helper +import pgl + +def masked_select(input, mask): + """masked_select + + Slice the value from given Mask + + Args: + input: Input tensor to be selected + + mask: A bool tensor for sliced. + + Return: + Part of inputs where mask is True. + """ + index = L.where(mask) + return L.gather(input, index, overwrite=False) + + + +class BigBirdWrapper(pgl.graph_wrapper.BaseGraphWrapper): + """Implement of Edge Drop """ + def __init__(self, input_mask): + super(BigBirdWrapper, self).__init__() + max_seqlen = L.shape(input_mask)[1] + input_mask = L.reshape(input_mask, [-1]) + num_nodes = L.shape(input_mask)[0] + src, dst = build_edges(num_nodes, input_mask, max_seqlen) + self._edges_src = src + self._edges_dst = dst + self._edges_src.stop_gradient=True + self._edges_dst.stop_gradient=True + self._num_nodes = num_nodes + self._num_edges = L.shape(self._edges_src)[0] + self._node_ids = L.range(0, self._num_nodes, step=1, dtype="int32") + self._edge_uniq_dst, _, uniq_count = L.unique_with_counts(self._edges_dst, dtype="int32") + self._edge_uniq_dst.stop_gradient=True + last = L.reduce_sum(uniq_count, keep_dim=True) + uniq_count = L.cumsum(uniq_count, exclusive=True) + self._edge_uniq_dst_count = L.concat([uniq_count, last]) + self._edge_uniq_dst_count.stop_gradient=True + +def select_edges(src, dst, input_mask, num_nodes, max_seqlen): + src = fluid.layers.elementwise_max(src, num_nodes * 0) + dst = fluid.layers.elementwise_max(dst, num_nodes * 0) + src = fluid.layers.elementwise_min(src, num_nodes - 1) + dst = fluid.layers.elementwise_min(dst, num_nodes - 1) + + conditions = [] + conditions.append(L.gather(input_mask, src) > 0.5) + conditions.append(L.gather(input_mask, dst) > 0.5) + block_src = src / max_seqlen + block_dst = dst / max_seqlen + conditions.append(block_src == block_dst) + mask = None + for cond in conditions: + if mask is None: + mask = cond + else: + mask = L.logical_and(mask, cond) + + dst = masked_select(dst, mask) + src = masked_select(src, mask) + return src, dst + +def uniq_edges(src, dst, num_nodes): + sorted_dst = L.cast(dst, dtype="int64") + sorted_src = L.cast(src, dtype="int64") + num_nodes = L.cast(num_nodes, dtype="int64") + edge_hash = sorted_dst * num_nodes + sorted_src + edge_hash, _ = L.argsort(edge_hash) + edge_hash, _ = L.unique(edge_hash, dtype="int64") + sorted_src = L.elementwise_mod(edge_hash, num_nodes) + sorted_dst = L.elementwise_div(edge_hash, num_nodes) + sorted_src = L.cast(sorted_src, dtype="int32") + sorted_dst = L.cast(sorted_dst, dtype="int32") + return sorted_src, sorted_dst + + +#def build_edges(num_nodes, input_mask, max_seqlen): +# edges = L.range(start=0, end=num_nodes, step=1, dtype="int32") +# all_edges = [] +# # Window +# filter_func = lambda x, y: select_edges(x, y, input_mask, num_nodes, max_seqlen) +# +# all_edges.append(filter_func(edges - 1, edges)) # win-1 +# all_edges.append(filter_func(edges + 1, edges)) # win-2 +# all_edges.append(filter_func(edges, edges)) #self-loop +# +# # Global Assume [CLS] is the first token. +# cls_position = edges / max_seqlen * max_seqlen +# all_edges.append(filter_func(cls_position, edges)) +# all_edges.append(filter_func(edges, cls_position)) +# +# # Random +# for i in range(2): +# rand_edge = L.floor(L.uniform_random(min=0, max=1, shape=[num_nodes]) * L.cast(max_seqlen, dtype="float32")) +# rand_edge = L.cast(rand_edge, dtype="int32") + cls_position +# all_edges.append(filter_func(rand_edge, edges)) +# +# if len(all_edges) > 1: +# src = L.concat([ s for s, d in all_edges], 0) +# dst = L.concat([ d for s, d in all_edges], 0) +# else: +# src = all_edges[0][0] +# dst = all_edges[0][1] +# +# # sort edges +# sorted_src, sorted_dst = uniq_edges(src, dst, num_nodes) +# return sorted_src, sorted_dst + +def build_edges(num_nodes, input_mask, max_seqlen): + edges = L.range(start=0, end=num_nodes, step=1, dtype="int32") + all_edges = [] + # Window + filter_func = lambda x, y: select_edges(x, y, input_mask, num_nodes, max_seqlen) + + all_edges.append(filter_func(edges - 1, edges)) # win-1 + #all_edges.append(filter_func(edges - 2, edges)) # win-1 + #all_edges.append(filter_func(edges - 3, edges)) # win-1 + all_edges.append(filter_func(edges + 1, edges)) # win-2 + #all_edges.append(filter_func(edges + 2, edges)) # win-2 + #all_edges.append(filter_func(edges + 3, edges)) # win-2 + all_edges.append(filter_func(edges, edges)) #self-loop + + # Global Assume [CLS] is the first token. + + # vertical cls-window attention + cls_position = edges / max_seqlen * max_seqlen + #all_edges.append(filter_func(cls_position + 1, edges)) + all_edges.append(filter_func(cls_position, edges)) + + # vertical sliding attention + #all_edges.append(filter_func(cls_position + 6, edges)) + #all_edges.append(filter_func(cls_position + max_seqlen - 6, edges)) + + # horizontal cls attention + all_edges.append(filter_func(edges, cls_position)) + #all_edges.append(filter_func(edges, cls_position)) + + # horizontal sliding attention + #all_edges.append(filter_func(edges, cls_position + 6) + #all_edges.append(filter_func(edges, cls_position + max_seq_len - 6) + + + # Random + #for i in range(2): + for i in range(2): + rand_edge = L.floor(L.uniform_random(min=0, max=1, shape=[num_nodes]) * L.cast(max_seqlen, dtype="float32")) + rand_edge = L.cast(rand_edge, dtype="int32") + cls_position + all_edges.append(filter_func(rand_edge, edges)) + + if len(all_edges) > 1: + src = L.concat([ s for s, d in all_edges], 0) + dst = L.concat([ d for s, d in all_edges], 0) + else: + src = all_edges[0][0] + dst = all_edges[0][1] + + # sort edges + sorted_src, sorted_dst = uniq_edges(src, dst, num_nodes) + return sorted_src, sorted_dst + + + + +def sparse_scaled_dot_product_attention(q, k, v, input_mask, dropout_rate, n_head, d_key, d_value): + def send_q_k_spmm(src_feat, dst_feat, edge_feat): + # q [ num_edges, n_head * dim] + # k [ num_edges, n_head * dim] + # v [ num_edges, n_head * dim] + _q = dst_feat["q"] + _k = src_feat["k"] + _v = src_feat["v"] + _q = L.reshape(_q, [-1, n_head, _q.shape[-1] // n_head]) + _k = L.reshape(_k, [-1, n_head, _k.shape[-1] // n_head]) + score = L.reduce_sum(_q * _k, -1) # [num_edge, n_head] + return { "score": score, "value": _v} + + def recv_score_v_spmm(msg): + score = msg["score"] + score = paddle_helper.sequence_softmax(score) + score = layers.dropout( + score, + dropout_prob=dropout_rate, + dropout_implementation="upscale_in_train", + is_test=False) + + score = L.reshape(score, [-1, n_head, 1]) + _v = msg["value"] + _new_v = L.reshape(_v, [-1, n_head, _v.shape[-1] // n_head]) + + _new_v = _new_v * score + + _new_v = L.reshape(_new_v, [-1, _v.shape[-1]]) + _new_v = L.lod_reset(_new_v, _v) + return L.sequence_pool(_new_v, "sum") + + graph_wrapper = BigBirdWrapper(input_mask) + old_v = v + + q = L.reshape(q, [-1, d_key * n_head]) + k = L.reshape(k, [-1, d_key * n_head]) + v = L.reshape(v, [-1, d_value * n_head]) + + q = L.scale(q, scale=d_key ** -0.5) + msg = graph_wrapper.send(send_q_k_spmm, nfeat_list=[("k", k), ("v", v), ("q", q)]) + out = graph_wrapper.recv(msg, recv_score_v_spmm) + out = L.reshape(out, [-1, L.shape(old_v)[1], d_value * n_head]) + return out, out + + + diff --git a/examples/xformer/transformer_encoder_sparse.py b/examples/xformer/transformer_encoder_sparse.py new file mode 100644 index 0000000..575a9b4 --- /dev/null +++ b/examples/xformer/transformer_encoder_sparse.py @@ -0,0 +1,361 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from functools import partial + +import paddle.fluid as fluid +import paddle.fluid.layers as L +import paddle.fluid.layers as layers + + +from .sparse_scaled_dot_product_attention import sparse_scaled_dot_product_attention + +to_3d = lambda a: a # will change later +to_2d = lambda a: a + + +def multi_head_attention(queries, + keys, + values, + attn_bias, + d_key, + d_value, + d_model, + input_mask, + n_head=1, + dropout_rate=0., + cache=None, + param_initializer=None, + name='multi_head_att'): + """ + Multi-Head Attention. Note that attn_bias is added to the logit before + computing softmax activiation to mask certain selected positions so that + they will not considered in attention weights. + """ + keys = queries if keys is None else keys + values = keys if values is None else values + + def __compute_qkv(queries, keys, values, n_head, d_key, d_value): + """ + Add linear projection to queries, keys, and values. + """ + q = layers.fc(input=queries, + size=d_key * n_head, + num_flatten_dims=len(queries.shape) - 1, + param_attr=fluid.ParamAttr( + name=name + '_query_fc.w_0', + initializer=param_initializer), + bias_attr=name + '_query_fc.b_0') + k = layers.fc(input=keys, + size=d_key * n_head, + num_flatten_dims=len(keys.shape) - 1, + param_attr=fluid.ParamAttr( + name=name + '_key_fc.w_0', + initializer=param_initializer), + bias_attr=name + '_key_fc.b_0') + v = layers.fc(input=values, + size=d_value * n_head, + num_flatten_dims=len(values.shape) - 1, + param_attr=fluid.ParamAttr( + name=name + '_value_fc.w_0', + initializer=param_initializer), + bias_attr=name + '_value_fc.b_0') + return q, k, v + + def __split_heads(x, n_head): + """ + Reshape the last dimension of inpunt tensor x so that it becomes two + dimensions and then transpose. Specifically, input a tensor with shape + [bs, max_sequence_length, n_head * hidden_dim] then output a tensor + with shape [bs, n_head, max_sequence_length, hidden_dim]. + """ + hidden_size = x.shape[-1] + # The value 0 in shape attr means copying the corresponding dimension + # size of the input as the output dimension size. + reshaped = layers.reshape( + x=x, shape=[0, 0, n_head, hidden_size // n_head], inplace=True) + + # permuate the dimensions into: + # [batch_size, n_head, max_sequence_len, hidden_size_per_head] + return layers.transpose(x=reshaped, perm=[0, 2, 1, 3]) + + def __combine_heads(x): + """ + Transpose and then reshape the last two dimensions of inpunt tensor x + so that it becomes one dimension, which is reverse to __split_heads. + """ + if len(x.shape) == 3: return x + if len(x.shape) != 4: + raise ValueError("Input(x) should be a 4-D Tensor.") + trans_x = layers.transpose(x, perm=[0, 2, 1, 3]) + # The value 0 in shape attr means copying the corresponding dimension + # size of the input as the output dimension size. + #trans_x.desc.set_shape((-1, 1, n_head, d_value)) + return layers.reshape(x=trans_x, shape=[0, 0, d_model], inplace=True) + + q, k, v = __compute_qkv(queries, keys, values, n_head, d_key, d_value) + q = to_3d(q) + k = to_3d(k) + v = to_3d(v) + + if cache is not None: # use cache and concat time steps + # Since the inplace reshape in __split_heads changes the shape of k and + # v, which is the cache input for next time step, reshape the cache + # input from the previous time step first. + k = cache["k"] = layers.concat( + [layers.reshape( + cache["k"], shape=[0, 0, d_model]), k], axis=1) + v = cache["v"] = layers.concat( + [layers.reshape( + cache["v"], shape=[0, 0, d_model]), v], axis=1) + + out, _ = sparse_scaled_dot_product_attention(q, k, v, + input_mask, dropout_rate, n_head, d_key, d_value) + + out = to_2d(out) + + # Project back to the model size. + proj_out = layers.fc(input=out, + size=d_model, + num_flatten_dims=len(out.shape) - 1, + param_attr=fluid.ParamAttr( + name=name + '_output_fc.w_0', + initializer=param_initializer), + bias_attr=name + '_output_fc.b_0') + return proj_out, _ + + +def positionwise_feed_forward(x, + d_inner_hid, + d_hid, + dropout_rate, + hidden_act, + param_initializer=None, + name='ffn'): + """ + Position-wise Feed-Forward Networks. + This module consists of two linear transformations with a ReLU activation + in between, which is applied to each position separately and identically. + """ + hidden = layers.fc(input=x, + size=d_inner_hid, + num_flatten_dims=len(x.shape) - 1, + act=hidden_act, + param_attr=fluid.ParamAttr( + name=name + '_fc_0.w_0', + initializer=param_initializer), + bias_attr=name + '_fc_0.b_0') + if dropout_rate: + hidden = layers.dropout( + hidden, + dropout_prob=dropout_rate, + dropout_implementation="upscale_in_train", + is_test=False) + out = layers.fc(input=hidden, + size=d_hid, + num_flatten_dims=len(hidden.shape) - 1, + param_attr=fluid.ParamAttr( + name=name + '_fc_1.w_0', + initializer=param_initializer), + bias_attr=name + '_fc_1.b_0') + return out + + +def pre_post_process_layer(prev_out, + out, + process_cmd, + dropout_rate=0., + name=''): + """ + Add residual connection, layer normalization and droput to the out tensor + optionally according to the value of process_cmd. + This will be used before or after multi-head attention and position-wise + feed-forward networks. + """ + for cmd in process_cmd: + if cmd == "a": # add residual connection + out = out + prev_out if prev_out else out + elif cmd == "n": # add layer normalization + out_dtype = out.dtype + if out_dtype == fluid.core.VarDesc.VarType.FP16: + out = layers.cast(x=out, dtype="float32") + out = layers.layer_norm( + out, + begin_norm_axis=len(out.shape) - 1, + param_attr=fluid.ParamAttr( + name=name + '_layer_norm_scale', + initializer=fluid.initializer.Constant(1.)), + bias_attr=fluid.ParamAttr( + name=name + '_layer_norm_bias', + initializer=fluid.initializer.Constant(0.))) + if out_dtype == fluid.core.VarDesc.VarType.FP16: + out = layers.cast(x=out, dtype="float16") + elif cmd == "d": # add dropout + if dropout_rate: + out = layers.dropout( + out, + dropout_prob=dropout_rate, + dropout_implementation="upscale_in_train", + is_test=False) + return out + + +pre_process_layer = partial(pre_post_process_layer, None) +post_process_layer = pre_post_process_layer + + +def encoder_layer(enc_input, + input_mask, + attn_bias, + n_head, + d_key, + d_value, + d_model, + d_inner_hid, + prepostprocess_dropout, + attention_dropout, + relu_dropout, + hidden_act, + preprocess_cmd="n", + postprocess_cmd="da", + param_initializer=None, + name=''): + """The encoder layers that can be stacked to form a deep encoder. + This module consits of a multi-head (self) attention followed by + position-wise feed-forward networks and both the two components companied + with the post_process_layer to add residual connection, layer normalization + and droput. + """ + attn_output, ctx_multiheads_attn = multi_head_attention( + pre_process_layer( + enc_input, + preprocess_cmd, + prepostprocess_dropout, + name=name + '_pre_att'), + None, + None, + attn_bias, + d_key, + d_value, + d_model, + input_mask, + n_head, + attention_dropout, + param_initializer=param_initializer, + name=name + '_multi_head_att') + attn_output = post_process_layer( + enc_input, + attn_output, + postprocess_cmd, + prepostprocess_dropout, + name=name + '_post_att') + + ffd_output = positionwise_feed_forward( + pre_process_layer( + attn_output, + preprocess_cmd, + prepostprocess_dropout, + name=name + '_pre_ffn'), + d_inner_hid, + d_model, + relu_dropout, + hidden_act, + param_initializer=param_initializer, + name=name + '_ffn') + + ret = post_process_layer( + attn_output, + ffd_output, + postprocess_cmd, + prepostprocess_dropout, + name=name + '_post_ffn') + + return ret, ctx_multiheads_attn, ffd_output + + +def build_pad_idx(input_mask): + pad_idx = L.where(L.cast(L.squeeze(input_mask, [2]), 'bool')) + return pad_idx + + +def build_attn_bias(input_mask, n_head, dtype): + attn_bias = L.matmul( + input_mask, input_mask, transpose_y=True) # [batch, seq, seq] + attn_bias = (1. - attn_bias) * -10000. + attn_bias = L.stack([attn_bias] * n_head, 1) + if attn_bias.dtype != dtype: + attn_bias = L.cast(attn_bias, dtype) + return attn_bias + + +def encoder(enc_input, + input_mask, + n_layer, + n_head, + d_key, + d_value, + d_model, + d_inner_hid, + prepostprocess_dropout, + attention_dropout, + relu_dropout, + hidden_act, + preprocess_cmd="n", + postprocess_cmd="da", + param_initializer=None, + name=''): + """ + The encoder is composed of a stack of identical layers returned by calling + encoder_layer. + """ + + d_shape = L.shape(input_mask) + pad_idx = build_pad_idx(input_mask) + attn_bias = build_attn_bias(input_mask, n_head, enc_input.dtype) + + enc_input = to_2d(enc_input) + all_hidden = [] + all_attn = [] + all_ffn = [] + for i in range(n_layer): + enc_output, ctx_multiheads_attn, ffn_output = encoder_layer( + enc_input, + input_mask, + attn_bias, + n_head, + d_key, + d_value, + d_model, + d_inner_hid, + prepostprocess_dropout, + attention_dropout, + relu_dropout, + hidden_act, + preprocess_cmd, + postprocess_cmd, + param_initializer=param_initializer, + name=name + '_layer_' + str(i)) + all_hidden.append(enc_output) + all_attn.append(ctx_multiheads_attn) + all_ffn.append(ffn_output) + enc_input = enc_output + enc_output = pre_process_layer( + enc_output, + preprocess_cmd, + prepostprocess_dropout, + name="post_encoder") + enc_output = to_3d(enc_output) + return enc_output, all_hidden, all_attn, all_ffn + + -- GitLab