From 01baa0b62deab168b14c94ca86fd042b70a22540 Mon Sep 17 00:00:00 2001 From: Weilong Wu Date: Wed, 12 Oct 2022 19:02:44 +0800 Subject: [PATCH] [Eager] polish the place setting code (#46840) --- paddle/fluid/pybind/eager_math_op_patch.cc | 86 ++++++++++++---------- 1 file changed, 49 insertions(+), 37 deletions(-) diff --git a/paddle/fluid/pybind/eager_math_op_patch.cc b/paddle/fluid/pybind/eager_math_op_patch.cc index f6ace5a9fef..dc85c17e6d6 100644 --- a/paddle/fluid/pybind/eager_math_op_patch.cc +++ b/paddle/fluid/pybind/eager_math_op_patch.cc @@ -195,8 +195,8 @@ static PyObject* tensor__add__method(TensorObject* self, CastPyArg2Scalar(other_obj, "__add__", 0); { eager_gil_scoped_release guard; - other_tensor = - full_ad_func(self_tensor.shape(), value, self_tensor.dtype(), place); + other_tensor = full_ad_func( + self_tensor.shape(), value, self_tensor.dtype(), self_tensor.place()); } } else { other_tensor = CastPyArg2Tensor(other_obj, 0); @@ -292,8 +292,8 @@ static PyObject* tensor__sub__method(TensorObject* self, CastPyArg2Scalar(other_obj, "__sub__", 0); { eager_gil_scoped_release guard; - other_tensor = - full_ad_func(self_tensor.shape(), value, self_tensor.dtype(), place); + other_tensor = full_ad_func( + self_tensor.shape(), value, self_tensor.dtype(), self_tensor.place()); } } else { other_tensor = CastPyArg2Tensor(other_obj, 0); @@ -385,8 +385,8 @@ static PyObject* tensor__rsub__method(TensorObject* self, CastPyArg2Scalar(other_obj, "__rsub__", 0); { eager_gil_scoped_release guard; - other_tensor = - full_ad_func(self_tensor.shape(), value, self_tensor.dtype(), place); + other_tensor = full_ad_func( + self_tensor.shape(), value, self_tensor.dtype(), self_tensor.place()); } } else { other_tensor = CastPyArg2Tensor(other_obj, 0); @@ -480,11 +480,12 @@ static PyObject* tensor__mul__method(TensorObject* self, CastPyArg2Scalar(other_obj, "__mul__", 0); if (PyComplex_Check(other_obj)) { eager_gil_scoped_release guard; - other_tensor = full_ad_func({1}, value, DataType::COMPLEX64, place); + other_tensor = + full_ad_func({1}, value, DataType::COMPLEX64, self_tensor.place()); } else { eager_gil_scoped_release guard; - other_tensor = - full_ad_func(self_tensor.shape(), value, self_tensor.dtype(), place); + other_tensor = full_ad_func( + self_tensor.shape(), value, self_tensor.dtype(), self_tensor.place()); } } else { other_tensor = CastPyArg2Tensor(other_obj, 0); @@ -581,11 +582,12 @@ static PyObject* tensor__div__method(TensorObject* self, CastPyArg2Scalar(other_obj, "__div__", 0); if (PyComplex_Check(other_obj)) { eager_gil_scoped_release guard; - other_tensor = full_ad_func({1}, value, DataType::COMPLEX64, place); + other_tensor = + full_ad_func({1}, value, DataType::COMPLEX64, self_tensor.place()); } else { eager_gil_scoped_release guard; - other_tensor = - full_ad_func(self_tensor.shape(), value, self_tensor.dtype(), place); + other_tensor = full_ad_func( + self_tensor.shape(), value, self_tensor.dtype(), self_tensor.place()); } } else { other_tensor = CastPyArg2Tensor(other_obj, 0); @@ -696,11 +698,12 @@ static PyObject* tensor__rdiv__method(TensorObject* self, CastPyArg2Scalar(other_obj, "__rdiv__", 0); if (PyComplex_Check(other_obj)) { eager_gil_scoped_release guard; - other_tensor = full_ad_func({1}, value, DataType::COMPLEX64, place); + other_tensor = + full_ad_func({1}, value, DataType::COMPLEX64, self_tensor.place()); } else { eager_gil_scoped_release guard; - other_tensor = - full_ad_func(self_tensor.shape(), value, self_tensor.dtype(), place); + other_tensor = full_ad_func( + self_tensor.shape(), value, self_tensor.dtype(), self_tensor.place()); } } else { other_tensor = CastPyArg2Tensor(other_obj, 0); @@ -809,11 +812,12 @@ static PyObject* tensor__gt__method(TensorObject* self, CastPyArg2Scalar(other_obj, "__gt__", 0); if (PyComplex_Check(other_obj)) { eager_gil_scoped_release guard; - other_tensor = full_ad_func({1}, value, DataType::COMPLEX64, place); + other_tensor = + full_ad_func({1}, value, DataType::COMPLEX64, self_tensor.place()); } else { eager_gil_scoped_release guard; - other_tensor = - full_ad_func(self_tensor.shape(), value, self_tensor.dtype(), place); + other_tensor = full_ad_func( + self_tensor.shape(), value, self_tensor.dtype(), self_tensor.place()); } } else { other_tensor = CastPyArg2Tensor(other_obj, 0); @@ -894,11 +898,12 @@ static PyObject* tensor__ge__method(TensorObject* self, CastPyArg2Scalar(other_obj, "__ge__", 0); if (PyComplex_Check(other_obj)) { eager_gil_scoped_release guard; - other_tensor = full_ad_func({1}, value, DataType::COMPLEX64, place); + other_tensor = + full_ad_func({1}, value, DataType::COMPLEX64, self_tensor.place()); } else { eager_gil_scoped_release guard; - other_tensor = - full_ad_func(self_tensor.shape(), value, self_tensor.dtype(), place); + other_tensor = full_ad_func( + self_tensor.shape(), value, self_tensor.dtype(), self_tensor.place()); } } else { other_tensor = CastPyArg2Tensor(other_obj, 0); @@ -974,17 +979,18 @@ static PyObject* tensor__mod__method(TensorObject* self, other_tensor = full_ad_func(self_tensor.shape(), phi::Scalar(other_float), self_tensor.dtype(), - place); + self_tensor.place()); } else if (!PyCheckTensor(other_obj)) { paddle::experimental::Scalar value = CastPyArg2Scalar(other_obj, "__mod__", 0); if (PyComplex_Check(other_obj)) { eager_gil_scoped_release guard; - other_tensor = full_ad_func({1}, value, DataType::COMPLEX64, place); + other_tensor = + full_ad_func({1}, value, DataType::COMPLEX64, self_tensor.place()); } else { eager_gil_scoped_release guard; - other_tensor = - full_ad_func(self_tensor.shape(), value, self_tensor.dtype(), place); + other_tensor = full_ad_func( + self_tensor.shape(), value, self_tensor.dtype(), self_tensor.place()); } } else { other_tensor = CastPyArg2Tensor(other_obj, 0); @@ -1056,17 +1062,21 @@ static PyObject* tensor__matmul__method(TensorObject* self, paddle::experimental::Tensor other_tensor; if (has_other_float) { eager_gil_scoped_release guard; - other_tensor = - full_ad_func({1}, phi::Scalar(other_float), self_tensor.dtype(), place); + other_tensor = full_ad_func({1}, + phi::Scalar(other_float), + self_tensor.dtype(), + self_tensor.place()); } else if (!PyCheckTensor(other_obj)) { paddle::experimental::Scalar value = CastPyArg2Scalar(other_obj, "__matmul__", 0); if (PyComplex_Check(other_obj)) { eager_gil_scoped_release guard; - other_tensor = full_ad_func({1}, value, DataType::COMPLEX64, place); + other_tensor = + full_ad_func({1}, value, DataType::COMPLEX64, self_tensor.place()); } else { eager_gil_scoped_release guard; - other_tensor = full_ad_func({1}, value, self_tensor.dtype(), place); + other_tensor = + full_ad_func({1}, value, self_tensor.dtype(), self_tensor.place()); } } else { other_tensor = CastPyArg2Tensor(other_obj, 0); @@ -1159,17 +1169,18 @@ static PyObject* tensor__lt__method(TensorObject* self, other_tensor = full_ad_func(self_tensor.shape(), phi::Scalar(other_float), self_tensor.dtype(), - place); + self_tensor.place()); } else if (!PyCheckTensor(other_obj)) { paddle::experimental::Scalar value = CastPyArg2Scalar(other_obj, "__lt__", 0); if (PyComplex_Check(other_obj)) { eager_gil_scoped_release guard; - other_tensor = full_ad_func({1}, value, DataType::COMPLEX64, place); + other_tensor = + full_ad_func({1}, value, DataType::COMPLEX64, self_tensor.place()); } else { eager_gil_scoped_release guard; - other_tensor = - full_ad_func(self_tensor.shape(), value, self_tensor.dtype(), place); + other_tensor = full_ad_func( + self_tensor.shape(), value, self_tensor.dtype(), self_tensor.place()); } } else { other_tensor = CastPyArg2Tensor(other_obj, 0); @@ -1244,17 +1255,18 @@ static PyObject* tensor__le__method(TensorObject* self, other_tensor = full_ad_func(self_tensor.shape(), phi::Scalar(other_float), self_tensor.dtype(), - place); + self_tensor.place()); } else if (!PyCheckTensor(other_obj)) { paddle::experimental::Scalar value = CastPyArg2Scalar(other_obj, "__le__", 0); if (PyComplex_Check(other_obj)) { eager_gil_scoped_release guard; - other_tensor = full_ad_func({1}, value, DataType::COMPLEX64, place); + other_tensor = + full_ad_func({1}, value, DataType::COMPLEX64, self_tensor.place()); } else { eager_gil_scoped_release guard; - other_tensor = - full_ad_func(self_tensor.shape(), value, self_tensor.dtype(), place); + other_tensor = full_ad_func( + self_tensor.shape(), value, self_tensor.dtype(), self_tensor.place()); } } else { other_tensor = CastPyArg2Tensor(other_obj, 0); -- GitLab