Unverified · Commit a0d0bb63 authored by Jiabin Yang, committed by GitHub

[Eager] Support set_grad_ivar for eager (#43378)

* support set_grad_ivar for eager

* support set_grad_ivar for eager

* support set_grad_ivar for eager
Parent commit: edf69ae0
@@ -1635,6 +1635,26 @@ static PyObject* tensor__grad_value(TensorObject* self, PyObject* args,
  EAGER_CATCH_AND_THROW_RETURN_NULL
}
static PyObject* tensor__unset_fake_empty(TensorObject* self, PyObject* args,
                                          PyObject* kwargs) {
  EAGER_TRY
  paddle::experimental::Tensor* grad =
      egr::EagerUtils::mutable_grad(self->tensor);
  PADDLE_ENFORCE_EQ(grad != nullptr, true,
                    platform::errors::InvalidArgument(
                        "Detected NULL grad. Please check if you have manually "
                        "cleared the grad inside autograd_meta"));
  bool is_leaf = egr::egr_utils_api::IsLeafTensor(self->tensor);
  if (is_leaf) {
    std::static_pointer_cast<egr::GradNodeAccumulation>(
        egr::EagerUtils::grad_node(self->tensor))
        ->SetFakeEmpty(false);
  }
  RETURN_PY_NONE
  EAGER_CATCH_AND_THROW_RETURN_NULL
}
#if defined(PADDLE_WITH_CUDA)
static PyObject* tensor_method__uva(TensorObject* self, PyObject* args,
                                    PyObject* kwargs) {
@@ -1791,6 +1811,8 @@ PyMethodDef variable_methods[] = {
     METH_VARARGS | METH_KEYWORDS, NULL},
    {"_grad_value", (PyCFunction)(void (*)(void))tensor__grad_value,
     METH_VARARGS | METH_KEYWORDS, NULL},
    {"_unset_fake_empty", (PyCFunction)(void (*)(void))tensor__unset_fake_empty,
     METH_VARARGS | METH_KEYWORDS, NULL},
#if defined(PADDLE_WITH_CUDA)
    {"_tensor_uva", (PyCFunction)(void (*)(void))tensor_method__uva,
     METH_VARARGS | METH_KEYWORDS, NULL},
......
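The new tensor__unset_fake_empty binding above is registered in variable_methods as _unset_fake_empty, so it can be called from Python on an eager tensor. A minimal sketch of exercising it follows; the layer, shapes, and the backward pass are illustrative assumptions, not part of this commit:

import paddle

linear = paddle.nn.Linear(4, 4)   # illustrative layer
param = linear.weight             # EagerParamBase, a leaf tensor in eager mode

# The binding enforces a non-NULL grad slot, so produce a grad first.
loss = paddle.sum(linear(paddle.rand([4, 4])))
loss.backward()

# Clears the fake-empty flag on the leaf tensor's GradNodeAccumulation
# (SetFakeEmpty(false) in the C++ code above).
param._unset_fake_empty()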
@@ -804,6 +804,7 @@ def monkey_patch_varbase():
    def _set_grad_ivar(self, value):
        if isinstance(self, EagerParamBase):
            self.grad = value
            self._unset_fake_empty()
        else:
            raise TypeError(
                "_set_grad_ivar is only supported for Parameter Tensor")
......
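With this Python-side change, _set_grad_ivar now also clears the fake-empty flag after assigning the gradient. A hedged usage sketch, again with an illustrative layer and values:

import paddle

linear = paddle.nn.Linear(2, 2)   # illustrative layer
param = linear.weight             # EagerParamBase

# Replace the parameter's gradient wholesale; after this patch the call also
# invokes _unset_fake_empty(), so later accumulation treats the assigned grad
# as real data rather than a fake-empty placeholder.
param._set_grad_ivar(paddle.zeros_like(param))

# Non-parameter tensors still reject the call, per the else branch above.
x = paddle.to_tensor([1.0, 2.0])
try:
    x._set_grad_ivar(paddle.zeros_like(x))
except TypeError as err:
    print(err)  # "_set_grad_ivar is only supported for Parameter Tensor"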