diff --git a/paddle/fluid/pybind/eager_utils.cc b/paddle/fluid/pybind/eager_utils.cc index bd7194198e565673b7f91a2af783eb17982bbd83..c0167a19e2b8c78d5d88a2d6578126e02a151de7 100644 --- a/paddle/fluid/pybind/eager_utils.cc +++ b/paddle/fluid/pybind/eager_utils.cc @@ -1100,7 +1100,8 @@ paddle::experimental::IntArray CastPyArg2IntArray(PyObject* obj, // obj could be: int, float, bool, paddle.Tensor PyTypeObject* type = obj->ob_type; auto type_name = std::string(type->tp_name); - if (type_name == "list" || type_name == "tuple") { + if (type_name == "list" || type_name == "tuple" || + type_name == "numpy.ndarray") { std::vector value = CastPyArg2Ints(obj, op_type, arg_pos); return paddle::experimental::IntArray(value); diff --git a/python/paddle/fluid/tests/unittests/test_expand_v2_op.py b/python/paddle/fluid/tests/unittests/test_expand_v2_op.py index 592a635ddcccc587cba766e00525fc9c8f3c6639..4932ea8a1b5c96a61b613e54cfe4c457aa93076f 100644 --- a/python/paddle/fluid/tests/unittests/test_expand_v2_op.py +++ b/python/paddle/fluid/tests/unittests/test_expand_v2_op.py @@ -20,6 +20,7 @@ from op_test import OpTest import paddle.fluid as fluid from paddle.fluid import compiler, Program, program_guard import paddle +from paddle.fluid.framework import _test_eager_guard # Situation 1: shape is a list(without tensor) @@ -243,6 +244,30 @@ class TestExpandInferShape(unittest.TestCase): self.assertListEqual(list(out.shape), [-1, -1, -1]) +# Test python Dygraph API +class TestExpandV2DygraphAPI(unittest.TestCase): + def test_expand_times_is_tensor(self): + with paddle.fluid.dygraph.guard(): + with _test_eager_guard(): + paddle.seed(1) + a = paddle.rand([2, 5]) + egr_expand_1 = paddle.expand(a, shape=[2, 5]) + np_array = np.array([2, 5]) + egr_expand_2 = paddle.expand(a, shape=np_array) + + paddle.seed(1) + a = paddle.rand([2, 5]) + expand_1 = paddle.expand(a, shape=[2, 5]) + np_array = np.array([2, 5]) + expand_2 = paddle.expand(a, shape=np_array) + + self.assertTrue( + np.array_equal(egr_expand_1.numpy(), egr_expand_2.numpy())) + self.assertTrue(np.array_equal(expand_1.numpy(), expand_2.numpy())) + self.assertTrue( + np.array_equal(expand_1.numpy(), egr_expand_1.numpy())) + + if __name__ == "__main__": paddle.enable_static() unittest.main()