未验证 提交 2e8425b6 编写于 作者: J Jiaqi Liu 提交者: GitHub

Fix beam search bug (#29824)

* fix beam search bug

* add dygraph unittest

* update dynamic_decode argument doc

* add warning info for state which has no lengths attribute
上级 f43e1d8c
......@@ -16,6 +16,7 @@ from __future__ import print_function
import sys
from functools import partial, reduce
import warnings
import paddle
from paddle.utils import deprecated
......@@ -1382,10 +1383,17 @@ def _dynamic_decode_imperative(decoder,
sequence_lengths,
tensor.cast(
control_flow.logical_not(finished), sequence_lengths.dtype))
if impute_finished: # rectify the states for the finished.
next_states = map_structure(
lambda x, y: _maybe_copy(x, y, finished), states, next_states)
lambda x, y: _maybe_copy(x, y, finished), states,
next_states)
else:
warnings.warn(
"`next_states` has no `lengths` attribute, the returned `sequence_lengths` would be all zeros."
) if not hasattr(next_states, "lengths") else None
next_sequence_lengths = getattr(next_states, "lengths",
sequence_lengths)
outputs = map_structure(
lambda x: ArrayWrapper(x),
step_outputs) if step_idx == 0 else map_structure(
......@@ -1505,12 +1513,17 @@ def _dynamic_decode_declarative(decoder,
tensor.cast(
control_flow.logical_not(global_finished),
sequence_lengths.dtype))
if impute_finished: # rectify the states for the finished.
next_states = map_structure(
lambda x, y: _maybe_copy(x, y, global_finished),
states,
next_states, )
else:
warnings.warn(
"`next_states` has no `lengths` attribute, the returned `sequence_lengths` would be all zeros."
) if not hasattr(next_states, "lengths") else None
next_sequence_lengths = getattr(next_states, "lengths",
sequence_lengths)
# create tensor array in global block after dtype[s] of outputs can be got
outputs_arrays = map_structure(
......@@ -1595,13 +1608,13 @@ def dynamic_decode(decoder,
attr:`False`, the data layout would be batch major with shape
`[batch_size, seq_len, ...]`. If attr:`True`, the data layout would
be time major with shape `[seq_len, batch_size, ...]`. Default: `False`.
impute_finished(bool, optional): If `True`, then states get copied through
for batch entries which are marked as finished, which differs with the
unfinished using the new states returned by :code:`decoder.step()` and
ensures that the final states have the correct values. Otherwise, states
wouldn't be copied through when finished. If the returned `final_states`
is needed, it should be set as True, which causes some slowdown.
Default `False`.
impute_finished(bool, optional): If `True` and `decoder.tracks_own_finished`
is False, then states get copied through for batch entries which are
marked as finished, which differs with the unfinished using the new states
returned by :code:`decoder.step()` and ensures that the final states have
the correct values. Otherwise, states wouldn't be copied through when
finished. If the returned `final_states` is needed, it should be set as
True, which causes some slowdown. Default `False`.
is_test(bool, optional): A flag indicating whether to use test mode. In
test mode, it is more memory saving. Default `False`.
return_length(bool, optional): A flag indicating whether to return an
......
......@@ -178,16 +178,14 @@ class Seq2SeqModel(object):
beam_size=4):
self.start_token, self.end_token = start_token, end_token
self.max_decoding_length, self.beam_size = max_decoding_length, beam_size
self.src_embeder = lambda x: fluid.embedding(
input=x,
size=[src_vocab_size, hidden_size],
dtype="float32",
param_attr=fluid.ParamAttr(name="source_embedding"))
self.trg_embeder = lambda x: fluid.embedding(
input=x,
size=[trg_vocab_size, hidden_size],
dtype="float32",
param_attr=fluid.ParamAttr(name="target_embedding"))
self.src_embeder = paddle.nn.Embedding(
src_vocab_size,
hidden_size,
weight_attr=fluid.ParamAttr(name="source_embedding"))
self.trg_embeder = paddle.nn.Embedding(
trg_vocab_size,
hidden_size,
weight_attr=fluid.ParamAttr(name="target_embedding"))
self.encoder = Encoder(num_layers, hidden_size, dropout_prob)
self.decoder = Decoder(num_layers, hidden_size, dropout_prob,
decoding_strategy, max_decoding_length)
......@@ -195,7 +193,7 @@ class Seq2SeqModel(object):
x,
size=trg_vocab_size,
num_flatten_dims=len(x.shape) - 1,
param_attr=fluid.ParamAttr(name="output_w"),
param_attr=fluid.ParamAttr(),
bias_attr=False)
def __call__(self, src, src_length, trg=None, trg_length=None):
......@@ -556,6 +554,14 @@ class TestDynamicDecode(unittest.TestCase):
},
fetch_list=[output])[0]
def test_dynamic_basic_decoder(self):
paddle.disable_static()
src = paddle.to_tensor(np.random.randint(8, size=(8, 4)))
src_length = paddle.to_tensor(np.random.randint(8, size=(8)))
model = Seq2SeqModel(**self.model_hparams)
probs, samples, sample_length = model(src, src_length)
paddle.enable_static()
class ModuleApiTest(unittest.TestCase):
@classmethod
......@@ -672,8 +678,8 @@ class TestBeamSearch(ModuleApiTest):
hidden_size,
bos_id=0,
eos_id=1,
beam_size=2,
max_step_num=2):
beam_size=4,
max_step_num=20):
embedder = paddle.fluid.dygraph.Embedding(
size=[vocab_size, embed_dim], dtype="float64")
output_layer = nn.Linear(hidden_size, vocab_size)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册