未验证 提交 f20b361c 编写于 作者: W Weilong Wu 提交者: GitHub

[Eager, Performance optimization] support divide( / operator) to sink to Cpp layer (#46329)

* [Eager] math op sink to Cpp level

* fix ci errors

* draft version

* support + and - operator under cpp directly

* add static test

* polish code

* promote types or unify right type to left

* recover static test case

* polish code and fix some ci errors

* support complex and polish code

* fix conflicts

* fix windows ci errors

* fix windows-inference-ci errors

* polish and fix tests

* fix test case

* polish code

* [Eager, Performance optimization] support multiply( * operator) to sink to Cpp layer

* rm useless glog

* [Eager, Performance optimization] support divide( / and // operator) to sink to Cpp layer

* polish code

* polish code and fix code-format

* polish code

* fix CI

* polish code

* update test

* support div operator under cpp

* fix scalar as input

* Polish div logic, fix ci test

* fix errors
上级 35bff2a5
...@@ -141,6 +141,8 @@ paddle::experimental::Tensor CallScalarFuction( ...@@ -141,6 +141,8 @@ paddle::experimental::Tensor CallScalarFuction(
ret = scale_ad_func(self_tensor, phi::Scalar(-1.0), other, true); ret = scale_ad_func(self_tensor, phi::Scalar(-1.0), other, true);
} else if (op_type == "mul") { } else if (op_type == "mul") {
ret = scale_ad_func(self_tensor, phi::Scalar(other), 0.0, true); ret = scale_ad_func(self_tensor, phi::Scalar(other), 0.0, true);
} else if (op_type == "div") {
ret = scale_ad_func(self_tensor, phi::Scalar(1.0 / other), 0.0, true);
} }
return ret; return ret;
...@@ -454,8 +456,8 @@ static PyObject* tensor__mul__method(TensorObject* self, ...@@ -454,8 +456,8 @@ static PyObject* tensor__mul__method(TensorObject* self,
PyObject* other_obj = PyTuple_GET_ITEM(args, 0); PyObject* other_obj = PyTuple_GET_ITEM(args, 0);
// 1. scalar exists cases // 1. scalar exists cases
if ((PyFloat_Check(other_obj) || PyLong_Check(other_obj)) && if (PyFloat_Check(other_obj) || PyCheckInteger(other_obj) ||
!PyBool_Check(other_obj)) { IsNumpyType(other_obj)) {
float other = 0.0; float other = 0.0;
if (PyFloat_Check(other_obj)) { if (PyFloat_Check(other_obj)) {
other = CastPyArg2AttrFloat(other_obj, 0); other = CastPyArg2AttrFloat(other_obj, 0);
...@@ -534,6 +536,233 @@ static PyObject* tensor__mul__method(TensorObject* self, ...@@ -534,6 +536,233 @@ static PyObject* tensor__mul__method(TensorObject* self,
EAGER_CATCH_AND_THROW_RETURN_NULL EAGER_CATCH_AND_THROW_RETURN_NULL
} }
// Eager-mode C++ implementation of `tensor / other`; the method table in this
// file registers it under both "__div__" and "__truediv__".
//
// self   : wrapper whose `tensor` member is the left operand (dividend).
// args   : positional tuple; only item 0 (the right operand) is read.
// kwargs : accepted for the METH_KEYWORDS signature but never read.
//
// Returns ToPyObject() of the quotient. Errors raised inside the
// EAGER_TRY / EAGER_CATCH_AND_THROW_RETURN_NULL pair are handled by those
// macros (presumably converted to a Python exception plus a NULL return --
// confirm against the macro definitions).
static PyObject* tensor__div__method(TensorObject* self,
                                     PyObject* args,
                                     PyObject* kwargs) {
  paddle::platform::RecordEvent pythonc_record_event(
      "__div__ pybind_patch_func",
      paddle::platform::TracerEventType::UserDefined,
      1);

  EAGER_TRY
  VLOG(6) << "Running Eager tensor__div__method";

  // Select the device that the eager controller currently expects.
  auto expected_place = egr::Controller::Instance().GetExpectedPlace();
  SetDevice(expected_place);

  // True when the tensor's dtype is listed in _supported_int_dtype_.
  const auto has_int_dtype = [](const paddle::experimental::Tensor& t) {
    return _supported_int_dtype_.find(t.dtype()) !=
           _supported_int_dtype_.end();
  };

  paddle::experimental::Tensor result;
  paddle::experimental::Tensor lhs_tensor = self->tensor;
  PyObject* other_obj = PyTuple_GET_ITEM(args, 0);

  // 1. Scalar fast path: a Python float, an integer, or a numpy scalar is
  //    routed through CallScalarFuction instead of building a tensor.
  const bool other_is_float = PyFloat_Check(other_obj);
  if (other_is_float || PyCheckInteger(other_obj) ||
      IsNumpyType(other_obj)) {
    float divisor = 0.0;
    divisor = other_is_float
                  ? CastPyArg2AttrFloat(other_obj, 0)
                  : static_cast<float>(CastPyArg2AttrInt(other_obj, 0));

    // Integer dividends are converted to FLOAT32 before dividing.
    if (has_int_dtype(lhs_tensor)) {
      eager_gil_scoped_release guard;
      lhs_tensor = cast_ad_func(lhs_tensor, DataType::FLOAT32);
    }
    {
      eager_gil_scoped_release guard;
      result = CallScalarFuction(lhs_tensor, divisor, "div");
    }
    return ToPyObject(result);
  }

  // 2. Obtain the right operand as a tensor: reuse it when it already is
  //    one, otherwise materialise the Python value with full_ad_func.
  paddle::experimental::Tensor rhs_tensor;
  if (PyCheckTensor(other_obj)) {
    rhs_tensor = CastPyArg2Tensor(other_obj, 0);
  } else {
    paddle::experimental::Scalar fill_value =
        CastPyArg2Scalar(other_obj, "__div__", 0);
    // Query the Python object before the guard below is constructed.
    const bool other_is_complex = PyComplex_Check(other_obj);
    eager_gil_scoped_release guard;
    if (other_is_complex) {
      // Complex values become a one-element COMPLEX64 tensor.
      rhs_tensor =
          full_ad_func({1}, fill_value, DataType::COMPLEX64, expected_place);
    } else {
      // Everything else is filled to the dividend's shape and dtype.
      rhs_tensor = full_ad_func(
          lhs_tensor.shape(), fill_value, lhs_tensor.dtype(), expected_place);
    }
  }

  // 3. Reconcile dtypes: promote when either side is complex, otherwise
  //    convert the right operand to the left operand's dtype.
  const phi::DataType lhs_dtype = lhs_tensor.dtype();
  const phi::DataType rhs_dtype = rhs_tensor.dtype();
  if (lhs_dtype != rhs_dtype) {
    const bool involves_complex =
        _complex_dtypes.find(lhs_dtype) != _complex_dtypes.end() ||
        _complex_dtypes.find(rhs_dtype) != _complex_dtypes.end();
    if (!involves_complex) {
      // The warning is emitted before the guard is created.
      LOG(WARNING)
          << "The dtype of left and right Tensor are not the same, left "
             "dtype is "
          << lhs_dtype << ", but right dtype is " << rhs_dtype
          << ", the right dtype will convert to " << lhs_dtype;
      eager_gil_scoped_release guard;
      rhs_tensor = cast_ad_func(rhs_tensor, lhs_dtype);
    } else {
      const phi::DataType promote_dtype =
          framework::TransToPhiDataType(framework::PromoteTypesIfComplexExists(
              framework::TransToProtoVarType(lhs_dtype),
              framework::TransToProtoVarType(rhs_dtype)));
      // Only the operand(s) that differ from the promoted dtype are cast.
      if (lhs_dtype != promote_dtype) {
        eager_gil_scoped_release guard;
        lhs_tensor = cast_ad_func(lhs_tensor, promote_dtype);
      }
      if (rhs_dtype != promote_dtype) {
        eager_gil_scoped_release guard;
        rhs_tensor = cast_ad_func(rhs_tensor, promote_dtype);
      }
    }
  }

  // Any operand that still has an integer dtype is converted to FLOAT32.
  if (has_int_dtype(lhs_tensor)) {
    eager_gil_scoped_release guard;
    lhs_tensor = cast_ad_func(lhs_tensor, DataType::FLOAT32);
  }
  if (has_int_dtype(rhs_tensor)) {
    eager_gil_scoped_release guard;
    rhs_tensor = cast_ad_func(rhs_tensor, DataType::FLOAT32);
  }

  // 4. Element-wise division: self / other.
  VLOG(6) << "Calling divide_ad_func in tensor__div__method";
  {
    eager_gil_scoped_release guard;
    result = divide_ad_func(lhs_tensor, rhs_tensor);
  }
  return ToPyObject(result);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}
// Eager-mode C++ implementation of the reflected division `other / tensor`;
// the method table in this file registers it under both "__rdiv__" and
// "__rtruediv__".
//
// self   : wrapper whose `tensor` member is the divisor.
// args   : positional tuple; only item 0 (the dividend) is read.
// kwargs : accepted for the METH_KEYWORDS signature but never read.
//
// Returns ToPyObject() of the quotient. Errors raised inside the
// EAGER_TRY / EAGER_CATCH_AND_THROW_RETURN_NULL pair are handled by those
// macros (presumably converted to a Python exception plus a NULL return --
// confirm against the macro definitions).
static PyObject* tensor__rdiv__method(TensorObject* self,
                                      PyObject* args,
                                      PyObject* kwargs) {
  paddle::platform::RecordEvent pythonc_record_event(
      "__rdiv__ pybind_patch_func",
      paddle::platform::TracerEventType::UserDefined,
      1);

  EAGER_TRY
  VLOG(6) << "Running Eager tensor__rdiv__method";

  // Select the device that the eager controller currently expects.
  auto expected_place = egr::Controller::Instance().GetExpectedPlace();
  SetDevice(expected_place);

  // True when the tensor's dtype is listed in _supported_int_dtype_.
  const auto has_int_dtype = [](const paddle::experimental::Tensor& t) {
    return _supported_int_dtype_.find(t.dtype()) !=
           _supported_int_dtype_.end();
  };

  paddle::experimental::Tensor result;
  paddle::experimental::Tensor divisor_tensor = self->tensor;
  PyObject* other_obj = PyTuple_GET_ITEM(args, 0);

  // 1. Scalar operand. No scalar shortcut is used for the reflected form, so
  //    the value is only parsed here and turned into a tensor in step 2.
  const bool other_is_float = PyFloat_Check(other_obj);
  const bool other_is_scalar = other_is_float || PyCheckInteger(other_obj) ||
                               IsNumpyType(other_obj);
  float scalar_value = 0.0;
  if (other_is_scalar) {
    scalar_value = other_is_float
                       ? CastPyArg2AttrFloat(other_obj, 0)
                       : static_cast<float>(CastPyArg2AttrInt(other_obj, 0));
    // Integer tensors are converted to FLOAT32 before dividing.
    if (has_int_dtype(divisor_tensor)) {
      eager_gil_scoped_release guard;
      divisor_tensor = cast_ad_func(divisor_tensor, DataType::FLOAT32);
    }
  }

  // 2. Obtain the dividend as a tensor.
  paddle::experimental::Tensor dividend_tensor;
  if (other_is_scalar) {
    // Fill a tensor matching self's shape and dtype with the parsed scalar.
    eager_gil_scoped_release guard;
    dividend_tensor = full_ad_func(divisor_tensor.shape(),
                                   phi::Scalar(scalar_value),
                                   divisor_tensor.dtype(),
                                   expected_place);
  } else if (PyCheckTensor(other_obj)) {
    dividend_tensor = CastPyArg2Tensor(other_obj, 0);
  } else {
    paddle::experimental::Scalar fill_value =
        CastPyArg2Scalar(other_obj, "__rdiv__", 0);
    // Query the Python object before the guard below is constructed.
    const bool other_is_complex = PyComplex_Check(other_obj);
    eager_gil_scoped_release guard;
    if (other_is_complex) {
      // Complex values become a one-element COMPLEX64 tensor.
      dividend_tensor =
          full_ad_func({1}, fill_value, DataType::COMPLEX64, expected_place);
    } else {
      dividend_tensor = full_ad_func(divisor_tensor.shape(),
                                     fill_value,
                                     divisor_tensor.dtype(),
                                     expected_place);
    }
  }

  // 3. Reconcile dtypes: promote when either side is complex, otherwise
  //    convert `other` to self's dtype.
  const phi::DataType lhs_dtype = divisor_tensor.dtype();
  const phi::DataType rhs_dtype = dividend_tensor.dtype();
  if (lhs_dtype != rhs_dtype) {
    const bool involves_complex =
        _complex_dtypes.find(lhs_dtype) != _complex_dtypes.end() ||
        _complex_dtypes.find(rhs_dtype) != _complex_dtypes.end();
    if (!involves_complex) {
      // The warning is emitted before the guard is created.
      LOG(WARNING)
          << "The dtype of left and right Tensor are not the same, left "
             "dtype is "
          << lhs_dtype << ", but right dtype is " << rhs_dtype
          << ", the right dtype will convert to " << lhs_dtype;
      eager_gil_scoped_release guard;
      dividend_tensor = cast_ad_func(dividend_tensor, lhs_dtype);
    } else {
      const phi::DataType promote_dtype =
          framework::TransToPhiDataType(framework::PromoteTypesIfComplexExists(
              framework::TransToProtoVarType(lhs_dtype),
              framework::TransToProtoVarType(rhs_dtype)));
      // Only the operand(s) that differ from the promoted dtype are cast.
      if (lhs_dtype != promote_dtype) {
        eager_gil_scoped_release guard;
        divisor_tensor = cast_ad_func(divisor_tensor, promote_dtype);
      }
      if (rhs_dtype != promote_dtype) {
        eager_gil_scoped_release guard;
        dividend_tensor = cast_ad_func(dividend_tensor, promote_dtype);
      }
    }
  }

  // Any operand that still has an integer dtype is converted to FLOAT32.
  if (has_int_dtype(divisor_tensor)) {
    eager_gil_scoped_release guard;
    divisor_tensor = cast_ad_func(divisor_tensor, DataType::FLOAT32);
  }
  if (has_int_dtype(dividend_tensor)) {
    eager_gil_scoped_release guard;
    dividend_tensor = cast_ad_func(dividend_tensor, DataType::FLOAT32);
  }

  // 4. Element-wise division with the operands swapped: other / self.
  VLOG(6) << "Calling divide_ad_func in tensor__rdiv__method";
  {
    eager_gil_scoped_release guard;
    result = divide_ad_func(dividend_tensor, divisor_tensor);
  }
  return ToPyObject(result);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}
PyMethodDef math_op_patch_methods[] = { PyMethodDef math_op_patch_methods[] = {
{"__add__", {"__add__",
(PyCFunction)(void (*)(void))tensor__add__method, (PyCFunction)(void (*)(void))tensor__add__method,
...@@ -559,6 +788,22 @@ PyMethodDef math_op_patch_methods[] = { ...@@ -559,6 +788,22 @@ PyMethodDef math_op_patch_methods[] = {
(PyCFunction)(void (*)(void))tensor__mul__method, (PyCFunction)(void (*)(void))tensor__mul__method,
METH_VARARGS | METH_KEYWORDS, METH_VARARGS | METH_KEYWORDS,
NULL}, NULL},
{"__div__",
(PyCFunction)(void (*)(void))tensor__div__method,
METH_VARARGS | METH_KEYWORDS,
NULL},
{"__truediv__",
(PyCFunction)(void (*)(void))tensor__div__method,
METH_VARARGS | METH_KEYWORDS,
NULL},
{"__rdiv__",
(PyCFunction)(void (*)(void))tensor__rdiv__method,
METH_VARARGS | METH_KEYWORDS,
NULL},
{"__rtruediv__",
(PyCFunction)(void (*)(void))tensor__rdiv__method,
METH_VARARGS | METH_KEYWORDS,
NULL},
{NULL, NULL, 0, NULL}}; {NULL, NULL, 0, NULL}};
} // namespace pybind } // namespace pybind
......
...@@ -387,13 +387,6 @@ def monkey_patch_math_varbase(): ...@@ -387,13 +387,6 @@ def monkey_patch_math_varbase():
('ndim', _ndim_), ('ndim', _ndim_),
('size', _size_), ('size', _size_),
('T', _T_), ('T', _T_),
('__div__',
_binary_creator_('__div__', 'divide', False, _scalar_div_, True)),
('__truediv__',
_binary_creator_('__truediv__', 'divide', False, _scalar_div_, True)),
('__rdiv__', _binary_creator_('__rdiv__', 'divide', True, None, True)),
('__rtruediv__',
_binary_creator_('rtruediv__', 'divide', True, None, True)),
('__pow__', _binary_creator_('__pow__', 'pow', False, _C_ops.pow, ('__pow__', _binary_creator_('__pow__', 'pow', False, _C_ops.pow,
True)), True)),
('__rpow__', _binary_creator_('__rpow__', 'elementwise_pow', True, ('__rpow__', _binary_creator_('__rpow__', 'elementwise_pow', True,
...@@ -423,6 +416,10 @@ def monkey_patch_math_varbase(): ...@@ -423,6 +416,10 @@ def monkey_patch_math_varbase():
'__rsub__', '__rsub__',
'__mul__', '__mul__',
'__rmul__', '__rmul__',
'__div__',
'__truediv__',
'__rdiv__',
'__rtruediv__',
] ]
global _already_patch_varbase global _already_patch_varbase
......
...@@ -59,9 +59,8 @@ class TestTensorTypePromotion(unittest.TestCase): ...@@ -59,9 +59,8 @@ class TestTensorTypePromotion(unittest.TestCase):
def test_operator(self): def test_operator(self):
with _test_eager_guard(): with _test_eager_guard():
self.setUp() pass
# add / sub / mul has been sunk to cpp level, there is no warnings to catch by this test. # add / sub / mul / div has been sunk to cpp level, there is no warnings to catch by this test.
self.div_operator()
self.setUp() self.setUp()
self.add_operator() self.add_operator()
self.sub_operator() self.sub_operator()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册