提交 497ef6c3 编写于 作者: M Megvii Engine Team 提交者: huangxinda

fix(mge/dist): fix gl oom error

GitOrigin-RevId: 4ba3d2cfd74ed1d63274f175e232eeead1ec7b6d
上级 43098fb8
......@@ -186,9 +186,9 @@ def _get_device_count_worker(queue, device_type):
queue.put(num)
-def _check_device_initialized(device_type: str):
+def _check_device_initialized(device_type: str, rank: int):
 try:
-test = Tensor(1, device=device_type)
+test = Tensor(1, device=(device_type + str(rank)))
inited = False
del test
except:
......
......@@ -39,7 +39,7 @@ def _run_wrapped(
machine_ranks: list,
):
"""Init distributed process group and run wrapped function."""
-_check_device_initialized(device_type)
+_check_device_initialized(device_type, dev)
init_process_group(
master_ip=master_ip,
port=port,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册