diff --git a/python_module/megengine/data/_queue.py b/python_module/megengine/data/_queue.py index a9e328c65c56e4f4ba736b510176677b6c735c32..8e359ae06c0bb2c7306c067c63a100d5528e9bad 100644 --- a/python_module/megengine/data/_queue.py +++ b/python_module/megengine/data/_queue.py @@ -26,7 +26,7 @@ def _clear_plasma_store(): # `_PlasmaStoreManager.__del__` will not be called automaticly in subprocess, # so this function should be called explicitly global MGE_PLASMA_STORE_MANAGER - if MGE_PLASMA_STORE_MANAGER is not None: + if MGE_PLASMA_STORE_MANAGER is not None and MGE_PLASMA_STORE_MANAGER.refcount == 0: del MGE_PLASMA_STORE_MANAGER MGE_PLASMA_STORE_MANAGER = None @@ -50,6 +50,7 @@ class _PlasmaStoreManager: stderr=None if debug_flag else subprocess.DEVNULL, ) self.__initialized = True + self.refcount = 1 def __del__(self): if self.__initialized and self.plasma_store.returncode is None: @@ -83,6 +84,8 @@ class PlasmaShmQueue: "Exception happened in starting plasma_store: {}\n" "Tips: {}".format(str(e), err_info) ) + else: + MGE_PLASMA_STORE_MANAGER.refcount += 1 self.socket_name = MGE_PLASMA_STORE_MANAGER.socket_name @@ -133,6 +136,8 @@ class PlasmaShmQueue: def close(self): self.queue.close() self.disconnect_client() + global MGE_PLASMA_STORE_MANAGER + MGE_PLASMA_STORE_MANAGER.refcount -= 1 _clear_plasma_store() def cancel_join_thread(self): diff --git a/python_module/test/unit/data/test_dataloader.py b/python_module/test/unit/data/test_dataloader.py index 7cf687c11365ccdcd6d90add440648482de8c93a..6bb0f3e32f592ea50b4514b6a3c616a4b0f8c117 100644 --- a/python_module/test/unit/data/test_dataloader.py +++ b/python_module/test/unit/data/test_dataloader.py @@ -132,3 +132,52 @@ def test_dataloader_parallel_worker_exception(): with pytest.raises(RuntimeError, match=r"worker.*died"): data_iter = iter(dataloader) batch_data = next(data_iter) + + +def _multi_instances_parallel_dataloader_worker(): + dataset = init_dataset() + + for divide_flag in [True, False]: + train_dataloader = DataLoader( + dataset, + sampler=RandomSampler(dataset, batch_size=4, drop_last=False), + num_workers=2, + divide=divide_flag, + ) + val_dataloader = DataLoader( + dataset, + sampler=RandomSampler(dataset, batch_size=10, drop_last=False), + num_workers=2, + divide=divide_flag, + ) + for idx, (data, label) in enumerate(train_dataloader): + assert data.shape == (4, 1, 32, 32) + assert label.shape == (4,) + if idx % 5 == 0: + for val_data, val_label in val_dataloader: + assert val_data.shape == (10, 1, 32, 32) + assert val_label.shape == (10,) + + +def test_dataloader_parallel_multi_instances(): + # set max shared memory to 100M + os.environ["MGE_PLASMA_MEMORY"] = "100000000" + + _multi_instances_parallel_dataloader_worker() + + +def test_dataloader_parallel_multi_instances_multiprocessing(): + # set max shared memory to 100M + os.environ["MGE_PLASMA_MEMORY"] = "100000000" + + import multiprocessing as mp + + # mp.set_start_method("spawn") + processes = [] + for i in range(4): + p = mp.Process(target=_multi_instances_parallel_dataloader_worker) + p.start() + processes.append(p) + + for p in processes: + p.join()