未验证 提交 3ac1ccf9 编写于 作者: Z zxcd 提交者: GitHub

fix processing logic of the arange function when dtype is empty. (#53800)

* fix processing logic of the arange function when dtype is empty.

* update commit version

* fix ValueError when end is None.

* add unitest for new case.

* fix tensor type.

* remove paddle.to_tensor(), add more test unit.

* remove useless line.

* fix enable_static

* add new test unit.

* fix by comment.
上级 c8375f86
......@@ -170,12 +170,46 @@ class TestArangeImperative(unittest.TestCase):
end = paddle.to_tensor(np.array([5], 'float32'))
step = paddle.to_tensor(np.array([1], 'float32'))
x4 = paddle.arange(start, end, step, 'int64')
paddle.enable_static()
expected_data = np.arange(0, 5, 1).astype(np.int64)
for i in [x1, x2, x3, x4]:
self.assertEqual((i.numpy() == expected_data).all(), True)
start_float = paddle.to_tensor(np.array([0.5], 'float32'))
end_float = paddle.to_tensor(np.array([1.5], 'float32'))
step_float = paddle.to_tensor(np.array([0.5], 'float32'))
# all [start, end, step] is float
x5 = paddle.arange(start_float, end_float, step_float)
x5_expected_data = np.arange(0.5, 1.5, 0.5).astype(np.float32)
self.assertEqual((x5.numpy() == x5_expected_data).all(), True)
# [start, end] is float , [step] is int
x6 = paddle.arange(start_float, end_float, 1)
x6_expected_data = np.arange(0.5, 1.5, 1).astype(np.float32)
self.assertEqual((x6.numpy() == x6_expected_data).all(), True)
# [start] is float , [end] is int
x7 = paddle.arange(start_float, 1)
x7_expected_data = np.arange(0.5, 1).astype(np.float32)
self.assertEqual((x7.numpy() == x7_expected_data).all(), True)
# [start] is float
x8 = paddle.arange(start_float)
x8_expected_data = np.arange(0.5).astype(np.float32)
self.assertEqual((x8.numpy() == x8_expected_data).all(), True)
# [start] is int
x9 = paddle.arange(1)
x9_expected_data = np.arange(1).astype(np.int64)
self.assertEqual((x9.numpy() == x9_expected_data).all(), True)
# [start] is float
x10 = paddle.arange(1.0)
x10_expected_data = np.arange(1).astype(np.float32)
self.assertEqual((x10.numpy() == x10_expected_data).all(), True)
paddle.enable_static()
if __name__ == "__main__":
unittest.main()
......@@ -1295,12 +1295,21 @@ def arange(start=0, end=None, step=1, dtype=None, name=None):
# [3, 4, 5, 6]
"""
if dtype is None:
dtype = 'int64'
if end is None:
end = start
start = 0
if dtype is None:
for val in [start, end, step]:
if isinstance(val, Variable) and not val.is_integer():
dtype = paddle.get_default_dtype()
break
elif not isinstance(val, int) and not isinstance(val, Variable):
dtype = paddle.get_default_dtype()
break
else:
dtype = 'int64'
out_shape = None
if not in_dynamic_mode() and (
not isinstance(start, Variable)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册