未验证 提交 08811d9b 编写于 作者: W Weilong Wu 提交者: GitHub

Update sequence_mask related code (#41393)

上级 e5e0b726
...@@ -1382,19 +1382,18 @@ def sequence_mask(x, maxlen=None, dtype='int64', name=None): ...@@ -1382,19 +1382,18 @@ def sequence_mask(x, maxlen=None, dtype='int64', name=None):
""" """
if _non_static_mode(): if in_dygraph_mode():
if not isinstance(dtype, core.VarDesc.VarType): if not isinstance(dtype, core.VarDesc.VarType):
dtype = convert_np_dtype_to_dtype_(dtype) dtype = convert_np_dtype_to_dtype_(dtype)
if in_dygraph_mode(): if maxlen is not None:
if maxlen is not None: if isinstance(maxlen, core.eager.Tensor):
if isinstance(maxlen, core.eager.Tensor): attrs = ('out_dtype', dtype)
attrs = ('out_dtype', dtype) out = _C_ops.sequence_mask(x, maxlen, *attrs)
out = _C_ops.sequence_mask(x, maxlen, *attrs) else:
else: attrs = ('out_dtype', dtype, 'maxlen', maxlen)
attrs = ('out_dtype', dtype, 'maxlen', maxlen) out = _C_ops.sequence_mask(x, None, *attrs)
out = _C_ops.sequence_mask(x, None, *attrs) out.stop_gradient = True
out.stop_gradient = True return out
return out
helper = LayerHelper('sequence_mask', **locals()) helper = LayerHelper('sequence_mask', **locals())
out = helper.create_variable_for_type_inference(dtype=dtype) out = helper.create_variable_for_type_inference(dtype=dtype)
......
...@@ -716,7 +716,6 @@ class TestBeamSearch(ModuleApiTest): ...@@ -716,7 +716,6 @@ class TestBeamSearch(ModuleApiTest):
def func_check_output(self): def func_check_output(self):
self.setUp() self.setUp()
self.make_inputs() self.make_inputs()
self.make_inputs()
self.check_output() self.check_output()
def test_check_output(self): def test_check_output(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册