From 643079037a4cc54f6f2d8e4038ba2e9891ae79c6 Mon Sep 17 00:00:00 2001 From: Weilong Wu Date: Mon, 17 Oct 2022 17:10:37 +0800 Subject: [PATCH] support __floordiv__ (#47060) --- paddle/fluid/pybind/eager_math_op_patch.cc | 94 ++++++++++++++++++++ python/paddle/fluid/dygraph/math_op_patch.py | 3 +- 2 files changed, 95 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/pybind/eager_math_op_patch.cc b/paddle/fluid/pybind/eager_math_op_patch.cc index a5cd16dd927..4cd76f4a104 100644 --- a/paddle/fluid/pybind/eager_math_op_patch.cc +++ b/paddle/fluid/pybind/eager_math_op_patch.cc @@ -1295,6 +1295,96 @@ static PyObject* tensor__le__method(TensorObject* self, EAGER_CATCH_AND_THROW_RETURN_NULL } +static PyObject* tensor__floordiv__method(TensorObject* self, + PyObject* args, + PyObject* kwargs) { + paddle::platform::RecordEvent pythonc_record_event( + "floordiv pybind_patch_func", + paddle::platform::TracerEventType::UserDefined, + 1); + EAGER_TRY + VLOG(6) << "Running Eager tensor__floordiv__method"; + + // Set Device ID + auto place = egr::Controller::Instance().GetExpectedPlace(); + SetDevice(place); + + paddle::experimental::Tensor ret; + paddle::experimental::Tensor self_tensor = self->tensor; + + PyObject* other_obj = PyTuple_GET_ITEM(args, 0); + + // 1. scalar exists cases or not + // there is no scalar case for floordiv, but also need to cast self_tensor + // if needed. 
+ double other_double = 0.0; + bool has_other_double = false; + if (PyFloat_Check(other_obj) || PyCheckInteger(other_obj) || + IsNumpyType(other_obj)) { + if (PyFloat_Check(other_obj)) { + other_double = CastPyArg2Double(other_obj, "__floordiv__", 0); + has_other_double = true; + if (_supported_int_dtype_.find(self_tensor.dtype()) != + _supported_int_dtype_.end()) { + eager_gil_scoped_release guard; + self_tensor = cast_ad_func(self_tensor, DataType::FLOAT32); + } + } else if (PyCheckInteger(other_obj) || IsNumpyType(other_obj)) { + other_double = CastPyArg2Double(other_obj, "__floordiv__", 0); + has_other_double = true; + } + } + + // 2. create or get tensor for other_obj + paddle::experimental::Tensor other_tensor; + if (has_other_double) { + eager_gil_scoped_release guard; + other_tensor = full_ad_func(self_tensor.shape(), + phi::Scalar(other_double), + self_tensor.dtype(), + self_tensor.place()); + } else if (!PyCheckTensor(other_obj)) { + paddle::experimental::Scalar value = + CastPyArg2Scalar(other_obj, "__floordiv__", 0); + if (PyComplex_Check(other_obj)) { + eager_gil_scoped_release guard; + other_tensor = + full_ad_func({1}, value, DataType::COMPLEX64, self_tensor.place()); + } else { + eager_gil_scoped_release guard; + other_tensor = + full_ad_func({1}, value, self_tensor.dtype(), self_tensor.place()); + } + } else { + other_tensor = CastPyArg2Tensor(other_obj, 0); + } + + // 3. 
promote types or unify right var type to left var + phi::DataType lhs_dtype = self_tensor.dtype(); + phi::DataType rhs_dtype = other_tensor.dtype(); + if (lhs_dtype != rhs_dtype) { + // note: only op_type in _supported_promote_complex_types_ should promote + // dtype, floordiv is not in _supported_promote_complex_types_, will not do + // promote dtype + VLOG(6) << "The dtype of left and right Tensor are not the same, left " + "dtype is " + << lhs_dtype << ", but right dtype is " << rhs_dtype + << ", the right dtype will convert to " << lhs_dtype; + eager_gil_scoped_release guard; + other_tensor = cast_ad_func(other_tensor, lhs_dtype); + } + + // 4. calculation + VLOG(6) << "Calling floor_divide_ad_func in tensor__floordiv__method"; + { + eager_gil_scoped_release guard; + ret = floor_divide_ad_func(self_tensor, other_tensor); + } + + return ToPyObject(ret); + EAGER_CATCH_AND_THROW_RETURN_NULL +} + PyMethodDef math_op_patch_methods[] = { {"__add__", (PyCFunction)(void (*)(void))tensor__add__method, @@ -1336,6 +1426,10 @@ PyMethodDef math_op_patch_methods[] = { (PyCFunction)(void (*)(void))tensor__rdiv__method, METH_VARARGS | METH_KEYWORDS, NULL}, + {"__floordiv__", + (PyCFunction)(void (*)(void))tensor__floordiv__method, + METH_VARARGS | METH_KEYWORDS, + NULL}, {"__mod__", (PyCFunction)(void (*)(void))tensor__mod__method, METH_VARARGS | METH_KEYWORDS, diff --git a/python/paddle/fluid/dygraph/math_op_patch.py b/python/paddle/fluid/dygraph/math_op_patch.py index 72e770e8e50..f754fb93c45 100644 --- a/python/paddle/fluid/dygraph/math_op_patch.py +++ b/python/paddle/fluid/dygraph/math_op_patch.py @@ -392,8 +392,6 @@ def monkey_patch_math_varbase(): True)), ('__rpow__', _binary_creator_('__rpow__', 'elementwise_pow', True, None)), - ('__floordiv__', - _binary_creator_('__floordiv__', 'floor_divide', False, None, True)), # for logical compare ('__eq__', _binary_creator_('__eq__', 'equal', False, None, True)), ('__ne__', _binary_creator_('__ne__', 'not_equal', False, None, 
True)), @@ -417,6 +415,7 @@ def monkey_patch_math_varbase(): '__ge__', '__lt__', '__le__', + '__floordiv__', ] global _already_patch_varbase -- GitLab