提交 afcda610 编写于 作者: M Megvii Engine Team

fix(mge/data/dataloader): fix typo, import and refine the logic of plasma_store

GitOrigin-RevId: 7d169a52945e74802e08ddaaf4b578945cc8424d
上级 65432d3b
...@@ -16,6 +16,19 @@ import pyarrow.plasma as plasma ...@@ -16,6 +16,19 @@ import pyarrow.plasma as plasma
MGE_PLASMA_MEMORY = int(os.environ.get("MGE_PLASMA_MEMORY", 4000000000)) # 4GB MGE_PLASMA_MEMORY = int(os.environ.get("MGE_PLASMA_MEMORY", 4000000000)) # 4GB
# Each process only needs to start one plasma store, so we set it as a global variable.
# TODO: how to share between different processes?
MGE_PLASMA_STORE_MANAGER = None
def _clear_plasma_store():
    """Release the module-level plasma store manager, if one exists.

    ``_PlasmaStoreManager.__del__`` is not called automatically in a
    subprocess, so this function has to be called explicitly.
    """
    global MGE_PLASMA_STORE_MANAGER
    if MGE_PLASMA_STORE_MANAGER is None:
        # Nothing was started (or it was already cleared).
        return
    # Remove the binding, then leave the name defined as ``None`` so that
    # subsequent ``is None`` checks on the global keep working.
    del MGE_PLASMA_STORE_MANAGER
    MGE_PLASMA_STORE_MANAGER = None
class _PlasmaStoreManager: class _PlasmaStoreManager:
def __init__(self): def __init__(self):
...@@ -34,11 +47,6 @@ class _PlasmaStoreManager: ...@@ -34,11 +47,6 @@ class _PlasmaStoreManager:
self.plasma_store.kill() self.plasma_store.kill()
# Each process only needs to start one plasma store, so we set it as a global variable.
# TODO: how to share between different processes?
MGE_PLASMA_STORE_MANAGER = _PlasmaStoreManager()
class PlasmaShmQueue: class PlasmaShmQueue:
def __init__(self, maxsize: int = 0): def __init__(self, maxsize: int = 0):
r"""Use pyarrow in-memory plasma store to implement shared memory queue. r"""Use pyarrow in-memory plasma store to implement shared memory queue.
...@@ -51,6 +59,17 @@ class PlasmaShmQueue: ...@@ -51,6 +59,17 @@ class PlasmaShmQueue:
:param maxsize: maximum size of the queue, `None` means no limit. (default: ``None``) :param maxsize: maximum size of the queue, `None` means no limit. (default: ``None``)
""" """
# Lazy start the plasma store manager
global MGE_PLASMA_STORE_MANAGER
if MGE_PLASMA_STORE_MANAGER is None:
try:
MGE_PLASMA_STORE_MANAGER = _PlasmaStoreManager()
except FileNotFoundError as e:
raise FileNotFoundError(
"command 'plasma_store' not found in your $PATH!"
"Please make sure pyarrow installed and add into $PATH."
)
self.socket_name = MGE_PLASMA_STORE_MANAGER.socket_name self.socket_name = MGE_PLASMA_STORE_MANAGER.socket_name
# TODO: how to catch the exception happened in `plasma.connect`? # TODO: how to catch the exception happened in `plasma.connect`?
...@@ -100,6 +119,7 @@ class PlasmaShmQueue: ...@@ -100,6 +119,7 @@ class PlasmaShmQueue:
def close(self): def close(self):
self.queue.close() self.queue.close()
self.disconnect_client() self.disconnect_client()
_clear_plasma_store()
def cancel_join_thread(self): def cancel_join_thread(self):
self.queue.cancel_join_thread() self.queue.cancel_join_thread()
...@@ -17,12 +17,13 @@ import numpy as np ...@@ -17,12 +17,13 @@ import numpy as np
import megengine as mge import megengine as mge
from ..logger import get_logger
from .collator import Collator from .collator import Collator
from .dataset import Dataset from .dataset import Dataset
from .sampler import Sampler, SequentialSampler from .sampler import Sampler, SequentialSampler
from .transform import PseudoTransform, Transform from .transform import PseudoTransform, Transform
logger = mge.get_logger(__name__) logger = get_logger(__name__)
MP_QUEUE_GET_TIMEOUT = 5 MP_QUEUE_GET_TIMEOUT = 5
...@@ -167,7 +168,7 @@ class _SerialDataLoaderIter(_BaseDataLoaderIter): ...@@ -167,7 +168,7 @@ class _SerialDataLoaderIter(_BaseDataLoaderIter):
class _ParallelDataLoaderIter(_BaseDataLoaderIter): class _ParallelDataLoaderIter(_BaseDataLoaderIter):
__initialzed = False __initialized = False
def __init__(self, loader): def __init__(self, loader):
super(_ParallelDataLoaderIter, self).__init__(loader) super(_ParallelDataLoaderIter, self).__init__(loader)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册