diff --git a/python/paddle/distributed/collective.py b/python/paddle/distributed/collective.py index e5dfb34f24304dcbfca25fdb8e9498d698e35437..70e16d67fb9f17ea9b51197abe73c7bca1ec8deb 100644 --- a/python/paddle/distributed/collective.py +++ b/python/paddle/distributed/collective.py @@ -1524,7 +1524,7 @@ def alltoall(in_tensor_list, out_tensor_list, group=None, use_calc_stream=True): inputs={'X': [temp]}, outputs={'Out': [out]}, attrs={ - 'ring_id': group, + 'ring_id': ring_id, 'use_calc_stream': use_calc_stream, }) out_tensor_list.extend(paddle.split(out, nranks, 0))