From d994d212260373426a26577fd233f6e9b1d76306 Mon Sep 17 00:00:00 2001 From: RedContritio Date: Mon, 6 Feb 2023 12:05:17 +0800 Subject: [PATCH] =?UTF-8?q?Fix=20=E7=A9=BA=E6=8C=87=E9=92=88=20(Null=20poi?= =?UTF-8?q?nter)=20of=20case=2017=20paddle.flip=20(#50028)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * check tensor numel in PyObject_CheckLongOrToLong * add unittest --- paddle/fluid/pybind/op_function_common.cc | 4 +++- .../paddle/fluid/tests/unittests/test_flip.py | 17 +++++++++++++++++ 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/pybind/op_function_common.cc b/paddle/fluid/pybind/op_function_common.cc index edab97c8b5e..9f97556e200 100644 --- a/paddle/fluid/pybind/op_function_common.cc +++ b/paddle/fluid/pybind/op_function_common.cc @@ -30,6 +30,7 @@ #include "paddle/fluid/imperative/tracer.h" #include "paddle/fluid/imperative/type_defs.h" #include "paddle/fluid/operators/ops_extra_info.h" +#include "paddle/fluid/pybind/eager.h" #include "paddle/fluid/pybind/imperative.h" #include "paddle/phi/common/complex.h" @@ -70,7 +71,8 @@ bool PyObject_CheckLongOrToLong(PyObject** obj) { if ((PyLong_Check(*obj) && !PyBool_Check(*obj)) || PyObject_IsInstance(*obj, (PyObject*)g_vartype_pytype) || // NOLINT PyObject_IsInstance(*obj, (PyObject*)g_varbase_pytype) || // NOLINT - PyObject_IsInstance(*obj, (PyObject*)p_tensor_type)) { // NOLINT + (PyObject_IsInstance(*obj, (PyObject*)p_tensor_type) && // NOLINT + (((TensorObject*)(*obj))->tensor.numel() == 1))) { // NOLINT return true; } diff --git a/python/paddle/fluid/tests/unittests/test_flip.py b/python/paddle/fluid/tests/unittests/test_flip.py index 4f095493f00..855108f52db 100644 --- a/python/paddle/fluid/tests/unittests/test_flip.py +++ b/python/paddle/fluid/tests/unittests/test_flip.py @@ -196,6 +196,23 @@ class TestFlipTripleGradCheck(unittest.TestCase): self.func(p) +class TestFlipError(unittest.TestCase): + def test_axis(self): + paddle.enable_static() + + def test_axis_rank(): + input = fluid.data(name='input', dtype='float32', shape=[2, 3]) + output = paddle.flip(input, axis=[[0]]) + + self.assertRaises(TypeError, test_axis_rank) + + def test_axis_rank2(): + input = fluid.data(name='input', dtype='float32', shape=[2, 3]) + output = paddle.flip(input, axis=[[0, 0], [1, 1]]) + + self.assertRaises(TypeError, test_axis_rank2) + + if __name__ == "__main__": paddle.enable_static() unittest.main() -- GitLab