未验证 提交 b4a1f43f 编写于 作者: W Weilong Wu 提交者: GitHub

[Eager] use CastPyArg2Double to parse python float obj (#47029)

上级 73196e5a
......@@ -129,7 +129,7 @@ void SetDevice(paddle::platform::Place place) {
// this function will update gradually.
paddle::experimental::Tensor CallScalarFuction(
const paddle::experimental::Tensor& self_tensor,
float other,
double other,
std::string op_type) {
paddle::experimental::Tensor ret;
if (op_type == "add" || op_type == "radd") {
......@@ -169,16 +169,16 @@ static PyObject* tensor__add__method(TensorObject* self,
// 1. scalar exists cases
if (PyFloat_Check(other_obj) || PyCheckInteger(other_obj) ||
IsNumpyType(other_obj)) {
float other = 0.0;
double other = 0.0;
if (PyFloat_Check(other_obj)) {
other = CastPyArg2AttrFloat(other_obj, 0);
other = CastPyArg2Double(other_obj, "__add__", 0);
if (_supported_int_dtype_.find(self_tensor.dtype()) !=
_supported_int_dtype_.end()) {
eager_gil_scoped_release guard;
self_tensor = cast_ad_func(self_tensor, DataType::FLOAT32);
}
} else if (PyCheckInteger(other_obj) || IsNumpyType(other_obj)) {
other = CastPyArg2AttrFloat(other_obj, 0);
other = CastPyArg2Double(other_obj, "__add__", 0);
}
{
......@@ -267,16 +267,16 @@ static PyObject* tensor__sub__method(TensorObject* self,
// 1. scalar exists cases
if (PyFloat_Check(other_obj) || PyCheckInteger(other_obj) ||
IsNumpyType(other_obj)) {
float other = 0.0;
double other = 0.0;
if (PyFloat_Check(other_obj)) {
other = CastPyArg2AttrFloat(other_obj, 0);
other = CastPyArg2Double(other_obj, "__sub__", 0);
if (_supported_int_dtype_.find(self_tensor.dtype()) !=
_supported_int_dtype_.end()) {
eager_gil_scoped_release guard;
self_tensor = cast_ad_func(self_tensor, DataType::FLOAT32);
}
} else if (PyCheckInteger(other_obj) || IsNumpyType(other_obj)) {
other = CastPyArg2AttrFloat(other_obj, 0);
other = CastPyArg2Double(other_obj, "__sub__", 0);
}
{
eager_gil_scoped_release guard;
......@@ -360,16 +360,16 @@ static PyObject* tensor__rsub__method(TensorObject* self,
// 1. scalar exists cases
if (PyFloat_Check(other_obj) || PyCheckInteger(other_obj) ||
IsNumpyType(other_obj)) {
float other = 0.0;
double other = 0.0;
if (PyFloat_Check(other_obj)) {
other = CastPyArg2AttrFloat(other_obj, 0);
other = CastPyArg2Double(other_obj, "__rsub__", 0);
if (_supported_int_dtype_.find(self_tensor.dtype()) !=
_supported_int_dtype_.end()) {
eager_gil_scoped_release guard;
self_tensor = cast_ad_func(self_tensor, DataType::FLOAT32);
}
} else if (PyCheckInteger(other_obj) || IsNumpyType(other_obj)) {
other = CastPyArg2AttrFloat(other_obj, 0);
other = CastPyArg2Double(other_obj, "__rsub__", 0);
}
{
eager_gil_scoped_release guard;
......@@ -455,16 +455,16 @@ static PyObject* tensor__mul__method(TensorObject* self,
// 1. scalar exists cases
if (PyFloat_Check(other_obj) || PyCheckInteger(other_obj) ||
IsNumpyType(other_obj)) {
float other = 0.0;
double other = 0.0;
if (PyFloat_Check(other_obj)) {
other = CastPyArg2AttrFloat(other_obj, 0);
other = CastPyArg2Double(other_obj, "__mul__", 0);
if (_supported_int_dtype_.find(self_tensor.dtype()) !=
_supported_int_dtype_.end()) {
eager_gil_scoped_release guard;
self_tensor = cast_ad_func(self_tensor, DataType::FLOAT32);
}
} else if (PyCheckInteger(other_obj) || IsNumpyType(other_obj)) {
other = CastPyArg2AttrFloat(other_obj, 0);
other = CastPyArg2Double(other_obj, "__mul__", 0);
}
{
eager_gil_scoped_release guard;
......@@ -557,11 +557,11 @@ static PyObject* tensor__div__method(TensorObject* self,
// 1. scalar exists cases
if (PyFloat_Check(other_obj) || PyCheckInteger(other_obj) ||
IsNumpyType(other_obj)) {
float other = 0.0;
double other = 0.0;
if (PyFloat_Check(other_obj)) {
other = CastPyArg2AttrFloat(other_obj, 0);
other = CastPyArg2Double(other_obj, "__div__", 0);
} else if (PyCheckInteger(other_obj) || IsNumpyType(other_obj)) {
other = CastPyArg2AttrFloat(other_obj, 0);
other = CastPyArg2Double(other_obj, "__div__", 0);
}
if (_supported_int_dtype_.find(self_tensor.dtype()) !=
_supported_int_dtype_.end()) {
......@@ -667,16 +667,16 @@ static PyObject* tensor__rdiv__method(TensorObject* self,
// 1. scalar exists cases
// there is no scalar_div function for __rdiv__ and __rtruediv__
float other_float = 0.0;
bool has_other_float = false;
double other_double = 0.0;
bool has_other_double = false;
if (PyFloat_Check(other_obj) || PyCheckInteger(other_obj) ||
IsNumpyType(other_obj)) {
if (PyFloat_Check(other_obj)) {
other_float = CastPyArg2AttrFloat(other_obj, 0);
has_other_float = true;
other_double = CastPyArg2Double(other_obj, "__rdiv__", 0);
has_other_double = true;
} else if (PyCheckInteger(other_obj) || IsNumpyType(other_obj)) {
other_float = CastPyArg2AttrFloat(other_obj, 0);
has_other_float = true;
other_double = CastPyArg2Double(other_obj, "__rdiv__", 0);
has_other_double = true;
}
if (_supported_int_dtype_.find(self_tensor.dtype()) !=
_supported_int_dtype_.end()) {
......@@ -687,10 +687,10 @@ static PyObject* tensor__rdiv__method(TensorObject* self,
// 2. create or get tensor for other_obj
paddle::experimental::Tensor other_tensor;
if (has_other_float) {
if (has_other_double) {
eager_gil_scoped_release guard;
other_tensor = full_ad_func(self_tensor.shape(),
phi::Scalar(other_float),
phi::Scalar(other_double),
self_tensor.dtype(),
place);
} else if (!PyCheckTensor(other_obj)) {
......@@ -781,30 +781,30 @@ static PyObject* tensor__gt__method(TensorObject* self,
// 1. scalar exists cases
// there is no scalar function for __gt__ now
float other_float = 0.0;
bool has_other_float = false;
double other_double = 0.0;
bool has_other_double = false;
if (PyFloat_Check(other_obj) || PyCheckInteger(other_obj) ||
IsNumpyType(other_obj)) {
if (PyFloat_Check(other_obj)) {
other_float = CastPyArg2AttrFloat(other_obj, 0);
has_other_float = true;
other_double = CastPyArg2Double(other_obj, "__gt__", 0);
has_other_double = true;
if (_supported_int_dtype_.find(self_tensor.dtype()) !=
_supported_int_dtype_.end()) {
eager_gil_scoped_release guard;
self_tensor = cast_ad_func(self_tensor, DataType::FLOAT32);
}
} else if (PyCheckInteger(other_obj) || IsNumpyType(other_obj)) {
other_float = CastPyArg2AttrFloat(other_obj, 0);
has_other_float = true;
other_double = CastPyArg2Double(other_obj, "__gt__", 0);
has_other_double = true;
}
}
// 2. create or get tensor for other_obj
paddle::experimental::Tensor other_tensor;
if (has_other_float) {
if (has_other_double) {
eager_gil_scoped_release guard;
other_tensor = full_ad_func(self_tensor.shape(),
phi::Scalar(other_float),
phi::Scalar(other_double),
self_tensor.dtype(),
place);
} else if (!PyCheckTensor(other_obj)) {
......@@ -867,30 +867,30 @@ static PyObject* tensor__ge__method(TensorObject* self,
// 1. scalar exists cases
// there is no scalar function for __ge__ now
float other_float = 0.0;
bool has_other_float = false;
double other_double = 0.0;
bool has_other_double = false;
if (PyFloat_Check(other_obj) || PyCheckInteger(other_obj) ||
IsNumpyType(other_obj)) {
if (PyFloat_Check(other_obj)) {
other_float = CastPyArg2AttrFloat(other_obj, 0);
has_other_float = true;
other_double = CastPyArg2Double(other_obj, "__ge__", 0);
has_other_double = true;
if (_supported_int_dtype_.find(self_tensor.dtype()) !=
_supported_int_dtype_.end()) {
eager_gil_scoped_release guard;
self_tensor = cast_ad_func(self_tensor, DataType::FLOAT32);
}
} else if (PyCheckInteger(other_obj) || IsNumpyType(other_obj)) {
other_float = CastPyArg2AttrFloat(other_obj, 0);
has_other_float = true;
other_double = CastPyArg2Double(other_obj, "__ge__", 0);
has_other_double = true;
}
}
// 2. create or get tensor for other_obj
paddle::experimental::Tensor other_tensor;
if (has_other_float) {
if (has_other_double) {
eager_gil_scoped_release guard;
other_tensor = full_ad_func(self_tensor.shape(),
phi::Scalar(other_float),
phi::Scalar(other_double),
self_tensor.dtype(),
place);
} else if (!PyCheckTensor(other_obj)) {
......@@ -954,30 +954,30 @@ static PyObject* tensor__mod__method(TensorObject* self,
// 1. scalar exists cases
// there is no scalar_mod function for __mod__ now
float other_float = 0.0;
bool has_other_float = false;
float other_double = 0.0;
bool has_other_double = false;
if (PyFloat_Check(other_obj) || PyCheckInteger(other_obj) ||
IsNumpyType(other_obj)) {
if (PyFloat_Check(other_obj)) {
other_float = CastPyArg2AttrFloat(other_obj, 0);
has_other_float = true;
other_double = CastPyArg2Double(other_obj, "__mod__", 0);
has_other_double = true;
if (_supported_int_dtype_.find(self_tensor.dtype()) !=
_supported_int_dtype_.end()) {
eager_gil_scoped_release guard;
self_tensor = cast_ad_func(self_tensor, DataType::FLOAT32);
}
} else if (PyCheckInteger(other_obj) || IsNumpyType(other_obj)) {
other_float = CastPyArg2AttrFloat(other_obj, 0);
has_other_float = true;
other_double = CastPyArg2Double(other_obj, "__mod__", 0);
has_other_double = true;
}
}
// 2. create or get tensor for other_obj
paddle::experimental::Tensor other_tensor;
if (has_other_float) {
if (has_other_double) {
eager_gil_scoped_release guard;
other_tensor = full_ad_func(self_tensor.shape(),
phi::Scalar(other_float),
phi::Scalar(other_double),
self_tensor.dtype(),
self_tensor.place());
} else if (!PyCheckTensor(other_obj)) {
......@@ -1040,30 +1040,30 @@ static PyObject* tensor__matmul__method(TensorObject* self,
// 1. scalar exists cases
// there is no scalar_matmul function for __matmul__ now
float other_float = 0.0;
bool has_other_float = false;
float other_double = 0.0;
bool has_other_double = false;
if (PyFloat_Check(other_obj) || PyCheckInteger(other_obj) ||
IsNumpyType(other_obj)) {
if (PyFloat_Check(other_obj)) {
other_float = CastPyArg2AttrFloat(other_obj, 0);
has_other_float = true;
other_double = CastPyArg2Double(other_obj, "__matmul__", 0);
has_other_double = true;
if (_supported_int_dtype_.find(self_tensor.dtype()) !=
_supported_int_dtype_.end()) {
eager_gil_scoped_release guard;
self_tensor = cast_ad_func(self_tensor, DataType::FLOAT32);
}
} else if (PyCheckInteger(other_obj) || IsNumpyType(other_obj)) {
other_float = CastPyArg2AttrFloat(other_obj, 0);
has_other_float = true;
other_double = CastPyArg2Double(other_obj, "__matmul__", 0);
has_other_double = true;
}
}
// 2. create or get tensor for other_obj
paddle::experimental::Tensor other_tensor;
if (has_other_float) {
if (has_other_double) {
eager_gil_scoped_release guard;
other_tensor = full_ad_func({1},
phi::Scalar(other_float),
phi::Scalar(other_double),
self_tensor.dtype(),
self_tensor.place());
} else if (!PyCheckTensor(other_obj)) {
......@@ -1144,30 +1144,30 @@ static PyObject* tensor__lt__method(TensorObject* self,
// 1. scalar exists cases
// there is no scalar function for __lt__ now
float other_float = 0.0;
bool has_other_float = false;
float other_double = 0.0;
bool has_other_double = false;
if (PyFloat_Check(other_obj) || PyCheckInteger(other_obj) ||
IsNumpyType(other_obj)) {
if (PyFloat_Check(other_obj)) {
other_float = CastPyArg2AttrFloat(other_obj, 0);
has_other_float = true;
other_double = CastPyArg2Double(other_obj, "__lt__", 0);
has_other_double = true;
if (_supported_int_dtype_.find(self_tensor.dtype()) !=
_supported_int_dtype_.end()) {
eager_gil_scoped_release guard;
self_tensor = cast_ad_func(self_tensor, DataType::FLOAT32);
}
} else if (PyCheckInteger(other_obj) || IsNumpyType(other_obj)) {
other_float = CastPyArg2AttrFloat(other_obj, 0);
has_other_float = true;
other_double = CastPyArg2Double(other_obj, "__lt__", 0);
has_other_double = true;
}
}
// 2. create or get tensor for other_obj
paddle::experimental::Tensor other_tensor;
if (has_other_float) {
if (has_other_double) {
eager_gil_scoped_release guard;
other_tensor = full_ad_func(self_tensor.shape(),
phi::Scalar(other_float),
phi::Scalar(other_double),
self_tensor.dtype(),
self_tensor.place());
} else if (!PyCheckTensor(other_obj)) {
......@@ -1230,30 +1230,30 @@ static PyObject* tensor__le__method(TensorObject* self,
// 1. scalar exists cases
// there is no scalar function for __le__ now
float other_float = 0.0;
bool has_other_float = false;
float other_double = 0.0;
bool has_other_double = false;
if (PyFloat_Check(other_obj) || PyCheckInteger(other_obj) ||
IsNumpyType(other_obj)) {
if (PyFloat_Check(other_obj)) {
other_float = CastPyArg2AttrFloat(other_obj, 0);
has_other_float = true;
other_double = CastPyArg2Double(other_obj, "__le__", 0);
has_other_double = true;
if (_supported_int_dtype_.find(self_tensor.dtype()) !=
_supported_int_dtype_.end()) {
eager_gil_scoped_release guard;
self_tensor = cast_ad_func(self_tensor, DataType::FLOAT32);
}
} else if (PyCheckInteger(other_obj) || IsNumpyType(other_obj)) {
other_float = CastPyArg2AttrFloat(other_obj, 0);
has_other_float = true;
other_double = CastPyArg2Double(other_obj, "__le__", 0);
has_other_double = true;
}
}
// 2. create or get tensor for other_obj
paddle::experimental::Tensor other_tensor;
if (has_other_float) {
if (has_other_double) {
eager_gil_scoped_release guard;
other_tensor = full_ad_func(self_tensor.shape(),
phi::Scalar(other_float),
phi::Scalar(other_double),
self_tensor.dtype(),
self_tensor.place());
} else if (!PyCheckTensor(other_obj)) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册