diff --git a/paddle/fluid/pybind/eager_method.cc b/paddle/fluid/pybind/eager_method.cc
index b54f4e1416c35a9cac5c5f856ade511a02c9b1f5..7831530bff03bf834b127aece0af8f75c8ea93c7 100644
--- a/paddle/fluid/pybind/eager_method.cc
+++ b/paddle/fluid/pybind/eager_method.cc
@@ -53,6 +53,7 @@ typedef SSIZE_T ssize_t;
 #include "paddle/fluid/memory/allocation/mmap_allocator.h"
 #include "paddle/fluid/pybind/tensor_py.h"
 #include "paddle/phi/core/ddim.h"
+#include "paddle/phi/kernels/funcs/math_function.h"

 namespace paddle {
 namespace pybind {
@@ -518,7 +519,10 @@ static PyObject* tensor_clear_gradient(TensorObject* self, PyObject* args,
   } else if (grad->is_dense_tensor()) {
     if (grad->initialized()) {
       if (set_to_zero) {
-        grad->set_impl(paddle::experimental::zeros_like(*grad).impl());
+        auto* grad_t = static_cast<phi::DenseTensor*>(grad->impl().get());
+        auto* dev_ctx =
+            platform::DeviceContextPool::Instance().Get(grad_t->place());
+        phi::funcs::set_constant(*dev_ctx, grad_t, 0.0);
         if (is_leaf) {
           std::static_pointer_cast<egr::GradNodeAccumulation>(
               egr::EagerUtils::grad_node(self->tensor))
@@ -555,13 +559,26 @@ static PyObject* tensor__zero_grads(TensorObject* self, PyObject* args,
                             "Please check if you have manually cleared"
                             "the grad inside autograd_meta"));
     if (grad->initialized()) {
-      grad->set_impl(paddle::experimental::zeros_like(*(grad)).impl());
+      if (grad->is_dense_tensor()) {
+        auto* t = static_cast<phi::DenseTensor*>(grad->impl().get());
+        auto* dev_ctx = platform::DeviceContextPool::Instance().Get(t->place());
+        phi::funcs::set_constant(*dev_ctx, t, 0.0);
+      } else {
+        grad->set_impl(paddle::experimental::zeros_like(*(grad)).impl());
+      }
     }
   } else {
     auto meta = egr::EagerUtils::unsafe_autograd_meta(self->tensor);
     if (meta->MutableGrad()->initialized()) {
-      meta->MutableGrad()->set_impl(
-          paddle::experimental::zeros_like(*(meta->MutableGrad())).impl());
+      if (meta->MutableGrad()->is_dense_tensor()) {
+        auto* t =
+            static_cast<phi::DenseTensor*>(meta->MutableGrad()->impl().get());
+        auto* dev_ctx = platform::DeviceContextPool::Instance().Get(t->place());
+        phi::funcs::set_constant(*dev_ctx, t, 0.0);
+      } else {
+        meta->MutableGrad()->set_impl(
+            paddle::experimental::zeros_like(*(meta->MutableGrad())).impl());
+      }
     }
   }
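For reviewers skimming the intent: every hunk replaces `paddle::experimental::zeros_like(...)`, which materializes a brand-new zero tensor and swaps it into the gradient slot via `set_impl`, with `phi::funcs::set_constant(...)`, which overwrites the existing `phi::DenseTensor` buffer in place. Dense gradients therefore skip an allocator round-trip on every `clear_gradient(set_to_zero=True)` / `_zero_grads()` call, while non-dense gradients keep the old `zeros_like` path. Below is a minimal standalone sketch of that trade-off in plain C++; it is not Paddle internals, and the two helper names are made up for illustration.

```cpp
// Standalone illustration (not Paddle code) of the allocation-vs-fill
// trade-off this patch targets. zero_by_realloc mimics the old
// zeros_like + set_impl path; zero_in_place mimics set_constant.
#include <algorithm>
#include <cstdio>
#include <vector>

// Analogue of grad->set_impl(zeros_like(*grad).impl()): a fresh buffer
// is allocated and swapped in, discarding the old storage.
std::vector<float> zero_by_realloc(const std::vector<float>& grad) {
  return std::vector<float>(grad.size(), 0.0f);
}

// Analogue of phi::funcs::set_constant(*dev_ctx, grad_t, 0.0): the
// existing storage is overwritten in place, no allocation at all.
void zero_in_place(std::vector<float>* grad) {
  std::fill(grad->begin(), grad->end(), 0.0f);
}

int main() {
  std::vector<float> grad(1 << 20, 1.5f);
  const float* before = grad.data();
  zero_in_place(&grad);
  // The storage pointer is unchanged: same buffer, zero reallocation.
  std::printf("same buffer after in-place zero: %d\n",
              grad.data() == before);
  return 0;
}
```

The in-place route is only safe here because the `is_dense_tensor()` guard guarantees a concrete buffer to overwrite; other gradient impls still fall back to `zeros_like` in the patch above.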