未验证 提交 0a9c1f59 编写于 作者: W wanghuancoder 提交者: GitHub

[multiprocessing] Eager tensor support pickle (#48179)

* eager tensor support pickle
上级 3f265815
......@@ -21,7 +21,6 @@ import logging
import itertools
import threading
import numpy as np
import multiprocessing
from collections import namedtuple
from paddle.fluid.framework import (
_set_expected_place,
......@@ -422,6 +421,8 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
self._shutdown = False
def _init_workers(self):
import paddle.incubate.multiprocessing as multiprocessing
# multiprocess worker and indice queue list initial as empty
self._workers = []
self._worker_status = []
......
......@@ -373,21 +373,19 @@ def _worker_loop(
out_queue.put((idx, batch, None))
batch, structure = _flatten_batch(batch)
if use_shared_memory:
# NOTE: In eager mode, Tensor._share_memory has no
# effect, fall back to _array_to_share_memory_tensor
def tensor_share_memory(tensor):
if _in_eager_without_dygraph_check():
return core._array_to_share_memory_tensor(tensor)
return tensor._share_memory()
def numpy2lodtensor(arr):
lodtensor = core.Tensor()
lodtensor.set(arr, core.CPUPlace())
return lodtensor
tensor_list = [
core._array_to_share_memory_tensor(b)
numpy2lodtensor(b)
if isinstance(b, np.ndarray)
else tensor_share_memory(b)
else b.value().get_tensor()
for b in batch
]
out_queue.put((idx, tensor_list, structure))
core._remove_tensor_list_mmap_fds(tensor_list)
else:
out_queue.put((idx, batch, structure))
except KeyboardInterrupt:
......
......@@ -34,6 +34,7 @@ from . import autograd # noqa: F401
from . import autotune # noqa: F401
from . import nn # noqa: F401
from . import asp # noqa: F401
from . import multiprocessing # noqa: F401
from ..fluid.layers.loss import identity_loss
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册