未验证 提交 7fb20b46 编写于 作者: W Weilong Wu 提交者: GitHub

[Eager, Performance optimization] support multiply( * operator) to sink to Cpp layer (#46326)

* [Eager] math op sink to Cpp level

* fix ci errors

* draft version


* support + and - operator under cpp directly

* add static test

* polish code

* promote types or unify right type to left

* recover static test case

* polish code and fix some ci errors

* support complex and polish code

* fix conflicts

* fix windows ci errors

* fix windows-inference-ci errors

* polish and fix tests

* fix test case

* polish code

* [Eager, Performance optimization] support multiply( * operator) to sink to Cpp layer

* rm useless glog

* polish code

* polish code and fix code-format

* polish code

* fix CI

* polish code
上级 808bf2b4
......@@ -103,7 +103,7 @@ void SetDevice(paddle::platform::Place place) {
if (paddle::platform::is_gpu_place(place)) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
phi::backends::gpu::SetDeviceId(place.device);
VLOG(1) << "CurrentDeviceId: " << phi::backends::gpu::GetCurrentDeviceId()
VLOG(6) << "CurrentDeviceId: " << phi::backends::gpu::GetCurrentDeviceId()
<< " from " << static_cast<int>(place.device);
#else
PADDLE_THROW(paddle::platform::errors::PreconditionNotMet(
......@@ -114,7 +114,7 @@ void SetDevice(paddle::platform::Place place) {
if (paddle::platform::is_custom_place(place)) {
#if defined(PADDLE_WITH_CUSTOM_DEVICE)
phi::DeviceManager::SetDevice(place);
VLOG(1) << "CurrentDeviceId: "
VLOG(6) << "CurrentDeviceId: "
<< phi::DeviceManager::GetDevice(place.GetDeviceType()) << " from "
<< static_cast<int>(place.device);
#else
......@@ -139,6 +139,8 @@ paddle::experimental::Tensor CallScalarFuction(
} else if (op_type == "rsub") {
ret = scale_ad_func(self_tensor, phi::Scalar(-1.0), other, true);
} else if (op_type == "mul") {
ret = scale_ad_func(self_tensor, phi::Scalar(other), 0.0, true);
}
return ret;
......@@ -431,6 +433,107 @@ static PyObject* tensor__rsub__method(TensorObject* self,
EAGER_CATCH_AND_THROW_RETURN_NULL
}
// C++ implementation of the eager Tensor `*` operator. In this file it is
// registered in math_op_patch_methods under both "__mul__" and "__rmul__".
//
// self   - the left-hand Tensor object.
// args   - positional-argument tuple; only item 0 (the other operand) is read.
// kwargs - not read by this function.
//
// Returns the product converted with ToPyObject. The work is done in four
// steps:
//   1. A Python float/int (but not bool) operand is multiplied through
//      CallScalarFuction(..., "mul") and returned immediately.
//   2. Any other non-Tensor operand is materialised as a Tensor with
//      full_ad_func; a Tensor operand is taken as is.
//   3. If the two dtypes differ, either both sides are cast to the promoted
//      complex dtype, or the right operand is cast to the left dtype (with a
//      warning).
//   4. multiply_ad_func computes the result.
//
// NOTE(review): EAGER_TRY / EAGER_CATCH_AND_THROW_RETURN_NULL are defined
// elsewhere; presumably they wrap the body in a try/catch that turns C++
// exceptions into a Python error plus a NULL return - confirm.
// NOTE(review): eager_gil_scoped_release is defined elsewhere; presumably it
// releases the Python GIL for the lifetime of `guard`, which is why each guard
// is scoped tightly around the *_ad_func call - confirm.
// NOTE(review): PyTuple_GET_ITEM(args, 0) is used without checking the tuple
// size; this assumes callers always pass exactly one operand - verify.
static PyObject* tensor__mul__method(TensorObject* self,
PyObject* args,
PyObject* kwargs) {
// Scoped event tagged "__mul__ pybind_patch_func" that lives for the whole
// call.
paddle::platform::RecordEvent pythonc_record_event(
"__mul__ pybind_patch_func",
paddle::platform::TracerEventType::UserDefined,
1);
EAGER_TRY
VLOG(6) << "Running Eager tensor__mul__method";
// Set Device ID
auto place = egr::Controller::Instance().GetExpectedPlace();
SetDevice(place);
paddle::experimental::Tensor ret;
// Local copy of the left operand; it may be replaced by a cast result below.
paddle::experimental::Tensor self_tensor = self->tensor;
PyObject* other_obj = PyTuple_GET_ITEM(args, 0);
// 1. scalar exists cases
// bool is excluded explicitly here, so a bool operand falls through to the
// tensor path in step 2.
if ((PyFloat_Check(other_obj) || PyLong_Check(other_obj)) &&
!PyBool_Check(other_obj)) {
// The scalar is carried as a float regardless of whether it came from a
// Python float or int.
float other = 0.0;
if (PyFloat_Check(other_obj)) {
other = CastPyArg2AttrFloat(other_obj, 0);
// A float scalar combined with a tensor whose dtype is listed in
// _supported_int_dtype_ first casts the tensor to FLOAT32.
if (_supported_int_dtype_.find(self_tensor.dtype()) !=
_supported_int_dtype_.end()) {
eager_gil_scoped_release guard;
self_tensor = cast_ad_func(self_tensor, DataType::FLOAT32);
}
} else if (PyCheckInteger(other_obj) || IsNumpyType(other_obj)) {
other = static_cast<float>(CastPyArg2AttrInt(other_obj, 0));
}
{
eager_gil_scoped_release guard;
ret = CallScalarFuction(self_tensor, other, "mul");
}
return ToPyObject(ret);
}
// 2. create or get tensor for other_obj
paddle::experimental::Tensor other_tensor;
if (!PyCheckTensor(other_obj)) {
paddle::experimental::Scalar value =
CastPyArg2Scalar(other_obj, "__mul__", 0);
if (PyComplex_Check(other_obj)) {
// Python complex: build a shape-{1} COMPLEX64 tensor; the dtype mismatch
// with self (if any) is resolved in step 3.
eager_gil_scoped_release guard;
other_tensor = full_ad_func({1}, value, DataType::COMPLEX64, place);
} else {
// Any other non-Tensor value: fill a tensor with self's shape and dtype.
eager_gil_scoped_release guard;
other_tensor =
full_ad_func(self_tensor.shape(), value, self_tensor.dtype(), place);
}
} else {
other_tensor = CastPyArg2Tensor(other_obj, 0);
}
// 3. promote types or unify right var type to left var
phi::DataType lhs_dtype = self_tensor.dtype();
phi::DataType rhs_dtype = other_tensor.dtype();
if (lhs_dtype != rhs_dtype) {
// note: only op_type in _supported_promote_complex_types_ should promote
// dtype
if (_complex_dtypes.find(lhs_dtype) != _complex_dtypes.end() ||
_complex_dtypes.find(rhs_dtype) != _complex_dtypes.end()) {
// At least one side is complex: compute the promoted dtype via the
// proto-var-type helpers and cast whichever side does not already match.
phi::DataType promote_dtype =
framework::TransToPhiDataType(framework::PromoteTypesIfComplexExists(
framework::TransToProtoVarType(lhs_dtype),
framework::TransToProtoVarType(rhs_dtype)));
if (lhs_dtype != promote_dtype) {
// cast
eager_gil_scoped_release guard;
self_tensor = cast_ad_func(self_tensor, promote_dtype);
}
if (rhs_dtype != promote_dtype) {
eager_gil_scoped_release guard;
other_tensor = cast_ad_func(other_tensor, promote_dtype);
}
} else {
// Neither side is complex: warn, then cast the right operand to the left
// operand's dtype. The warning is emitted before the guard is created.
LOG(WARNING)
<< "The dtype of left and right Tensor are not the same, left "
"dtype is "
<< lhs_dtype << ", but right dtype is " << rhs_dtype
<< ", the right dtype will convert to " << lhs_dtype;
eager_gil_scoped_release guard;
other_tensor = cast_ad_func(other_tensor, lhs_dtype);
}
}
// 4. calculation
VLOG(6) << "Calling multiply_ad_func in tensor__mul__method";
{
eager_gil_scoped_release guard;
ret = multiply_ad_func(self_tensor, other_tensor);
}
// ToPyObject runs after the inner scope above has closed, i.e. after `guard`
// has been destroyed.
return ToPyObject(ret);
EAGER_CATCH_AND_THROW_RETURN_NULL
}
PyMethodDef math_op_patch_methods[] = {
{"__add__",
(PyCFunction)(void (*)(void))tensor__add__method,
......@@ -448,6 +551,14 @@ PyMethodDef math_op_patch_methods[] = {
(PyCFunction)(void (*)(void))tensor__rsub__method,
METH_VARARGS | METH_KEYWORDS,
NULL},
{"__mul__",
(PyCFunction)(void (*)(void))tensor__mul__method,
METH_VARARGS | METH_KEYWORDS,
NULL},
{"__rmul__",
(PyCFunction)(void (*)(void))tensor__mul__method,
METH_VARARGS | METH_KEYWORDS,
NULL},
{NULL, NULL, 0, NULL}};
} // namespace pybind
......
......@@ -1300,6 +1300,9 @@ paddle::experimental::Scalar CastPyArg2Scalar(PyObject* obj,
return paddle::experimental::Scalar(value);
} else if (type_name.find("numpy") != std::string::npos) {
return CastNumpy2Scalar(obj, op_type, arg_pos);
} else if (PyComplex_Check(obj)) {
auto value = CastPyArg2Complex(obj, op_type, arg_pos);
return paddle::experimental::Scalar(value);
} else if (PyObject_CheckLongOrToLong(&obj)) {
int value = CastPyArg2Int(obj, op_type, arg_pos);
return paddle::experimental::Scalar(value);
......
......@@ -31,6 +31,7 @@
#include "paddle/fluid/imperative/type_defs.h"
#include "paddle/fluid/operators/ops_extra_info.h"
#include "paddle/fluid/pybind/imperative.h"
#include "paddle/phi/common/complex.h"
namespace py = pybind11;
namespace paddle {
......@@ -214,6 +215,25 @@ double CastPyArg2Double(PyObject* obj,
return 0.0;
}
// Converts a Python `complex` object into a phi::dtype::complex<float>.
//
// obj     - Python object to convert; must satisfy PyComplex_Check.
// op_type - operator name, used only when formatting the error message.
// arg_pos - zero-based argument position; reported as arg_pos + 1 in the
//           error message.
//
// The real and imaginary parts are read as double and handed to the
// complex<float> constructor. When obj is not a complex, an InvalidArgument
// error is raised through PADDLE_THROW.
phi::dtype::complex<float> CastPyArg2Complex(PyObject* obj,
                                             const std::string& op_type,
                                             ssize_t arg_pos) {
  // Guard clause: reject anything that is not a Python complex up front.
  if (!PyComplex_Check(obj)) {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "%s(): argument (position %d) must be "
        "complex, but got %s",
        op_type,
        arg_pos + 1,
        ((PyTypeObject*)obj->ob_type)->tp_name));  // NOLINT
    // Fallback value, kept in case control continues past PADDLE_THROW.
    return phi::dtype::complex<float>(0, 0);
  }

  const double real_part = PyComplex_RealAsDouble(obj);
  const double imag_part = PyComplex_ImagAsDouble(obj);
  return phi::dtype::complex<float>(real_part, imag_part);
}
void CastPyArg2AttrDouble(PyObject* obj,
paddle::framework::AttributeMap& attrs, // NOLINT
const std::string& key,
......
......@@ -61,6 +61,9 @@ float CastPyArg2Float(PyObject* obj,
double CastPyArg2Double(PyObject* obj,
const std::string& op_type,
ssize_t arg_pos);
// Converts a Python complex object to phi::dtype::complex<float>.
// op_type and arg_pos are only used to build the InvalidArgument error that is
// raised when obj is not a Python complex (the position is reported one-based).
phi::dtype::complex<float> CastPyArg2Complex(PyObject* obj,
const std::string& op_type,
ssize_t arg_pos);
std::string CastPyArg2String(PyObject* obj,
const std::string& op_type,
ssize_t arg_pos);
......
......@@ -389,10 +389,6 @@ def monkey_patch_math_varbase():
('ndim', _ndim_),
('size', _size_),
('T', _T_),
('__mul__',
_binary_creator_('__mul__', 'multiply', False, _scalar_mul_, True)),
('__rmul__',
_binary_creator_('__rmul__', 'multiply', False, _scalar_mul_, True)),
('__div__',
_binary_creator_('__div__', 'divide', False, _scalar_div_, True)),
('__truediv__',
......@@ -427,6 +423,8 @@ def monkey_patch_math_varbase():
"__radd__",
'__sub__',
'__rsub__',
'__mul__',
'__rmul__',
]
global _already_patch_varbase
......
......@@ -62,8 +62,7 @@ class TestTensorTypePromotion(unittest.TestCase):
def test_operator(self):
with _test_eager_guard():
self.setUp()
# add and sub has been sunk to cpp level, there is no warnings to catch by this test.
self.mul_operator()
# add / sub / mul have been sunk to the C++ level, so there are no warnings for this test to catch.
self.div_operator()
self.setUp()
self.add_operator()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册