未验证 提交 7d238139 编写于 作者: W Weilong Wu 提交者: GitHub

[Eager, Performance optimization] support less_than & less_equal( < & <=...

[Eager, Performance optimization] support less_than & less_equal( < & <= operator) to sink to Cpp layer (#46542)
上级 2aec65be
......@@ -933,6 +933,176 @@ static PyObject* tensor__ge__method(TensorObject* self,
EAGER_CATCH_AND_THROW_RETURN_NULL
}
static PyObject* tensor__lt__method(TensorObject* self,
PyObject* args,
PyObject* kwargs) {
paddle::platform::RecordEvent pythonc_record_event(
"__lt__ pybind_patch_func",
paddle::platform::TracerEventType::UserDefined,
1);
EAGER_TRY
VLOG(1) << "Running Eager tensor__lt__method";
// Set Device ID
auto place = egr::Controller::Instance().GetExpectedPlace();
SetDevice(place);
paddle::experimental::Tensor ret;
paddle::experimental::Tensor self_tensor = self->tensor;
PyObject* other_obj = PyTuple_GET_ITEM(args, 0);
// 1. scalar exists cases
// there is no scalar function for __lt__ now
float other_float = 0.0;
bool has_other_float = false;
if (PyFloat_Check(other_obj) || PyCheckInteger(other_obj) ||
IsNumpyType(other_obj)) {
if (PyFloat_Check(other_obj)) {
other_float = CastPyArg2AttrFloat(other_obj, 0);
has_other_float = true;
if (_supported_int_dtype_.find(self_tensor.dtype()) !=
_supported_int_dtype_.end()) {
eager_gil_scoped_release guard;
self_tensor = cast_ad_func(self_tensor, DataType::FLOAT32);
}
} else if (PyCheckInteger(other_obj) || IsNumpyType(other_obj)) {
other_float = static_cast<float>(CastPyArg2AttrInt(other_obj, 0));
has_other_float = true;
}
}
// 2. create or get tensor for other_obj
paddle::experimental::Tensor other_tensor;
if (has_other_float) {
eager_gil_scoped_release guard;
other_tensor = full_ad_func(self_tensor.shape(),
phi::Scalar(other_float),
self_tensor.dtype(),
place);
} else if (!PyCheckTensor(other_obj)) {
paddle::experimental::Scalar value =
CastPyArg2Scalar(other_obj, "__lt__", 0);
if (PyComplex_Check(other_obj)) {
eager_gil_scoped_release guard;
other_tensor = full_ad_func({1}, value, DataType::COMPLEX64, place);
} else {
eager_gil_scoped_release guard;
other_tensor =
full_ad_func(self_tensor.shape(), value, self_tensor.dtype(), place);
}
} else {
other_tensor = CastPyArg2Tensor(other_obj, 0);
}
// 3. promote types or unify right var type to left var
phi::DataType lhs_dtype = self_tensor.dtype();
phi::DataType rhs_dtype = other_tensor.dtype();
if (lhs_dtype != rhs_dtype) {
LOG(WARNING) << "The dtype of left and right Tensor are not the same, left "
"dtype is "
<< lhs_dtype << ", but right dtype is " << rhs_dtype
<< ", the right dtype will convert to " << lhs_dtype;
eager_gil_scoped_release guard;
other_tensor = cast_ad_func(other_tensor, lhs_dtype);
}
// 4. calculation
VLOG(6) << "Calling less_than_ad_func in tensor__lt__method";
{
eager_gil_scoped_release guard;
ret = less_than_ad_func(self_tensor, other_tensor, -1);
}
return ToPyObject(ret);
EAGER_CATCH_AND_THROW_RETURN_NULL
}
static PyObject* tensor__le__method(TensorObject* self,
PyObject* args,
PyObject* kwargs) {
paddle::platform::RecordEvent pythonc_record_event(
"__le__ pybind_patch_func",
paddle::platform::TracerEventType::UserDefined,
1);
EAGER_TRY
VLOG(1) << "Running Eager tensor__le__method";
// Set Device ID
auto place = egr::Controller::Instance().GetExpectedPlace();
SetDevice(place);
paddle::experimental::Tensor ret;
paddle::experimental::Tensor self_tensor = self->tensor;
PyObject* other_obj = PyTuple_GET_ITEM(args, 0);
// 1. scalar exists cases
// there is no scalar function for __le__ now
float other_float = 0.0;
bool has_other_float = false;
if (PyFloat_Check(other_obj) || PyCheckInteger(other_obj) ||
IsNumpyType(other_obj)) {
if (PyFloat_Check(other_obj)) {
other_float = CastPyArg2AttrFloat(other_obj, 0);
has_other_float = true;
if (_supported_int_dtype_.find(self_tensor.dtype()) !=
_supported_int_dtype_.end()) {
eager_gil_scoped_release guard;
self_tensor = cast_ad_func(self_tensor, DataType::FLOAT32);
}
} else if (PyCheckInteger(other_obj) || IsNumpyType(other_obj)) {
other_float = static_cast<float>(CastPyArg2AttrInt(other_obj, 0));
has_other_float = true;
}
}
// 2. create or get tensor for other_obj
paddle::experimental::Tensor other_tensor;
if (has_other_float) {
eager_gil_scoped_release guard;
other_tensor = full_ad_func(self_tensor.shape(),
phi::Scalar(other_float),
self_tensor.dtype(),
place);
} else if (!PyCheckTensor(other_obj)) {
paddle::experimental::Scalar value =
CastPyArg2Scalar(other_obj, "__le__", 0);
if (PyComplex_Check(other_obj)) {
eager_gil_scoped_release guard;
other_tensor = full_ad_func({1}, value, DataType::COMPLEX64, place);
} else {
eager_gil_scoped_release guard;
other_tensor =
full_ad_func(self_tensor.shape(), value, self_tensor.dtype(), place);
}
} else {
other_tensor = CastPyArg2Tensor(other_obj, 0);
}
// 3. promote types or unify right var type to left var
phi::DataType lhs_dtype = self_tensor.dtype();
phi::DataType rhs_dtype = other_tensor.dtype();
if (lhs_dtype != rhs_dtype) {
LOG(WARNING) << "The dtype of left and right Tensor are not the same, left "
"dtype is "
<< lhs_dtype << ", but right dtype is " << rhs_dtype
<< ", the right dtype will convert to " << lhs_dtype;
eager_gil_scoped_release guard;
other_tensor = cast_ad_func(other_tensor, lhs_dtype);
}
// 4. calculation
VLOG(6) << "Calling less_equal_ad_func in tensor__le__method";
{
eager_gil_scoped_release guard;
ret = less_equal_ad_func(self_tensor, other_tensor, -1);
}
return ToPyObject(ret);
EAGER_CATCH_AND_THROW_RETURN_NULL
}
PyMethodDef math_op_patch_methods[] = {
{"__add__",
(PyCFunction)(void (*)(void))tensor__add__method,
......@@ -982,6 +1152,14 @@ PyMethodDef math_op_patch_methods[] = {
(PyCFunction)(void (*)(void))tensor__ge__method,
METH_VARARGS | METH_KEYWORDS,
NULL},
{"__lt__",
(PyCFunction)(void (*)(void))tensor__lt__method,
METH_VARARGS | METH_KEYWORDS,
NULL},
{"__le__",
(PyCFunction)(void (*)(void))tensor__le__method,
METH_VARARGS | METH_KEYWORDS,
NULL},
{NULL, NULL, 0, NULL}};
} // namespace pybind
......
......@@ -400,8 +400,6 @@ def monkey_patch_math_varbase():
# for logical compare
('__eq__', _binary_creator_('__eq__', 'equal', False, None, True)),
('__ne__', _binary_creator_('__ne__', 'not_equal', False, None, True)),
('__lt__', _binary_creator_('__lt__', 'less_than', False, None, True)),
('__le__', _binary_creator_('__le__', 'less_equal', False, None, True)),
('__array_ufunc__', None)
]
......@@ -418,6 +416,8 @@ def monkey_patch_math_varbase():
'__rtruediv__',
'__gt__',
'__ge__',
'__lt__',
'__le__',
]
global _already_patch_varbase
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册