diff --git a/imperative/python/megengine/distributed/functional.py b/imperative/python/megengine/distributed/functional.py index 453d3ba557b0f7c39dd974f7188be7dca7cc102e..d832ae86c1eb711f748fa01d433ec3a465b2217f 100644 --- a/imperative/python/megengine/distributed/functional.py +++ b/imperative/python/megengine/distributed/functional.py @@ -196,6 +196,13 @@ def broadcast( return out +def _bcast_param( + inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = "" +) -> Tensor: + mode = CollectiveComm.Mode.BROADCAST + return collective_comm(inp, mode, group, device) + + def all_gather( inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = "" ) -> Tensor: diff --git a/imperative/python/megengine/distributed/helper.py b/imperative/python/megengine/distributed/helper.py index 0a67f2dd7264c58269ea0f22e356e6e5c9c94573..c0743958a31c3ea5cbcf0fecbbfce01d3dc3a78d 100644 --- a/imperative/python/megengine/distributed/helper.py +++ b/imperative/python/megengine/distributed/helper.py @@ -22,7 +22,7 @@ from ..core.ops.builtin import ParamPackConcat, ParamPackSplit from ..functional.tensor import copy from ..tensor import Tensor from ..utils.future import Future -from .functional import all_reduce_sum, broadcast +from .functional import _bcast_param, all_reduce_sum, broadcast from .group import WORLD, Group, group_barrier, is_distributed @@ -186,7 +186,7 @@ def bcast_list_(inps: list, group: Group = WORLD): :param group: communication group. """ for inp in inps: - inp._reset(broadcast(inp, group)) + inp._reset(_bcast_param(inp, group)) class AllreduceCallback: