diff --git a/python_module/megengine/data/_queue.py b/python_module/megengine/data/_queue.py index f90f090dd59b3114284ba9e50bad3db95a4fde23..94898e99ccc0784ce8f93f457dfb7472f3b90e72 100644 --- a/python_module/megengine/data/_queue.py +++ b/python_module/megengine/data/_queue.py @@ -16,6 +16,19 @@ import pyarrow.plasma as plasma MGE_PLASMA_MEMORY = int(os.environ.get("MGE_PLASMA_MEMORY", 4000000000)) # 4GB +# Each process only needs to start one plasma store, so we set it as a global variable. +# TODO: how to share between different processes? +MGE_PLASMA_STORE_MANAGER = None + + +def _clear_plasma_store(): + # `_PlasmaStoreManager.__del__` will not be called automatically in a subprocess, + # so this function should be called explicitly + global MGE_PLASMA_STORE_MANAGER + if MGE_PLASMA_STORE_MANAGER is not None: + del MGE_PLASMA_STORE_MANAGER + MGE_PLASMA_STORE_MANAGER = None + class _PlasmaStoreManager: def __init__(self): @@ -34,11 +47,6 @@ class _PlasmaStoreManager: self.plasma_store.kill() -# Each process only need to start one plasma store, so we set it as a global variable. -# TODO: how to share between different processes? -MGE_PLASMA_STORE_MANAGER = _PlasmaStoreManager() - - class PlasmaShmQueue: def __init__(self, maxsize: int = 0): r"""Use pyarrow in-memory plasma store to implement shared memory queue. @@ -51,6 +59,17 @@ class PlasmaShmQueue: :param maxsize: maximum size of the queue, `None` means no limit. (default: ``None``) """ + # Lazy start the plasma store manager + global MGE_PLASMA_STORE_MANAGER + if MGE_PLASMA_STORE_MANAGER is None: + try: + MGE_PLASMA_STORE_MANAGER = _PlasmaStoreManager() + except FileNotFoundError as e: + raise FileNotFoundError( + "command 'plasma_store' not found in your $PATH! " + "Please make sure pyarrow installed and add into $PATH." + ) from e + self.socket_name = MGE_PLASMA_STORE_MANAGER.socket_name # TODO: how to catch the exception happened in `plasma.connect`? 
@@ -100,6 +119,7 @@ class PlasmaShmQueue: def close(self): self.queue.close() self.disconnect_client() + _clear_plasma_store() def cancel_join_thread(self): self.queue.cancel_join_thread() diff --git a/python_module/megengine/data/dataloader.py b/python_module/megengine/data/dataloader.py index d6e03c8bb3e0f56c3f7f01b1fedc9a596df23005..388a8f0e657a9bc03364b2ecc36f439b53d220bc 100644 --- a/python_module/megengine/data/dataloader.py +++ b/python_module/megengine/data/dataloader.py @@ -17,12 +17,13 @@ import numpy as np import megengine as mge +from ..logger import get_logger from .collator import Collator from .dataset import Dataset from .sampler import Sampler, SequentialSampler from .transform import PseudoTransform, Transform -logger = mge.get_logger(__name__) +logger = get_logger(__name__) MP_QUEUE_GET_TIMEOUT = 5 @@ -167,7 +168,7 @@ class _SerialDataLoaderIter(_BaseDataLoaderIter): class _ParallelDataLoaderIter(_BaseDataLoaderIter): - __initialzed = False + __initialized = False def __init__(self, loader): super(_ParallelDataLoaderIter, self).__init__(loader)