From 3da8066a05c2eadc172e0669cbb6bcdbc7f8d057 Mon Sep 17 00:00:00 2001 From: Weilong Wu Date: Thu, 21 Apr 2022 10:32:25 +0800 Subject: [PATCH] [Eager] Support numpy.narray as input for eager expand (#42043) --- paddle/fluid/pybind/eager_utils.cc | 3 ++- .../tests/unittests/test_expand_v2_op.py | 25 +++++++++++++++++++ 2 files changed, 27 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/pybind/eager_utils.cc b/paddle/fluid/pybind/eager_utils.cc index ec391a7fa64..9719963d51d 100644 --- a/paddle/fluid/pybind/eager_utils.cc +++ b/paddle/fluid/pybind/eager_utils.cc @@ -1101,7 +1101,8 @@ paddle::experimental::IntArray CastPyArg2IntArray(PyObject* obj, // obj could be: int, float, bool, paddle.Tensor PyTypeObject* type = obj->ob_type; auto type_name = std::string(type->tp_name); - if (type_name == "list" || type_name == "tuple") { + if (type_name == "list" || type_name == "tuple" || + type_name == "numpy.ndarray") { std::vector value = CastPyArg2Ints(obj, op_type, arg_pos); return paddle::experimental::IntArray(value); diff --git a/python/paddle/fluid/tests/unittests/test_expand_v2_op.py b/python/paddle/fluid/tests/unittests/test_expand_v2_op.py index 592a635ddcc..4932ea8a1b5 100644 --- a/python/paddle/fluid/tests/unittests/test_expand_v2_op.py +++ b/python/paddle/fluid/tests/unittests/test_expand_v2_op.py @@ -20,6 +20,7 @@ from op_test import OpTest import paddle.fluid as fluid from paddle.fluid import compiler, Program, program_guard import paddle +from paddle.fluid.framework import _test_eager_guard # Situation 1: shape is a list(without tensor) @@ -243,6 +244,30 @@ class TestExpandInferShape(unittest.TestCase): self.assertListEqual(list(out.shape), [-1, -1, -1]) +# Test python Dygraph API +class TestExpandV2DygraphAPI(unittest.TestCase): + def test_expand_times_is_tensor(self): + with paddle.fluid.dygraph.guard(): + with _test_eager_guard(): + paddle.seed(1) + a = paddle.rand([2, 5]) + egr_expand_1 = paddle.expand(a, shape=[2, 5]) + np_array = np.array([2, 5]) + egr_expand_2 = paddle.expand(a, shape=np_array) + + paddle.seed(1) + a = paddle.rand([2, 5]) + expand_1 = paddle.expand(a, shape=[2, 5]) + np_array = np.array([2, 5]) + expand_2 = paddle.expand(a, shape=np_array) + + self.assertTrue( + np.array_equal(egr_expand_1.numpy(), egr_expand_2.numpy())) + self.assertTrue(np.array_equal(expand_1.numpy(), expand_2.numpy())) + self.assertTrue( + np.array_equal(expand_1.numpy(), egr_expand_1.numpy())) + + if __name__ == "__main__": paddle.enable_static() unittest.main() -- GitLab