From 6c5cf25f4d8ad36d3dfc02b246562214fb283212 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Mon, 28 Sep 2020 14:41:54 +0800 Subject: [PATCH] docs(mge/distributed): add distributed.helper docs GitOrigin-RevId: 37c14aa11f7514d6c6c7d2051d9f11c6b261a6fa --- .../python/megengine/distributed/helper.py | 31 ++++++++++++++----- 1 file changed, 23 insertions(+), 8 deletions(-) diff --git a/imperative/python/megengine/distributed/helper.py b/imperative/python/megengine/distributed/helper.py index 72a85f2a..3f1637ed 100644 --- a/imperative/python/megengine/distributed/helper.py +++ b/imperative/python/megengine/distributed/helper.py @@ -21,7 +21,7 @@ from ..functional.param_pack import get_offsets, pack_allreduce_split from ..functional.utils import copy from ..utils.future import Future from .functional import all_reduce_sum, broadcast -from .group import WORLD, group_barrier, is_distributed +from .group import WORLD, Group, group_barrier, is_distributed class TensorFuture(Future): @@ -54,28 +54,43 @@ def synchronized(func: Callable): return wrapper -def worker(queue, device_type): +def _get_device_count_worker(queue, device_type): num = get_device_count(device_type) queue.put(num) def get_device_count_by_fork(device_type: str): + """Get device count in fork thread. + See https://stackoverflow.com/questions/22950047/cuda-initialization-error-after-fork + for more information. + """ q = mp.Queue() - p = mp.Process(target=worker, args=(q, device_type)) + p = mp.Process(target=_get_device_count_worker, args=(q, device_type)) p.start() p.join() return q.get() -def bcast_list_(params, group): - for p in params: - p._reset(broadcast(p, group)) +def bcast_list_(inps: list, group: Group = WORLD): + """Broadcast tensors between given group. + + :param inps: input tensors. + :param group: communication group. + """ + for inp in inps: + inp._reset(broadcast(inp, group)) class AllreduceCallback: - def __init__(self, reduce_method, group=WORLD): + """Allreduce Callback with tensor fusion optimization. + + :param reduce_method: the method to reduce gradiants. + :param group: communication group. + """ + + def __init__(self, reduce_method: str, group: Group = WORLD): reduce_method = reduce_method.lower() - assert reduce_method in ["sum", "mean"] + assert reduce_method in ["sum", "mean"], "reduce_method should be sum or mean" self._reduce_method = reduce_method self._group = group self._marked_gm = WeakSet() -- GitLab