From 6972bfdeba9b9eb1a33eacb1d54ab9c07431a02c Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Tue, 19 May 2020 10:15:43 +0800 Subject: [PATCH] fix(mge/distributed): fix input comp_graph of broadcast operator GitOrigin-RevId: 039fd06a933b4d1e9a74fbb5a3eaf467cc486178 --- .../megengine/distributed/functional.py | 32 ++++++++----------- 1 file changed, 13 insertions(+), 19 deletions(-) diff --git a/python_module/megengine/distributed/functional.py b/python_module/megengine/distributed/functional.py index b0e7cf0bf..dd353bf5f 100644 --- a/python_module/megengine/distributed/functional.py +++ b/python_module/megengine/distributed/functional.py @@ -112,26 +112,20 @@ def broadcast( rank = get_rank() if rank == root: - return _collective_comm( - tensor, - key, - CollParam.Mode.BROADCAST, - nr_ranks, - rank, - root, - device=tensor.device, - ) + inp = tensor else: - return _collective_comm( - get_default_graph(), - key, - CollParam.Mode.BROADCAST, - nr_ranks, - rank, - root, - dtype=tensor._symvar.dtype, - device=tensor.device, - ) + inp = tensor._symvar.owner_graph + + return _collective_comm( + inp, + key, + CollParam.Mode.BROADCAST, + nr_ranks, + rank, + root, + dtype=tensor.dtype, + device=tensor.device, + ) def all_gather( -- GitLab