diff --git a/python/paddle/fluid/layers/sequence_lod.py b/python/paddle/fluid/layers/sequence_lod.py index ccdf28f09a5869d0b04e7259eea67d4ed22eae3c..6c33371ef6e3fc52bfd5a21205c3d75509a2284a 100644 --- a/python/paddle/fluid/layers/sequence_lod.py +++ b/python/paddle/fluid/layers/sequence_lod.py @@ -18,6 +18,7 @@ from .layer_function_generator import templatedoc from ..framework import Variable, in_dygraph_mode from ..layer_helper import LayerHelper from ..data_feeder import check_variable_and_dtype, check_type, check_dtype +from ..core import VarDesc __all__ = [ 'sequence_conv', @@ -941,7 +942,7 @@ def sequence_pad(x, pad_value, maxlen=None, name=None): 'fluid.layers.sequence_pad') dtype = helper.input_dtype() out = helper.create_variable_for_type_inference(dtype) - length = helper.create_variable_for_type_inference(dtype) + length = helper.create_variable_for_type_inference(VarDesc.VarType.INT64) pad_value.stop_gradient = True length.stop_gradient = True diff --git a/python/paddle/fluid/tests/unittests/sequence/test_sequence_pad_op.py b/python/paddle/fluid/tests/unittests/sequence/test_sequence_pad_op.py index b9d53452adead5a796ee646acf5dce725ca0a745..7d2ba834de1633f0558e527a39cea297f5c4b778 100644 --- a/python/paddle/fluid/tests/unittests/sequence/test_sequence_pad_op.py +++ b/python/paddle/fluid/tests/unittests/sequence/test_sequence_pad_op.py @@ -19,6 +19,7 @@ sys.path.append("../") from op_test import OpTest import paddle.fluid as fluid +import paddle.fluid.core as core class TestSequencePadOp(OpTest): @@ -173,6 +174,13 @@ class TestSequencePadOpError(unittest.TestCase): self.assertRaises(TypeError, test_dtype) + def test_length_dtype(self): + x = fluid.data(name='x', shape=[10, 5], dtype='float32', lod_level=1) + pad_value = fluid.layers.assign(input=np.array([0.0], dtype=np.float32)) + out, length = fluid.layers.sequence_pad(x=x, pad_value=pad_value) + # check if the dtype of length is int64 in compile time + self.assertEqual(length.dtype, core.VarDesc.VarType.INT64) + if __name__ == '__main__': unittest.main()