From be77aeea7265df7141b2a18069f670e8cdbe117b Mon Sep 17 00:00:00 2001 From: Kaipeng Deng Date: Wed, 4 May 2022 17:04:14 +0800 Subject: [PATCH] fix Tensor share memory in eager mode. test=develop (#42445) --- python/paddle/fluid/dataloader/worker.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/python/paddle/fluid/dataloader/worker.py b/python/paddle/fluid/dataloader/worker.py index 304f31c2b16..6dc3813fa6d 100644 --- a/python/paddle/fluid/dataloader/worker.py +++ b/python/paddle/fluid/dataloader/worker.py @@ -22,7 +22,7 @@ from collections import namedtuple from .. import core from .fetcher import _IterableDatasetFetcher, _MapDatasetFetcher from ..multiprocess_utils import _cleanup_mmap, CleanupFuncRegistrar, MP_STATUS_CHECK_INTERVAL -from ..framework import _non_static_mode +from ..framework import _non_static_mode, _in_eager_without_dygraph_check from .flat import _flatten_batch # NOTE: queue has a different name in python2 and python3 @@ -339,10 +339,16 @@ def _worker_loop(dataset, dataset_kind, indices_queue, out_queue, done_event, out_queue.put((idx, batch, None)) batch, structure = _flatten_batch(batch) if use_shared_memory: + # NOTE: In eager mode, Tensor._share_memory has no + # effect, fall back to _array_to_share_memory_tensor + def tensor_share_memory(tensor): + if _in_eager_without_dygraph_check(): + return core._array_to_share_memory_tensor(tensor) + return tensor._share_memory() tensor_list = [ core._array_to_share_memory_tensor(b) - if isinstance(b, np.ndarray) else b._share_memory() - for b in batch + if isinstance(b, np.ndarray) \ + else tensor_share_memory(b) for b in batch ] out_queue.put((idx, tensor_list, structure)) core._remove_tensor_list_mmap_fds(tensor_list) -- GitLab