Unverified · Commit 1e0ea6a4 · Authored by wanghuancoder · Committed by GitHub

[Eager] clear_gradient use set_constant but not zeros_like (#43171)

* clear_gradient use set_constant but not zeros_like
Parent 3fcfcd51
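In short: when the gradient is an already-initialized dense tensor, the commit overwrites its existing buffer with zeros via phi::funcs::set_constant instead of materializing a fresh zero tensor with paddle::experimental::zeros_like and swapping the impl. Below is a minimal sketch of that in-place pattern, using only the calls that appear in the diff; the helper name ZeroGradInPlace is hypothetical, not part of Paddle.

    // Sketch of the in-place zeroing pattern this commit adopts.
    #include "paddle/fluid/platform/device_context.h"
    #include "paddle/phi/core/dense_tensor.h"
    #include "paddle/phi/kernels/funcs/math_function.h"

    namespace paddle {

    // Overwrite an already-allocated gradient buffer with zeros, rather than
    // allocating a new zero tensor the way zeros_like does.
    void ZeroGradInPlace(phi::DenseTensor* grad_t) {
      // set_constant dispatches by device, so fetch the DeviceContext that
      // matches the tensor's placement (CPU or a particular GPU) from the
      // global pool.
      auto* dev_ctx =
          platform::DeviceContextPool::Instance().Get(grad_t->place());
      phi::funcs::set_constant(*dev_ctx, grad_t, 0.0);
    }

    }  // namespace paddle

Gradients that are not dense tensors keep the zeros_like fallback, as the else branches in the tensor__zero_grads hunk below show.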
@@ -53,6 +53,7 @@ typedef SSIZE_T ssize_t;
 #include "paddle/fluid/memory/allocation/mmap_allocator.h"
 #include "paddle/fluid/pybind/tensor_py.h"
 #include "paddle/phi/core/ddim.h"
+#include "paddle/phi/kernels/funcs/math_function.h"
 
 namespace paddle {
 namespace pybind {
@@ -518,7 +519,10 @@ static PyObject* tensor_clear_gradient(TensorObject* self, PyObject* args,
   } else if (grad->is_dense_tensor()) {
     if (grad->initialized()) {
       if (set_to_zero) {
-        grad->set_impl(paddle::experimental::zeros_like(*grad).impl());
+        auto* grad_t = static_cast<phi::DenseTensor*>(grad->impl().get());
+        auto* dev_ctx =
+            platform::DeviceContextPool::Instance().Get(grad_t->place());
+        phi::funcs::set_constant(*dev_ctx, grad_t, 0.0);
         if (is_leaf) {
           std::static_pointer_cast<egr::GradNodeAccumulation>(
               egr::EagerUtils::grad_node(self->tensor))
@@ -555,15 +559,28 @@ static PyObject* tensor__zero_grads(TensorObject* self, PyObject* args,
                        "Please check if you have manually cleared"
                        "the grad inside autograd_meta"));
     if (grad->initialized()) {
-      grad->set_impl(paddle::experimental::zeros_like(*(grad)).impl());
+      if (grad->is_dense_tensor()) {
+        auto* t = static_cast<phi::DenseTensor*>(grad->impl().get());
+        auto* dev_ctx = platform::DeviceContextPool::Instance().Get(t->place());
+        phi::funcs::set_constant(*dev_ctx, t, 0.0);
+      } else {
+        grad->set_impl(paddle::experimental::zeros_like(*(grad)).impl());
+      }
     }
   } else {
     auto meta = egr::EagerUtils::unsafe_autograd_meta(self->tensor);
     if (meta->MutableGrad()->initialized()) {
-      meta->MutableGrad()->set_impl(
-          paddle::experimental::zeros_like(*(meta->MutableGrad())).impl());
+      if (meta->MutableGrad()->is_dense_tensor()) {
+        auto* t =
+            static_cast<phi::DenseTensor*>(meta->MutableGrad()->impl().get());
+        auto* dev_ctx = platform::DeviceContextPool::Instance().Get(t->place());
+        phi::funcs::set_constant(*dev_ctx, t, 0.0);
+      } else {
+        meta->MutableGrad()->set_impl(
+            paddle::experimental::zeros_like(*(meta->MutableGrad())).impl());
+      }
     }
   }
   RETURN_PY_NONE
...