提交 6c5cf25f 编写于 作者: M Megvii Engine Team

docs(mge/distributed): add distributed.helper docs

GitOrigin-RevId: 37c14aa11f7514d6c6c7d2051d9f11c6b261a6fa
上级 cb8e5363
...@@ -21,7 +21,7 @@ from ..functional.param_pack import get_offsets, pack_allreduce_split ...@@ -21,7 +21,7 @@ from ..functional.param_pack import get_offsets, pack_allreduce_split
from ..functional.utils import copy from ..functional.utils import copy
from ..utils.future import Future from ..utils.future import Future
from .functional import all_reduce_sum, broadcast from .functional import all_reduce_sum, broadcast
from .group import WORLD, group_barrier, is_distributed from .group import WORLD, Group, group_barrier, is_distributed
class TensorFuture(Future): class TensorFuture(Future):
...@@ -54,28 +54,43 @@ def synchronized(func: Callable): ...@@ -54,28 +54,43 @@ def synchronized(func: Callable):
return wrapper return wrapper
def _get_device_count_worker(queue, device_type):
    """Count devices of ``device_type`` in this (child) process and report
    the result back to the parent through ``queue``.

    Run inside a forked process so that device initialization (e.g. CUDA)
    does not pollute the parent process.
    """
    queue.put(get_device_count(device_type))
def get_device_count_by_fork(device_type: str):
    """Get device count in fork thread.

    See https://stackoverflow.com/questions/22950047/cuda-initialization-error-after-fork
    for more information.
    """
    # Probe the device count in a separate process: initializing CUDA in the
    # parent would break later fork()s, so the query runs fully isolated.
    result_queue = mp.Queue()
    prober = mp.Process(
        target=_get_device_count_worker, args=(result_queue, device_type)
    )
    prober.start()
    # A single small int fits in the queue's pipe buffer, so joining before
    # reading cannot deadlock here.
    prober.join()
    return result_queue.get()
def bcast_list_(inps: list, group: Group = WORLD):
    """Broadcast tensors between given group.

    Each tensor is overwritten in place (via ``_reset``) with the value
    broadcast from the group's root rank.

    :param inps: input tensors.
    :param group: communication group.
    """
    for tensor in inps:
        tensor._reset(broadcast(tensor, group))
class AllreduceCallback: class AllreduceCallback:
def __init__(self, reduce_method, group=WORLD): """Allreduce Callback with tensor fusion optimization.
:param reduce_method: the method to reduce gradiants.
:param group: communication group.
"""
def __init__(self, reduce_method: str, group: Group = WORLD):
reduce_method = reduce_method.lower() reduce_method = reduce_method.lower()
assert reduce_method in ["sum", "mean"] assert reduce_method in ["sum", "mean"], "reduce_method should be sum or mean"
self._reduce_method = reduce_method self._reduce_method = reduce_method
self._group = group self._group = group
self._marked_gm = WeakSet() self._marked_gm = WeakSet()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册