未验证 提交 7fb20b46 编写于 作者: W Weilong Wu 提交者: GitHub

[Eager, Performance optimization] support multiply( * operator) to sink to Cpp layer (#46326)

* [Eager] math op sink to Cpp level

* fix ci errors

* draft version


* support + and - operator under cpp directly

* add static test

* polish code

* promote types or unify right type to left

* recover static test case

* polish code and fix some ci errors

* support complex and polish code

* fix conflicts

* fix windows ci errors

* fix windows-inference-ci errors

* polish and fix tests

* fix test case

* polish code

* [Eager, Performance optimization] support multiply( * operator) to sink to Cpp layer

* rm useless glog

* polish code

* polish code and fix code-format

* polish code

* fix CI

* polish code
上级 808bf2b4
...@@ -103,7 +103,7 @@ void SetDevice(paddle::platform::Place place) { ...@@ -103,7 +103,7 @@ void SetDevice(paddle::platform::Place place) {
if (paddle::platform::is_gpu_place(place)) { if (paddle::platform::is_gpu_place(place)) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
phi::backends::gpu::SetDeviceId(place.device); phi::backends::gpu::SetDeviceId(place.device);
VLOG(1) << "CurrentDeviceId: " << phi::backends::gpu::GetCurrentDeviceId() VLOG(6) << "CurrentDeviceId: " << phi::backends::gpu::GetCurrentDeviceId()
<< " from " << static_cast<int>(place.device); << " from " << static_cast<int>(place.device);
#else #else
PADDLE_THROW(paddle::platform::errors::PreconditionNotMet( PADDLE_THROW(paddle::platform::errors::PreconditionNotMet(
...@@ -114,7 +114,7 @@ void SetDevice(paddle::platform::Place place) { ...@@ -114,7 +114,7 @@ void SetDevice(paddle::platform::Place place) {
if (paddle::platform::is_custom_place(place)) { if (paddle::platform::is_custom_place(place)) {
#if defined(PADDLE_WITH_CUSTOM_DEVICE) #if defined(PADDLE_WITH_CUSTOM_DEVICE)
phi::DeviceManager::SetDevice(place); phi::DeviceManager::SetDevice(place);
VLOG(1) << "CurrentDeviceId: " VLOG(6) << "CurrentDeviceId: "
<< phi::DeviceManager::GetDevice(place.GetDeviceType()) << " from " << phi::DeviceManager::GetDevice(place.GetDeviceType()) << " from "
<< static_cast<int>(place.device); << static_cast<int>(place.device);
#else #else
...@@ -139,6 +139,8 @@ paddle::experimental::Tensor CallScalarFuction( ...@@ -139,6 +139,8 @@ paddle::experimental::Tensor CallScalarFuction(
} else if (op_type == "rsub") { } else if (op_type == "rsub") {
ret = scale_ad_func(self_tensor, phi::Scalar(-1.0), other, true); ret = scale_ad_func(self_tensor, phi::Scalar(-1.0), other, true);
} else if (op_type == "mul") {
ret = scale_ad_func(self_tensor, phi::Scalar(other), 0.0, true);
} }
return ret; return ret;
...@@ -431,6 +433,107 @@ static PyObject* tensor__rsub__method(TensorObject* self, ...@@ -431,6 +433,107 @@ static PyObject* tensor__rsub__method(TensorObject* self,
EAGER_CATCH_AND_THROW_RETURN_NULL EAGER_CATCH_AND_THROW_RETURN_NULL
} }
// Eager-mode C++ implementation bound to Tensor.__mul__ and Tensor.__rmul__.
//
// Parameters:
//   self   - the left-hand tensor object; its `tensor` member is the operand.
//   args   - positional arguments; only item 0 (the other operand) is read.
//   kwargs - accepted because the method is registered with METH_KEYWORDS,
//            but never read here.
//
// Returns ToPyObject(result) on success. C++ exceptions raised inside the
// EAGER_TRY region are handled by EAGER_CATCH_AND_THROW_RETURN_NULL
// (presumably translated into a Python error with a NULL return, as the macro
// name suggests -- confirm against its definition).
//
// Dispatch:
//   1. Python float / int (but not bool): scalar fast path through
//      CallScalarFuction(..., "mul").
//   2. Anything else that is not a Tensor is materialised as a tensor with
//      full_ad_func; Tensors are used directly.
//   3. If dtypes differ: promote when either side is complex, otherwise cast
//      the right operand to the left operand's dtype (with a warning).
//   4. multiply_ad_func(self, other).
static PyObject* tensor__mul__method(TensorObject* self,
                                     PyObject* args,
                                     PyObject* kwargs) {
  paddle::platform::RecordEvent pythonc_record_event(
      "__mul__ pybind_patch_func",
      paddle::platform::TracerEventType::UserDefined,
      1);

  EAGER_TRY
  VLOG(6) << "Running Eager tensor__mul__method";

  // Set Device ID
  auto place = egr::Controller::Instance().GetExpectedPlace();
  SetDevice(place);

  paddle::experimental::Tensor ret;
  paddle::experimental::Tensor self_tensor = self->tensor;

  // PyTuple_GET_ITEM does no bounds checking, so a direct call such as
  // `x.__mul__()` (empty positional tuple) would read past the tuple's
  // storage. Reject it explicitly; the throw happens inside EAGER_TRY and
  // while the GIL is still held.
  if (PyTuple_GET_SIZE(args) < 1) {
    PADDLE_THROW(paddle::platform::errors::InvalidArgument(
        "__mul__() expects at least 1 positional argument, but received %d.",
        static_cast<int>(PyTuple_GET_SIZE(args))));
  }
  PyObject* other_obj = PyTuple_GET_ITEM(args, 0);

  // 1. scalar exists cases
  // Python bool passes PyLong_Check, so it is excluded explicitly and falls
  // through to the generic path below.
  if ((PyFloat_Check(other_obj) || PyLong_Check(other_obj)) &&
      !PyBool_Check(other_obj)) {
    float other = 0.0;
    if (PyFloat_Check(other_obj)) {
      other = CastPyArg2AttrFloat(other_obj, 0);
      // A float scalar applied to an integer tensor: cast the tensor to
      // FLOAT32 first so the product is computed in floating point.
      if (_supported_int_dtype_.find(self_tensor.dtype()) !=
          _supported_int_dtype_.end()) {
        eager_gil_scoped_release guard;
        self_tensor = cast_ad_func(self_tensor, DataType::FLOAT32);
      }
    } else if (PyCheckInteger(other_obj) || IsNumpyType(other_obj)) {
      other = static_cast<float>(CastPyArg2AttrInt(other_obj, 0));
    }
    {
      // Release the GIL only around the actual computation.
      eager_gil_scoped_release guard;
      ret = CallScalarFuction(self_tensor, other, "mul");
    }
    return ToPyObject(ret);
  }

  // 2. create or get tensor for other_obj
  paddle::experimental::Tensor other_tensor;
  if (!PyCheckTensor(other_obj)) {
    paddle::experimental::Scalar value =
        CastPyArg2Scalar(other_obj, "__mul__", 0);
    if (PyComplex_Check(other_obj)) {
      // Complex scalars become a one-element COMPLEX64 tensor; dtype
      // unification with self happens in step 3.
      eager_gil_scoped_release guard;
      other_tensor = full_ad_func({1}, value, DataType::COMPLEX64, place);
    } else {
      // Other scalars are filled into a tensor matching self's shape/dtype.
      eager_gil_scoped_release guard;
      other_tensor =
          full_ad_func(self_tensor.shape(), value, self_tensor.dtype(), place);
    }
  } else {
    other_tensor = CastPyArg2Tensor(other_obj, 0);
  }

  // 3. promote types or unify right var type to left var
  phi::DataType lhs_dtype = self_tensor.dtype();
  phi::DataType rhs_dtype = other_tensor.dtype();
  if (lhs_dtype != rhs_dtype) {
    // note: only op_type in _supported_promote_complex_types_ should promote
    // dtype
    if (_complex_dtypes.find(lhs_dtype) != _complex_dtypes.end() ||
        _complex_dtypes.find(rhs_dtype) != _complex_dtypes.end()) {
      phi::DataType promote_dtype =
          framework::TransToPhiDataType(framework::PromoteTypesIfComplexExists(
              framework::TransToProtoVarType(lhs_dtype),
              framework::TransToProtoVarType(rhs_dtype)));
      if (lhs_dtype != promote_dtype) {
        // cast
        eager_gil_scoped_release guard;
        self_tensor = cast_ad_func(self_tensor, promote_dtype);
      }
      if (rhs_dtype != promote_dtype) {
        eager_gil_scoped_release guard;
        other_tensor = cast_ad_func(other_tensor, promote_dtype);
      }
    } else {
      // No complex operand: the left dtype wins. Log while still holding the
      // GIL, then release it for the cast.
      LOG(WARNING)
          << "The dtype of left and right Tensor are not the same, left "
             "dtype is "
          << lhs_dtype << ", but right dtype is " << rhs_dtype
          << ", the right dtype will convert to " << lhs_dtype;
      eager_gil_scoped_release guard;
      other_tensor = cast_ad_func(other_tensor, lhs_dtype);
    }
  }

  // 4. calculation
  VLOG(6) << "Calling multiply_ad_func in tensor__mul__method";
  {
    eager_gil_scoped_release guard;
    ret = multiply_ad_func(self_tensor, other_tensor);
  }
  return ToPyObject(ret);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}
PyMethodDef math_op_patch_methods[] = { PyMethodDef math_op_patch_methods[] = {
{"__add__", {"__add__",
(PyCFunction)(void (*)(void))tensor__add__method, (PyCFunction)(void (*)(void))tensor__add__method,
...@@ -448,6 +551,14 @@ PyMethodDef math_op_patch_methods[] = { ...@@ -448,6 +551,14 @@ PyMethodDef math_op_patch_methods[] = {
(PyCFunction)(void (*)(void))tensor__rsub__method, (PyCFunction)(void (*)(void))tensor__rsub__method,
METH_VARARGS | METH_KEYWORDS, METH_VARARGS | METH_KEYWORDS,
NULL}, NULL},
{"__mul__",
(PyCFunction)(void (*)(void))tensor__mul__method,
METH_VARARGS | METH_KEYWORDS,
NULL},
{"__rmul__",
(PyCFunction)(void (*)(void))tensor__mul__method,
METH_VARARGS | METH_KEYWORDS,
NULL},
{NULL, NULL, 0, NULL}}; {NULL, NULL, 0, NULL}};
} // namespace pybind } // namespace pybind
......
...@@ -1300,6 +1300,9 @@ paddle::experimental::Scalar CastPyArg2Scalar(PyObject* obj, ...@@ -1300,6 +1300,9 @@ paddle::experimental::Scalar CastPyArg2Scalar(PyObject* obj,
return paddle::experimental::Scalar(value); return paddle::experimental::Scalar(value);
} else if (type_name.find("numpy") != std::string::npos) { } else if (type_name.find("numpy") != std::string::npos) {
return CastNumpy2Scalar(obj, op_type, arg_pos); return CastNumpy2Scalar(obj, op_type, arg_pos);
} else if (PyComplex_Check(obj)) {
auto value = CastPyArg2Complex(obj, op_type, arg_pos);
return paddle::experimental::Scalar(value);
} else if (PyObject_CheckLongOrToLong(&obj)) { } else if (PyObject_CheckLongOrToLong(&obj)) {
int value = CastPyArg2Int(obj, op_type, arg_pos); int value = CastPyArg2Int(obj, op_type, arg_pos);
return paddle::experimental::Scalar(value); return paddle::experimental::Scalar(value);
......
...@@ -31,6 +31,7 @@ ...@@ -31,6 +31,7 @@
#include "paddle/fluid/imperative/type_defs.h" #include "paddle/fluid/imperative/type_defs.h"
#include "paddle/fluid/operators/ops_extra_info.h" #include "paddle/fluid/operators/ops_extra_info.h"
#include "paddle/fluid/pybind/imperative.h" #include "paddle/fluid/pybind/imperative.h"
#include "paddle/phi/common/complex.h"
namespace py = pybind11; namespace py = pybind11;
namespace paddle { namespace paddle {
...@@ -214,6 +215,25 @@ double CastPyArg2Double(PyObject* obj, ...@@ -214,6 +215,25 @@ double CastPyArg2Double(PyObject* obj,
return 0.0; return 0.0;
} }
// Converts a Python `complex` object into a phi::dtype::complex<float>.
//
// obj      - the Python object to convert.
// op_type  - operator name, used only in the error message.
// arg_pos  - zero-based argument position; reported one-based in the error.
//
// Throws InvalidArgument (via PADDLE_THROW) when `obj` is not a Python
// complex. The real and imaginary parts are read as double and narrowed to
// float by the complex<float> constructor.
phi::dtype::complex<float> CastPyArg2Complex(PyObject* obj,
                                             const std::string& op_type,
                                             ssize_t arg_pos) {
  // Guard clause: reject anything that is not a Python complex up front.
  if (!PyComplex_Check(obj)) {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "%s(): argument (position %d) must be "
        "complex, but got %s",
        op_type,
        arg_pos + 1,
        ((PyTypeObject*)obj->ob_type)->tp_name));  // NOLINT
  }

  const double real_part = PyComplex_RealAsDouble(obj);
  const double imag_part = PyComplex_ImagAsDouble(obj);
  return phi::dtype::complex<float>(real_part, imag_part);
}
void CastPyArg2AttrDouble(PyObject* obj, void CastPyArg2AttrDouble(PyObject* obj,
paddle::framework::AttributeMap& attrs, // NOLINT paddle::framework::AttributeMap& attrs, // NOLINT
const std::string& key, const std::string& key,
......
...@@ -61,6 +61,9 @@ float CastPyArg2Float(PyObject* obj, ...@@ -61,6 +61,9 @@ float CastPyArg2Float(PyObject* obj,
double CastPyArg2Double(PyObject* obj, double CastPyArg2Double(PyObject* obj,
const std::string& op_type, const std::string& op_type,
ssize_t arg_pos); ssize_t arg_pos);
// Converts a Python `complex` object to phi::dtype::complex<float>.
// `op_type` and `arg_pos` are only used to build the InvalidArgument error
// raised when `obj` is not a Python complex.
phi::dtype::complex<float> CastPyArg2Complex(PyObject* obj,
const std::string& op_type,
ssize_t arg_pos);
std::string CastPyArg2String(PyObject* obj, std::string CastPyArg2String(PyObject* obj,
const std::string& op_type, const std::string& op_type,
ssize_t arg_pos); ssize_t arg_pos);
......
...@@ -389,10 +389,6 @@ def monkey_patch_math_varbase(): ...@@ -389,10 +389,6 @@ def monkey_patch_math_varbase():
('ndim', _ndim_), ('ndim', _ndim_),
('size', _size_), ('size', _size_),
('T', _T_), ('T', _T_),
('__mul__',
_binary_creator_('__mul__', 'multiply', False, _scalar_mul_, True)),
('__rmul__',
_binary_creator_('__rmul__', 'multiply', False, _scalar_mul_, True)),
('__div__', ('__div__',
_binary_creator_('__div__', 'divide', False, _scalar_div_, True)), _binary_creator_('__div__', 'divide', False, _scalar_div_, True)),
('__truediv__', ('__truediv__',
...@@ -427,6 +423,8 @@ def monkey_patch_math_varbase(): ...@@ -427,6 +423,8 @@ def monkey_patch_math_varbase():
"__radd__", "__radd__",
'__sub__', '__sub__',
'__rsub__', '__rsub__',
'__mul__',
'__rmul__',
] ]
global _already_patch_varbase global _already_patch_varbase
......
...@@ -62,8 +62,7 @@ class TestTensorTypePromotion(unittest.TestCase): ...@@ -62,8 +62,7 @@ class TestTensorTypePromotion(unittest.TestCase):
def test_operator(self): def test_operator(self):
with _test_eager_guard(): with _test_eager_guard():
self.setUp() self.setUp()
# add and sub has been sunk to cpp level, there is no warnings to catch by this test. # add / sub / mul has been sunk to cpp level, there is no warnings to catch by this test.
self.mul_operator()
self.div_operator() self.div_operator()
self.setUp() self.setUp()
self.add_operator() self.add_operator()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册