From 08811d9b873948d2d5b1bf2f9b9811fc7a2d6e60 Mon Sep 17 00:00:00 2001 From: Weilong Wu Date: Mon, 4 Apr 2022 16:38:10 +0800 Subject: [PATCH] Update sequence_mask related code (#41393) --- python/paddle/fluid/layers/sequence_lod.py | 21 +++++++++---------- .../tests/unittests/test_rnn_decode_api.py | 1 - 2 files changed, 10 insertions(+), 12 deletions(-) diff --git a/python/paddle/fluid/layers/sequence_lod.py b/python/paddle/fluid/layers/sequence_lod.py index 1758123f0e6..80dc990af45 100644 --- a/python/paddle/fluid/layers/sequence_lod.py +++ b/python/paddle/fluid/layers/sequence_lod.py @@ -1382,19 +1382,18 @@ def sequence_mask(x, maxlen=None, dtype='int64', name=None): """ - if _non_static_mode(): + if in_dygraph_mode(): if not isinstance(dtype, core.VarDesc.VarType): dtype = convert_np_dtype_to_dtype_(dtype) - if in_dygraph_mode(): - if maxlen is not None: - if isinstance(maxlen, core.eager.Tensor): - attrs = ('out_dtype', dtype) - out = _C_ops.sequence_mask(x, maxlen, *attrs) - else: - attrs = ('out_dtype', dtype, 'maxlen', maxlen) - out = _C_ops.sequence_mask(x, None, *attrs) - out.stop_gradient = True - return out + if maxlen is not None: + if isinstance(maxlen, core.eager.Tensor): + attrs = ('out_dtype', dtype) + out = _C_ops.sequence_mask(x, maxlen, *attrs) + else: + attrs = ('out_dtype', dtype, 'maxlen', maxlen) + out = _C_ops.sequence_mask(x, None, *attrs) + out.stop_gradient = True + return out helper = LayerHelper('sequence_mask', **locals()) out = helper.create_variable_for_type_inference(dtype=dtype) diff --git a/python/paddle/fluid/tests/unittests/test_rnn_decode_api.py b/python/paddle/fluid/tests/unittests/test_rnn_decode_api.py index bf848357e31..dacb7a5b599 100644 --- a/python/paddle/fluid/tests/unittests/test_rnn_decode_api.py +++ b/python/paddle/fluid/tests/unittests/test_rnn_decode_api.py @@ -716,7 +716,6 @@ class TestBeamSearch(ModuleApiTest): def func_check_output(self): self.setUp() self.make_inputs() - self.make_inputs() self.check_output() def test_check_output(self): -- GitLab