提交 afcda610 编写于 作者: M Megvii Engine Team

fix(mge/data/dataloader): fix typo, import and refine the logic of plasma_store

GitOrigin-RevId: 7d169a52945e74802e08ddaaf4b578945cc8424d
上级 65432d3b
......@@ -16,6 +16,19 @@ import pyarrow.plasma as plasma
MGE_PLASMA_MEMORY = int(os.environ.get("MGE_PLASMA_MEMORY", 4000000000)) # 4GB
# Each process only need to start one plasma store, so we set it as a global variable.
# TODO: how to share between different processes?
MGE_PLASMA_STORE_MANAGER = None
def _clear_plasma_store():
# `_PlasmaStoreManager.__del__` will not ne called automaticly in subprocess,
# so this function should be called explicitly
global MGE_PLASMA_STORE_MANAGER
if MGE_PLASMA_STORE_MANAGER is not None:
del MGE_PLASMA_STORE_MANAGER
MGE_PLASMA_STORE_MANAGER = None
class _PlasmaStoreManager:
def __init__(self):
......@@ -34,11 +47,6 @@ class _PlasmaStoreManager:
self.plasma_store.kill()
# Each process only need to start one plasma store, so we set it as a global variable.
# TODO: how to share between different processes?
MGE_PLASMA_STORE_MANAGER = _PlasmaStoreManager()
class PlasmaShmQueue:
def __init__(self, maxsize: int = 0):
r"""Use pyarrow in-memory plasma store to implement shared memory queue.
......@@ -51,6 +59,17 @@ class PlasmaShmQueue:
:param maxsize: maximum size of the queue, `None` means no limit. (default: ``None``)
"""
# Lazy start the plasma store manager
global MGE_PLASMA_STORE_MANAGER
if MGE_PLASMA_STORE_MANAGER is None:
try:
MGE_PLASMA_STORE_MANAGER = _PlasmaStoreManager()
except FileNotFoundError as e:
raise FileNotFoundError(
"command 'plasma_store' not found in your $PATH!"
"Please make sure pyarrow installed and add into $PATH."
)
self.socket_name = MGE_PLASMA_STORE_MANAGER.socket_name
# TODO: how to catch the exception happened in `plasma.connect`?
......@@ -100,6 +119,7 @@ class PlasmaShmQueue:
def close(self):
self.queue.close()
self.disconnect_client()
_clear_plasma_store()
def cancel_join_thread(self):
self.queue.cancel_join_thread()
......@@ -17,12 +17,13 @@ import numpy as np
import megengine as mge
from ..logger import get_logger
from .collator import Collator
from .dataset import Dataset
from .sampler import Sampler, SequentialSampler
from .transform import PseudoTransform, Transform
logger = mge.get_logger(__name__)
logger = get_logger(__name__)
MP_QUEUE_GET_TIMEOUT = 5
......@@ -167,7 +168,7 @@ class _SerialDataLoaderIter(_BaseDataLoaderIter):
class _ParallelDataLoaderIter(_BaseDataLoaderIter):
__initialzed = False
__initialized = False
def __init__(self, loader):
super(_ParallelDataLoaderIter, self).__init__(loader)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册