From a000e9b8399fe79724858b04a3bc19267867d7ac Mon Sep 17 00:00:00 2001 From: HongyuJia Date: Fri, 13 Jan 2023 10:43:14 +0800 Subject: [PATCH] unify PyCheckTensor function (#49751) --- paddle/fluid/pybind/eager_math_op_patch.cc | 6 ------ paddle/fluid/pybind/eager_method.cc | 4 ---- paddle/fluid/pybind/eager_py_layer.cc | 24 +++++++++++----------- paddle/fluid/pybind/eager_utils.cc | 10 ++++----- paddle/fluid/pybind/eager_utils.h | 2 +- paddle/fluid/pybind/imperative.cc | 5 +---- paddle/fluid/pybind/slice_utils.h | 1 - 7 files changed, 19 insertions(+), 33 deletions(-) diff --git a/paddle/fluid/pybind/eager_math_op_patch.cc b/paddle/fluid/pybind/eager_math_op_patch.cc index 8387123ae1..ef8ff4d6c1 100644 --- a/paddle/fluid/pybind/eager_math_op_patch.cc +++ b/paddle/fluid/pybind/eager_math_op_patch.cc @@ -52,12 +52,6 @@ typedef SSIZE_T ssize_t; namespace paddle { namespace pybind { -extern PyTypeObject* p_tensor_type; - -bool PyCheckTensor(PyObject* obj) { - return PyObject_IsInstance(obj, reinterpret_cast(p_tensor_type)); -} - static bool PyCheckInteger(PyObject* obj) { #if PY_VERSION_HEX < 0x03000000 return (PyLong_Check(obj) || PyInt_Check(obj)) && !PyBool_Check(obj); diff --git a/paddle/fluid/pybind/eager_method.cc b/paddle/fluid/pybind/eager_method.cc index 443dbd1557..ac36f142d4 100644 --- a/paddle/fluid/pybind/eager_method.cc +++ b/paddle/fluid/pybind/eager_method.cc @@ -88,10 +88,6 @@ Py_ssize_t GetSliceIndexFromPyObject(PyObject* obj) { } } -bool PyCheckTensor(PyObject* obj) { - return PyObject_IsInstance(obj, reinterpret_cast(p_tensor_type)); -} - static PyObject* tensor_method_numpy(TensorObject* self, PyObject* args, PyObject* kwargs) { diff --git a/paddle/fluid/pybind/eager_py_layer.cc b/paddle/fluid/pybind/eager_py_layer.cc index 8befe6318b..de3cf80cef 100644 --- a/paddle/fluid/pybind/eager_py_layer.cc +++ b/paddle/fluid/pybind/eager_py_layer.cc @@ -49,12 +49,12 @@ std::set GetTensorsFromPyObject(PyObject* obj) { if (obj == nullptr) { return result; } - if (IsEagerTensor(obj)) { + if (PyCheckTensor(obj)) { result.insert(&reinterpret_cast(obj)->tensor); // NOLINT } else if (PyList_Check(obj)) { Py_ssize_t len = PyList_Size(obj); for (Py_ssize_t i = 0; i < len; i++) { - if (IsEagerTensor(PyList_GetItem(obj, i))) { + if (PyCheckTensor(PyList_GetItem(obj, i))) { result.insert( &reinterpret_cast(PyList_GetItem(obj, i)) // NOLINT ->tensor); @@ -63,7 +63,7 @@ std::set GetTensorsFromPyObject(PyObject* obj) { } else if (PyTuple_Check(obj)) { Py_ssize_t len = PyTuple_Size(obj); for (Py_ssize_t i = 0; i < len; i++) { - if (IsEagerTensor(PyTuple_GetItem(obj, i))) { + if (PyCheckTensor(PyTuple_GetItem(obj, i))) { result.insert( &reinterpret_cast(PyTuple_GetItem(obj, i)) // NOLINT ->tensor); @@ -177,7 +177,7 @@ PyObject* pylayer_method_apply(PyObject* cls, } else { obj = PyTuple_GET_ITEM(args, i); } - if (IsEagerTensor(obj)) { + if (PyCheckTensor(obj)) { input_tensorbases.insert( reinterpret_cast(obj)->tensor.impl().get()); auto autograd_meta = egr::EagerUtils::nullable_autograd_meta( @@ -196,7 +196,7 @@ PyObject* pylayer_method_apply(PyObject* cls, Py_ssize_t len = PyList_Size(obj); for (Py_ssize_t j = 0; j < len; j++) { PyObject* o = PyList_GetItem(obj, j); - if (IsEagerTensor(o)) { + if (PyCheckTensor(o)) { input_tensorbases.insert( reinterpret_cast(o)->tensor.impl().get()); tensors.push_back(&(reinterpret_cast(o)->tensor)); @@ -219,7 +219,7 @@ PyObject* pylayer_method_apply(PyObject* cls, Py_ssize_t len = PyTuple_Size(obj); for (Py_ssize_t j = 0; j < len; j++) { PyObject* o = PyTuple_GetItem(obj, j); - if (IsEagerTensor(o)) { + if (PyCheckTensor(o)) { input_tensorbases.insert( reinterpret_cast(o)->tensor.impl().get()); tensors.push_back(&(reinterpret_cast(o)->tensor)); @@ -292,7 +292,7 @@ PyObject* pylayer_method_apply(PyObject* cls, ctx->forward_output_tensor_is_duplicable.reserve(outputs_size); for (Py_ssize_t i = 0; i < outputs_size; i++) { PyObject* obj = PyTuple_GET_ITEM(outputs_tuple, i); - if (IsEagerTensor(obj)) { + if (PyCheckTensor(obj)) { outputs_tensor.push_back( {&(reinterpret_cast(obj)->tensor)}); outputs_autograd_meta.push_back({egr::EagerUtils::autograd_meta( @@ -316,7 +316,7 @@ PyObject* pylayer_method_apply(PyObject* cls, Py_ssize_t len = PyList_Size(obj); for (Py_ssize_t j = 0; j < len; j++) { PyObject* o = PyList_GetItem(obj, j); - if (IsEagerTensor(o)) { + if (PyCheckTensor(o)) { tensors.push_back(&(reinterpret_cast(o)->tensor)); if (input_tensorbases.count( reinterpret_cast(o)->tensor.impl().get())) { @@ -344,7 +344,7 @@ PyObject* pylayer_method_apply(PyObject* cls, Py_ssize_t len = PyTuple_Size(obj); for (Py_ssize_t j = 0; j < len; j++) { PyObject* o = PyTuple_GetItem(obj, j); - if (IsEagerTensor(o)) { + if (PyCheckTensor(o)) { tensors.push_back(&(reinterpret_cast(o)->tensor)); if (input_tensorbases.count( reinterpret_cast(o)->tensor.impl().get())) { @@ -538,7 +538,7 @@ void call_pack_hook(PyLayerObject* self, PyObject* value) { for (Py_ssize_t i = 0; i < saved_value_size; i++) { PyObject* obj = PyTuple_GET_ITEM(saved_value, i); - if (IsEagerTensor(obj)) { + if (PyCheckTensor(obj)) { PyTuple_SET_ITEM(packed_value, i, reinterpret_cast( @@ -548,7 +548,7 @@ void call_pack_hook(PyLayerObject* self, PyObject* value) { auto tmp_list = PyList_New(len); for (Py_ssize_t j = 0; j < len; j++) { PyObject* o = PyList_GetItem(obj, j); - if (IsEagerTensor(o)) { + if (PyCheckTensor(o)) { PyTuple_SET_ITEM(tmp_list, j, reinterpret_cast( @@ -565,7 +565,7 @@ void call_pack_hook(PyLayerObject* self, PyObject* value) { auto tmp_tuple = PyTuple_New(len); for (Py_ssize_t j = 0; j < len; j++) { PyObject* o = PyTuple_GetItem(obj, j); - if (IsEagerTensor(o)) { + if (PyCheckTensor(o)) { PyTuple_SET_ITEM(tmp_tuple, j, reinterpret_cast( diff --git a/paddle/fluid/pybind/eager_utils.cc b/paddle/fluid/pybind/eager_utils.cc index 1eca3dd1d9..371ba65a46 100644 --- a/paddle/fluid/pybind/eager_utils.cc +++ b/paddle/fluid/pybind/eager_utils.cc @@ -207,7 +207,7 @@ std::string CastPyArg2AttrString(PyObject* obj, ssize_t arg_pos) { } } -bool IsEagerTensor(PyObject* obj) { +bool PyCheckTensor(PyObject* obj) { return PyObject_IsInstance(obj, reinterpret_cast(p_tensor_type)); } @@ -1307,7 +1307,7 @@ std::vector GetTensorListFromPyObject( } paddle::experimental::Tensor& GetTensorFromPyObject(PyObject* obj) { - if (!IsEagerTensor(obj)) { + if (!PyCheckTensor(obj)) { PADDLE_THROW(platform::errors::InvalidArgument( "argument must be " "Tensor, but got %s", @@ -1384,7 +1384,7 @@ paddle::experimental::Scalar CastPyArg2Scalar(PyObject* obj, } else if (PyFloat_Check(obj)) { double value = CastPyArg2Double(obj, op_type, arg_pos); return paddle::experimental::Scalar(value); - } else if (IsEagerTensor(obj)) { + } else if (PyCheckTensor(obj)) { paddle::experimental::Tensor& value = GetTensorFromPyObject( op_type, "" /*arg_name*/, obj, arg_pos, false /*dispensable*/); return paddle::experimental::Scalar(value); @@ -1715,7 +1715,7 @@ paddle::experimental::Tensor UnPackHook::operator()( Py_XDECREF(args); egr::Controller::Instance().SetHasGrad(grad_tmp); - PADDLE_ENFORCE_EQ(paddle::pybind::IsEagerTensor(ret), + PADDLE_ENFORCE_EQ(paddle::pybind::PyCheckTensor(ret), true, paddle::platform::errors::InvalidArgument( "paddle.autograd.saved_tensors_hooks only one pair " @@ -1740,7 +1740,7 @@ void* UnPackHook::operator()(void* packed_value, void* other) { Py_XDECREF(args); egr::Controller::Instance().SetHasGrad(grad_tmp); - PADDLE_ENFORCE_EQ(paddle::pybind::IsEagerTensor(ret), + PADDLE_ENFORCE_EQ(paddle::pybind::PyCheckTensor(ret), true, paddle::platform::errors::InvalidArgument( "paddle.autograd.saved_tensors_hooks only one pair " diff --git a/paddle/fluid/pybind/eager_utils.h b/paddle/fluid/pybind/eager_utils.h index 02a8e10dac..063e14903c 100644 --- a/paddle/fluid/pybind/eager_utils.h +++ b/paddle/fluid/pybind/eager_utils.h @@ -46,7 +46,7 @@ namespace py = ::pybind11; int TensorDtype2NumpyDtype(phi::DataType dtype); -bool IsEagerTensor(PyObject* obj); +bool PyCheckTensor(PyObject* obj); bool PyObject_CheckLongOrConvertToLong(PyObject** obj); bool PyObject_CheckFloatOrConvertToFloat(PyObject** obj); diff --git a/paddle/fluid/pybind/imperative.cc b/paddle/fluid/pybind/imperative.cc index 67db8027eb..9262fec62b 100644 --- a/paddle/fluid/pybind/imperative.cc +++ b/paddle/fluid/pybind/imperative.cc @@ -350,9 +350,6 @@ Py_ssize_t GetSliceIndexFromPyObject(PyObject *obj) { } } -bool PyCheckTensor(PyObject *obj) { - return py::isinstance(obj); -} using PyNameVarBaseMap = std::unordered_map; // NOTE(zjl): py::handle is a very light wrapper of PyObject *. @@ -872,7 +869,7 @@ void BindImperative(py::module *m_ptr) { self->Name())); } - if (PyCheckTensor(value_obj.ptr())) { + if (py::isinstance(value_obj.ptr())) { auto value_tensor = value_obj.cast>(); ins.insert({"ValueTensor", {value_tensor}}); diff --git a/paddle/fluid/pybind/slice_utils.h b/paddle/fluid/pybind/slice_utils.h index 5f74a534ff..1e0a4bd67b 100644 --- a/paddle/fluid/pybind/slice_utils.h +++ b/paddle/fluid/pybind/slice_utils.h @@ -30,7 +30,6 @@ namespace py = pybind11; namespace paddle { namespace pybind { -static bool PyCheckTensor(PyObject* obj); static Py_ssize_t GetSliceIndexFromPyObject(PyObject* obj); // Slice related methods static bool PyCheckInteger(PyObject* obj) { -- GitLab