Unverified commit db870872, authored by xiemoyuan, committed by GitHub

Optimize the encoder of Transformer. (#30439)

* Add cache for Transformer encoder.

* Bug fixes.

* Add unit tests for the Transformer encoder.
Parent: cb66c53c
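
For context, this change lets the Transformer encoder be driven incrementally at inference time by threading a cache through `forward`, in the same way `MultiHeadAttention` already supports. A minimal usage sketch of the API added here (sizes and values below are illustrative assumptions, not taken from the diff):

import numpy as np
import paddle
from paddle.nn import TransformerEncoder, TransformerEncoderLayer

# Illustrative sizes only.
batch_size, src_len, d_model, n_head = 2, 4, 8, 2

encoder_layer = TransformerEncoderLayer(d_model, n_head, dim_feedforward=16)
encoder = TransformerEncoder(encoder_layer, num_layers=2)
src = paddle.to_tensor(
    np.random.rand(batch_size, src_len, d_model).astype("float32"))

# Without a cache, forward returns only the encoder output.
output = encoder(src)

# With a cache (inference only): gen_cache builds one incremental cache per
# layer, and forward returns an (output, new_caches) tuple so the updated
# caches can be fed back in on the next incremental call.
caches = encoder.gen_cache(src)
output, new_caches = encoder(src, cache=caches)
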
......@@ -318,6 +318,61 @@ class TestTransformer(unittest.TestCase):
np.testing.assert_allclose(
encoder_output.numpy(), src, rtol=1e-5, atol=1e-6)
def test_transformer_encoder_layer_attr_1(self):
with fluid.dygraph.guard(fluid.CPUPlace()):
paddle.framework.seed(2020)
paddle.framework.random._manual_program_seed(2020)
ffn_fc1_act = "relu"
# 1.generate basic params
batch_size, d_model, n_head, dim_feedforward, dropout, attn_dropout, act_dropout, sequence_length = generate_basic_params(
mode="encoder_layer")
# 2.generate input for encoder
src = np.random.rand(batch_size, sequence_length,
d_model).astype("float32")
src_mask = np.zeros((batch_size, n_head, sequence_length,
sequence_length)).astype("float32")
src_mask[0][0][0][0] = -np.inf
for cache in [True, False]:
# paddle
encoder_layer = TransformerEncoderLayer(
d_model, n_head, dim_feedforward, dropout, ffn_fc1_act,
attn_dropout, act_dropout)
cache_objs = None
if cache:
cache_objs = encoder_layer.gen_cache(paddle.to_tensor(src))
encoder_output = encoder_layer(
paddle.to_tensor(src),
paddle.to_tensor(src_mask), cache_objs)
encoder_output = encoder_output[0].numpy(
) if cache else encoder_output.numpy()
# 4.numpy:
residual = src
# paddle self attention
self_attn = MultiHeadAttention(
d_model, n_head, dropout=attn_dropout)
attn_output = self_attn(
paddle.to_tensor(src),
paddle.to_tensor(src),
paddle.to_tensor(src),
paddle.to_tensor(src_mask), cache_objs)
attn_output = attn_output[0].numpy(
) if cache else attn_output.numpy()
src = attn_output + residual
src_norm = layer_norm(src, d_model, encoder_layer.norm1)
residual = src_norm
ffn_output = ffn(src_norm, encoder_layer, ffn_fc1_act)
src = residual + ffn_output
src = layer_norm(src, d_model, encoder_layer.norm2)
np.testing.assert_allclose(
encoder_output, src, rtol=1e-5, atol=1e-6)
def test_transformer_decoder_layer(self):
with fluid.dygraph.guard(fluid.CPUPlace()):
paddle.framework.seed(2020)
......@@ -418,6 +473,32 @@ class TestTransformer(unittest.TestCase):
enc_output = encoder(
paddle.to_tensor(src), paddle.to_tensor(src_mask))
def test_encoder_attr_1(self):
batch_size, d_model, n_head, dim_feedforward, dropout, attn_dropout, act_dropout, sequence_length = generate_basic_params(
mode="encoder_layer")
src = np.random.rand(batch_size, sequence_length,
d_model).astype("float32")
src_mask = np.zeros((batch_size, n_head, sequence_length,
sequence_length)).astype("float32")
src_mask[0][0][0][0] = -np.inf
with fluid.dygraph.guard(fluid.CPUPlace()):
for cache in [True, False]:
# paddle
encoder_layer = TransformerEncoderLayer(
d_model, n_head, dim_feedforward, dropout)
num_layers = 6
encoder = TransformerEncoder(encoder_layer, num_layers)
cache_objs = None
if cache:
cache_objs = encoder.gen_cache(paddle.to_tensor(src))
# src, src_mask
enc_output = encoder(
paddle.to_tensor(src),
paddle.to_tensor(src_mask), cache_objs)
def test_decoder(self):
batch_size, d_model, n_head, dim_feedforward, dropout, _, _, source_length, target_length = generate_basic_params(
mode="decoder_layer")
......
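
The reference computation in `test_transformer_encoder_layer_attr_1` above relies on `layer_norm` and `ffn` helpers defined earlier in the test file, outside this hunk. A rough numpy sketch of what such reference helpers could look like, reusing the weights of the paddle sublayers (the actual helpers in the test file may differ in detail):

import numpy as np

def layer_norm(x, d_model, norm_layer, epsilon=1e-5):
    # Reference layer normalization over the last axis, using the weight and
    # bias of the corresponding paddle LayerNorm sublayer. d_model is accepted
    # only to match the call site; the shape is taken from x itself.
    weight = norm_layer.weight.numpy()
    bias = norm_layer.bias.numpy()
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + epsilon) * weight + bias

def ffn(x, encoder_layer, activation="relu"):
    # Reference position-wise feed-forward network: linear1 -> activation ->
    # linear2, reusing the paddle encoder layer's weights (dropout omitted).
    w1 = encoder_layer.linear1.weight.numpy()
    b1 = encoder_layer.linear1.bias.numpy()
    w2 = encoder_layer.linear2.weight.numpy()
    b2 = encoder_layer.linear2.bias.numpy()
    hidden = np.dot(x, w1) + b1
    if activation == "relu":
        hidden = np.maximum(hidden, 0)
    return np.dot(hidden, w2) + b2
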
......@@ -311,7 +311,7 @@ class MultiHeadAttention(Layer):
# incremental_state with initial value, mainly for usage like UniLM
return self.Cache(key, value)
def forward(self, query, key, value, attn_mask=None, cache=None):
def forward(self, query, key=None, value=None, attn_mask=None, cache=None):
r"""
Applies multi-head attention to map queries and a set of key-value pairs
to outputs.
......@@ -498,7 +498,7 @@ class TransformerEncoderLayer(Layer):
self.dropout2 = Dropout(dropout, mode="upscale_in_train")
self.activation = getattr(F, activation)
def forward(self, src, src_mask=None):
def forward(self, src, src_mask=None, cache=None):
r"""
Applies a Transformer encoder layer on the input.
......@@ -514,16 +514,30 @@ class TransformerEncoderLayer(Layer):
have 0 values. The data type should be float32 or float64. It can
be None when nothing wanted or needed to be prevented attention to.
Default None
cache (MultiHeadAttention.Cache, optional): It is an instance of
    `MultiHeadAttention.Cache`. See `TransformerEncoderLayer.gen_cache`
    for more details. It is only used for inference and should be None
    for training. Default None.
Returns:
Tensor: The output of Transformer encoder layer. It is a tensor that \
has the same shape and data type as `enc_input`.
Tensor|tuple: It is a tensor that has the same shape and data type \
    as `enc_input`, representing the output of the Transformer encoder \
    layer. Or a tuple if `cache` is not None: besides the encoder layer \
    output, the tuple also includes the new cache, which is the same as \
    the input `cache` argument except that its `incremental_cache` has \
    an incremented length. See `MultiHeadAttention.gen_cache` and \
    `MultiHeadAttention.forward` for more details.
"""
residual = src
if self.normalize_before:
src = self.norm1(src)
# TODO(guosheng): Add cache for encoder for the usage like UniLM
src = self.self_attn(src, src, src, src_mask)
if cache is None:
src = self.self_attn(src, src, src, src_mask)
else:
src, incremental_cache = self.self_attn(src, src, src, src_mask,
cache)
src = residual + self.dropout1(src)
if not self.normalize_before:
src = self.norm1(src)
......@@ -535,7 +549,28 @@ class TransformerEncoderLayer(Layer):
src = residual + self.dropout2(src)
if not self.normalize_before:
src = self.norm2(src)
return src
return src if cache is None else (src, incremental_cache)
def gen_cache(self, src):
r"""
Generates cache for `forward` usage. The generated cache is an
instance of `MultiHeadAttention.Cache`.
Parameters:
src (Tensor): The input of Transformer encoder. It is a tensor
with shape `[batch_size, source_length, d_model]`. The data
type should be float32 or float64.
Returns:
incremental_cache: It is an instance of `MultiHeadAttention.Cache` \
    produced by `self_attn.gen_cache`. It holds two tensors of shape \
    `[batch_size, nhead, 0, d_model // nhead]`. See \
`MultiHeadAttention.gen_cache` and `MultiHeadAttention.forward` \
for more details.
"""
incremental_cache = self.self_attn.gen_cache(
src, type=self.self_attn.Cache)
return incremental_cache
class TransformerEncoder(Layer):
......@@ -574,7 +609,7 @@ class TransformerEncoder(Layer):
self.num_layers = num_layers
self.norm = norm
def forward(self, src, src_mask=None):
def forward(self, src, src_mask=None, cache=None):
r"""
Applies a stack of N Transformer encoder layers on inputs. If `norm` is
provided, also applies layer normalization on the output of last encoder
......@@ -592,20 +627,55 @@ class TransformerEncoder(Layer):
have 0 values. The data type should be float32 or float64. It can
be None when nothing wanted or needed to be prevented attention to.
Default None
cache (list, optional): It is a list, and each element in the list
is `incremental_cache` produced by `TransformerEncoderLayer.gen_cache`.
See `TransformerEncoder.gen_cache` for more details. It is only
used for inference and should be None for training. Default None.
Returns:
Tensor: The output of Transformer encoder. It is a tensor that \
has the same shape and data type as `src`.
Tensor|tuple: It is a tensor that has the same shape and data type \
    as `src`, representing the output of the Transformer encoder. \
    Or a tuple if `cache` is not None: besides the encoder output, the \
    tuple also includes the new cache, which is the same as the input \
    `cache` argument except that each `incremental_cache` in it has an \
    incremented length. See `MultiHeadAttention.gen_cache` and \
    `MultiHeadAttention.forward` for more details.
"""
output = src
for mod in self.layers:
output = mod(output, src_mask=src_mask)
new_caches = []
for i, mod in enumerate(self.layers):
if cache is None:
output = mod(output, src_mask=src_mask)
else:
output, new_cache = mod(output,
src_mask=src_mask,
cache=cache[i])
new_caches.append(new_cache)
if self.norm is not None:
output = self.norm(output)
return output
return output if cache is None else (output, new_caches)
def gen_cache(self, src):
r"""
Generates cache for `forward` usage. The generated cache is a list, and
each element in it is `incremental_cache` produced by
`TransformerEncoderLayer.gen_cache`. See `TransformerEncoderLayer.gen_cache`
for more details.
Parameters:
src (Tensor): The input of Transformer encoder. It is a tensor
with shape `[batch_size, source_length, d_model]`. The data type
should be float32 or float64.
Returns:
list: It is a list, and each element in the list is `incremental_cache`
produced by `TransformerEncoderLayer.gen_cache`. See
`TransformerEncoderLayer.gen_cache` for more details.
"""
cache = [layer.gen_cache(src) for layer in self.layers]
return cache
class TransformerDecoderLayer(Layer):
......
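
To make the "incremental length" wording in the docstrings above concrete, here is a small sketch (sizes invented, paddle.nn public API assumed) of how the layer-level cache evolves across a call:

import numpy as np
import paddle
from paddle.nn import TransformerEncoderLayer

batch_size, seq_len, d_model, n_head = 2, 3, 8, 2
layer = TransformerEncoderLayer(d_model, n_head, dim_feedforward=16)
src = paddle.to_tensor(
    np.random.rand(batch_size, seq_len, d_model).astype("float32"))

# gen_cache builds an empty MultiHeadAttention.Cache: its k and v tensors
# start with a zero-length sequence axis,
# i.e. shape [batch_size, n_head, 0, d_model // n_head].
cache = layer.gen_cache(src)
print(cache.k.shape)  # e.g. [2, 2, 0, 4]

# When a cache is passed, forward returns (output, new_cache); the new cache's
# sequence axis has grown by the number of positions just processed.
out, cache = layer(src, cache=cache)
print(cache.k.shape)  # e.g. [2, 2, 3, 4]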