未验证 提交 08811d9b 编写于 作者: W Weilong Wu 提交者: GitHub

Update sequence_mask related code (#41393)

上级 e5e0b726
......@@ -1382,19 +1382,18 @@ def sequence_mask(x, maxlen=None, dtype='int64', name=None):
"""
if _non_static_mode():
if in_dygraph_mode():
if not isinstance(dtype, core.VarDesc.VarType):
dtype = convert_np_dtype_to_dtype_(dtype)
if in_dygraph_mode():
if maxlen is not None:
if isinstance(maxlen, core.eager.Tensor):
attrs = ('out_dtype', dtype)
out = _C_ops.sequence_mask(x, maxlen, *attrs)
else:
attrs = ('out_dtype', dtype, 'maxlen', maxlen)
out = _C_ops.sequence_mask(x, None, *attrs)
out.stop_gradient = True
return out
if maxlen is not None:
if isinstance(maxlen, core.eager.Tensor):
attrs = ('out_dtype', dtype)
out = _C_ops.sequence_mask(x, maxlen, *attrs)
else:
attrs = ('out_dtype', dtype, 'maxlen', maxlen)
out = _C_ops.sequence_mask(x, None, *attrs)
out.stop_gradient = True
return out
helper = LayerHelper('sequence_mask', **locals())
out = helper.create_variable_for_type_inference(dtype=dtype)
......
......@@ -716,7 +716,6 @@ class TestBeamSearch(ModuleApiTest):
def func_check_output(self):
self.setUp()
self.make_inputs()
self.make_inputs()
self.check_output()
def test_check_output(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册