Commit a9d30244 authored by liujiaxiang

add big bird transformer based on PGL

Parent c0f98318
# X-Transformer
Models based on Transformers are wildly successful for a wide variety of Natural Language Processing (NLP) tasks and consequently are a mainstay of modern NLP research. A Transformer layer consists of a self-attention module and a feed-forward module. The self-attention mechanism allows each token in the input sequence to attend independently to every other token in the sequence. From the viewpoint of graph representation, this generalized attention mechanism can be described by a directed graph whose vertices are the tokens and whose edges indicate which tokens attend to which. The attention module can therefore be implemented with a graph library, which is especially convenient for recent efficient attention variants such as [BigBird](https://arxiv.org/abs/2007.14062), [Longformer](https://arxiv.org/abs/2004.05150) and [Sparse Transformer](https://arxiv.org/abs/1904.10509).
We showcase a [BigBird](https://arxiv.org/abs/2007.14062) implementation and report its performance below; [Longformer](https://arxiv.org/abs/2004.05150) and [Sparse Transformer](https://arxiv.org/abs/1904.10509) can be implemented by revising the corresponding edge-building code.
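To make the graph view concrete, the following NumPy-only sketch (an illustration, not part of this repository's code) builds a BigBird-style edge list for a toy sequence: a sliding window, global connections to and from the first ([CLS]) token, and a few random links, so the number of edges grows linearly instead of quadratically with the sequence length.
```python
import numpy as np

seq_len = 8      # toy sequence length
window = 1       # attend to +/- 1 neighbours
num_rand = 2     # random source tokens per destination token
rng = np.random.default_rng(0)

edges = set()
for dst in range(seq_len):
    edges.add((dst, dst))                      # self-loop
    for off in range(1, window + 1):           # sliding window
        if dst - off >= 0:
            edges.add((dst - off, dst))
        if dst + off < seq_len:
            edges.add((dst + off, dst))
    edges.add((0, dst))                        # every token attends to [CLS]
    edges.add((dst, 0))                        # [CLS] attends to every token
    for src in rng.integers(0, seq_len, size=num_rand):
        edges.add((int(src), dst))             # random attention

print(len(edges), "edges instead of", seq_len * seq_len)   # sparse vs. dense
```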
## Dependencies
- [paddlepaddle >= 1.7](https://github.com/PaddlePaddle/paddle)
- [pgl 1.1](https://github.com/PaddlePaddle/PGL)
## Performance
We evaluated the implementation on the CNN/DailyMail (CNN/DM) summarization dataset. The experiments were conducted on two P40 GPU cards.
| CNN/DM | BatchSize | ROUGE-1 | ROUGE-2 | ROUGE-L | speed (steps/s) |
| ------------------ | --------- | ----------------- | ----------------- | ----------------- | ------ |
| LEAD | - | 40.42 | 17.62 | 36.67 | - |
| Oracle | - | 52.59 | 31.24 | 48.87 | - |
| non-sparse, L=512 | 32 | 42.175 | 19.392 | 38.613 | 0.6359 |
| L=2048 | 10 | 41.334 | 18.369 | 37.752 | 0.8246 |
| L=1024 | 20 | 41.453 | 18.529 | 37.872 | 0.6432 |
| L=768 | 26 | 41.611 | 18.735 | 38.051 | 0.6517 |
| L=512 | 40 | 41.742 | 18.733 | 38.127 | 0.6213 |
#!/usr/bin/env python
# -*- coding: utf-8 -*-
########################################################################
# #
# Copyright (c) 2019 Baidu.com, Inc. All Rights Reserved #
# #
########################################################################
import paddle.fluid as fluid
import paddle.fluid.layers as L
import paddle.fluid.layers as layers
from pgl.utils import paddle_helper
import pgl
def masked_select(input, mask):
    """masked_select

    Slice values from `input` according to the given mask.

    Args:
        input: Input tensor to be selected from.
        mask: A bool tensor used for slicing.

    Return:
        The part of `input` where `mask` is True.
    """
    index = L.where(mask)
    return L.gather(input, index, overwrite=False)
class BigBirdWrapper(pgl.graph_wrapper.BaseGraphWrapper):
    """Graph wrapper that builds the BigBird sparse attention graph
    on the fly from the input mask."""

    def __init__(self, input_mask):
        super(BigBirdWrapper, self).__init__()
        max_seqlen = L.shape(input_mask)[1]
        input_mask = L.reshape(input_mask, [-1])
        num_nodes = L.shape(input_mask)[0]
        src, dst = build_edges(num_nodes, input_mask, max_seqlen)
        self._edges_src = src
        self._edges_dst = dst
        self._edges_src.stop_gradient = True
        self._edges_dst.stop_gradient = True
        self._num_nodes = num_nodes
        self._num_edges = L.shape(self._edges_src)[0]
        self._node_ids = L.range(0, self._num_nodes, step=1, dtype="int32")
        self._edge_uniq_dst, _, uniq_count = L.unique_with_counts(
            self._edges_dst, dtype="int32")
        self._edge_uniq_dst.stop_gradient = True
        last = L.reduce_sum(uniq_count, keep_dim=True)
        uniq_count = L.cumsum(uniq_count, exclusive=True)
        self._edge_uniq_dst_count = L.concat([uniq_count, last])
        self._edge_uniq_dst_count.stop_gradient = True
def select_edges(src, dst, input_mask, num_nodes, max_seqlen):
    """Clamp candidate edges into valid range and keep only edges whose
    endpoints are real (non-padding) tokens of the same sequence."""
    src = fluid.layers.elementwise_max(src, num_nodes * 0)
    dst = fluid.layers.elementwise_max(dst, num_nodes * 0)
    src = fluid.layers.elementwise_min(src, num_nodes - 1)
    dst = fluid.layers.elementwise_min(dst, num_nodes - 1)

    conditions = []
    conditions.append(L.gather(input_mask, src) > 0.5)
    conditions.append(L.gather(input_mask, dst) > 0.5)
    # Integer division: both endpoints must fall into the same sequence block.
    block_src = src / max_seqlen
    block_dst = dst / max_seqlen
    conditions.append(block_src == block_dst)

    mask = None
    for cond in conditions:
        if mask is None:
            mask = cond
        else:
            mask = L.logical_and(mask, cond)

    dst = masked_select(dst, mask)
    src = masked_select(src, mask)
    return src, dst
def uniq_edges(src, dst, num_nodes):
    """Sort edges by destination and remove duplicates via an integer hash."""
    sorted_dst = L.cast(dst, dtype="int64")
    sorted_src = L.cast(src, dtype="int64")
    num_nodes = L.cast(num_nodes, dtype="int64")
    edge_hash = sorted_dst * num_nodes + sorted_src
    edge_hash, _ = L.argsort(edge_hash)
    edge_hash, _ = L.unique(edge_hash, dtype="int64")
    sorted_src = L.elementwise_mod(edge_hash, num_nodes)
    sorted_dst = L.elementwise_div(edge_hash, num_nodes)
    sorted_src = L.cast(sorted_src, dtype="int32")
    sorted_dst = L.cast(sorted_dst, dtype="int32")
    return sorted_src, sorted_dst
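
# ---------------------------------------------------------------------------
# Illustration only: a NumPy sketch of the hash-based deduplication above.
# `_demo_uniq_edges_numpy` is a hypothetical helper added for documentation;
# it mirrors uniq_edges without Paddle and is not used by the model code.
# ---------------------------------------------------------------------------
def _demo_uniq_edges_numpy(src, dst, num_nodes):
    import numpy as np
    edge_hash = (np.asarray(dst, dtype="int64") * num_nodes
                 + np.asarray(src, dtype="int64"))
    edge_hash = np.unique(edge_hash)         # np.unique sorts and deduplicates
    return edge_hash % num_nodes, edge_hash // num_nodes  # decoded (src, dst)
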
def build_edges(num_nodes, input_mask, max_seqlen):
    """Build the BigBird edge list: sliding-window, global [CLS] and random edges."""
    edges = L.range(start=0, end=num_nodes, step=1, dtype="int32")
    all_edges = []
    filter_func = lambda x, y: select_edges(x, y, input_mask, num_nodes, max_seqlen)

    # Sliding-window attention (left neighbour, right neighbour, self-loop).
    # Wider windows can be added with more offsets, e.g. edges +/- 2, edges +/- 3.
    all_edges.append(filter_func(edges - 1, edges))  # left neighbour
    all_edges.append(filter_func(edges + 1, edges))  # right neighbour
    all_edges.append(filter_func(edges, edges))      # self-loop

    # Global attention. Assume [CLS] is the first token of every sequence.
    # Integer division maps each flattened token id to its sequence's [CLS] id.
    cls_position = edges / max_seqlen * max_seqlen
    # vertical: every token attends to [CLS]
    all_edges.append(filter_func(cls_position, edges))
    # horizontal: [CLS] attends to every token
    all_edges.append(filter_func(edges, cls_position))
    # Extra global positions (e.g. cls_position + 6) can be appended here to
    # build other sparse patterns.

    # Random attention: two random in-sequence source tokens per destination.
    for i in range(2):
        rand_edge = L.floor(
            L.uniform_random(min=0, max=1, shape=[num_nodes]) *
            L.cast(max_seqlen, dtype="float32"))
        rand_edge = L.cast(rand_edge, dtype="int32") + cls_position
        all_edges.append(filter_func(rand_edge, edges))

    if len(all_edges) > 1:
        src = L.concat([s for s, d in all_edges], 0)
        dst = L.concat([d for s, d in all_edges], 0)
    else:
        src = all_edges[0][0]
        dst = all_edges[0][1]

    # Sort and deduplicate the edges.
    sorted_src, sorted_dst = uniq_edges(src, dst, num_nodes)
    return sorted_src, sorted_dst
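
# ---------------------------------------------------------------------------
# Illustration only: sequences in a batch are flattened to
# num_nodes = batch_size * max_seqlen node ids, so integer division recovers
# the [CLS] position of each token's own sequence. `_demo_cls_position_numpy`
# is a hypothetical helper added for documentation, not used by the model.
# ---------------------------------------------------------------------------
def _demo_cls_position_numpy(batch_size=2, max_seqlen=4):
    import numpy as np
    edges = np.arange(batch_size * max_seqlen)       # flattened token ids
    cls_position = edges // max_seqlen * max_seqlen  # -> [0 0 0 0 4 4 4 4]
    return cls_position
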
def sparse_scaled_dot_product_attention(q, k, v, input_mask, dropout_rate,
                                        n_head, d_key, d_value):
    """Sparse multi-head attention computed by message passing on the BigBird graph."""

    def send_q_k_spmm(src_feat, dst_feat, edge_feat):
        # q [num_edges, n_head * dim]
        # k [num_edges, n_head * dim]
        # v [num_edges, n_head * dim]
        _q = dst_feat["q"]
        _k = src_feat["k"]
        _v = src_feat["v"]
        _q = L.reshape(_q, [-1, n_head, _q.shape[-1] // n_head])
        _k = L.reshape(_k, [-1, n_head, _k.shape[-1] // n_head])
        score = L.reduce_sum(_q * _k, -1)  # [num_edges, n_head]
        return {"score": score, "value": _v}

    def recv_score_v_spmm(msg):
        # Softmax over the incoming edges of each destination node, per head.
        score = msg["score"]
        score = paddle_helper.sequence_softmax(score)
        score = layers.dropout(
            score,
            dropout_prob=dropout_rate,
            dropout_implementation="upscale_in_train",
            is_test=False)
        score = L.reshape(score, [-1, n_head, 1])
        _v = msg["value"]
        _new_v = L.reshape(_v, [-1, n_head, _v.shape[-1] // n_head])
        _new_v = _new_v * score
        _new_v = L.reshape(_new_v, [-1, _v.shape[-1]])
        _new_v = L.lod_reset(_new_v, _v)
        return L.sequence_pool(_new_v, "sum")

    graph_wrapper = BigBirdWrapper(input_mask)
    old_v = v
    q = L.reshape(q, [-1, d_key * n_head])
    k = L.reshape(k, [-1, d_key * n_head])
    v = L.reshape(v, [-1, d_value * n_head])
    q = L.scale(q, scale=d_key ** -0.5)
    msg = graph_wrapper.send(
        send_q_k_spmm, nfeat_list=[("k", k), ("v", v), ("q", q)])
    out = graph_wrapper.recv(msg, recv_score_v_spmm)
    out = L.reshape(out, [-1, L.shape(old_v)[1], d_value * n_head])
    return out, out
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
import paddle.fluid as fluid
import paddle.fluid.layers as L
import paddle.fluid.layers as layers
from .sparse_scaled_dot_product_attention import sparse_scaled_dot_product_attention
to_3d = lambda a: a # will change later
to_2d = lambda a: a
def multi_head_attention(queries,
                         keys,
                         values,
                         attn_bias,
                         d_key,
                         d_value,
                         d_model,
                         input_mask,
                         n_head=1,
                         dropout_rate=0.,
                         cache=None,
                         param_initializer=None,
                         name='multi_head_att'):
    """
    Multi-Head Attention. Note that attn_bias is added to the logit before
    computing the softmax activation to mask certain selected positions so
    that they will not be considered in attention weights.
    """
    keys = queries if keys is None else keys
    values = keys if values is None else values

    def __compute_qkv(queries, keys, values, n_head, d_key, d_value):
        """
        Add linear projection to queries, keys, and values.
        """
        q = layers.fc(input=queries,
                      size=d_key * n_head,
                      num_flatten_dims=len(queries.shape) - 1,
                      param_attr=fluid.ParamAttr(
                          name=name + '_query_fc.w_0',
                          initializer=param_initializer),
                      bias_attr=name + '_query_fc.b_0')
        k = layers.fc(input=keys,
                      size=d_key * n_head,
                      num_flatten_dims=len(keys.shape) - 1,
                      param_attr=fluid.ParamAttr(
                          name=name + '_key_fc.w_0',
                          initializer=param_initializer),
                      bias_attr=name + '_key_fc.b_0')
        v = layers.fc(input=values,
                      size=d_value * n_head,
                      num_flatten_dims=len(values.shape) - 1,
                      param_attr=fluid.ParamAttr(
                          name=name + '_value_fc.w_0',
                          initializer=param_initializer),
                      bias_attr=name + '_value_fc.b_0')
        return q, k, v

    def __split_heads(x, n_head):
        """
        Reshape the last dimension of input tensor x so that it becomes two
        dimensions and then transpose. Specifically, transform an input tensor
        with shape [bs, max_sequence_length, n_head * hidden_dim] into an
        output tensor with shape [bs, n_head, max_sequence_length, hidden_dim].
        """
        hidden_size = x.shape[-1]
        # The value 0 in shape attr means copying the corresponding dimension
        # size of the input as the output dimension size.
        reshaped = layers.reshape(
            x=x, shape=[0, 0, n_head, hidden_size // n_head], inplace=True)
        # permute the dimensions into:
        # [batch_size, n_head, max_sequence_len, hidden_size_per_head]
        return layers.transpose(x=reshaped, perm=[0, 2, 1, 3])

    def __combine_heads(x):
        """
        Transpose and then reshape the last two dimensions of input tensor x
        so that it becomes one dimension, which is the reverse of __split_heads.
        """
        if len(x.shape) == 3:
            return x
        if len(x.shape) != 4:
            raise ValueError("Input(x) should be a 4-D Tensor.")

        trans_x = layers.transpose(x, perm=[0, 2, 1, 3])
        # The value 0 in shape attr means copying the corresponding dimension
        # size of the input as the output dimension size.
        return layers.reshape(x=trans_x, shape=[0, 0, d_model], inplace=True)

    q, k, v = __compute_qkv(queries, keys, values, n_head, d_key, d_value)

    q = to_3d(q)
    k = to_3d(k)
    v = to_3d(v)

    if cache is not None:  # use cache and concat time steps
        # Since the inplace reshape in __split_heads changes the shape of k and
        # v, which is the cache input for next time step, reshape the cache
        # input from the previous time step first.
        k = cache["k"] = layers.concat(
            [layers.reshape(
                cache["k"], shape=[0, 0, d_model]), k], axis=1)
        v = cache["v"] = layers.concat(
            [layers.reshape(
                cache["v"], shape=[0, 0, d_model]), v], axis=1)

    out, _ = sparse_scaled_dot_product_attention(
        q, k, v, input_mask, dropout_rate, n_head, d_key, d_value)

    out = to_2d(out)
    # Project back to the model size.
    proj_out = layers.fc(input=out,
                         size=d_model,
                         num_flatten_dims=len(out.shape) - 1,
                         param_attr=fluid.ParamAttr(
                             name=name + '_output_fc.w_0',
                             initializer=param_initializer),
                         bias_attr=name + '_output_fc.b_0')
    return proj_out, _
def positionwise_feed_forward(x,
                              d_inner_hid,
                              d_hid,
                              dropout_rate,
                              hidden_act,
                              param_initializer=None,
                              name='ffn'):
    """
    Position-wise Feed-Forward Networks.
    This module consists of two linear transformations with an activation
    (hidden_act) in between, applied to each position separately and
    identically.
    """
    hidden = layers.fc(input=x,
                       size=d_inner_hid,
                       num_flatten_dims=len(x.shape) - 1,
                       act=hidden_act,
                       param_attr=fluid.ParamAttr(
                           name=name + '_fc_0.w_0',
                           initializer=param_initializer),
                       bias_attr=name + '_fc_0.b_0')
    if dropout_rate:
        hidden = layers.dropout(
            hidden,
            dropout_prob=dropout_rate,
            dropout_implementation="upscale_in_train",
            is_test=False)
    out = layers.fc(input=hidden,
                    size=d_hid,
                    num_flatten_dims=len(hidden.shape) - 1,
                    param_attr=fluid.ParamAttr(
                        name=name + '_fc_1.w_0',
                        initializer=param_initializer),
                    bias_attr=name + '_fc_1.b_0')
    return out
def pre_post_process_layer(prev_out,
                           out,
                           process_cmd,
                           dropout_rate=0.,
                           name=''):
    """
    Add residual connection, layer normalization and dropout to the out tensor
    optionally according to the value of process_cmd.
    This will be used before or after multi-head attention and position-wise
    feed-forward networks.
    """
    for cmd in process_cmd:
        if cmd == "a":  # add residual connection
            out = out + prev_out if prev_out else out
        elif cmd == "n":  # add layer normalization
            out_dtype = out.dtype
            if out_dtype == fluid.core.VarDesc.VarType.FP16:
                out = layers.cast(x=out, dtype="float32")
            out = layers.layer_norm(
                out,
                begin_norm_axis=len(out.shape) - 1,
                param_attr=fluid.ParamAttr(
                    name=name + '_layer_norm_scale',
                    initializer=fluid.initializer.Constant(1.)),
                bias_attr=fluid.ParamAttr(
                    name=name + '_layer_norm_bias',
                    initializer=fluid.initializer.Constant(0.)))
            if out_dtype == fluid.core.VarDesc.VarType.FP16:
                out = layers.cast(x=out, dtype="float16")
        elif cmd == "d":  # add dropout
            if dropout_rate:
                out = layers.dropout(
                    out,
                    dropout_prob=dropout_rate,
                    dropout_implementation="upscale_in_train",
                    is_test=False)
    return out
pre_process_layer = partial(pre_post_process_layer, None)
post_process_layer = pre_post_process_layer
def encoder_layer(enc_input,
                  input_mask,
                  attn_bias,
                  n_head,
                  d_key,
                  d_value,
                  d_model,
                  d_inner_hid,
                  prepostprocess_dropout,
                  attention_dropout,
                  relu_dropout,
                  hidden_act,
                  preprocess_cmd="n",
                  postprocess_cmd="da",
                  param_initializer=None,
                  name=''):
    """The encoder layer that can be stacked to form a deep encoder.
    This module consists of multi-head (self-)attention followed by a
    position-wise feed-forward network, both accompanied by
    pre_process_layer / post_process_layer to add residual connections,
    layer normalization and dropout.
    """
    attn_output, ctx_multiheads_attn = multi_head_attention(
        pre_process_layer(
            enc_input,
            preprocess_cmd,
            prepostprocess_dropout,
            name=name + '_pre_att'),
        None,
        None,
        attn_bias,
        d_key,
        d_value,
        d_model,
        input_mask,
        n_head,
        attention_dropout,
        param_initializer=param_initializer,
        name=name + '_multi_head_att')
    attn_output = post_process_layer(
        enc_input,
        attn_output,
        postprocess_cmd,
        prepostprocess_dropout,
        name=name + '_post_att')

    ffd_output = positionwise_feed_forward(
        pre_process_layer(
            attn_output,
            preprocess_cmd,
            prepostprocess_dropout,
            name=name + '_pre_ffn'),
        d_inner_hid,
        d_model,
        relu_dropout,
        hidden_act,
        param_initializer=param_initializer,
        name=name + '_ffn')
    ret = post_process_layer(
        attn_output,
        ffd_output,
        postprocess_cmd,
        prepostprocess_dropout,
        name=name + '_post_ffn')
    return ret, ctx_multiheads_attn, ffd_output
def build_pad_idx(input_mask):
    pad_idx = L.where(L.cast(L.squeeze(input_mask, [2]), 'bool'))
    return pad_idx
def build_attn_bias(input_mask, n_head, dtype):
    attn_bias = L.matmul(
        input_mask, input_mask, transpose_y=True)  # [batch, seq, seq]
    attn_bias = (1. - attn_bias) * -10000.
    attn_bias = L.stack([attn_bias] * n_head, 1)   # [batch, n_head, seq, seq]
    if attn_bias.dtype != dtype:
        attn_bias = L.cast(attn_bias, dtype)
    return attn_bias
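
# ---------------------------------------------------------------------------
# Illustration only: what build_attn_bias computes, in NumPy. Padded positions
# (mask == 0) receive a large negative bias so softmax assigns them near-zero
# weight. `_demo_attn_bias_numpy` is a hypothetical helper added for
# documentation and is not used by the model code.
# ---------------------------------------------------------------------------
def _demo_attn_bias_numpy(input_mask, n_head):
    import numpy as np
    mask = np.asarray(input_mask, dtype="float32")           # [batch, seq, 1]
    bias = np.matmul(mask, mask.transpose(0, 2, 1))          # [batch, seq, seq]
    bias = (1.0 - bias) * -10000.0
    return np.tile(bias[:, None, :, :], (1, n_head, 1, 1))   # [batch, n_head, seq, seq]
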
def encoder(enc_input,
            input_mask,
            n_layer,
            n_head,
            d_key,
            d_value,
            d_model,
            d_inner_hid,
            prepostprocess_dropout,
            attention_dropout,
            relu_dropout,
            hidden_act,
            preprocess_cmd="n",
            postprocess_cmd="da",
            param_initializer=None,
            name=''):
    """
    The encoder is composed of a stack of identical layers returned by calling
    encoder_layer.
    """
    d_shape = L.shape(input_mask)
    pad_idx = build_pad_idx(input_mask)
    attn_bias = build_attn_bias(input_mask, n_head, enc_input.dtype)

    enc_input = to_2d(enc_input)
    all_hidden = []
    all_attn = []
    all_ffn = []
    for i in range(n_layer):
        enc_output, ctx_multiheads_attn, ffn_output = encoder_layer(
            enc_input,
            input_mask,
            attn_bias,
            n_head,
            d_key,
            d_value,
            d_model,
            d_inner_hid,
            prepostprocess_dropout,
            attention_dropout,
            relu_dropout,
            hidden_act,
            preprocess_cmd,
            postprocess_cmd,
            param_initializer=param_initializer,
            name=name + '_layer_' + str(i))
        all_hidden.append(enc_output)
        all_attn.append(ctx_multiheads_attn)
        all_ffn.append(ffn_output)
        enc_input = enc_output
    enc_output = pre_process_layer(
        enc_output,
        preprocess_cmd,
        prepostprocess_dropout,
        name="post_encoder")
    enc_output = to_3d(enc_output)
    return enc_output, all_hidden, all_attn, all_ffn
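
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustration only, untested): build a BigBird encoder
# inside a fluid static-graph program. The shapes and hyper-parameters below
# are assumptions for the sketch, not values shipped with this commit, and
# `_demo_build_bigbird_encoder` is a hypothetical helper that is never called.
# ---------------------------------------------------------------------------
def _demo_build_bigbird_encoder():
    enc_input = fluid.layers.data(
        name="enc_input", shape=[-1, 512, 768],
        dtype="float32", append_batch_size=False)
    input_mask = fluid.layers.data(
        name="input_mask", shape=[-1, 512, 1],
        dtype="float32", append_batch_size=False)

    enc_output, all_hidden, all_attn, all_ffn = encoder(
        enc_input,
        input_mask,
        n_layer=12,
        n_head=12,
        d_key=64,
        d_value=64,
        d_model=768,
        d_inner_hid=3072,
        prepostprocess_dropout=0.1,
        attention_dropout=0.1,
        relu_dropout=0.1,
        hidden_act="relu",
        name="encoder")
    return enc_output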