Unverified commit 01baa0b6, authored by Weilong Wu, committed by GitHub

[Eager] polish the place setting code (#46840)

Parent 67c9b0b3
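
Every hunk below applies the same pattern: the temporary scalar tensor created with full_ad_func now takes its place from self_tensor.place() instead of the method-local place variable, so the operand built for a Python scalar ends up on the same device as the tensor it is combined with. The minimal before/after sketch below illustrates that pattern; the identifiers (other_tensor, self_tensor, value, full_ad_func) come directly from the diff, while the note about where the old place value came from is only an assumption, since its definition is not part of these hunks.

    // Before (assumption: `place` was captured earlier in each method, e.g.
    // from the globally expected place; its definition is not shown here):
    other_tensor =
        full_ad_func(self_tensor.shape(), value, self_tensor.dtype(), place);

    // After: the scalar operand follows self_tensor's device, so the
    // subsequent elementwise kernel sees both inputs on the same place.
    other_tensor = full_ad_func(
        self_tensor.shape(), value, self_tensor.dtype(), self_tensor.place());
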
@@ -195,8 +195,8 @@ static PyObject* tensor__add__method(TensorObject* self,
         CastPyArg2Scalar(other_obj, "__add__", 0);
     {
       eager_gil_scoped_release guard;
-      other_tensor =
-          full_ad_func(self_tensor.shape(), value, self_tensor.dtype(), place);
+      other_tensor = full_ad_func(
+          self_tensor.shape(), value, self_tensor.dtype(), self_tensor.place());
     }
   } else {
     other_tensor = CastPyArg2Tensor(other_obj, 0);
@@ -292,8 +292,8 @@ static PyObject* tensor__sub__method(TensorObject* self,
         CastPyArg2Scalar(other_obj, "__sub__", 0);
     {
       eager_gil_scoped_release guard;
-      other_tensor =
-          full_ad_func(self_tensor.shape(), value, self_tensor.dtype(), place);
+      other_tensor = full_ad_func(
+          self_tensor.shape(), value, self_tensor.dtype(), self_tensor.place());
     }
   } else {
     other_tensor = CastPyArg2Tensor(other_obj, 0);
@@ -385,8 +385,8 @@ static PyObject* tensor__rsub__method(TensorObject* self,
         CastPyArg2Scalar(other_obj, "__rsub__", 0);
     {
       eager_gil_scoped_release guard;
-      other_tensor =
-          full_ad_func(self_tensor.shape(), value, self_tensor.dtype(), place);
+      other_tensor = full_ad_func(
+          self_tensor.shape(), value, self_tensor.dtype(), self_tensor.place());
     }
   } else {
     other_tensor = CastPyArg2Tensor(other_obj, 0);
@@ -480,11 +480,12 @@ static PyObject* tensor__mul__method(TensorObject* self,
         CastPyArg2Scalar(other_obj, "__mul__", 0);
     if (PyComplex_Check(other_obj)) {
       eager_gil_scoped_release guard;
-      other_tensor = full_ad_func({1}, value, DataType::COMPLEX64, place);
+      other_tensor =
+          full_ad_func({1}, value, DataType::COMPLEX64, self_tensor.place());
     } else {
       eager_gil_scoped_release guard;
-      other_tensor =
-          full_ad_func(self_tensor.shape(), value, self_tensor.dtype(), place);
+      other_tensor = full_ad_func(
+          self_tensor.shape(), value, self_tensor.dtype(), self_tensor.place());
     }
   } else {
     other_tensor = CastPyArg2Tensor(other_obj, 0);
@@ -581,11 +582,12 @@ static PyObject* tensor__div__method(TensorObject* self,
         CastPyArg2Scalar(other_obj, "__div__", 0);
     if (PyComplex_Check(other_obj)) {
       eager_gil_scoped_release guard;
-      other_tensor = full_ad_func({1}, value, DataType::COMPLEX64, place);
+      other_tensor =
+          full_ad_func({1}, value, DataType::COMPLEX64, self_tensor.place());
     } else {
       eager_gil_scoped_release guard;
-      other_tensor =
-          full_ad_func(self_tensor.shape(), value, self_tensor.dtype(), place);
+      other_tensor = full_ad_func(
+          self_tensor.shape(), value, self_tensor.dtype(), self_tensor.place());
     }
   } else {
     other_tensor = CastPyArg2Tensor(other_obj, 0);
@@ -696,11 +698,12 @@ static PyObject* tensor__rdiv__method(TensorObject* self,
         CastPyArg2Scalar(other_obj, "__rdiv__", 0);
     if (PyComplex_Check(other_obj)) {
       eager_gil_scoped_release guard;
-      other_tensor = full_ad_func({1}, value, DataType::COMPLEX64, place);
+      other_tensor =
+          full_ad_func({1}, value, DataType::COMPLEX64, self_tensor.place());
     } else {
       eager_gil_scoped_release guard;
-      other_tensor =
-          full_ad_func(self_tensor.shape(), value, self_tensor.dtype(), place);
+      other_tensor = full_ad_func(
+          self_tensor.shape(), value, self_tensor.dtype(), self_tensor.place());
     }
   } else {
     other_tensor = CastPyArg2Tensor(other_obj, 0);
@@ -809,11 +812,12 @@ static PyObject* tensor__gt__method(TensorObject* self,
         CastPyArg2Scalar(other_obj, "__gt__", 0);
     if (PyComplex_Check(other_obj)) {
       eager_gil_scoped_release guard;
-      other_tensor = full_ad_func({1}, value, DataType::COMPLEX64, place);
+      other_tensor =
+          full_ad_func({1}, value, DataType::COMPLEX64, self_tensor.place());
     } else {
       eager_gil_scoped_release guard;
-      other_tensor =
-          full_ad_func(self_tensor.shape(), value, self_tensor.dtype(), place);
+      other_tensor = full_ad_func(
+          self_tensor.shape(), value, self_tensor.dtype(), self_tensor.place());
     }
   } else {
     other_tensor = CastPyArg2Tensor(other_obj, 0);
@@ -894,11 +898,12 @@ static PyObject* tensor__ge__method(TensorObject* self,
         CastPyArg2Scalar(other_obj, "__ge__", 0);
     if (PyComplex_Check(other_obj)) {
       eager_gil_scoped_release guard;
-      other_tensor = full_ad_func({1}, value, DataType::COMPLEX64, place);
+      other_tensor =
+          full_ad_func({1}, value, DataType::COMPLEX64, self_tensor.place());
     } else {
       eager_gil_scoped_release guard;
-      other_tensor =
-          full_ad_func(self_tensor.shape(), value, self_tensor.dtype(), place);
+      other_tensor = full_ad_func(
+          self_tensor.shape(), value, self_tensor.dtype(), self_tensor.place());
     }
   } else {
     other_tensor = CastPyArg2Tensor(other_obj, 0);
@@ -974,17 +979,18 @@ static PyObject* tensor__mod__method(TensorObject* self,
     other_tensor = full_ad_func(self_tensor.shape(),
                                 phi::Scalar(other_float),
                                 self_tensor.dtype(),
-                                place);
+                                self_tensor.place());
   } else if (!PyCheckTensor(other_obj)) {
     paddle::experimental::Scalar value =
         CastPyArg2Scalar(other_obj, "__mod__", 0);
     if (PyComplex_Check(other_obj)) {
       eager_gil_scoped_release guard;
-      other_tensor = full_ad_func({1}, value, DataType::COMPLEX64, place);
+      other_tensor =
+          full_ad_func({1}, value, DataType::COMPLEX64, self_tensor.place());
     } else {
       eager_gil_scoped_release guard;
-      other_tensor =
-          full_ad_func(self_tensor.shape(), value, self_tensor.dtype(), place);
+      other_tensor = full_ad_func(
+          self_tensor.shape(), value, self_tensor.dtype(), self_tensor.place());
     }
   } else {
     other_tensor = CastPyArg2Tensor(other_obj, 0);
@@ -1056,17 +1062,21 @@ static PyObject* tensor__matmul__method(TensorObject* self,
   paddle::experimental::Tensor other_tensor;
   if (has_other_float) {
     eager_gil_scoped_release guard;
-    other_tensor =
-        full_ad_func({1}, phi::Scalar(other_float), self_tensor.dtype(), place);
+    other_tensor = full_ad_func({1},
+                                phi::Scalar(other_float),
+                                self_tensor.dtype(),
+                                self_tensor.place());
   } else if (!PyCheckTensor(other_obj)) {
     paddle::experimental::Scalar value =
         CastPyArg2Scalar(other_obj, "__matmul__", 0);
     if (PyComplex_Check(other_obj)) {
       eager_gil_scoped_release guard;
-      other_tensor = full_ad_func({1}, value, DataType::COMPLEX64, place);
+      other_tensor =
+          full_ad_func({1}, value, DataType::COMPLEX64, self_tensor.place());
     } else {
       eager_gil_scoped_release guard;
-      other_tensor = full_ad_func({1}, value, self_tensor.dtype(), place);
+      other_tensor =
+          full_ad_func({1}, value, self_tensor.dtype(), self_tensor.place());
     }
   } else {
     other_tensor = CastPyArg2Tensor(other_obj, 0);
@@ -1159,17 +1169,18 @@ static PyObject* tensor__lt__method(TensorObject* self,
     other_tensor = full_ad_func(self_tensor.shape(),
                                 phi::Scalar(other_float),
                                 self_tensor.dtype(),
-                                place);
+                                self_tensor.place());
   } else if (!PyCheckTensor(other_obj)) {
     paddle::experimental::Scalar value =
         CastPyArg2Scalar(other_obj, "__lt__", 0);
     if (PyComplex_Check(other_obj)) {
       eager_gil_scoped_release guard;
-      other_tensor = full_ad_func({1}, value, DataType::COMPLEX64, place);
+      other_tensor =
+          full_ad_func({1}, value, DataType::COMPLEX64, self_tensor.place());
     } else {
       eager_gil_scoped_release guard;
-      other_tensor =
-          full_ad_func(self_tensor.shape(), value, self_tensor.dtype(), place);
+      other_tensor = full_ad_func(
+          self_tensor.shape(), value, self_tensor.dtype(), self_tensor.place());
     }
   } else {
     other_tensor = CastPyArg2Tensor(other_obj, 0);
@@ -1244,17 +1255,18 @@ static PyObject* tensor__le__method(TensorObject* self,
     other_tensor = full_ad_func(self_tensor.shape(),
                                 phi::Scalar(other_float),
                                 self_tensor.dtype(),
-                                place);
+                                self_tensor.place());
   } else if (!PyCheckTensor(other_obj)) {
     paddle::experimental::Scalar value =
         CastPyArg2Scalar(other_obj, "__le__", 0);
     if (PyComplex_Check(other_obj)) {
       eager_gil_scoped_release guard;
-      other_tensor = full_ad_func({1}, value, DataType::COMPLEX64, place);
+      other_tensor =
+          full_ad_func({1}, value, DataType::COMPLEX64, self_tensor.place());
     } else {
       eager_gil_scoped_release guard;
-      other_tensor =
-          full_ad_func(self_tensor.shape(), value, self_tensor.dtype(), place);
+      other_tensor = full_ad_func(
+          self_tensor.shape(), value, self_tensor.dtype(), self_tensor.place());
     }
   } else {
     other_tensor = CastPyArg2Tensor(other_obj, 0);