未验证 提交 3777d4a9 编写于 作者: Z zxcd 提交者: GitHub

arange api add deal with np.int type (#54134)

* add deal with np.int type

* add assertEqual for dtype.

* fix bug

* fix by comment

* fix np.integer.
上级 af39c163
......@@ -182,31 +182,43 @@ class TestArangeImperative(unittest.TestCase):
x5 = paddle.arange(start_float, end_float, step_float)
x5_expected_data = np.arange(0.5, 1.5, 0.5).astype(np.float32)
self.assertEqual((x5.numpy() == x5_expected_data).all(), True)
self.assertEqual(x5.numpy().dtype, np.float32)
# [start, end] is float , [step] is int
x6 = paddle.arange(start_float, end_float, 1)
x6_expected_data = np.arange(0.5, 1.5, 1).astype(np.float32)
self.assertEqual((x6.numpy() == x6_expected_data).all(), True)
self.assertEqual(x6.numpy().dtype, np.float32)
# [start] is float , [end] is int
x7 = paddle.arange(start_float, 1)
x7_expected_data = np.arange(0.5, 1).astype(np.float32)
self.assertEqual((x7.numpy() == x7_expected_data).all(), True)
self.assertEqual(x7.numpy().dtype, np.float32)
# [start] is float
x8 = paddle.arange(start_float)
x8_expected_data = np.arange(0.5).astype(np.float32)
self.assertEqual((x8.numpy() == x8_expected_data).all(), True)
self.assertEqual(x8.numpy().dtype, np.float32)
# [start] is int
x9 = paddle.arange(1)
x9_expected_data = np.arange(1).astype(np.int64)
self.assertEqual((x9.numpy() == x9_expected_data).all(), True)
self.assertEqual(x9.numpy().dtype, np.int64)
# [start] is float
x10 = paddle.arange(1.0)
x10_expected_data = np.arange(1).astype(np.float32)
self.assertEqual((x10.numpy() == x10_expected_data).all(), True)
self.assertEqual(x10.numpy().dtype, np.float32)
# [start] is np.int
x11 = paddle.arange(np.int64(10))
x11_expected_data = np.arange(10).astype(np.int64)
self.assertEqual((x11.numpy() == x11_expected_data).all(), True)
self.assertEqual(x11.numpy().dtype, np.int64)
paddle.enable_static()
......
......@@ -1304,7 +1304,9 @@ def arange(start=0, end=None, step=1, dtype=None, name=None):
if isinstance(val, Variable) and not val.is_integer():
dtype = paddle.get_default_dtype()
break
elif not isinstance(val, int) and not isinstance(val, Variable):
elif not isinstance(val, (int, np.integer)) and not isinstance(
val, Variable
):
dtype = paddle.get_default_dtype()
break
else:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册