Unverified commit 7d7444cc, authored by Weilong Wu, committed by GitHub

[Eager, Performance optimization] Support mod / matmul (% and @ operators) sinking to the C++ layer (#46565)

* [Eager, Performance optimization] Support mod (% operator) sinking to the C++ layer

* Fix mod logic

* Support the matmul math operator

* Replace LOG(WARNING) with VLOG(6)

* Fix a mistake made while resolving conflicts
Parent 58a478f8
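As a quick orientation, here is a minimal usage sketch (not part of the diff; it assumes a Paddle build that already contains this change). It shows the Python operators whose dispatch this commit moves from the dygraph monkey patch into the pybind C++ layer, where % is lowered to remainder_ad_func and @ to matmul_ad_func.

# Minimal usage sketch, not part of the diff; assumes a Paddle build with this change,
# so % and @ on eager Tensors dispatch to tensor__mod__method / tensor__matmul__method.
import numpy as np
import paddle

x = paddle.to_tensor(np.array([[5., 7., 9.], [4., 6., 8.]], dtype="float32"))
y = paddle.to_tensor(np.array([[2., 3.], [4., 5.], [6., 7.]], dtype="float32"))

print(x % 2)  # __mod__: the scalar is materialized as a filled tensor, then remainder_ad_func runs
print(x @ y)  # __matmul__: lowered to matmul_ad_func(self_tensor, other_tensor, false, false)
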
@@ -224,8 +224,7 @@ static PyObject* tensor__add__method(TensorObject* self,
        other_tensor = cast_ad_func(other_tensor, promote_dtype);
      }
    } else {
-     LOG(WARNING)
-         << "The dtype of left and right Tensor are not the same, left "
-            "dtype is "
-         << lhs_dtype << ", but right dtype is " << rhs_dtype
-         << ", the right dtype will convert to " << lhs_dtype;
+     VLOG(6) << "The dtype of left and right Tensor are not the same, left "
+                "dtype is "
+             << lhs_dtype << ", but right dtype is " << rhs_dtype
+             << ", the right dtype will convert to " << lhs_dtype;

@@ -320,8 +319,7 @@ static PyObject* tensor__sub__method(TensorObject* self,
        other_tensor = cast_ad_func(other_tensor, promote_dtype);
      }
    } else {
-     LOG(WARNING)
-         << "The dtype of left and right Tensor are not the same, left "
-            "dtype is "
-         << lhs_dtype << ", but right dtype is " << rhs_dtype
-         << ", the right dtype will convert to " << lhs_dtype;
+     VLOG(6) << "The dtype of left and right Tensor are not the same, left "
+                "dtype is "
+             << lhs_dtype << ", but right dtype is " << rhs_dtype
+             << ", the right dtype will convert to " << lhs_dtype;

@@ -414,8 +412,7 @@ static PyObject* tensor__rsub__method(TensorObject* self,
        other_tensor = cast_ad_func(other_tensor, promote_dtype);
      }
    } else {
-     LOG(WARNING)
-         << "The dtype of left and right Tensor are not the same, left "
-            "dtype is "
-         << lhs_dtype << ", but right dtype is " << rhs_dtype
-         << ", the right dtype will convert to " << lhs_dtype;
+     VLOG(6) << "The dtype of left and right Tensor are not the same, left "
+                "dtype is "
+             << lhs_dtype << ", but right dtype is " << rhs_dtype
+             << ", the right dtype will convert to " << lhs_dtype;

@@ -515,8 +512,7 @@ static PyObject* tensor__mul__method(TensorObject* self,
        other_tensor = cast_ad_func(other_tensor, promote_dtype);
      }
    } else {
-     LOG(WARNING)
-         << "The dtype of left and right Tensor are not the same, left "
-            "dtype is "
-         << lhs_dtype << ", but right dtype is " << rhs_dtype
-         << ", the right dtype will convert to " << lhs_dtype;
+     VLOG(6) << "The dtype of left and right Tensor are not the same, left "
+                "dtype is "
+             << lhs_dtype << ", but right dtype is " << rhs_dtype
+             << ", the right dtype will convert to " << lhs_dtype;

@@ -617,8 +613,7 @@ static PyObject* tensor__div__method(TensorObject* self,
        other_tensor = cast_ad_func(other_tensor, promote_dtype);
      }
    } else {
-     LOG(WARNING)
-         << "The dtype of left and right Tensor are not the same, left "
-            "dtype is "
-         << lhs_dtype << ", but right dtype is " << rhs_dtype
-         << ", the right dtype will convert to " << lhs_dtype;
+     VLOG(6) << "The dtype of left and right Tensor are not the same, left "
+                "dtype is "
+             << lhs_dtype << ", but right dtype is " << rhs_dtype
+             << ", the right dtype will convert to " << lhs_dtype;

@@ -733,8 +728,7 @@ static PyObject* tensor__rdiv__method(TensorObject* self,
        other_tensor = cast_ad_func(other_tensor, promote_dtype);
      }
    } else {
-     LOG(WARNING)
-         << "The dtype of left and right Tensor are not the same, left "
-            "dtype is "
-         << lhs_dtype << ", but right dtype is " << rhs_dtype
-         << ", the right dtype will convert to " << lhs_dtype;
+     VLOG(6) << "The dtype of left and right Tensor are not the same, left "
+                "dtype is "
+             << lhs_dtype << ", but right dtype is " << rhs_dtype
+             << ", the right dtype will convert to " << lhs_dtype;

@@ -829,7 +823,7 @@ static PyObject* tensor__gt__method(TensorObject* self,
  phi::DataType lhs_dtype = self_tensor.dtype();
  phi::DataType rhs_dtype = other_tensor.dtype();
  if (lhs_dtype != rhs_dtype) {
-   LOG(WARNING) << "The dtype of left and right Tensor are not the same, left "
-                   "dtype is "
-                << lhs_dtype << ", but right dtype is " << rhs_dtype
-                << ", the right dtype will convert to " << lhs_dtype;
+   VLOG(6) << "The dtype of left and right Tensor are not the same, left "
+              "dtype is "
+           << lhs_dtype << ", but right dtype is " << rhs_dtype
+           << ", the right dtype will convert to " << lhs_dtype;

@@ -914,7 +908,7 @@ static PyObject* tensor__ge__method(TensorObject* self,
  phi::DataType lhs_dtype = self_tensor.dtype();
  phi::DataType rhs_dtype = other_tensor.dtype();
  if (lhs_dtype != rhs_dtype) {
-   LOG(WARNING) << "The dtype of left and right Tensor are not the same, left "
-                   "dtype is "
-                << lhs_dtype << ", but right dtype is " << rhs_dtype
-                << ", the right dtype will convert to " << lhs_dtype;
+   VLOG(6) << "The dtype of left and right Tensor are not the same, left "
+              "dtype is "
+           << lhs_dtype << ", but right dtype is " << rhs_dtype
+           << ", the right dtype will convert to " << lhs_dtype;

@@ -933,6 +927,192 @@ static PyObject* tensor__ge__method(TensorObject* self,
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

static PyObject* tensor__mod__method(TensorObject* self,
                                     PyObject* args,
                                     PyObject* kwargs) {
  paddle::platform::RecordEvent pythonc_record_event(
      "__mod__ pybind_patch_func",
      paddle::platform::TracerEventType::UserDefined,
      1);

  EAGER_TRY
  VLOG(6) << "Running Eager tensor__mod__method";

  // Set Device ID
  auto place = egr::Controller::Instance().GetExpectedPlace();
  SetDevice(place);

  paddle::experimental::Tensor ret;
  paddle::experimental::Tensor self_tensor = self->tensor;

  PyObject* other_obj = PyTuple_GET_ITEM(args, 0);

  // 1. scalar exists cases
  // there is no scalar_mod function for __mod__ now
  float other_float = 0.0;
  bool has_other_float = false;
  if (PyFloat_Check(other_obj) || PyCheckInteger(other_obj) ||
      IsNumpyType(other_obj)) {
    if (PyFloat_Check(other_obj)) {
      other_float = CastPyArg2AttrFloat(other_obj, 0);
      has_other_float = true;
      if (_supported_int_dtype_.find(self_tensor.dtype()) !=
          _supported_int_dtype_.end()) {
        eager_gil_scoped_release guard;
        self_tensor = cast_ad_func(self_tensor, DataType::FLOAT32);
      }
    } else if (PyCheckInteger(other_obj) || IsNumpyType(other_obj)) {
      other_float = static_cast<float>(CastPyArg2AttrInt(other_obj, 0));
      has_other_float = true;
    }
  }

  // 2. create or get tensor for other_obj
  paddle::experimental::Tensor other_tensor;
  if (has_other_float) {
    eager_gil_scoped_release guard;
    other_tensor = full_ad_func(self_tensor.shape(),
                                phi::Scalar(other_float),
                                self_tensor.dtype(),
                                place);
  } else if (!PyCheckTensor(other_obj)) {
    paddle::experimental::Scalar value =
        CastPyArg2Scalar(other_obj, "__mod__", 0);
    if (PyComplex_Check(other_obj)) {
      eager_gil_scoped_release guard;
      other_tensor = full_ad_func({1}, value, DataType::COMPLEX64, place);
    } else {
      eager_gil_scoped_release guard;
      other_tensor =
          full_ad_func(self_tensor.shape(), value, self_tensor.dtype(), place);
    }
  } else {
    other_tensor = CastPyArg2Tensor(other_obj, 0);
  }

  // 3. promote types or unify right var type to left var
  phi::DataType lhs_dtype = self_tensor.dtype();
  phi::DataType rhs_dtype = other_tensor.dtype();
  if (lhs_dtype != rhs_dtype) {
    VLOG(6) << "The dtype of left and right Tensor are not the same, left "
               "dtype is "
            << lhs_dtype << ", but right dtype is " << rhs_dtype
            << ", the right dtype will convert to " << lhs_dtype;
    eager_gil_scoped_release guard;
    other_tensor = cast_ad_func(other_tensor, lhs_dtype);
  }

  // 4. calculation
  VLOG(6) << "Calling remainder_ad_func in tensor__mod__method";
  {
    eager_gil_scoped_release guard;
    ret = remainder_ad_func(self_tensor, other_tensor);
  }

  return ToPyObject(ret);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

static PyObject* tensor__matmul__method(TensorObject* self,
                                        PyObject* args,
                                        PyObject* kwargs) {
  paddle::platform::RecordEvent pythonc_record_event(
      "__matmul__ pybind_patch_func",
      paddle::platform::TracerEventType::UserDefined,
      1);

  EAGER_TRY
  VLOG(6) << "Running Eager tensor__matmul__method";

  // Set Device ID
  auto place = egr::Controller::Instance().GetExpectedPlace();
  SetDevice(place);

  paddle::experimental::Tensor ret;
  paddle::experimental::Tensor self_tensor = self->tensor;

  PyObject* other_obj = PyTuple_GET_ITEM(args, 0);

  // 1. scalar exists cases
  // there is no scalar_matmul function for __matmul__ now
  float other_float = 0.0;
  bool has_other_float = false;
  if (PyFloat_Check(other_obj) || PyCheckInteger(other_obj) ||
      IsNumpyType(other_obj)) {
    if (PyFloat_Check(other_obj)) {
      other_float = CastPyArg2AttrFloat(other_obj, 0);
      has_other_float = true;
      if (_supported_int_dtype_.find(self_tensor.dtype()) !=
          _supported_int_dtype_.end()) {
        eager_gil_scoped_release guard;
        self_tensor = cast_ad_func(self_tensor, DataType::FLOAT32);
      }
    } else if (PyCheckInteger(other_obj) || IsNumpyType(other_obj)) {
      other_float = static_cast<float>(CastPyArg2AttrInt(other_obj, 0));
      has_other_float = true;
    }
  }

  // 2. create or get tensor for other_obj
  paddle::experimental::Tensor other_tensor;
  if (has_other_float) {
    eager_gil_scoped_release guard;
    other_tensor =
        full_ad_func({1}, phi::Scalar(other_float), self_tensor.dtype(), place);
  } else if (!PyCheckTensor(other_obj)) {
    paddle::experimental::Scalar value =
        CastPyArg2Scalar(other_obj, "__matmul__", 0);
    if (PyComplex_Check(other_obj)) {
      eager_gil_scoped_release guard;
      other_tensor = full_ad_func({1}, value, DataType::COMPLEX64, place);
    } else {
      eager_gil_scoped_release guard;
      other_tensor = full_ad_func({1}, value, self_tensor.dtype(), place);
    }
  } else {
    other_tensor = CastPyArg2Tensor(other_obj, 0);
  }

  // 3. promote types or unify right var type to left var
  phi::DataType lhs_dtype = self_tensor.dtype();
  phi::DataType rhs_dtype = other_tensor.dtype();
  if (lhs_dtype != rhs_dtype) {
    // note: only op_type in _supported_promote_complex_types_ should promote
    // dtype
    if (_complex_dtypes.find(lhs_dtype) != _complex_dtypes.end() ||
        _complex_dtypes.find(rhs_dtype) != _complex_dtypes.end()) {
      phi::DataType promote_dtype =
          framework::TransToPhiDataType(framework::PromoteTypesIfComplexExists(
              framework::TransToProtoVarType(lhs_dtype),
              framework::TransToProtoVarType(rhs_dtype)));
      if (lhs_dtype != promote_dtype) {
        // cast
        eager_gil_scoped_release guard;
        self_tensor = cast_ad_func(self_tensor, promote_dtype);
      }
      if (rhs_dtype != promote_dtype) {
        eager_gil_scoped_release guard;
        other_tensor = cast_ad_func(other_tensor, promote_dtype);
      }
    } else {
      VLOG(6) << "The dtype of left and right Tensor are not the same, left "
                 "dtype is "
              << lhs_dtype << ", but right dtype is " << rhs_dtype
              << ", the right dtype will convert to " << lhs_dtype;
      eager_gil_scoped_release guard;
      other_tensor = cast_ad_func(other_tensor, lhs_dtype);
    }
  }

  // 4. calculation
  VLOG(6) << "Calling matmul_ad_func in tensor__matmul__method";
  {
    eager_gil_scoped_release guard;
    ret = matmul_ad_func(self_tensor, other_tensor, false, false);
  }

  return ToPyObject(ret);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

static PyObject* tensor__lt__method(TensorObject* self,
                                    PyObject* args,
                                    PyObject* kwargs) {

@@ -999,7 +1179,7 @@ static PyObject* tensor__lt__method(TensorObject* self,
  phi::DataType lhs_dtype = self_tensor.dtype();
  phi::DataType rhs_dtype = other_tensor.dtype();
  if (lhs_dtype != rhs_dtype) {
-   LOG(WARNING) << "The dtype of left and right Tensor are not the same, left "
-                   "dtype is "
-                << lhs_dtype << ", but right dtype is " << rhs_dtype
-                << ", the right dtype will convert to " << lhs_dtype;
+   VLOG(6) << "The dtype of left and right Tensor are not the same, left "
+              "dtype is "
+           << lhs_dtype << ", but right dtype is " << rhs_dtype
+           << ", the right dtype will convert to " << lhs_dtype;

@@ -1084,7 +1264,7 @@ static PyObject* tensor__le__method(TensorObject* self,
  phi::DataType lhs_dtype = self_tensor.dtype();
  phi::DataType rhs_dtype = other_tensor.dtype();
  if (lhs_dtype != rhs_dtype) {
-   LOG(WARNING) << "The dtype of left and right Tensor are not the same, left "
-                   "dtype is "
-                << lhs_dtype << ", but right dtype is " << rhs_dtype
-                << ", the right dtype will convert to " << lhs_dtype;
+   VLOG(6) << "The dtype of left and right Tensor are not the same, left "
+              "dtype is "
+           << lhs_dtype << ", but right dtype is " << rhs_dtype
+           << ", the right dtype will convert to " << lhs_dtype;

@@ -1144,6 +1324,14 @@ PyMethodDef math_op_patch_methods[] = {
     (PyCFunction)(void (*)(void))tensor__rdiv__method,
     METH_VARARGS | METH_KEYWORDS,
     NULL},
+   {"__mod__",
+    (PyCFunction)(void (*)(void))tensor__mod__method,
+    METH_VARARGS | METH_KEYWORDS,
+    NULL},
+   {"__matmul__",
+    (PyCFunction)(void (*)(void))tensor__matmul__method,
+    METH_VARARGS | METH_KEYWORDS,
+    NULL},
    {"__gt__",
     (PyCFunction)(void (*)(void))tensor__gt__method,
     METH_VARARGS | METH_KEYWORDS,

...
@@ -393,10 +393,6 @@ def monkey_patch_math_varbase():
                                      None)),
        ('__floordiv__',
         _binary_creator_('__floordiv__', 'floor_divide', False, None, True)),
-       ('__mod__', _binary_creator_('__mod__', 'remainder', False, None,
-                                    True)),
-       ('__matmul__',
-        _binary_creator_('__matmul__', "matmul", False, None, True)),
        # for logical compare
        ('__eq__', _binary_creator_('__eq__', 'equal', False, None, True)),
        ('__ne__', _binary_creator_('__ne__', 'not_equal', False, None, True)),

@@ -414,6 +410,8 @@ def monkey_patch_math_varbase():
        '__truediv__',
        '__rdiv__',
        '__rtruediv__',
+       '__mod__',
+       '__matmul__',
        '__gt__',
        '__ge__',
        '__lt__',

...
@@ -732,6 +732,22 @@ class TestMathOpPatchesVarBase(unittest.TestCase):
            self.func_test_complex_scalar()
        self.func_test_complex_scalar()

+    def func_test_matmul(self):
+        x_np = np.random.uniform(-1, 1, [2, 3]).astype(self.dtype)
+        y_np = np.random.uniform(-1, 1, [3, 2]).astype(self.dtype)
+        except_out = x_np @ y_np
+
+        with fluid.dygraph.guard():
+            x = paddle.to_tensor(x_np)
+            y = paddle.to_tensor(y_np)
+            out = x @ y
+            np.testing.assert_allclose(out.numpy(), except_out, atol=1e-03)
+
+    def test_matmul(self):
+        with _test_eager_guard():
+            self.func_test_matmul()
+        self.func_test_matmul()

if __name__ == '__main__':
    unittest.main()
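
For completeness, an analogous check for the new % dispatch could look like the sketch below. It is illustrative only and not part of this diff; the helper name check_mod and the chosen tolerance are assumptions.

# Illustrative sketch only; not part of this commit. Mirrors func_test_matmul
# above, but for the __mod__ ( % ) operator added in this change.
import numpy as np
import paddle

def check_mod(dtype="float32"):
    x_np = np.random.uniform(1, 5, [2, 3]).astype(dtype)
    y_np = np.random.uniform(1, 5, [2, 3]).astype(dtype)
    expected = x_np % y_np
    x = paddle.to_tensor(x_np)
    y = paddle.to_tensor(y_np)
    out = x % y  # dispatches to tensor__mod__method (remainder_ad_func)
    np.testing.assert_allclose(out.numpy(), expected, atol=1e-03)

check_mod()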