Commit 6c5cf25f authored by Megvii Engine Team

docs(mge/distributed): add distributed.helper docs

GitOrigin-RevId: 37c14aa11f7514d6c6c7d2051d9f11c6b261a6fa
Parent cb8e5363
@@ -21,7 +21,7 @@ from ..functional.param_pack import get_offsets, pack_allreduce_split
 from ..functional.utils import copy
 from ..utils.future import Future
 from .functional import all_reduce_sum, broadcast
-from .group import WORLD, group_barrier, is_distributed
+from .group import WORLD, Group, group_barrier, is_distributed

 class TensorFuture(Future):
@@ -54,28 +54,43 @@ def synchronized(func: Callable):
     return wrapper

-def worker(queue, device_type):
+def _get_device_count_worker(queue, device_type):
     num = get_device_count(device_type)
     queue.put(num)

 def get_device_count_by_fork(device_type: str):
+    """Get device count in a forked subprocess.
+
+    See https://stackoverflow.com/questions/22950047/cuda-initialization-error-after-fork
+    for more information.
+    """
     q = mp.Queue()
-    p = mp.Process(target=worker, args=(q, device_type))
+    p = mp.Process(target=_get_device_count_worker, args=(q, device_type))
     p.start()
     p.join()
     return q.get()

-def bcast_list_(params, group):
-    for p in params:
-        p._reset(broadcast(p, group))
+def bcast_list_(inps: list, group: Group = WORLD):
+    """Broadcast tensors among the given group in place.
+
+    :param inps: input tensors.
+    :param group: communication group.
+    """
+    for inp in inps:
+        inp._reset(broadcast(inp, group))

 class AllreduceCallback:
-    def __init__(self, reduce_method, group=WORLD):
+    """Allreduce callback with tensor fusion optimization.
+
+    :param reduce_method: the method used to reduce gradients, "sum" or "mean".
+    :param group: communication group.
+    """
+
+    def __init__(self, reduce_method: str, group: Group = WORLD):
         reduce_method = reduce_method.lower()
-        assert reduce_method in ["sum", "mean"]
+        assert reduce_method in ["sum", "mean"], "reduce_method should be sum or mean"
         self._reduce_method = reduce_method
         self._group = group
         self._marked_gm = WeakSet()
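
Usage note: initializing CUDA in the parent process breaks later fork() calls, which is why the count is queried in a short-lived child process (see the StackOverflow link in the docstring above). A minimal sketch, assuming this helper is importable as megengine.distributed.helper.get_device_count_by_fork and that "gpu" is a valid device_type string:

    from megengine.distributed.helper import get_device_count_by_fork

    # The query runs in a child process, so the parent never touches the
    # CUDA driver and stays safe to fork afterwards.
    n_gpus = get_device_count_by_fork("gpu")
    print("visible GPUs:", n_gpus)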
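bcast_list_ is the startup synchronization primitive: every rank passes the same tensor list, and each tensor is reset in place to the value broadcast from the root rank of the group. A hedged sketch of a typical call site, assuming MegEngine's dist.launcher decorator and an M.Linear module purely for illustration:

    import megengine.distributed as dist
    import megengine.module as M
    from megengine.distributed.helper import bcast_list_

    @dist.launcher
    def main():
        net = M.Linear(4, 2)
        # After this call all ranks train from identical weights
        # (group defaults to WORLD).
        bcast_list_(list(net.parameters()))

    if __name__ == "__main__":
        main()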
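AllreduceCallback is meant to be attached to gradient computation so that gradients are reduced across ranks as they become available (the tensor fusion named in the docstring batches small gradients into fewer collective calls, via the pack_allreduce_split import at the top of the file). A sketch of wiring it into MegEngine's GradManager through the callbacks argument of attach(); everything outside this diff is an assumption:

    import megengine.distributed as dist
    import megengine.module as M
    from megengine.autodiff import GradManager
    from megengine.distributed.helper import AllreduceCallback

    @dist.launcher
    def train():
        net = M.Linear(4, 2)
        gm = GradManager()
        # "mean" averages gradients over the group; any value other than
        # "sum" or "mean" trips the assert added in this commit.
        gm.attach(net.parameters(), callbacks=[AllreduceCallback("mean")])

    if __name__ == "__main__":
        train()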