Commit cf752eba authored by: G guosheng

Add api docs for TransformerCell and TransformerBeamSearchDecoder.

Parent f1243462
@@ -1856,7 +1856,7 @@ class GRU(Layer):
Parameters:
input_size (int): The input feature size for the first GRU cell.
hidden_size (int): The hidden size for every GRU cell.
gate_activation (function, optional): The activation function for gates
of GRU, that is :math:`act_g` in the formula. Default: None,
@@ -1971,7 +1971,7 @@ class BidirectionalGRU(Layer):
Parameters:
input_size (int): The input feature size for the first GRU cell.
hidden_size (int): The hidden size for every GRU cell.
gate_activation (function, optional): The activation function for gates
of GRU, that is :math:`act_g` in the formula. Default: None,
@@ -2346,8 +2346,59 @@ class DynamicDecode(Layer):
class TransformerCell(Layer):
"""
TransformerCell wraps a Transformer decoder to produce logits from `inputs`
composed of ids and positions, so that the Transformer decoder can be used
like an RNN cell (for example together with `DynamicDecode`).
Parameters:
decoder(callable): A TransformerDecoder instance, or a wrapper of it that
includes an embedding layer accepting ids and positions instead of embeddings
and an output layer transforming decoder output features to logits.
embedding_fn(function, optional): A callable that accepts ids and positions
as arguments and returns embeddings as the input of `decoder`. It can be
None if `decoder` includes an embedding layer. Default None.
output_fn(callable, optional): A callable applied to the `decoder` output to
transform decoder output features into logits. Mostly it is a Linear
layer with the vocabulary size. It can be None if `decoder` includes an
output layer. Default None.
Examples:
.. code-block:: python
import paddle
import paddle.fluid as fluid
from paddle.fluid.dygraph import Embedding, Linear
from paddle.incubate.hapi.text import TransformerDecoder
from paddle.incubate.hapi.text import TransformerCell
from paddle.incubate.hapi.text import TransformerBeamSearchDecoder
from paddle.incubate.hapi.text import DynamicDecode
# word and position embeddings (sizes are illustrative); their sum is
# used as the decoder input
word_embedder = Embedding(size=[1000, 128])
pos_embedder = Embedding(size=[500, 128])
embedding_fn = lambda word, pos: word_embedder(word) + pos_embedder(pos)
output_layer = Linear(128, 1000)
# (n_layer, n_head, d_key, d_value, d_model, d_inner_hid)
decoder = TransformerDecoder(2, 2, 64, 64, 128, 512)
transformer_cell = TransformerCell(decoder, embedding_fn, output_layer)
dynamic_decoder = DynamicDecode(
TransformerBeamSearchDecoder(
transformer_cell,
start_token=0,
end_token=1,
beam_size=4,
var_dim_in_state=2),
max_step_num=10,
is_test=True)
# encoder output: [batch_size, src_len, d_model]
enc_output = paddle.rand((2, 4, 128))
# cross attention bias: [batch_size, n_head, trg_len, src_len]
trg_src_attn_bias = paddle.rand((2, 2, 1, 4))
# inputs for beam search on Transformer
caches = transformer_cell.get_initial_states(enc_output)
enc_output = TransformerBeamSearchDecoder.tile_beam_merge_with_batch(
enc_output, beam_size=4)
trg_src_attn_bias = TransformerBeamSearchDecoder.tile_beam_merge_with_batch(
trg_src_attn_bias, beam_size=4)
static_caches = decoder.prepare_static_cache(enc_output)
outputs = dynamic_decoder(
inits=caches,
enc_output=enc_output,
trg_src_attn_bias=trg_src_attn_bias,
static_caches=static_caches)
"""
def __init__(self, decoder, embedding_fn=None, output_fn=None):
@@ -2356,11 +2407,56 @@ class TransformerCell(Layer):
self.embedding_fn = embedding_fn
self.output_fn = output_fn
def forward(self,
inputs,
states=None,
enc_output=None,
trg_slf_attn_bias=None,
trg_src_attn_bias=None,
static_caches=[]):
"""
Produces logits from `inputs` composed of ids and positions.
Parameters:
inputs(tuple): A tuple including target ids and positions. The two
tensors both have int64 data type and 2D shape
`[batch_size, sequence_length]`, where `sequence_length` is 1
for inference.
states(list): It caches the multi-head attention intermediate results
of history decoding steps. It is a list of dicts, where the list length
equals the number of decoder layers, and each dict has `k` and `v` as
keys with cached results as values. Default None.
enc_output(Variable): The output of Transformer encoder. It is a tensor
with shape `[batch_size, sequence_length, d_model]`. The data type
should be float32 or float64.
trg_slf_attn_bias(Variable, optional): A tensor used in decoder self
attention to mask out attention on unwanted target positions. It
is a tensor with shape `[batch_size, n_head, target_length, target_length]`,
where the unwanted positions have `-INF` values and the others
have 0 values. It can be None for inference. The data type should
be float32 or float64.
trg_src_attn_bias(Variable, optional): A tensor used in decoder encoder
cross attention to mask out unwanted attention on source (encoder output).
It is a tensor with shape `[batch_size, n_head, target_length, source_length]`,
where the unwanted positions have `-INF` values and the others
have 0 values. The data type should be float32 or float64.
static_caches(list): It stores the multi-head attention intermediate
results computed from the encoder output. It is a list of dicts, where
the list length equals the number of decoder layers, and each dict has
`static_k` and `static_v` as keys with stored results as values.
Default empty list.
Returns:
tuple: A tuple( :code:`(outputs, new_states)` ), where `outputs` \
is a float32 or float64 3D tensor representing logits shaped \
`[batch_size, sequence_length, vocab_size]`. `new_states` has \
the same structure and data type as `states`, while its cached \
length is one larger, since the intermediate results of the current \
step have been concatenated into the caches.
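The sketch below shows a single direct call of this method, reusing the
setup of the class-level example above (the embedding sizes, the ids and
the random attention bias are illustrative stand-ins):
.. code-block:: python
import paddle
import paddle.fluid as fluid
from paddle.fluid.dygraph import Embedding, Linear
from paddle.incubate.hapi.text import TransformerDecoder
from paddle.incubate.hapi.text import TransformerCell
word_embedder = Embedding(size=[1000, 128])
pos_embedder = Embedding(size=[500, 128])
embedding_fn = lambda word, pos: word_embedder(word) + pos_embedder(pos)
output_layer = Linear(128, 1000)
decoder = TransformerDecoder(2, 2, 64, 64, 128, 512)
cell = TransformerCell(decoder, embedding_fn, output_layer)
# encoder output: [batch_size, src_len, d_model]
enc_output = paddle.rand((2, 4, 128))
# target ids and positions for one decoding step: [batch_size, 1], int64
trg_word = fluid.layers.fill_constant([2, 1], "int64", 0)
trg_pos = fluid.layers.zeros_like(trg_word)
caches = cell.get_initial_states(enc_output)
static_caches = decoder.prepare_static_cache(enc_output)
trg_src_attn_bias = paddle.rand((2, 2, 1, 4))
logits, new_caches = cell(
(trg_word, trg_pos), caches, enc_output=enc_output,
trg_src_attn_bias=trg_src_attn_bias, static_caches=static_caches)
# logits: [batch_size, 1, vocab_size], i.e. [2, 1, 1000]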
"""
trg_word, trg_pos = inputs
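# merge the per-layer static (encoder-side) caches into the decoding caches
# before running the decoder layers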
if states and static_caches:
for cache, static_cache in zip(states, static_caches):
cache.update(static_cache)
if self.embedding_fn is not None:
dec_input = self.embedding_fn(trg_word, trg_pos)
outputs = self.decoder(dec_input, enc_output, None,
@@ -2370,14 +2466,30 @@ class TransformerCell(Layer):
trg_src_attn_bias, states)
if self.output_fn is not None:
outputs = self.output_fn(outputs)
new_states = [{
"k": cache["k"],
"v": cache["v"]
} for cache in states] if states else states
return outputs, new_states
@property
def state_shape(self):
"""
States of TransformerCell cache the multi-head attention intermediate
results of history decoding steps, and have an increasing length as
decoding continues.
`state_shape` of TransformerCell is used to initialize states. It is a
list of dicts, where the list length equals the number of decoder layers,
and each dict has `k` and `v` as keys whose values are `[n_head, 0, d_key]`
and `[n_head, 0, d_value]` respectively (-1 for the batch size would be
automatically inserted into the shapes).
Returns:
list: A list of dicts with the shape description above, where the list \
length equals the number of decoder layers and each dict has `k` and \
`v` as keys.
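For example, with a 2-layer decoder like the one in the class-level example
(a minimal sketch; the constructor arguments are illustrative):
.. code-block:: python
from paddle.incubate.hapi.text import TransformerDecoder, TransformerCell
# n_layer=2, n_head=2, d_key=64, d_value=64, d_model=128, d_inner_hid=512
decoder = TransformerDecoder(2, 2, 64, 64, 128, 512)
cell = TransformerCell(decoder)
# expected: two dicts (one per decoder layer), each like
# {"k": [2, 0, 64], "v": [2, 0, 64]}
print(cell.state_shape)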
"""
return [{
"k": [self.decoder.n_head, 0, self.decoder.d_key],
"v": [self.decoder.n_head, 0, self.decoder.d_value],
@@ -2385,6 +2497,60 @@ class TransformerCell(Layer):
class TransformerBeamSearchDecoder(layers.BeamSearchDecoder):
"""
Compared with an RNN step :code:`outputs, new_states = cell(inputs, states)`,
the Transformer decoder's `inputs` is a 2D tensor shaped `[batch_size * beam_size, 1]`
and includes extra position data, and its `states` (caches) grow in length
as decoding proceeds. Since these are not consistent with `BeamSearchDecoder`,
`BeamSearchDecoder` is subclassed here to adapt beam search to the Transformer decoder.
Parameters:
cell(TransformerCell): An instance of `TransformerCell`.
start_token(int): The start token id.
end_token(int): The end token id.
beam_size(int): The beam width used in beam search.
var_dim_in_state(int): Indicates which dimension of `states` has a variable
size that grows as decoding proceeds.
Examples:
.. code-block:: python
import paddle
import paddle.fluid as fluid
from paddle.fluid.dygraph import Embedding, Linear
from paddle.incubate.hapi.text import TransformerDecoder
from paddle.incubate.hapi.text import TransformerCell
from paddle.incubate.hapi.text import TransformerBeamSearchDecoder
from paddle.incubate.hapi.text import DynamicDecode
# word and position embeddings (sizes are illustrative); their sum is
# used as the decoder input
word_embedder = Embedding(size=[1000, 128])
pos_embedder = Embedding(size=[500, 128])
embedding_fn = lambda word, pos: word_embedder(word) + pos_embedder(pos)
output_layer = Linear(128, 1000)
# (n_layer, n_head, d_key, d_value, d_model, d_inner_hid)
decoder = TransformerDecoder(2, 2, 64, 64, 128, 512)
transformer_cell = TransformerCell(decoder, embedding_fn, output_layer)
dynamic_decoder = DynamicDecode(
TransformerBeamSearchDecoder(
transformer_cell,
start_token=0,
end_token=1,
beam_size=4,
var_dim_in_state=2),
max_step_num=10,
is_test=True)
# encoder output: [batch_size, src_len, d_model]
enc_output = paddle.rand((2, 4, 128))
# cross attention bias: [batch_size, n_head, trg_len, src_len]
trg_src_attn_bias = paddle.rand((2, 2, 1, 4))
# inputs for beam search on Transformer
caches = transformer_cell.get_initial_states(enc_output)
enc_output = TransformerBeamSearchDecoder.tile_beam_merge_with_batch(
enc_output, beam_size=4)
trg_src_attn_bias = TransformerBeamSearchDecoder.tile_beam_merge_with_batch(
trg_src_attn_bias, beam_size=4)
static_caches = decoder.prepare_static_cache(enc_output)
outputs = dynamic_decoder(
inits=caches,
enc_output=enc_output,
trg_src_attn_bias=trg_src_attn_bias,
static_caches=static_caches)
"""
def __init__(self, cell, start_token, end_token, beam_size,
var_dim_in_state):
super(TransformerBeamSearchDecoder,
@@ -2393,6 +2559,18 @@ class TransformerBeamSearchDecoder(layers.BeamSearchDecoder):
self.var_dim_in_state = var_dim_in_state
def _merge_batch_beams_with_var_dim(self, x):
"""
Reshape a tensor with shape `[batch_size, beam_size, ...]` to a new
tensor with shape `[batch_size * beam_size, ...]`.
Parameters:
x(Variable): A tensor with shape `[batch_size, beam_size, ...]`. The
data type should be float32, float64, int32, int64 or bool.
Returns:
Variable: A tensor with shape `[batch_size * beam_size, ...]`, whose \
data type is the same as that of `x`.
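For intuition, the sketch below shows the plain reshape effect on a cached
tensor with a concrete (made-up) length; the actual implementation reads the
variable cache length from the runtime tensor shape instead of hard-coding it:
.. code-block:: python
import paddle
import paddle.fluid as fluid
# a cached "k" tensor after 3 decoded steps:
# [batch_size=2, beam_size=4, n_head=2, cur_len=3, d_key=64]
x = paddle.rand((2, 4, 2, 3, 64))
# merge: fold beam into batch -> [batch_size * beam_size, n_head, cur_len, d_key]
merged = fluid.layers.reshape(x, shape=[-1, 2, 3, 64])  # [8, 2, 3, 64]
# split (the inverse direction) restores the beam dimension
restored = fluid.layers.reshape(merged, shape=[-1, 4, 2, 3, 64])  # [2, 4, 2, 3, 64]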
"""
# the initial length of the cache is 0 and it increases as decoding carries
# on, thus the reshape needs to be done elaborately
var_dim_in_state = self.var_dim_in_state + 1  # count in beam dim
@@ -2410,6 +2588,18 @@ class TransformerBeamSearchDecoder(layers.BeamSearchDecoder):
return x
def _split_batch_beams_with_var_dim(self, x):
"""
Reshape a tensor with shape `[batch_size * beam_size, ...]` to a new
tensor with shape `[batch_size, beam_size, ...]`.
Parameters:
x(Variable): A tensor with shape `[batch_size * beam_size, ...]`. The
data type should be float32, float64, int32, int64 or bool.
Returns:
Variable: A tensor with shape `[batch_size, beam_size, ...]`, whose \
data type is the same as that of `x`.
"""
var_dim_size = layers.shape(x)[self.var_dim_in_state]
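# restore the beam dimension; the variable cache length is read from the
# runtime tensor shape above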
x = layers.reshape(
x, [-1, self.beam_size] +
@@ -2419,6 +2609,38 @@ class TransformerBeamSearchDecoder(layers.BeamSearchDecoder):
return x
def step(self, time, inputs, states, **kwargs):
"""
Perform a beam search decoding step, which uses `cell` to get probabilities,
and follows a beam search step to calculate scores and select candidate
token ids.
Note: compared with `BeamSearchDecoder.step`, it feeds a 2D id tensor shaped
`[batch_size * beam_size, 1]` rather than `[batch_size * beam_size]`, combined
with position data, as the inputs to `cell`.
Parameters:
time(Variable): An `int64` tensor with shape `[1]` provided by the caller,
representing the current time step number of decoding.
inputs(Variable): A tensor variable. It is the same as `initial_inputs`
returned by `initialize()` for the first decoding step and
`next_inputs` returned by `step()` for the others. It is an int64
id tensor with shape `[batch_size * beam_size]`.
states(Variable): A structure of tensor variables.
It is the same as the `initial_states` returned by `initialize()` for
the first decoding step and the `beam_search_state` returned by
`step()` for the others.
**kwargs: Additional keyword arguments, provided by the caller.
Returns:
tuple: A tuple( :code:`(beam_search_output, beam_search_state, next_inputs, finished)` ). \
`beam_search_state` and `next_inputs` have the same structure, \
shape and data type as the input arguments `states` and `inputs` respectively. \
`beam_search_output` is a namedtuple (with `scores`, `predicted_ids` and \
`parent_ids` as fields) of tensor variables, where \
`scores, predicted_ids, parent_ids` all have tensor values shaped \
`[batch_size, beam_size]`, with data types `float32`, `int64`, `int64` respectively. \
`finished` is a `bool` tensor with shape `[batch_size, beam_size]`.
"""
# compared to RNN, Transformer has 3D data at every decoding step
inputs = layers.reshape(inputs, [-1, 1])  # token
pos = layers.ones_like(inputs) * time  # pos
@@ -2427,6 +2649,11 @@ class TransformerBeamSearchDecoder(layers.BeamSearchDecoder):
cell_outputs, next_cell_states = self.cell((inputs, pos), cell_states,
**kwargs)
# squeeze to adapt to BeamSearchDecoder which uses 2D logits
cell_outputs = map_structure(
lambda x: layers.squeeze(x, [1]) if len(x.shape) == 3 else x,
cell_outputs)
cell_outputs = map_structure(self._split_batch_beams, cell_outputs)
next_cell_states = map_structure(self._split_batch_beams_with_var_dim,
next_cell_states)
@@ -2715,7 +2942,66 @@ class TransformerEncoderLayer(Layer):
class TransformerEncoder(Layer):
"""
TransformerEncoder is a stack of N Transformer encoder layers, each of which
applies multi-head self attention (MHA) followed by a position-wise
feedforward network (FFN) on the input sequence.
Parameters:
n_layer (int): The number of encoder layers to be stacked.
n_head (int): The number of heads in multi-head attention(MHA).
d_key (int): The feature size of queries and keys in MHA. Mostly it equals
`d_model // n_head`.
d_value (int): The feature size of values in MHA. Mostly it equals
`d_model // n_head`.
d_model (int): The expected feature size in the input and output.
d_inner_hid (int): The hidden layer size in the feedforward network(FFN).
prepostprocess_dropout (float, optional): The dropout probability used
in pre-process and post-process of the MHA and FFN sub-layers. Default 0.1
attention_dropout (float, optional): The dropout probability used
in MHA to drop some attention target. Default 0.1
relu_dropout (float, optional): The dropout probability used after the
activation in FFN. Default 0.1
preprocess_cmd (str, optional): The process applied before each MHA and
FFN sub-layer. It should be a string composed of `d`, `a`, `n`, where
`d` stands for dropout, `a` for adding the residual connection and `n`
for layer normalization. Default `n`.
postprocess_cmd (str, optional): The process applied after each MHA and
FFN sub-layer, with the same format as `preprocess_cmd`. Default `da`.
ffn_fc1_act (str, optional): The activation function in the feedforward
network. Default relu.
Examples:
.. code-block:: python
import paddle
import paddle.fluid as fluid
from paddle.incubate.hapi.text import TransformerEncoder
# encoder input: [batch_size, src_len, d_model]
enc_input = paddle.rand((2, 4, 128))
# self attention bias: [batch_size, n_head, src_len, src_len]
attn_bias = paddle.rand((2, 2, 4, 4))
encoder = TransformerEncoder(2, 2, 64, 64, 128, 512)
enc_output = encoder(enc_input, attn_bias)  # [2, 4, 128]
"""
def __init__(self,
@@ -2725,9 +3011,9 @@ class TransformerEncoder(Layer):
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout=0.1,
attention_dropout=0.1,
relu_dropout=0.1,
preprocess_cmd="n",
postprocess_cmd="da",
ffn_fc1_act="relu"):
@@ -2908,8 +3194,8 @@ class TransformerDecoder(Layer):
caches=None):
for i, decoder_layer in enumerate(self.decoder_layers):
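# pass this layer's incremental cache (if any) so the layer can reuse and
# extend it during decoding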
dec_output = decoder_layer(dec_input, enc_output, self_attn_bias,
cross_attn_bias, caches[i]
if caches else None)
dec_input = dec_output
return self.processer(dec_output)
...