提交 c8a9094b 编写于 作者: M Megvii Engine Team 提交者: Xu Xinran

fix(mge/data/dataloader): add refcount in _PlasmaStoreManager

GitOrigin-RevId: 9a95cb1a5d82dcb20a43493bfc59dd0100ca85f4
上级 05682707
......@@ -26,7 +26,7 @@ def _clear_plasma_store():
# `_PlasmaStoreManager.__del__` will not be called automaticly in subprocess,
# so this function should be called explicitly
global MGE_PLASMA_STORE_MANAGER
if MGE_PLASMA_STORE_MANAGER is not None:
if MGE_PLASMA_STORE_MANAGER is not None and MGE_PLASMA_STORE_MANAGER.refcount == 0:
del MGE_PLASMA_STORE_MANAGER
MGE_PLASMA_STORE_MANAGER = None
......@@ -50,6 +50,7 @@ class _PlasmaStoreManager:
stderr=None if debug_flag else subprocess.DEVNULL,
)
self.__initialized = True
self.refcount = 1
def __del__(self):
if self.__initialized and self.plasma_store.returncode is None:
......@@ -83,6 +84,8 @@ class PlasmaShmQueue:
"Exception happened in starting plasma_store: {}\n"
"Tips: {}".format(str(e), err_info)
)
else:
MGE_PLASMA_STORE_MANAGER.refcount += 1
self.socket_name = MGE_PLASMA_STORE_MANAGER.socket_name
......@@ -133,6 +136,8 @@ class PlasmaShmQueue:
def close(self):
self.queue.close()
self.disconnect_client()
global MGE_PLASMA_STORE_MANAGER
MGE_PLASMA_STORE_MANAGER.refcount -= 1
_clear_plasma_store()
def cancel_join_thread(self):
......
......@@ -132,3 +132,52 @@ def test_dataloader_parallel_worker_exception():
with pytest.raises(RuntimeError, match=r"worker.*died"):
data_iter = iter(dataloader)
batch_data = next(data_iter)
def _multi_instances_parallel_dataloader_worker():
dataset = init_dataset()
for divide_flag in [True, False]:
train_dataloader = DataLoader(
dataset,
sampler=RandomSampler(dataset, batch_size=4, drop_last=False),
num_workers=2,
divide=divide_flag,
)
val_dataloader = DataLoader(
dataset,
sampler=RandomSampler(dataset, batch_size=10, drop_last=False),
num_workers=2,
divide=divide_flag,
)
for idx, (data, label) in enumerate(train_dataloader):
assert data.shape == (4, 1, 32, 32)
assert label.shape == (4,)
if idx % 5 == 0:
for val_data, val_label in val_dataloader:
assert val_data.shape == (10, 1, 32, 32)
assert val_label.shape == (10,)
def test_dataloader_parallel_multi_instances():
# set max shared memory to 100M
os.environ["MGE_PLASMA_MEMORY"] = "100000000"
_multi_instances_parallel_dataloader_worker()
def test_dataloader_parallel_multi_instances_multiprocessing():
# set max shared memory to 100M
os.environ["MGE_PLASMA_MEMORY"] = "100000000"
import multiprocessing as mp
# mp.set_start_method("spawn")
processes = []
for i in range(4):
p = mp.Process(target=_multi_instances_parallel_dataloader_worker)
p.start()
processes.append(p)
for p in processes:
p.join()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册