Unverified · Commit 5131b11f authored by Aurelius84, committed by GitHub

[Eager]Fix EagerTensor _copy_to memory overlap problem (#42668)

Parent 754820fe
...@@ -361,12 +361,33 @@ static PyObject* tensor_method__is_dense_tensor_hold_allocation(
  EAGER_CATCH_AND_THROW_RETURN_NULL
}
static void IncreaseTensorReferenceCountUntilCopyComplete(
    const paddle::experimental::Tensor& tensor, const platform::Place& place) {
  auto place_ = platform::is_gpu_place(place) ? place : tensor.place();
  auto tracer = egr::Controller::Instance().GetCurrentTracer();
  auto gc = tracer->MutableGarbageCollectorIfNotExists(place_);
  // Note(dev): This is an empty callback whose only purpose is to hold a
  // reference to the inner memory Holder, so the Holder is not destructed
  // until the kernels launched on the current stream of the given place have
  // finished, e.g. a CUDAPinned Mem -> CUDA copy issued via cudaMemcpyAsync.
  auto callback = [tensor, place_]() {
    VLOG(3) << "Run callback of Tensor:" << tensor.name() << " at place "
            << place_;
  };
  gc->DirectClearCallback(callback);
}
static PyObject* tensor_method__copy_to(TensorObject* self, PyObject* args,
                                        PyObject* kwargs) {
  EAGER_TRY
  auto place = CastPyArg2Place(PyTuple_GET_ITEM(args, 0), 0);
  bool blocking = CastPyArg2AttrBoolean(PyTuple_GET_ITEM(args, 1), 1);
  auto cp_tensor = self->tensor.copy_to(place, blocking);
  if (!blocking) {
    IncreaseTensorReferenceCountUntilCopyComplete(self->tensor, place);
  }
  egr::EagerUtils::autograd_meta(&cp_tensor)->SetStopGradient(true);
  egr::EagerUtils::autograd_meta(&cp_tensor)
      ->SetPersistable(
......
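For readers unfamiliar with the trick in this patch, the sketch below shows the same lifetime-extension pattern in a self-contained form. The Stream, Holder, and Enqueue names are toy stand-ins built only on the C++ standard library, not the Paddle or CUDA API: the asynchronous "memcpy" only sees raw pointers (as cudaMemcpyAsync does), so an empty callback that captures the source holder by value is enqueued on the same stream to keep the allocation alive until the copy has run.

// Minimal sketch of the lifetime-extension pattern (toy types, assumptions
// noted above; not the Paddle or CUDA API).
#include <algorithm>
#include <condition_variable>
#include <cstddef>
#include <cstdio>
#include <functional>
#include <memory>
#include <mutex>
#include <queue>
#include <thread>
#include <vector>

// Toy in-order "stream": runs enqueued tasks one after another on a worker
// thread, mimicking how work submitted to a single CUDA stream is ordered.
class Stream {
 public:
  Stream() : worker_([this] { Run(); }) {}
  ~Stream() {
    {
      std::lock_guard<std::mutex> lk(mu_);
      done_ = true;
    }
    cv_.notify_one();
    worker_.join();  // drains any remaining tasks before destruction
  }
  void Enqueue(std::function<void()> task) {
    {
      std::lock_guard<std::mutex> lk(mu_);
      tasks_.push(std::move(task));
    }
    cv_.notify_one();
  }

 private:
  void Run() {
    for (;;) {
      std::function<void()> task;
      {
        std::unique_lock<std::mutex> lk(mu_);
        cv_.wait(lk, [this] { return done_ || !tasks_.empty(); });
        if (tasks_.empty()) return;  // done_ is set and nothing is pending
        task = std::move(tasks_.front());
        tasks_.pop();
      }
      task();
    }
  }
  std::mutex mu_;
  std::condition_variable cv_;
  std::queue<std::function<void()>> tasks_;
  bool done_ = false;
  std::thread worker_;
};

using Holder = std::shared_ptr<std::vector<float>>;

int main() {
  auto dst = std::make_shared<std::vector<float>>();
  {
    Stream stream;
    Holder src = std::make_shared<std::vector<float>>(1 << 20, 1.0f);
    dst->resize(src->size());

    // The "async memcpy": like cudaMemcpyAsync, it only sees raw pointers
    // and does NOT keep the source allocation alive by itself.
    const float* src_ptr = src->data();
    float* dst_ptr = dst->data();
    std::size_t n = src->size();
    stream.Enqueue(
        [src_ptr, dst_ptr, n] { std::copy(src_ptr, src_ptr + n, dst_ptr); });

    // The fix in miniature: an empty callback whose only job is to hold a
    // reference to `src`. It runs after the copy on the same stream, so the
    // source buffer cannot be freed while the copy is still pending.
    stream.Enqueue([src] {});

    src.reset();  // the caller drops its reference immediately (blocking=false)
  }  // ~Stream() drains the queue, so the copy has finished by this point
  std::printf("dst[0] = %f\n", static_cast<double>((*dst)[0]));
  return 0;
}

In the real patch the empty callback is handed to the per-place garbage collector via gc->DirectClearCallback(callback); as the Note(dev) comment explains, the lambda's by-value capture of the tensor is what pins the underlying memory Holder until the work already queued on that place's stream has finished.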