diff --git a/python/paddle/distributed/collective.py b/python/paddle/distributed/collective.py index 620f23b9e1919448d9bfee754922085d9f751ba9..bf36015f894071b0d4aa7759b1f260a9291d9d53 100644 --- a/python/paddle/distributed/collective.py +++ b/python/paddle/distributed/collective.py @@ -476,6 +476,10 @@ def new_group(ranks=None, backend=None): place = core.MLUPlace(genv.device_id) core.CNCLParallelContext(strategy, place).init_with_ring_id(ring_id) + elif core.is_compiled_with_xpu(): + place = core.XPUPlace(genv.device_id) + core.BKCLParallelContext(strategy, + place).init_with_ring_id(ring_id) else: assert False, ("no cuda device found") else: