未验证 提交 a0d0bb63 编写于 作者: J Jiabin Yang 提交者: GitHub

[Eager] Support set_grad_ivar for eager (#43378)

* support set_grad_ivar for eager

* support set_grad_ivar for eager

* support set_grad_ivar for eager
上级 edf69ae0
......@@ -1635,6 +1635,26 @@ static PyObject* tensor__grad_value(TensorObject* self, PyObject* args,
EAGER_CATCH_AND_THROW_RETURN_NULL
}
static PyObject* tensor__unset_fake_empty(TensorObject* self, PyObject* args,
PyObject* kwargs) {
EAGER_TRY
paddle::experimental::Tensor* grad =
egr::EagerUtils::mutable_grad(self->tensor);
PADDLE_ENFORCE_EQ(grad != nullptr, true,
platform::errors::InvalidArgument(
"Detected NULL grad. Please check if you have manually "
"cleared the grad inside autograd_meta"));
bool is_leaf = egr::egr_utils_api::IsLeafTensor(self->tensor);
if (is_leaf) {
std::static_pointer_cast<egr::GradNodeAccumulation>(
egr::EagerUtils::grad_node(self->tensor))
->SetFakeEmpty(false);
}
RETURN_PY_NONE
EAGER_CATCH_AND_THROW_RETURN_NULL
}
#if defined(PADDLE_WITH_CUDA)
static PyObject* tensor_method__uva(TensorObject* self, PyObject* args,
PyObject* kwargs) {
......@@ -1791,6 +1811,8 @@ PyMethodDef variable_methods[] = {
METH_VARARGS | METH_KEYWORDS, NULL},
{"_grad_value", (PyCFunction)(void (*)(void))tensor__grad_value,
METH_VARARGS | METH_KEYWORDS, NULL},
{"_unset_fake_empty", (PyCFunction)(void (*)(void))tensor__unset_fake_empty,
METH_VARARGS | METH_KEYWORDS, NULL},
#if defined(PADDLE_WITH_CUDA)
{"_tensor_uva", (PyCFunction)(void (*)(void))tensor_method__uva,
METH_VARARGS | METH_KEYWORDS, NULL},
......
......@@ -804,6 +804,7 @@ def monkey_patch_varbase():
def _set_grad_ivar(self, value):
if isinstance(self, EagerParamBase):
self.grad = value
self._unset_fake_empty()
else:
raise TypeError(
"_set_grad_ivar is only supported for Parameter Tensor")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册