未验证 提交 93cee48e 编写于 作者: W wanghuancoder 提交者: GitHub

refine _grad_ivar (#49787)

上级 1508cae7
...@@ -1893,6 +1893,21 @@ static PyObject* tensor_data_ptr(TensorObject* self, ...@@ -1893,6 +1893,21 @@ static PyObject* tensor_data_ptr(TensorObject* self,
EAGER_CATCH_AND_THROW_RETURN_NULL EAGER_CATCH_AND_THROW_RETURN_NULL
} }
static PyObject* tensor__grad_ivar(TensorObject* self,
PyObject* args,
PyObject* kwargs) {
EAGER_TRY
VLOG(6) << "Get grad for tensor: " << self->tensor.name();
auto meta = egr::EagerUtils::nullable_autograd_meta(self->tensor);
VLOG(6) << meta << " initialized: " << meta->Grad().initialized();
if (meta && meta->Grad().initialized()) {
return ToPyObject(meta->Grad());
} else {
RETURN_PY_NONE
}
EAGER_CATCH_AND_THROW_RETURN_NULL
}
#if defined(PADDLE_WITH_CUDA) #if defined(PADDLE_WITH_CUDA)
static PyObject* tensor_method__uva(TensorObject* self, static PyObject* tensor_method__uva(TensorObject* self,
PyObject* args, PyObject* args,
...@@ -2152,6 +2167,10 @@ PyMethodDef variable_methods[] = { ...@@ -2152,6 +2167,10 @@ PyMethodDef variable_methods[] = {
(PyCFunction)(void (*)(void))tensor_data_ptr, (PyCFunction)(void (*)(void))tensor_data_ptr,
METH_VARARGS | METH_KEYWORDS, METH_VARARGS | METH_KEYWORDS,
NULL}, NULL},
{"_grad_ivar",
(PyCFunction)(void (*)(void))tensor__grad_ivar,
METH_VARARGS | METH_KEYWORDS,
NULL},
#if defined(PADDLE_WITH_CUDA) #if defined(PADDLE_WITH_CUDA)
{"_tensor_uva", {"_tensor_uva",
(PyCFunction)(void (*)(void))tensor_method__uva, (PyCFunction)(void (*)(void))tensor_method__uva,
......
...@@ -856,13 +856,6 @@ def monkey_patch_varbase(): ...@@ -856,13 +856,6 @@ def monkey_patch_varbase():
# Call c++ func __setitem_varbase__ to speedup. # Call c++ func __setitem_varbase__ to speedup.
return self.__setitem_varbase__(item, value) return self.__setitem_varbase__(item, value)
@framework.dygraph_only
def _grad_ivar(self):
if self.grad is not None:
if self.grad._is_initialized():
return self.grad
return None
@framework.dygraph_only @framework.dygraph_only
def _set_grad_ivar(self, value): def _set_grad_ivar(self, value):
if isinstance(self, EagerParamBase): if isinstance(self, EagerParamBase):
...@@ -1060,7 +1053,6 @@ def monkey_patch_varbase(): ...@@ -1060,7 +1053,6 @@ def monkey_patch_varbase():
setattr(core.VarBase, method_name, method) setattr(core.VarBase, method_name, method)
if framework._in_eager_mode_: if framework._in_eager_mode_:
setattr(core.eager.Tensor, "_grad_ivar", _grad_ivar)
setattr(core.eager.Tensor, "_set_grad_ivar", _set_grad_ivar) setattr(core.eager.Tensor, "_set_grad_ivar", _set_grad_ivar)
setattr(core.eager.Tensor, "value", value) setattr(core.eager.Tensor, "value", value)
setattr(core.eager.Tensor, "cpu", cpu) setattr(core.eager.Tensor, "cpu", cpu)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册