diff --git a/paddle/fluid/pybind/eager_math_op_patch.cc b/paddle/fluid/pybind/eager_math_op_patch.cc index 537f9469c1ccf59fc9385f09e66c86c4130f9066..3e13c71cf22d3d40f945ec5ecffc0247e47157d0 100644 --- a/paddle/fluid/pybind/eager_math_op_patch.cc +++ b/paddle/fluid/pybind/eager_math_op_patch.cc @@ -763,6 +763,176 @@ static PyObject* tensor__rdiv__method(TensorObject* self, EAGER_CATCH_AND_THROW_RETURN_NULL } +static PyObject* tensor__gt__method(TensorObject* self, + PyObject* args, + PyObject* kwargs) { + paddle::platform::RecordEvent pythonc_record_event( + "__gt__ pybind_patch_func", + paddle::platform::TracerEventType::UserDefined, + 1); + + EAGER_TRY + VLOG(1) << "Running Eager tensor__gt__method"; + + // Set Device ID + auto place = egr::Controller::Instance().GetExpectedPlace(); + SetDevice(place); + + paddle::experimental::Tensor ret; + paddle::experimental::Tensor self_tensor = self->tensor; + PyObject* other_obj = PyTuple_GET_ITEM(args, 0); + + // 1. scalar exists cases + // there is no scalar function for __gt__ now + float other_float = 0.0; + bool has_other_float = false; + if (PyFloat_Check(other_obj) || PyCheckInteger(other_obj) || + IsNumpyType(other_obj)) { + if (PyFloat_Check(other_obj)) { + other_float = CastPyArg2AttrFloat(other_obj, 0); + has_other_float = true; + if (_supported_int_dtype_.find(self_tensor.dtype()) != + _supported_int_dtype_.end()) { + eager_gil_scoped_release guard; + self_tensor = cast_ad_func(self_tensor, DataType::FLOAT32); + } + } else if (PyCheckInteger(other_obj) || IsNumpyType(other_obj)) { + other_float = static_cast(CastPyArg2AttrInt(other_obj, 0)); + has_other_float = true; + } + } + + // 2. create or get tensor for other_obj + paddle::experimental::Tensor other_tensor; + if (has_other_float) { + eager_gil_scoped_release guard; + other_tensor = full_ad_func(self_tensor.shape(), + phi::Scalar(other_float), + self_tensor.dtype(), + place); + } else if (!PyCheckTensor(other_obj)) { + paddle::experimental::Scalar value = + CastPyArg2Scalar(other_obj, "__gt__", 0); + if (PyComplex_Check(other_obj)) { + eager_gil_scoped_release guard; + other_tensor = full_ad_func({1}, value, DataType::COMPLEX64, place); + } else { + eager_gil_scoped_release guard; + other_tensor = + full_ad_func(self_tensor.shape(), value, self_tensor.dtype(), place); + } + } else { + other_tensor = CastPyArg2Tensor(other_obj, 0); + } + + // 3. promote types or unify right var type to left var + phi::DataType lhs_dtype = self_tensor.dtype(); + phi::DataType rhs_dtype = other_tensor.dtype(); + if (lhs_dtype != rhs_dtype) { + LOG(WARNING) << "The dtype of left and right Tensor are not the same, left " + "dtype is " + << lhs_dtype << ", but right dtype is " << rhs_dtype + << ", the right dtype will convert to " << lhs_dtype; + eager_gil_scoped_release guard; + other_tensor = cast_ad_func(other_tensor, lhs_dtype); + } + + // 4. calculation + VLOG(6) << "Calling greater_than_ad_func in tensor__gt__method"; + { + eager_gil_scoped_release guard; + ret = greater_than_ad_func(self_tensor, other_tensor, -1); + } + + return ToPyObject(ret); + EAGER_CATCH_AND_THROW_RETURN_NULL +} + +static PyObject* tensor__ge__method(TensorObject* self, + PyObject* args, + PyObject* kwargs) { + paddle::platform::RecordEvent pythonc_record_event( + "__ge__ pybind_patch_func", + paddle::platform::TracerEventType::UserDefined, + 1); + + EAGER_TRY + VLOG(1) << "Running Eager tensor__ge__method"; + + // Set Device ID + auto place = egr::Controller::Instance().GetExpectedPlace(); + SetDevice(place); + + paddle::experimental::Tensor ret; + paddle::experimental::Tensor self_tensor = self->tensor; + PyObject* other_obj = PyTuple_GET_ITEM(args, 0); + + // 1. scalar exists cases + // there is no scalar function for __ge__ now + float other_float = 0.0; + bool has_other_float = false; + if (PyFloat_Check(other_obj) || PyCheckInteger(other_obj) || + IsNumpyType(other_obj)) { + if (PyFloat_Check(other_obj)) { + other_float = CastPyArg2AttrFloat(other_obj, 0); + has_other_float = true; + if (_supported_int_dtype_.find(self_tensor.dtype()) != + _supported_int_dtype_.end()) { + eager_gil_scoped_release guard; + self_tensor = cast_ad_func(self_tensor, DataType::FLOAT32); + } + } else if (PyCheckInteger(other_obj) || IsNumpyType(other_obj)) { + other_float = static_cast(CastPyArg2AttrInt(other_obj, 0)); + has_other_float = true; + } + } + + // 2. create or get tensor for other_obj + paddle::experimental::Tensor other_tensor; + if (has_other_float) { + eager_gil_scoped_release guard; + other_tensor = full_ad_func(self_tensor.shape(), + phi::Scalar(other_float), + self_tensor.dtype(), + place); + } else if (!PyCheckTensor(other_obj)) { + paddle::experimental::Scalar value = + CastPyArg2Scalar(other_obj, "__ge__", 0); + if (PyComplex_Check(other_obj)) { + eager_gil_scoped_release guard; + other_tensor = full_ad_func({1}, value, DataType::COMPLEX64, place); + } else { + eager_gil_scoped_release guard; + other_tensor = + full_ad_func(self_tensor.shape(), value, self_tensor.dtype(), place); + } + } else { + other_tensor = CastPyArg2Tensor(other_obj, 0); + } + + // 3. promote types or unify right var type to left var + phi::DataType lhs_dtype = self_tensor.dtype(); + phi::DataType rhs_dtype = other_tensor.dtype(); + if (lhs_dtype != rhs_dtype) { + LOG(WARNING) << "The dtype of left and right Tensor are not the same, left " + "dtype is " + << lhs_dtype << ", but right dtype is " << rhs_dtype + << ", the right dtype will convert to " << lhs_dtype; + eager_gil_scoped_release guard; + other_tensor = cast_ad_func(other_tensor, lhs_dtype); + } + + // 4. calculation + VLOG(6) << "Calling greater_equal_ad_func in tensor__ge__method"; + { + eager_gil_scoped_release guard; + ret = greater_equal_ad_func(self_tensor, other_tensor, -1); + } + + return ToPyObject(ret); + EAGER_CATCH_AND_THROW_RETURN_NULL +} + PyMethodDef math_op_patch_methods[] = { {"__add__", (PyCFunction)(void (*)(void))tensor__add__method, @@ -804,6 +974,14 @@ PyMethodDef math_op_patch_methods[] = { (PyCFunction)(void (*)(void))tensor__rdiv__method, METH_VARARGS | METH_KEYWORDS, NULL}, + {"__gt__", + (PyCFunction)(void (*)(void))tensor__gt__method, + METH_VARARGS | METH_KEYWORDS, + NULL}, + {"__ge__", + (PyCFunction)(void (*)(void))tensor__ge__method, + METH_VARARGS | METH_KEYWORDS, + NULL}, {NULL, NULL, 0, NULL}}; } // namespace pybind diff --git a/python/paddle/fluid/dygraph/math_op_patch.py b/python/paddle/fluid/dygraph/math_op_patch.py index 8e0f3c6436f9e5b61304cab34a14d828e5996c24..c30bb97b9745e3881b7b10777324bb73cc6ebce9 100644 --- a/python/paddle/fluid/dygraph/math_op_patch.py +++ b/python/paddle/fluid/dygraph/math_op_patch.py @@ -402,10 +402,6 @@ def monkey_patch_math_varbase(): ('__ne__', _binary_creator_('__ne__', 'not_equal', False, None, True)), ('__lt__', _binary_creator_('__lt__', 'less_than', False, None, True)), ('__le__', _binary_creator_('__le__', 'less_equal', False, None, True)), - ('__gt__', _binary_creator_('__gt__', 'greater_than', False, None, - True)), - ('__ge__', _binary_creator_('__ge__', 'greater_equal', False, None, - True)), ('__array_ufunc__', None) ] @@ -420,6 +416,8 @@ def monkey_patch_math_varbase(): '__truediv__', '__rdiv__', '__rtruediv__', + '__gt__', + '__ge__', ] global _already_patch_varbase