diff --git a/paddle/fluid/pybind/eager_math_op_patch.cc b/paddle/fluid/pybind/eager_math_op_patch.cc
index 5a45c14f78d07d11a977b516495f40b47c33cbed..537f9469c1ccf59fc9385f09e66c86c4130f9066 100644
--- a/paddle/fluid/pybind/eager_math_op_patch.cc
+++ b/paddle/fluid/pybind/eager_math_op_patch.cc
@@ -141,6 +141,8 @@ paddle::experimental::Tensor CallScalarFuction(
     ret = scale_ad_func(self_tensor, phi::Scalar(-1.0), other, true);
   } else if (op_type == "mul") {
     ret = scale_ad_func(self_tensor, phi::Scalar(other), 0.0, true);
+  } else if (op_type == "div") {
+    ret = scale_ad_func(self_tensor, phi::Scalar(1.0 / other), 0.0, true);
   }
 
   return ret;
@@ -454,8 +456,8 @@ static PyObject* tensor__mul__method(TensorObject* self,
   PyObject* other_obj = PyTuple_GET_ITEM(args, 0);
 
   // 1. scalar exists cases
-  if ((PyFloat_Check(other_obj) || PyLong_Check(other_obj)) &&
-      !PyBool_Check(other_obj)) {
+  if (PyFloat_Check(other_obj) || PyCheckInteger(other_obj) ||
+      IsNumpyType(other_obj)) {
     float other = 0.0;
     if (PyFloat_Check(other_obj)) {
       other = CastPyArg2AttrFloat(other_obj, 0);
@@ -534,6 +536,233 @@ static PyObject* tensor__mul__method(TensorObject* self,
   EAGER_CATCH_AND_THROW_RETURN_NULL
 }
 
+static PyObject* tensor__div__method(TensorObject* self,
+                                     PyObject* args,
+                                     PyObject* kwargs) {
+  paddle::platform::RecordEvent pythonc_record_event(
+      "__div__ pybind_patch_func",
+      paddle::platform::TracerEventType::UserDefined,
+      1);
+
+  EAGER_TRY
+
+  VLOG(6) << "Running Eager tensor__div__method";
+
+  // Set Device ID
+  auto place = egr::Controller::Instance().GetExpectedPlace();
+  SetDevice(place);
+
+  paddle::experimental::Tensor ret;
+  paddle::experimental::Tensor self_tensor = self->tensor;
+
+  PyObject* other_obj = PyTuple_GET_ITEM(args, 0);
+
+  // 1. scalar exists cases
+  if (PyFloat_Check(other_obj) || PyCheckInteger(other_obj) ||
+      IsNumpyType(other_obj)) {
+    float other = 0.0;
+    if (PyFloat_Check(other_obj)) {
+      other = CastPyArg2AttrFloat(other_obj, 0);
+    } else if (PyCheckInteger(other_obj) || IsNumpyType(other_obj)) {
+      other = static_cast<float>(CastPyArg2AttrInt(other_obj, 0));
+    }
+    if (_supported_int_dtype_.find(self_tensor.dtype()) !=
+        _supported_int_dtype_.end()) {
+      eager_gil_scoped_release guard;
+      self_tensor = cast_ad_func(self_tensor, DataType::FLOAT32);
+    }
+    {
+      eager_gil_scoped_release guard;
+      ret = CallScalarFuction(self_tensor, other, "div");
+    }
+    return ToPyObject(ret);
+  }
+
+  // 2. create or get tensor for other_obj
+  paddle::experimental::Tensor other_tensor;
+  if (!PyCheckTensor(other_obj)) {
+    paddle::experimental::Scalar value =
+        CastPyArg2Scalar(other_obj, "__div__", 0);
+    if (PyComplex_Check(other_obj)) {
+      eager_gil_scoped_release guard;
+      other_tensor = full_ad_func({1}, value, DataType::COMPLEX64, place);
+    } else {
+      eager_gil_scoped_release guard;
+      other_tensor =
+          full_ad_func(self_tensor.shape(), value, self_tensor.dtype(), place);
+    }
+  } else {
+    other_tensor = CastPyArg2Tensor(other_obj, 0);
+  }
+
+  // 3. promote types or unify right var type to left var
+  phi::DataType lhs_dtype = self_tensor.dtype();
+  phi::DataType rhs_dtype = other_tensor.dtype();
+  if (lhs_dtype != rhs_dtype) {
+    // note: only op_type in _supported_promote_complex_types_ should promote
+    // dtype
+    if (_complex_dtypes.find(lhs_dtype) != _complex_dtypes.end() ||
+        _complex_dtypes.find(rhs_dtype) != _complex_dtypes.end()) {
+      phi::DataType promote_dtype =
+          framework::TransToPhiDataType(framework::PromoteTypesIfComplexExists(
+              framework::TransToProtoVarType(lhs_dtype),
+              framework::TransToProtoVarType(rhs_dtype)));
+      if (lhs_dtype != promote_dtype) {
+        // cast
+        eager_gil_scoped_release guard;
+        self_tensor = cast_ad_func(self_tensor, promote_dtype);
+      }
+      if (rhs_dtype != promote_dtype) {
+        eager_gil_scoped_release guard;
+        other_tensor = cast_ad_func(other_tensor, promote_dtype);
+      }
+    } else {
+      LOG(WARNING)
+          << "The dtype of left and right Tensor are not the same, left "
+             "dtype is "
+          << lhs_dtype << ", but right dtype is " << rhs_dtype
+          << ", the right dtype will convert to " << lhs_dtype;
+      eager_gil_scoped_release guard;
+      other_tensor = cast_ad_func(other_tensor, lhs_dtype);
+    }
+  }
+  if (_supported_int_dtype_.find(self_tensor.dtype()) !=
+      _supported_int_dtype_.end()) {
+    eager_gil_scoped_release guard;
+    self_tensor = cast_ad_func(self_tensor, DataType::FLOAT32);
+  }
+  if (_supported_int_dtype_.find(other_tensor.dtype()) !=
+      _supported_int_dtype_.end()) {
+    eager_gil_scoped_release guard;
+    other_tensor = cast_ad_func(other_tensor, DataType::FLOAT32);
+  }
+
+  // 4. calculation
+  VLOG(6) << "Calling divide_ad_func in tensor__div__method";
+  {
+    eager_gil_scoped_release guard;
+    ret = divide_ad_func(self_tensor, other_tensor);
+  }
+
+  return ToPyObject(ret);
+  EAGER_CATCH_AND_THROW_RETURN_NULL
+}
+
+static PyObject* tensor__rdiv__method(TensorObject* self,
+                                      PyObject* args,
+                                      PyObject* kwargs) {
+  paddle::platform::RecordEvent pythonc_record_event(
+      "__rdiv__ pybind_patch_func",
+      paddle::platform::TracerEventType::UserDefined,
+      1);
+  EAGER_TRY
+
+  VLOG(6) << "Running Eager tensor__rdiv__method";
+
+  // Set Device ID
+  auto place = egr::Controller::Instance().GetExpectedPlace();
+  SetDevice(place);
+
+  paddle::experimental::Tensor ret;
+  paddle::experimental::Tensor self_tensor = self->tensor;
+
+  PyObject* other_obj = PyTuple_GET_ITEM(args, 0);
+
+  // 1. scalar exists cases
+  // there is no scalar_div function for __rdiv__ and __rtruediv__
+  float other_float = 0.0;
+  bool has_other_float = false;
+  if (PyFloat_Check(other_obj) || PyCheckInteger(other_obj) ||
+      IsNumpyType(other_obj)) {
+    if (PyFloat_Check(other_obj)) {
+      other_float = CastPyArg2AttrFloat(other_obj, 0);
+      has_other_float = true;
+    } else if (PyCheckInteger(other_obj) || IsNumpyType(other_obj)) {
+      other_float = static_cast<float>(CastPyArg2AttrInt(other_obj, 0));
+      has_other_float = true;
+    }
+    if (_supported_int_dtype_.find(self_tensor.dtype()) !=
+        _supported_int_dtype_.end()) {
+      eager_gil_scoped_release guard;
+      self_tensor = cast_ad_func(self_tensor, DataType::FLOAT32);
+    }
+  }
+
+  // 2. create or get tensor for other_obj
+  paddle::experimental::Tensor other_tensor;
+  if (has_other_float) {
+    eager_gil_scoped_release guard;
+    other_tensor = full_ad_func(self_tensor.shape(),
+                                phi::Scalar(other_float),
+                                self_tensor.dtype(),
+                                place);
+  } else if (!PyCheckTensor(other_obj)) {
+    paddle::experimental::Scalar value =
+        CastPyArg2Scalar(other_obj, "__rdiv__", 0);
+    if (PyComplex_Check(other_obj)) {
+      eager_gil_scoped_release guard;
+      other_tensor = full_ad_func({1}, value, DataType::COMPLEX64, place);
+    } else {
+      eager_gil_scoped_release guard;
+      other_tensor =
+          full_ad_func(self_tensor.shape(), value, self_tensor.dtype(), place);
+    }
+  } else {
+    other_tensor = CastPyArg2Tensor(other_obj, 0);
+  }
+
+  // 3. promote types or unify right var type to left var
+  phi::DataType lhs_dtype = self_tensor.dtype();
+  phi::DataType rhs_dtype = other_tensor.dtype();
+  if (lhs_dtype != rhs_dtype) {
+    // note: only op_type in _supported_promote_complex_types_ should promote
+    // dtype
+    if (_complex_dtypes.find(lhs_dtype) != _complex_dtypes.end() ||
+        _complex_dtypes.find(rhs_dtype) != _complex_dtypes.end()) {
+      phi::DataType promote_dtype =
+          framework::TransToPhiDataType(framework::PromoteTypesIfComplexExists(
+              framework::TransToProtoVarType(lhs_dtype),
+              framework::TransToProtoVarType(rhs_dtype)));
+      if (lhs_dtype != promote_dtype) {
+        // cast
+        eager_gil_scoped_release guard;
+        self_tensor = cast_ad_func(self_tensor, promote_dtype);
+      }
+      if (rhs_dtype != promote_dtype) {
+        eager_gil_scoped_release guard;
+        other_tensor = cast_ad_func(other_tensor, promote_dtype);
+      }
+    } else {
+      LOG(WARNING)
+          << "The dtype of left and right Tensor are not the same, left "
+             "dtype is "
+          << lhs_dtype << ", but right dtype is " << rhs_dtype
+          << ", the right dtype will convert to " << lhs_dtype;
+      eager_gil_scoped_release guard;
+      other_tensor = cast_ad_func(other_tensor, lhs_dtype);
+    }
+  }
+  if (_supported_int_dtype_.find(self_tensor.dtype()) !=
+      _supported_int_dtype_.end()) {
+    eager_gil_scoped_release guard;
+    self_tensor = cast_ad_func(self_tensor, DataType::FLOAT32);
+  }
+  if (_supported_int_dtype_.find(other_tensor.dtype()) !=
+      _supported_int_dtype_.end()) {
+    eager_gil_scoped_release guard;
+    other_tensor = cast_ad_func(other_tensor, DataType::FLOAT32);
+  }
+
+  // 4. calculation
+  VLOG(6) << "Calling divide_ad_func in tensor__rdiv__method";
+  {
+    eager_gil_scoped_release guard;
+    ret = divide_ad_func(other_tensor, self_tensor);
+  }
+  return ToPyObject(ret);
+  EAGER_CATCH_AND_THROW_RETURN_NULL
+}
+
 PyMethodDef math_op_patch_methods[] = {
     {"__add__",
      (PyCFunction)(void (*)(void))tensor__add__method,
@@ -559,6 +788,22 @@ PyMethodDef math_op_patch_methods[] = {
      (PyCFunction)(void (*)(void))tensor__mul__method,
      METH_VARARGS | METH_KEYWORDS,
      NULL},
+    {"__div__",
+     (PyCFunction)(void (*)(void))tensor__div__method,
+     METH_VARARGS | METH_KEYWORDS,
+     NULL},
+    {"__truediv__",
+     (PyCFunction)(void (*)(void))tensor__div__method,
+     METH_VARARGS | METH_KEYWORDS,
+     NULL},
+    {"__rdiv__",
+     (PyCFunction)(void (*)(void))tensor__rdiv__method,
+     METH_VARARGS | METH_KEYWORDS,
+     NULL},
+    {"__rtruediv__",
+     (PyCFunction)(void (*)(void))tensor__rdiv__method,
+     METH_VARARGS | METH_KEYWORDS,
+     NULL},
     {NULL, NULL, 0, NULL}};
 
 }  // namespace pybind
diff --git a/python/paddle/fluid/dygraph/math_op_patch.py b/python/paddle/fluid/dygraph/math_op_patch.py
index 6d445d36f10d9f3b511c3026a98d41377d26e912..8e0f3c6436f9e5b61304cab34a14d828e5996c24 100644
--- a/python/paddle/fluid/dygraph/math_op_patch.py
+++ b/python/paddle/fluid/dygraph/math_op_patch.py
@@ -387,13 +387,6 @@ def monkey_patch_math_varbase():
         ('ndim', _ndim_),
         ('size', _size_),
         ('T', _T_),
-        ('__div__',
-         _binary_creator_('__div__', 'divide', False, _scalar_div_, True)),
-        ('__truediv__',
-         _binary_creator_('__truediv__', 'divide', False, _scalar_div_, True)),
-        ('__rdiv__', _binary_creator_('__rdiv__', 'divide', True, None, True)),
-        ('__rtruediv__',
-         _binary_creator_('rtruediv__', 'divide', True, None, True)),
         ('__pow__',
          _binary_creator_('__pow__', 'pow', False, _C_ops.pow, True)),
         ('__rpow__', _binary_creator_('__rpow__', 'elementwise_pow', True,
@@ -423,6 +416,10 @@ def monkey_patch_math_varbase():
         '__rsub__',
         '__mul__',
         '__rmul__',
+        '__div__',
+        '__truediv__',
+        '__rdiv__',
+        '__rtruediv__',
     ]
 
     global _already_patch_varbase
diff --git a/python/paddle/fluid/tests/unittests/test_tensor_type_promotion.py b/python/paddle/fluid/tests/unittests/test_tensor_type_promotion.py
index 2c373bb0e57ef7cccafcdcd6f9b2110f073b6f98..9137ba7c6d9a3b544cc6f542b61881992dbbfa37 100644
--- a/python/paddle/fluid/tests/unittests/test_tensor_type_promotion.py
+++ b/python/paddle/fluid/tests/unittests/test_tensor_type_promotion.py
@@ -59,9 +59,8 @@ class TestTensorTypePromotion(unittest.TestCase):
 
     def test_operator(self):
         with _test_eager_guard():
-            self.setUp()
-            # add / sub / mul has been sunk to cpp level, there is no warnings to catch by this test.
-            self.div_operator()
+            pass
+            # add / sub / mul / div have been sunk to the C++ level; there are no warnings for this test to catch.
         self.setUp()
         self.add_operator()
         self.sub_operator()
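
Usage sketch (not part of the patch): the snippet below illustrates the Python-visible behavior once `__div__`/`__truediv__`/`__rdiv__`/`__rtruediv__` dispatch to the new C++ methods. The example tensor, the scalar operands, and the expected outputs are my own illustration under the casting rules shown in the diff, not output taken from this PR or its tests.

```python
import paddle

x = paddle.to_tensor([2, 4, 6], dtype='int64')

# __truediv__ with a Python scalar takes the scalar fast path:
# CallScalarFuction(self_tensor, other, "div"), i.e. scale by 1.0 / other,
# after integer dtypes are first cast to float32.
y = x / 2        # expected: [1., 2., 3.], dtype float32

# __rtruediv__ has no scalar kernel: the scalar is materialized with
# full_ad_func(...) and divide_ad_func(other_tensor, self_tensor) is called.
z = 12.0 / x     # expected: [6., 3., 2.], dtype float32
```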