未验证 提交 f5c08c3f 编写于 作者: Z Zhang Ting 提交者: GitHub

set int64 for Output(Length) of sequence_pad, test=develop (#24161)

上级 9a93f6aa
...@@ -18,6 +18,7 @@ from .layer_function_generator import templatedoc ...@@ -18,6 +18,7 @@ from .layer_function_generator import templatedoc
from ..framework import Variable, in_dygraph_mode from ..framework import Variable, in_dygraph_mode
from ..layer_helper import LayerHelper from ..layer_helper import LayerHelper
from ..data_feeder import check_variable_and_dtype, check_type, check_dtype from ..data_feeder import check_variable_and_dtype, check_type, check_dtype
from ..core import VarDesc
__all__ = [ __all__ = [
'sequence_conv', 'sequence_conv',
...@@ -941,7 +942,7 @@ def sequence_pad(x, pad_value, maxlen=None, name=None): ...@@ -941,7 +942,7 @@ def sequence_pad(x, pad_value, maxlen=None, name=None):
'fluid.layers.sequence_pad') 'fluid.layers.sequence_pad')
dtype = helper.input_dtype() dtype = helper.input_dtype()
out = helper.create_variable_for_type_inference(dtype) out = helper.create_variable_for_type_inference(dtype)
length = helper.create_variable_for_type_inference(dtype) length = helper.create_variable_for_type_inference(VarDesc.VarType.INT64)
pad_value.stop_gradient = True pad_value.stop_gradient = True
length.stop_gradient = True length.stop_gradient = True
......
...@@ -19,6 +19,7 @@ sys.path.append("../") ...@@ -19,6 +19,7 @@ sys.path.append("../")
from op_test import OpTest from op_test import OpTest
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle.fluid.core as core
class TestSequencePadOp(OpTest): class TestSequencePadOp(OpTest):
...@@ -173,6 +174,13 @@ class TestSequencePadOpError(unittest.TestCase): ...@@ -173,6 +174,13 @@ class TestSequencePadOpError(unittest.TestCase):
self.assertRaises(TypeError, test_dtype) self.assertRaises(TypeError, test_dtype)
def test_length_dtype(self):
x = fluid.data(name='x', shape=[10, 5], dtype='float32', lod_level=1)
pad_value = fluid.layers.assign(input=np.array([0.0], dtype=np.float32))
out, length = fluid.layers.sequence_pad(x=x, pad_value=pad_value)
# check if the dtype of length is int64 in compile time
self.assertEqual(length.dtype, core.VarDesc.VarType.INT64)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册