未验证 提交 d994d212 编写于 作者: R RedContritio 提交者: GitHub

Fix 空指针 (Null pointer) of case 17 paddle.flip (#50028)

* check tensor numel in PyObject_CheckLongOrToLong

* add unittest
上级 e79867b6
......@@ -30,6 +30,7 @@
#include "paddle/fluid/imperative/tracer.h"
#include "paddle/fluid/imperative/type_defs.h"
#include "paddle/fluid/operators/ops_extra_info.h"
#include "paddle/fluid/pybind/eager.h"
#include "paddle/fluid/pybind/imperative.h"
#include "paddle/phi/common/complex.h"
......@@ -70,7 +71,8 @@ bool PyObject_CheckLongOrToLong(PyObject** obj) {
if ((PyLong_Check(*obj) && !PyBool_Check(*obj)) ||
PyObject_IsInstance(*obj, (PyObject*)g_vartype_pytype) || // NOLINT
PyObject_IsInstance(*obj, (PyObject*)g_varbase_pytype) || // NOLINT
PyObject_IsInstance(*obj, (PyObject*)p_tensor_type)) { // NOLINT
(PyObject_IsInstance(*obj, (PyObject*)p_tensor_type) && // NOLINT
(((TensorObject*)(*obj))->tensor.numel() == 1))) { // NOLINT
return true;
}
......
......@@ -196,6 +196,23 @@ class TestFlipTripleGradCheck(unittest.TestCase):
self.func(p)
class TestFlipError(unittest.TestCase):
def test_axis(self):
paddle.enable_static()
def test_axis_rank():
input = fluid.data(name='input', dtype='float32', shape=[2, 3])
output = paddle.flip(input, axis=[[0]])
self.assertRaises(TypeError, test_axis_rank)
def test_axis_rank2():
input = fluid.data(name='input', dtype='float32', shape=[2, 3])
output = paddle.flip(input, axis=[[0, 0], [1, 1]])
self.assertRaises(TypeError, test_axis_rank2)
if __name__ == "__main__":
paddle.enable_static()
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册