未验证 提交 2e8425b6 编写于 作者: J Jiaqi Liu 提交者: GitHub

Fix beam search bug (#29824)

* fix beam search bug

* add dygraph unittest

* update dynamic_decode argument doc

* add warning info for state which has no lengths attribute
上级 f43e1d8c
...@@ -16,6 +16,7 @@ from __future__ import print_function ...@@ -16,6 +16,7 @@ from __future__ import print_function
import sys import sys
from functools import partial, reduce from functools import partial, reduce
import warnings
import paddle import paddle
from paddle.utils import deprecated from paddle.utils import deprecated
...@@ -1382,10 +1383,17 @@ def _dynamic_decode_imperative(decoder, ...@@ -1382,10 +1383,17 @@ def _dynamic_decode_imperative(decoder,
sequence_lengths, sequence_lengths,
tensor.cast( tensor.cast(
control_flow.logical_not(finished), sequence_lengths.dtype)) control_flow.logical_not(finished), sequence_lengths.dtype))
if impute_finished: # rectify the states for the finished. if impute_finished: # rectify the states for the finished.
next_states = map_structure( next_states = map_structure(
lambda x, y: _maybe_copy(x, y, finished), states, next_states) lambda x, y: _maybe_copy(x, y, finished), states,
next_states)
else:
warnings.warn(
"`next_states` has no `lengths` attribute, the returned `sequence_lengths` would be all zeros."
) if not hasattr(next_states, "lengths") else None
next_sequence_lengths = getattr(next_states, "lengths",
sequence_lengths)
outputs = map_structure( outputs = map_structure(
lambda x: ArrayWrapper(x), lambda x: ArrayWrapper(x),
step_outputs) if step_idx == 0 else map_structure( step_outputs) if step_idx == 0 else map_structure(
...@@ -1505,12 +1513,17 @@ def _dynamic_decode_declarative(decoder, ...@@ -1505,12 +1513,17 @@ def _dynamic_decode_declarative(decoder,
tensor.cast( tensor.cast(
control_flow.logical_not(global_finished), control_flow.logical_not(global_finished),
sequence_lengths.dtype)) sequence_lengths.dtype))
if impute_finished: # rectify the states for the finished. if impute_finished: # rectify the states for the finished.
next_states = map_structure( next_states = map_structure(
lambda x, y: _maybe_copy(x, y, global_finished), lambda x, y: _maybe_copy(x, y, global_finished),
states, states,
next_states, ) next_states, )
else:
warnings.warn(
"`next_states` has no `lengths` attribute, the returned `sequence_lengths` would be all zeros."
) if not hasattr(next_states, "lengths") else None
next_sequence_lengths = getattr(next_states, "lengths",
sequence_lengths)
# create tensor array in global block after dtype[s] of outputs can be got # create tensor array in global block after dtype[s] of outputs can be got
outputs_arrays = map_structure( outputs_arrays = map_structure(
...@@ -1595,13 +1608,13 @@ def dynamic_decode(decoder, ...@@ -1595,13 +1608,13 @@ def dynamic_decode(decoder,
attr:`False`, the data layout would be batch major with shape attr:`False`, the data layout would be batch major with shape
`[batch_size, seq_len, ...]`. If attr:`True`, the data layout would `[batch_size, seq_len, ...]`. If attr:`True`, the data layout would
be time major with shape `[seq_len, batch_size, ...]`. Default: `False`. be time major with shape `[seq_len, batch_size, ...]`. Default: `False`.
impute_finished(bool, optional): If `True`, then states get copied through impute_finished(bool, optional): If `True` and `decoder.tracks_own_finished`
for batch entries which are marked as finished, which differs with the is False, then states get copied through for batch entries which are
unfinished using the new states returned by :code:`decoder.step()` and marked as finished, which differs with the unfinished using the new states
ensures that the final states have the correct values. Otherwise, states returned by :code:`decoder.step()` and ensures that the final states have
wouldn't be copied through when finished. If the returned `final_states` the correct values. Otherwise, states wouldn't be copied through when
is needed, it should be set as True, which causes some slowdown. finished. If the returned `final_states` is needed, it should be set as
Default `False`. True, which causes some slowdown. Default `False`.
is_test(bool, optional): A flag indicating whether to use test mode. In is_test(bool, optional): A flag indicating whether to use test mode. In
test mode, it is more memory saving. Default `False`. test mode, it is more memory saving. Default `False`.
return_length(bool, optional): A flag indicating whether to return an return_length(bool, optional): A flag indicating whether to return an
......
...@@ -178,16 +178,14 @@ class Seq2SeqModel(object): ...@@ -178,16 +178,14 @@ class Seq2SeqModel(object):
beam_size=4): beam_size=4):
self.start_token, self.end_token = start_token, end_token self.start_token, self.end_token = start_token, end_token
self.max_decoding_length, self.beam_size = max_decoding_length, beam_size self.max_decoding_length, self.beam_size = max_decoding_length, beam_size
self.src_embeder = lambda x: fluid.embedding( self.src_embeder = paddle.nn.Embedding(
input=x, src_vocab_size,
size=[src_vocab_size, hidden_size], hidden_size,
dtype="float32", weight_attr=fluid.ParamAttr(name="source_embedding"))
param_attr=fluid.ParamAttr(name="source_embedding")) self.trg_embeder = paddle.nn.Embedding(
self.trg_embeder = lambda x: fluid.embedding( trg_vocab_size,
input=x, hidden_size,
size=[trg_vocab_size, hidden_size], weight_attr=fluid.ParamAttr(name="target_embedding"))
dtype="float32",
param_attr=fluid.ParamAttr(name="target_embedding"))
self.encoder = Encoder(num_layers, hidden_size, dropout_prob) self.encoder = Encoder(num_layers, hidden_size, dropout_prob)
self.decoder = Decoder(num_layers, hidden_size, dropout_prob, self.decoder = Decoder(num_layers, hidden_size, dropout_prob,
decoding_strategy, max_decoding_length) decoding_strategy, max_decoding_length)
...@@ -195,7 +193,7 @@ class Seq2SeqModel(object): ...@@ -195,7 +193,7 @@ class Seq2SeqModel(object):
x, x,
size=trg_vocab_size, size=trg_vocab_size,
num_flatten_dims=len(x.shape) - 1, num_flatten_dims=len(x.shape) - 1,
param_attr=fluid.ParamAttr(name="output_w"), param_attr=fluid.ParamAttr(),
bias_attr=False) bias_attr=False)
def __call__(self, src, src_length, trg=None, trg_length=None): def __call__(self, src, src_length, trg=None, trg_length=None):
...@@ -556,6 +554,14 @@ class TestDynamicDecode(unittest.TestCase): ...@@ -556,6 +554,14 @@ class TestDynamicDecode(unittest.TestCase):
}, },
fetch_list=[output])[0] fetch_list=[output])[0]
def test_dynamic_basic_decoder(self):
paddle.disable_static()
src = paddle.to_tensor(np.random.randint(8, size=(8, 4)))
src_length = paddle.to_tensor(np.random.randint(8, size=(8)))
model = Seq2SeqModel(**self.model_hparams)
probs, samples, sample_length = model(src, src_length)
paddle.enable_static()
class ModuleApiTest(unittest.TestCase): class ModuleApiTest(unittest.TestCase):
@classmethod @classmethod
...@@ -672,8 +678,8 @@ class TestBeamSearch(ModuleApiTest): ...@@ -672,8 +678,8 @@ class TestBeamSearch(ModuleApiTest):
hidden_size, hidden_size,
bos_id=0, bos_id=0,
eos_id=1, eos_id=1,
beam_size=2, beam_size=4,
max_step_num=2): max_step_num=20):
embedder = paddle.fluid.dygraph.Embedding( embedder = paddle.fluid.dygraph.Embedding(
size=[vocab_size, embed_dim], dtype="float64") size=[vocab_size, embed_dim], dtype="float64")
output_layer = nn.Linear(hidden_size, vocab_size) output_layer = nn.Linear(hidden_size, vocab_size)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册