diff --git a/paddle/fluid/pybind/eager_math_op_patch.cc b/paddle/fluid/pybind/eager_math_op_patch.cc index 3e13c71cf22d3d40f945ec5ecffc0247e47157d0..7e4f7b0cfac7e896b8abf4a864b4c92649bccdc3 100644 --- a/paddle/fluid/pybind/eager_math_op_patch.cc +++ b/paddle/fluid/pybind/eager_math_op_patch.cc @@ -933,6 +933,176 @@ static PyObject* tensor__ge__method(TensorObject* self, EAGER_CATCH_AND_THROW_RETURN_NULL } +static PyObject* tensor__lt__method(TensorObject* self, + PyObject* args, + PyObject* kwargs) { + paddle::platform::RecordEvent pythonc_record_event( + "__lt__ pybind_patch_func", + paddle::platform::TracerEventType::UserDefined, + 1); + + EAGER_TRY + VLOG(1) << "Running Eager tensor__lt__method"; + + // Set Device ID + auto place = egr::Controller::Instance().GetExpectedPlace(); + SetDevice(place); + + paddle::experimental::Tensor ret; + paddle::experimental::Tensor self_tensor = self->tensor; + PyObject* other_obj = PyTuple_GET_ITEM(args, 0); + + // 1. scalar exists cases + // there is no scalar function for __lt__ now + float other_float = 0.0; + bool has_other_float = false; + if (PyFloat_Check(other_obj) || PyCheckInteger(other_obj) || + IsNumpyType(other_obj)) { + if (PyFloat_Check(other_obj)) { + other_float = CastPyArg2AttrFloat(other_obj, 0); + has_other_float = true; + if (_supported_int_dtype_.find(self_tensor.dtype()) != + _supported_int_dtype_.end()) { + eager_gil_scoped_release guard; + self_tensor = cast_ad_func(self_tensor, DataType::FLOAT32); + } + } else if (PyCheckInteger(other_obj) || IsNumpyType(other_obj)) { + other_float = static_cast(CastPyArg2AttrInt(other_obj, 0)); + has_other_float = true; + } + } + + // 2. create or get tensor for other_obj + paddle::experimental::Tensor other_tensor; + if (has_other_float) { + eager_gil_scoped_release guard; + other_tensor = full_ad_func(self_tensor.shape(), + phi::Scalar(other_float), + self_tensor.dtype(), + place); + } else if (!PyCheckTensor(other_obj)) { + paddle::experimental::Scalar value = + CastPyArg2Scalar(other_obj, "__lt__", 0); + if (PyComplex_Check(other_obj)) { + eager_gil_scoped_release guard; + other_tensor = full_ad_func({1}, value, DataType::COMPLEX64, place); + } else { + eager_gil_scoped_release guard; + other_tensor = + full_ad_func(self_tensor.shape(), value, self_tensor.dtype(), place); + } + } else { + other_tensor = CastPyArg2Tensor(other_obj, 0); + } + + // 3. promote types or unify right var type to left var + phi::DataType lhs_dtype = self_tensor.dtype(); + phi::DataType rhs_dtype = other_tensor.dtype(); + if (lhs_dtype != rhs_dtype) { + LOG(WARNING) << "The dtype of left and right Tensor are not the same, left " + "dtype is " + << lhs_dtype << ", but right dtype is " << rhs_dtype + << ", the right dtype will convert to " << lhs_dtype; + eager_gil_scoped_release guard; + other_tensor = cast_ad_func(other_tensor, lhs_dtype); + } + + // 4. calculation + VLOG(6) << "Calling less_than_ad_func in tensor__lt__method"; + { + eager_gil_scoped_release guard; + ret = less_than_ad_func(self_tensor, other_tensor, -1); + } + + return ToPyObject(ret); + EAGER_CATCH_AND_THROW_RETURN_NULL +} + +static PyObject* tensor__le__method(TensorObject* self, + PyObject* args, + PyObject* kwargs) { + paddle::platform::RecordEvent pythonc_record_event( + "__le__ pybind_patch_func", + paddle::platform::TracerEventType::UserDefined, + 1); + + EAGER_TRY + VLOG(1) << "Running Eager tensor__le__method"; + + // Set Device ID + auto place = egr::Controller::Instance().GetExpectedPlace(); + SetDevice(place); + + paddle::experimental::Tensor ret; + paddle::experimental::Tensor self_tensor = self->tensor; + PyObject* other_obj = PyTuple_GET_ITEM(args, 0); + + // 1. scalar exists cases + // there is no scalar function for __le__ now + float other_float = 0.0; + bool has_other_float = false; + if (PyFloat_Check(other_obj) || PyCheckInteger(other_obj) || + IsNumpyType(other_obj)) { + if (PyFloat_Check(other_obj)) { + other_float = CastPyArg2AttrFloat(other_obj, 0); + has_other_float = true; + if (_supported_int_dtype_.find(self_tensor.dtype()) != + _supported_int_dtype_.end()) { + eager_gil_scoped_release guard; + self_tensor = cast_ad_func(self_tensor, DataType::FLOAT32); + } + } else if (PyCheckInteger(other_obj) || IsNumpyType(other_obj)) { + other_float = static_cast(CastPyArg2AttrInt(other_obj, 0)); + has_other_float = true; + } + } + + // 2. create or get tensor for other_obj + paddle::experimental::Tensor other_tensor; + if (has_other_float) { + eager_gil_scoped_release guard; + other_tensor = full_ad_func(self_tensor.shape(), + phi::Scalar(other_float), + self_tensor.dtype(), + place); + } else if (!PyCheckTensor(other_obj)) { + paddle::experimental::Scalar value = + CastPyArg2Scalar(other_obj, "__le__", 0); + if (PyComplex_Check(other_obj)) { + eager_gil_scoped_release guard; + other_tensor = full_ad_func({1}, value, DataType::COMPLEX64, place); + } else { + eager_gil_scoped_release guard; + other_tensor = + full_ad_func(self_tensor.shape(), value, self_tensor.dtype(), place); + } + } else { + other_tensor = CastPyArg2Tensor(other_obj, 0); + } + + // 3. promote types or unify right var type to left var + phi::DataType lhs_dtype = self_tensor.dtype(); + phi::DataType rhs_dtype = other_tensor.dtype(); + if (lhs_dtype != rhs_dtype) { + LOG(WARNING) << "The dtype of left and right Tensor are not the same, left " + "dtype is " + << lhs_dtype << ", but right dtype is " << rhs_dtype + << ", the right dtype will convert to " << lhs_dtype; + eager_gil_scoped_release guard; + other_tensor = cast_ad_func(other_tensor, lhs_dtype); + } + + // 4. calculation + VLOG(6) << "Calling less_equal_ad_func in tensor__le__method"; + { + eager_gil_scoped_release guard; + ret = less_equal_ad_func(self_tensor, other_tensor, -1); + } + + return ToPyObject(ret); + EAGER_CATCH_AND_THROW_RETURN_NULL +} + PyMethodDef math_op_patch_methods[] = { {"__add__", (PyCFunction)(void (*)(void))tensor__add__method, @@ -982,6 +1152,14 @@ PyMethodDef math_op_patch_methods[] = { (PyCFunction)(void (*)(void))tensor__ge__method, METH_VARARGS | METH_KEYWORDS, NULL}, + {"__lt__", + (PyCFunction)(void (*)(void))tensor__lt__method, + METH_VARARGS | METH_KEYWORDS, + NULL}, + {"__le__", + (PyCFunction)(void (*)(void))tensor__le__method, + METH_VARARGS | METH_KEYWORDS, + NULL}, {NULL, NULL, 0, NULL}}; } // namespace pybind diff --git a/python/paddle/fluid/dygraph/math_op_patch.py b/python/paddle/fluid/dygraph/math_op_patch.py index c30bb97b9745e3881b7b10777324bb73cc6ebce9..eca775df5307b667f2643f5bfea696a7a01e17c5 100644 --- a/python/paddle/fluid/dygraph/math_op_patch.py +++ b/python/paddle/fluid/dygraph/math_op_patch.py @@ -400,8 +400,6 @@ def monkey_patch_math_varbase(): # for logical compare ('__eq__', _binary_creator_('__eq__', 'equal', False, None, True)), ('__ne__', _binary_creator_('__ne__', 'not_equal', False, None, True)), - ('__lt__', _binary_creator_('__lt__', 'less_than', False, None, True)), - ('__le__', _binary_creator_('__le__', 'less_equal', False, None, True)), ('__array_ufunc__', None) ] @@ -418,6 +416,8 @@ def monkey_patch_math_varbase(): '__rtruediv__', '__gt__', '__ge__', + '__lt__', + '__le__', ] global _already_patch_varbase