diff --git a/python/paddle/distributed/collective.py b/python/paddle/distributed/collective.py index 2a251460e47bf6e0a012f7d03e5b891e48371db8..468f20952c20e031eac61171cf9884ea2e31ce7e 100644 --- a/python/paddle/distributed/collective.py +++ b/python/paddle/distributed/collective.py @@ -1176,6 +1176,7 @@ def _parallel_linear(x, inputs={'X': linear_out}, outputs={'Out': out}, attrs={ + 'rank': inner_rank, 'ring_id': ring_id, 'nranks': nranks, 'use_calc_stream': True,