未验证 提交 01baa0b6 编写于 作者: W Weilong Wu 提交者: GitHub

[Eager] polish the place setting code (#46840)

上级 67c9b0b3
...@@ -195,8 +195,8 @@ static PyObject* tensor__add__method(TensorObject* self, ...@@ -195,8 +195,8 @@ static PyObject* tensor__add__method(TensorObject* self,
CastPyArg2Scalar(other_obj, "__add__", 0); CastPyArg2Scalar(other_obj, "__add__", 0);
{ {
eager_gil_scoped_release guard; eager_gil_scoped_release guard;
other_tensor = other_tensor = full_ad_func(
full_ad_func(self_tensor.shape(), value, self_tensor.dtype(), place); self_tensor.shape(), value, self_tensor.dtype(), self_tensor.place());
} }
} else { } else {
other_tensor = CastPyArg2Tensor(other_obj, 0); other_tensor = CastPyArg2Tensor(other_obj, 0);
...@@ -292,8 +292,8 @@ static PyObject* tensor__sub__method(TensorObject* self, ...@@ -292,8 +292,8 @@ static PyObject* tensor__sub__method(TensorObject* self,
CastPyArg2Scalar(other_obj, "__sub__", 0); CastPyArg2Scalar(other_obj, "__sub__", 0);
{ {
eager_gil_scoped_release guard; eager_gil_scoped_release guard;
other_tensor = other_tensor = full_ad_func(
full_ad_func(self_tensor.shape(), value, self_tensor.dtype(), place); self_tensor.shape(), value, self_tensor.dtype(), self_tensor.place());
} }
} else { } else {
other_tensor = CastPyArg2Tensor(other_obj, 0); other_tensor = CastPyArg2Tensor(other_obj, 0);
...@@ -385,8 +385,8 @@ static PyObject* tensor__rsub__method(TensorObject* self, ...@@ -385,8 +385,8 @@ static PyObject* tensor__rsub__method(TensorObject* self,
CastPyArg2Scalar(other_obj, "__rsub__", 0); CastPyArg2Scalar(other_obj, "__rsub__", 0);
{ {
eager_gil_scoped_release guard; eager_gil_scoped_release guard;
other_tensor = other_tensor = full_ad_func(
full_ad_func(self_tensor.shape(), value, self_tensor.dtype(), place); self_tensor.shape(), value, self_tensor.dtype(), self_tensor.place());
} }
} else { } else {
other_tensor = CastPyArg2Tensor(other_obj, 0); other_tensor = CastPyArg2Tensor(other_obj, 0);
...@@ -480,11 +480,12 @@ static PyObject* tensor__mul__method(TensorObject* self, ...@@ -480,11 +480,12 @@ static PyObject* tensor__mul__method(TensorObject* self,
CastPyArg2Scalar(other_obj, "__mul__", 0); CastPyArg2Scalar(other_obj, "__mul__", 0);
if (PyComplex_Check(other_obj)) { if (PyComplex_Check(other_obj)) {
eager_gil_scoped_release guard; eager_gil_scoped_release guard;
other_tensor = full_ad_func({1}, value, DataType::COMPLEX64, place); other_tensor =
full_ad_func({1}, value, DataType::COMPLEX64, self_tensor.place());
} else { } else {
eager_gil_scoped_release guard; eager_gil_scoped_release guard;
other_tensor = other_tensor = full_ad_func(
full_ad_func(self_tensor.shape(), value, self_tensor.dtype(), place); self_tensor.shape(), value, self_tensor.dtype(), self_tensor.place());
} }
} else { } else {
other_tensor = CastPyArg2Tensor(other_obj, 0); other_tensor = CastPyArg2Tensor(other_obj, 0);
...@@ -581,11 +582,12 @@ static PyObject* tensor__div__method(TensorObject* self, ...@@ -581,11 +582,12 @@ static PyObject* tensor__div__method(TensorObject* self,
CastPyArg2Scalar(other_obj, "__div__", 0); CastPyArg2Scalar(other_obj, "__div__", 0);
if (PyComplex_Check(other_obj)) { if (PyComplex_Check(other_obj)) {
eager_gil_scoped_release guard; eager_gil_scoped_release guard;
other_tensor = full_ad_func({1}, value, DataType::COMPLEX64, place); other_tensor =
full_ad_func({1}, value, DataType::COMPLEX64, self_tensor.place());
} else { } else {
eager_gil_scoped_release guard; eager_gil_scoped_release guard;
other_tensor = other_tensor = full_ad_func(
full_ad_func(self_tensor.shape(), value, self_tensor.dtype(), place); self_tensor.shape(), value, self_tensor.dtype(), self_tensor.place());
} }
} else { } else {
other_tensor = CastPyArg2Tensor(other_obj, 0); other_tensor = CastPyArg2Tensor(other_obj, 0);
...@@ -696,11 +698,12 @@ static PyObject* tensor__rdiv__method(TensorObject* self, ...@@ -696,11 +698,12 @@ static PyObject* tensor__rdiv__method(TensorObject* self,
CastPyArg2Scalar(other_obj, "__rdiv__", 0); CastPyArg2Scalar(other_obj, "__rdiv__", 0);
if (PyComplex_Check(other_obj)) { if (PyComplex_Check(other_obj)) {
eager_gil_scoped_release guard; eager_gil_scoped_release guard;
other_tensor = full_ad_func({1}, value, DataType::COMPLEX64, place); other_tensor =
full_ad_func({1}, value, DataType::COMPLEX64, self_tensor.place());
} else { } else {
eager_gil_scoped_release guard; eager_gil_scoped_release guard;
other_tensor = other_tensor = full_ad_func(
full_ad_func(self_tensor.shape(), value, self_tensor.dtype(), place); self_tensor.shape(), value, self_tensor.dtype(), self_tensor.place());
} }
} else { } else {
other_tensor = CastPyArg2Tensor(other_obj, 0); other_tensor = CastPyArg2Tensor(other_obj, 0);
...@@ -809,11 +812,12 @@ static PyObject* tensor__gt__method(TensorObject* self, ...@@ -809,11 +812,12 @@ static PyObject* tensor__gt__method(TensorObject* self,
CastPyArg2Scalar(other_obj, "__gt__", 0); CastPyArg2Scalar(other_obj, "__gt__", 0);
if (PyComplex_Check(other_obj)) { if (PyComplex_Check(other_obj)) {
eager_gil_scoped_release guard; eager_gil_scoped_release guard;
other_tensor = full_ad_func({1}, value, DataType::COMPLEX64, place); other_tensor =
full_ad_func({1}, value, DataType::COMPLEX64, self_tensor.place());
} else { } else {
eager_gil_scoped_release guard; eager_gil_scoped_release guard;
other_tensor = other_tensor = full_ad_func(
full_ad_func(self_tensor.shape(), value, self_tensor.dtype(), place); self_tensor.shape(), value, self_tensor.dtype(), self_tensor.place());
} }
} else { } else {
other_tensor = CastPyArg2Tensor(other_obj, 0); other_tensor = CastPyArg2Tensor(other_obj, 0);
...@@ -894,11 +898,12 @@ static PyObject* tensor__ge__method(TensorObject* self, ...@@ -894,11 +898,12 @@ static PyObject* tensor__ge__method(TensorObject* self,
CastPyArg2Scalar(other_obj, "__ge__", 0); CastPyArg2Scalar(other_obj, "__ge__", 0);
if (PyComplex_Check(other_obj)) { if (PyComplex_Check(other_obj)) {
eager_gil_scoped_release guard; eager_gil_scoped_release guard;
other_tensor = full_ad_func({1}, value, DataType::COMPLEX64, place); other_tensor =
full_ad_func({1}, value, DataType::COMPLEX64, self_tensor.place());
} else { } else {
eager_gil_scoped_release guard; eager_gil_scoped_release guard;
other_tensor = other_tensor = full_ad_func(
full_ad_func(self_tensor.shape(), value, self_tensor.dtype(), place); self_tensor.shape(), value, self_tensor.dtype(), self_tensor.place());
} }
} else { } else {
other_tensor = CastPyArg2Tensor(other_obj, 0); other_tensor = CastPyArg2Tensor(other_obj, 0);
...@@ -974,17 +979,18 @@ static PyObject* tensor__mod__method(TensorObject* self, ...@@ -974,17 +979,18 @@ static PyObject* tensor__mod__method(TensorObject* self,
other_tensor = full_ad_func(self_tensor.shape(), other_tensor = full_ad_func(self_tensor.shape(),
phi::Scalar(other_float), phi::Scalar(other_float),
self_tensor.dtype(), self_tensor.dtype(),
place); self_tensor.place());
} else if (!PyCheckTensor(other_obj)) { } else if (!PyCheckTensor(other_obj)) {
paddle::experimental::Scalar value = paddle::experimental::Scalar value =
CastPyArg2Scalar(other_obj, "__mod__", 0); CastPyArg2Scalar(other_obj, "__mod__", 0);
if (PyComplex_Check(other_obj)) { if (PyComplex_Check(other_obj)) {
eager_gil_scoped_release guard; eager_gil_scoped_release guard;
other_tensor = full_ad_func({1}, value, DataType::COMPLEX64, place); other_tensor =
full_ad_func({1}, value, DataType::COMPLEX64, self_tensor.place());
} else { } else {
eager_gil_scoped_release guard; eager_gil_scoped_release guard;
other_tensor = other_tensor = full_ad_func(
full_ad_func(self_tensor.shape(), value, self_tensor.dtype(), place); self_tensor.shape(), value, self_tensor.dtype(), self_tensor.place());
} }
} else { } else {
other_tensor = CastPyArg2Tensor(other_obj, 0); other_tensor = CastPyArg2Tensor(other_obj, 0);
...@@ -1056,17 +1062,21 @@ static PyObject* tensor__matmul__method(TensorObject* self, ...@@ -1056,17 +1062,21 @@ static PyObject* tensor__matmul__method(TensorObject* self,
paddle::experimental::Tensor other_tensor; paddle::experimental::Tensor other_tensor;
if (has_other_float) { if (has_other_float) {
eager_gil_scoped_release guard; eager_gil_scoped_release guard;
other_tensor = other_tensor = full_ad_func({1},
full_ad_func({1}, phi::Scalar(other_float), self_tensor.dtype(), place); phi::Scalar(other_float),
self_tensor.dtype(),
self_tensor.place());
} else if (!PyCheckTensor(other_obj)) { } else if (!PyCheckTensor(other_obj)) {
paddle::experimental::Scalar value = paddle::experimental::Scalar value =
CastPyArg2Scalar(other_obj, "__matmul__", 0); CastPyArg2Scalar(other_obj, "__matmul__", 0);
if (PyComplex_Check(other_obj)) { if (PyComplex_Check(other_obj)) {
eager_gil_scoped_release guard; eager_gil_scoped_release guard;
other_tensor = full_ad_func({1}, value, DataType::COMPLEX64, place); other_tensor =
full_ad_func({1}, value, DataType::COMPLEX64, self_tensor.place());
} else { } else {
eager_gil_scoped_release guard; eager_gil_scoped_release guard;
other_tensor = full_ad_func({1}, value, self_tensor.dtype(), place); other_tensor =
full_ad_func({1}, value, self_tensor.dtype(), self_tensor.place());
} }
} else { } else {
other_tensor = CastPyArg2Tensor(other_obj, 0); other_tensor = CastPyArg2Tensor(other_obj, 0);
...@@ -1159,17 +1169,18 @@ static PyObject* tensor__lt__method(TensorObject* self, ...@@ -1159,17 +1169,18 @@ static PyObject* tensor__lt__method(TensorObject* self,
other_tensor = full_ad_func(self_tensor.shape(), other_tensor = full_ad_func(self_tensor.shape(),
phi::Scalar(other_float), phi::Scalar(other_float),
self_tensor.dtype(), self_tensor.dtype(),
place); self_tensor.place());
} else if (!PyCheckTensor(other_obj)) { } else if (!PyCheckTensor(other_obj)) {
paddle::experimental::Scalar value = paddle::experimental::Scalar value =
CastPyArg2Scalar(other_obj, "__lt__", 0); CastPyArg2Scalar(other_obj, "__lt__", 0);
if (PyComplex_Check(other_obj)) { if (PyComplex_Check(other_obj)) {
eager_gil_scoped_release guard; eager_gil_scoped_release guard;
other_tensor = full_ad_func({1}, value, DataType::COMPLEX64, place); other_tensor =
full_ad_func({1}, value, DataType::COMPLEX64, self_tensor.place());
} else { } else {
eager_gil_scoped_release guard; eager_gil_scoped_release guard;
other_tensor = other_tensor = full_ad_func(
full_ad_func(self_tensor.shape(), value, self_tensor.dtype(), place); self_tensor.shape(), value, self_tensor.dtype(), self_tensor.place());
} }
} else { } else {
other_tensor = CastPyArg2Tensor(other_obj, 0); other_tensor = CastPyArg2Tensor(other_obj, 0);
...@@ -1244,17 +1255,18 @@ static PyObject* tensor__le__method(TensorObject* self, ...@@ -1244,17 +1255,18 @@ static PyObject* tensor__le__method(TensorObject* self,
other_tensor = full_ad_func(self_tensor.shape(), other_tensor = full_ad_func(self_tensor.shape(),
phi::Scalar(other_float), phi::Scalar(other_float),
self_tensor.dtype(), self_tensor.dtype(),
place); self_tensor.place());
} else if (!PyCheckTensor(other_obj)) { } else if (!PyCheckTensor(other_obj)) {
paddle::experimental::Scalar value = paddle::experimental::Scalar value =
CastPyArg2Scalar(other_obj, "__le__", 0); CastPyArg2Scalar(other_obj, "__le__", 0);
if (PyComplex_Check(other_obj)) { if (PyComplex_Check(other_obj)) {
eager_gil_scoped_release guard; eager_gil_scoped_release guard;
other_tensor = full_ad_func({1}, value, DataType::COMPLEX64, place); other_tensor =
full_ad_func({1}, value, DataType::COMPLEX64, self_tensor.place());
} else { } else {
eager_gil_scoped_release guard; eager_gil_scoped_release guard;
other_tensor = other_tensor = full_ad_func(
full_ad_func(self_tensor.shape(), value, self_tensor.dtype(), place); self_tensor.shape(), value, self_tensor.dtype(), self_tensor.place());
} }
} else { } else {
other_tensor = CastPyArg2Tensor(other_obj, 0); other_tensor = CastPyArg2Tensor(other_obj, 0);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册