未验证 提交 be77aeea 编写于 作者: K Kaipeng Deng 提交者: GitHub

fix Tensor share memory in eager mode. test=develop (#42445)

上级 d6442df6
......@@ -22,7 +22,7 @@ from collections import namedtuple
from .. import core
from .fetcher import _IterableDatasetFetcher, _MapDatasetFetcher
from ..multiprocess_utils import _cleanup_mmap, CleanupFuncRegistrar, MP_STATUS_CHECK_INTERVAL
from ..framework import _non_static_mode
from ..framework import _non_static_mode, _in_eager_without_dygraph_check
from .flat import _flatten_batch
# NOTE: queue has a different name in python2 and python3
......@@ -339,10 +339,16 @@ def _worker_loop(dataset, dataset_kind, indices_queue, out_queue, done_event,
out_queue.put((idx, batch, None))
batch, structure = _flatten_batch(batch)
if use_shared_memory:
# NOTE: In eager mode, Tensor._share_memory has no
# effect, fall back to _array_to_share_memory_tensor
def tensor_share_memory(tensor):
if _in_eager_without_dygraph_check():
return core._array_to_share_memory_tensor(tensor)
return tensor._share_memory()
tensor_list = [
core._array_to_share_memory_tensor(b)
if isinstance(b, np.ndarray) else b._share_memory()
for b in batch
if isinstance(b, np.ndarray) \
else tensor_share_memory(b) for b in batch
]
out_queue.put((idx, tensor_list, structure))
core._remove_tensor_list_mmap_fds(tensor_list)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册