From 7fb20b46dc7d67b36b24787dbd22f902f4869d82 Mon Sep 17 00:00:00 2001 From: Weilong Wu Date: Mon, 26 Sep 2022 19:44:55 +0800 Subject: [PATCH] [Eager, Performance optimization] support multiply( * operator) to sink to Cpp layer (#46326) * [Eager] math op sink to Cpp level * fix ci errors * draft version * support + and - operator under cpp directly * add static test * polish code * promote types or unify right type to left * recover static test case * polish code and fix some ci errors * support complex and polish code * fix conflicts * fix windows ci errors * fix windows-inference-ci errors * polish and fix tests * fix test case * polish code * [Eager, Performance optimization] support multiply( * operator) to sink to Cpp layer * rm useless glog * polish code * polish code and fix code-format * polish code * fix CI * polish code --- paddle/fluid/pybind/eager_math_op_patch.cc | 115 +++++++++++++++++- paddle/fluid/pybind/eager_utils.cc | 3 + paddle/fluid/pybind/op_function_common.cc | 20 +++ paddle/fluid/pybind/op_function_common.h | 3 + python/paddle/fluid/dygraph/math_op_patch.py | 6 +- .../unittests/test_tensor_type_promotion.py | 3 +- 6 files changed, 142 insertions(+), 8 deletions(-) diff --git a/paddle/fluid/pybind/eager_math_op_patch.cc b/paddle/fluid/pybind/eager_math_op_patch.cc index 9dc54092e8..5a45c14f78 100644 --- a/paddle/fluid/pybind/eager_math_op_patch.cc +++ b/paddle/fluid/pybind/eager_math_op_patch.cc @@ -103,7 +103,7 @@ void SetDevice(paddle::platform::Place place) { if (paddle::platform::is_gpu_place(place)) { #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) phi::backends::gpu::SetDeviceId(place.device); - VLOG(1) << "CurrentDeviceId: " << phi::backends::gpu::GetCurrentDeviceId() + VLOG(6) << "CurrentDeviceId: " << phi::backends::gpu::GetCurrentDeviceId() << " from " << static_cast(place.device); #else PADDLE_THROW(paddle::platform::errors::PreconditionNotMet( @@ -114,7 +114,7 @@ void SetDevice(paddle::platform::Place place) { if (paddle::platform::is_custom_place(place)) { #if defined(PADDLE_WITH_CUSTOM_DEVICE) phi::DeviceManager::SetDevice(place); - VLOG(1) << "CurrentDeviceId: " + VLOG(6) << "CurrentDeviceId: " << phi::DeviceManager::GetDevice(place.GetDeviceType()) << " from " << static_cast(place.device); #else @@ -139,6 +139,8 @@ paddle::experimental::Tensor CallScalarFuction( } else if (op_type == "rsub") { ret = scale_ad_func(self_tensor, phi::Scalar(-1.0), other, true); + } else if (op_type == "mul") { + ret = scale_ad_func(self_tensor, phi::Scalar(other), 0.0, true); } return ret; @@ -431,6 +433,107 @@ static PyObject* tensor__rsub__method(TensorObject* self, EAGER_CATCH_AND_THROW_RETURN_NULL } +static PyObject* tensor__mul__method(TensorObject* self, + PyObject* args, + PyObject* kwargs) { + paddle::platform::RecordEvent pythonc_record_event( + "__mul__ pybind_patch_func", + paddle::platform::TracerEventType::UserDefined, + 1); + + EAGER_TRY + VLOG(6) << "Running Eager tensor__mul__method"; + + // Set Device ID + auto place = egr::Controller::Instance().GetExpectedPlace(); + SetDevice(place); + + paddle::experimental::Tensor ret; + paddle::experimental::Tensor self_tensor = self->tensor; + + PyObject* other_obj = PyTuple_GET_ITEM(args, 0); + + // 1. scalar exists cases + if ((PyFloat_Check(other_obj) || PyLong_Check(other_obj)) && + !PyBool_Check(other_obj)) { + float other = 0.0; + if (PyFloat_Check(other_obj)) { + other = CastPyArg2AttrFloat(other_obj, 0); + if (_supported_int_dtype_.find(self_tensor.dtype()) != + _supported_int_dtype_.end()) { + eager_gil_scoped_release guard; + self_tensor = cast_ad_func(self_tensor, DataType::FLOAT32); + } + } else if (PyCheckInteger(other_obj) || IsNumpyType(other_obj)) { + other = static_cast(CastPyArg2AttrInt(other_obj, 0)); + } + { + eager_gil_scoped_release guard; + ret = CallScalarFuction(self_tensor, other, "mul"); + } + return ToPyObject(ret); + } + + // 2. create or get tensor for other_obj + paddle::experimental::Tensor other_tensor; + if (!PyCheckTensor(other_obj)) { + paddle::experimental::Scalar value = + CastPyArg2Scalar(other_obj, "__mul__", 0); + if (PyComplex_Check(other_obj)) { + eager_gil_scoped_release guard; + other_tensor = full_ad_func({1}, value, DataType::COMPLEX64, place); + } else { + eager_gil_scoped_release guard; + other_tensor = + full_ad_func(self_tensor.shape(), value, self_tensor.dtype(), place); + } + } else { + other_tensor = CastPyArg2Tensor(other_obj, 0); + } + + // 3. promote types or unify right var type to left var + phi::DataType lhs_dtype = self_tensor.dtype(); + phi::DataType rhs_dtype = other_tensor.dtype(); + if (lhs_dtype != rhs_dtype) { + // note: only op_type in _supported_promote_complex_types_ should promote + // dtype + if (_complex_dtypes.find(lhs_dtype) != _complex_dtypes.end() || + _complex_dtypes.find(rhs_dtype) != _complex_dtypes.end()) { + phi::DataType promote_dtype = + framework::TransToPhiDataType(framework::PromoteTypesIfComplexExists( + framework::TransToProtoVarType(lhs_dtype), + framework::TransToProtoVarType(rhs_dtype))); + if (lhs_dtype != promote_dtype) { + // cast + eager_gil_scoped_release guard; + self_tensor = cast_ad_func(self_tensor, promote_dtype); + } + if (rhs_dtype != promote_dtype) { + eager_gil_scoped_release guard; + other_tensor = cast_ad_func(other_tensor, promote_dtype); + } + } else { + LOG(WARNING) + << "The dtype of left and right Tensor are not the same, left " + "dtype is " + << lhs_dtype << ", but right dtype is " << rhs_dtype + << ", the right dtype will convert to " << lhs_dtype; + eager_gil_scoped_release guard; + other_tensor = cast_ad_func(other_tensor, lhs_dtype); + } + } + + // 4. calculation + VLOG(6) << "Calling multiply_ad_func in tensor__mul__method"; + { + eager_gil_scoped_release guard; + ret = multiply_ad_func(self_tensor, other_tensor); + } + + return ToPyObject(ret); + EAGER_CATCH_AND_THROW_RETURN_NULL +} + PyMethodDef math_op_patch_methods[] = { {"__add__", (PyCFunction)(void (*)(void))tensor__add__method, @@ -448,6 +551,14 @@ PyMethodDef math_op_patch_methods[] = { (PyCFunction)(void (*)(void))tensor__rsub__method, METH_VARARGS | METH_KEYWORDS, NULL}, + {"__mul__", + (PyCFunction)(void (*)(void))tensor__mul__method, + METH_VARARGS | METH_KEYWORDS, + NULL}, + {"__rmul__", + (PyCFunction)(void (*)(void))tensor__mul__method, + METH_VARARGS | METH_KEYWORDS, + NULL}, {NULL, NULL, 0, NULL}}; } // namespace pybind diff --git a/paddle/fluid/pybind/eager_utils.cc b/paddle/fluid/pybind/eager_utils.cc index df09dd7ec0..6582bffcb8 100644 --- a/paddle/fluid/pybind/eager_utils.cc +++ b/paddle/fluid/pybind/eager_utils.cc @@ -1300,6 +1300,9 @@ paddle::experimental::Scalar CastPyArg2Scalar(PyObject* obj, return paddle::experimental::Scalar(value); } else if (type_name.find("numpy") != std::string::npos) { return CastNumpy2Scalar(obj, op_type, arg_pos); + } else if (PyComplex_Check(obj)) { + auto value = CastPyArg2Complex(obj, op_type, arg_pos); + return paddle::experimental::Scalar(value); } else if (PyObject_CheckLongOrToLong(&obj)) { int value = CastPyArg2Int(obj, op_type, arg_pos); return paddle::experimental::Scalar(value); diff --git a/paddle/fluid/pybind/op_function_common.cc b/paddle/fluid/pybind/op_function_common.cc index 5680a84ca4..6a6b8841d3 100644 --- a/paddle/fluid/pybind/op_function_common.cc +++ b/paddle/fluid/pybind/op_function_common.cc @@ -31,6 +31,7 @@ #include "paddle/fluid/imperative/type_defs.h" #include "paddle/fluid/operators/ops_extra_info.h" #include "paddle/fluid/pybind/imperative.h" +#include "paddle/phi/common/complex.h" namespace py = pybind11; namespace paddle { @@ -214,6 +215,25 @@ double CastPyArg2Double(PyObject* obj, return 0.0; } +phi::dtype::complex CastPyArg2Complex(PyObject* obj, + const std::string& op_type, + ssize_t arg_pos) { + if (PyComplex_Check(obj)) { + double real = PyComplex_RealAsDouble(obj); + double imag = PyComplex_ImagAsDouble(obj); + return phi::dtype::complex(real, imag); + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "%s(): argument (position %d) must be " + "complex, but got %s", + op_type, + arg_pos + 1, + ((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT + } + + return phi::dtype::complex(0, 0); +} + void CastPyArg2AttrDouble(PyObject* obj, paddle::framework::AttributeMap& attrs, // NOLINT const std::string& key, diff --git a/paddle/fluid/pybind/op_function_common.h b/paddle/fluid/pybind/op_function_common.h index efa16494e7..686694631c 100644 --- a/paddle/fluid/pybind/op_function_common.h +++ b/paddle/fluid/pybind/op_function_common.h @@ -61,6 +61,9 @@ float CastPyArg2Float(PyObject* obj, double CastPyArg2Double(PyObject* obj, const std::string& op_type, ssize_t arg_pos); +phi::dtype::complex CastPyArg2Complex(PyObject* obj, + const std::string& op_type, + ssize_t arg_pos); std::string CastPyArg2String(PyObject* obj, const std::string& op_type, ssize_t arg_pos); diff --git a/python/paddle/fluid/dygraph/math_op_patch.py b/python/paddle/fluid/dygraph/math_op_patch.py index 5284c7763a..89974ed6a5 100644 --- a/python/paddle/fluid/dygraph/math_op_patch.py +++ b/python/paddle/fluid/dygraph/math_op_patch.py @@ -389,10 +389,6 @@ def monkey_patch_math_varbase(): ('ndim', _ndim_), ('size', _size_), ('T', _T_), - ('__mul__', - _binary_creator_('__mul__', 'multiply', False, _scalar_mul_, True)), - ('__rmul__', - _binary_creator_('__rmul__', 'multiply', False, _scalar_mul_, True)), ('__div__', _binary_creator_('__div__', 'divide', False, _scalar_div_, True)), ('__truediv__', @@ -427,6 +423,8 @@ def monkey_patch_math_varbase(): "__radd__", '__sub__', '__rsub__', + '__mul__', + '__rmul__', ] global _already_patch_varbase diff --git a/python/paddle/fluid/tests/unittests/test_tensor_type_promotion.py b/python/paddle/fluid/tests/unittests/test_tensor_type_promotion.py index dc3485b932..2da5530d52 100644 --- a/python/paddle/fluid/tests/unittests/test_tensor_type_promotion.py +++ b/python/paddle/fluid/tests/unittests/test_tensor_type_promotion.py @@ -62,8 +62,7 @@ class TestTensorTypePromotion(unittest.TestCase): def test_operator(self): with _test_eager_guard(): self.setUp() - # add and sub has been sunk to cpp level, there is no warnings to catch by this test. - self.mul_operator() + # add / sub / mul has been sunk to cpp level, there is no warnings to catch by this test. self.div_operator() self.setUp() self.add_operator() -- GitLab