未验证 提交 669c7d51 编写于 作者: W Weilong Wu 提交者: GitHub

[Eager, Performance Optimization] Optimize clone interface (#46190)

* [Eager] polish clone interface

* rm clone in python, add clone in eager_method.cc
上级 b28bff06
...@@ -445,6 +445,24 @@ static PyObject* tensor_method_copy_(TensorObject* self, ...@@ -445,6 +445,24 @@ static PyObject* tensor_method_copy_(TensorObject* self,
EAGER_CATCH_AND_THROW_RETURN_NULL EAGER_CATCH_AND_THROW_RETURN_NULL
} }
static PyObject* tensor_method_clone(TensorObject* self,
PyObject* args,
PyObject* kwargs) {
EAGER_TRY
PADDLE_ENFORCE_EQ(
self->tensor.initialized(),
true,
paddle::platform::errors::InvalidArgument(
"We can only support initialized tensor in clone, however we got "
"uninitialized tensor %s, please check your code.",
self->tensor.name()));
auto out = assign_ad_func(self->tensor);
return ToPyObject(out);
EAGER_CATCH_AND_THROW_RETURN_NULL
}
static PyObject* tensor_retain_grads(TensorObject* self, static PyObject* tensor_retain_grads(TensorObject* self,
PyObject* args, PyObject* args,
PyObject* kwargs) { PyObject* kwargs) {
...@@ -1854,6 +1872,10 @@ PyMethodDef variable_methods[] = { ...@@ -1854,6 +1872,10 @@ PyMethodDef variable_methods[] = {
(PyCFunction)(void (*)(void))tensor_method_copy_, (PyCFunction)(void (*)(void))tensor_method_copy_,
METH_VARARGS | METH_KEYWORDS, METH_VARARGS | METH_KEYWORDS,
NULL}, NULL},
{"clone",
(PyCFunction)(void (*)(void))tensor_method_clone,
METH_VARARGS | METH_KEYWORDS,
NULL},
{"reconstruct_from_", {"reconstruct_from_",
(PyCFunction)(void (*)(void))tensor_method_reconstruct_from_, (PyCFunction)(void (*)(void))tensor_method_reconstruct_from_,
METH_VARARGS | METH_KEYWORDS, METH_VARARGS | METH_KEYWORDS,
......
...@@ -815,17 +815,6 @@ def monkey_patch_varbase(): ...@@ -815,17 +815,6 @@ def monkey_patch_varbase():
raise TypeError( raise TypeError(
"_set_grad_ivar is only supported for Parameter Tensor") "_set_grad_ivar is only supported for Parameter Tensor")
@framework.dygraph_only
def clone(self):
if in_dygraph_mode():
return _C_ops.assign(self)
if _in_legacy_dygraph():
output = core.VarBase()
else:
output = core.eager.Tensor()
return _legacy_C_ops.assign(self, output)
@framework.dygraph_only @framework.dygraph_only
def value(self): def value(self):
return self return self
...@@ -1009,7 +998,6 @@ def monkey_patch_varbase(): ...@@ -1009,7 +998,6 @@ def monkey_patch_varbase():
if framework._in_eager_mode_: if framework._in_eager_mode_:
setattr(core.eager.Tensor, "_grad_ivar", _grad_ivar) setattr(core.eager.Tensor, "_grad_ivar", _grad_ivar)
setattr(core.eager.Tensor, "_set_grad_ivar", _set_grad_ivar) setattr(core.eager.Tensor, "_set_grad_ivar", _set_grad_ivar)
setattr(core.eager.Tensor, "clone", clone)
setattr(core.eager.Tensor, "value", value) setattr(core.eager.Tensor, "value", value)
setattr(core.eager.Tensor, "cpu", cpu) setattr(core.eager.Tensor, "cpu", cpu)
setattr(core.eager.Tensor, "cuda", cuda) setattr(core.eager.Tensor, "cuda", cuda)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册