From 7875accb66cb486133456ec5bbf829225ebdaa71 Mon Sep 17 00:00:00 2001
From: Roc <30228238+sljlp@users.noreply.github.com>
Date: Wed, 4 Jan 2023 19:22:33 +0800
Subject: [PATCH] support mp on xpu (#49531)

---
 .../fleet/utils/hybrid_parallel_util.py      | 28 +++++++++++--------
 1 file changed, 16 insertions(+), 12 deletions(-)

diff --git a/python/paddle/distributed/fleet/utils/hybrid_parallel_util.py b/python/paddle/distributed/fleet/utils/hybrid_parallel_util.py
index 29f3ee0719d..86688562163 100644
--- a/python/paddle/distributed/fleet/utils/hybrid_parallel_util.py
+++ b/python/paddle/distributed/fleet/utils/hybrid_parallel_util.py
@@ -138,15 +138,23 @@ def _broadcast_data_help(data, shape, dtype, hcg):
 
 def broadcast_input_data(hcg, *inputs, **kwargs):
     cur_device = paddle.get_device()
+    dev = cur_device.split(":")[0]
+    assert dev in [
+        "xpu",
+        "gpu",
+        "npu",
+    ], f"Only support xpu, gpu and npu now, but this is {dev}"
+    dev_idx = int(cur_device.split(':')[1])
+    if dev == "gpu":
+        place = paddle.CUDAPlace(dev_idx)
+    else:
+        place = eval(f"paddle.{dev.upper()}Place")(dev_idx)
+
     for v in inputs:
         if isinstance(v, (core.VarBase, core.eager.Tensor)):
             with framework.no_grad():
-                if (
-                    "gpu" in cur_device
-                    and in_dygraph_mode()
-                    and not v.place.is_gpu_place()
-                ):
-                    v_gpu = v.cuda(int(cur_device.split(":")[1]))
+                if in_dygraph_mode() and not eval(f"v.place.is_{dev}_place")():
+                    v_gpu = v._copy_to(place, True)
                     v._clear_data()
                     v_gpu._share_buffer_to(v)
                 _broadcast_data_help(v, v.shape, v.dtype, hcg)
@@ -156,12 +164,8 @@ def broadcast_input_data(hcg, *inputs, **kwargs):
     for k, v in kwargs.items():
         if isinstance(v, (core.VarBase, core.eager.Tensor)):
             with framework.no_grad():
-                if (
-                    "gpu" in cur_device
-                    and in_dygraph_mode()
-                    and not v.place.is_gpu_place()
-                ):
-                    v_gpu = v.cuda(int(cur_device.split(":")[1]))
+                if in_dygraph_mode() and not eval(f"v.place.is_{dev}_place")():
+                    v_gpu = v._copy_to(place, True)
                     v._clear_data()
                     v_gpu._share_buffer_to(v)
                 _broadcast_data_help(v, v.shape, v.dtype, hcg)
-- 
GitLab
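
Note on the device dispatch in this patch: both the Place construction and the
is_<dev>_place() check are built with eval on f-strings. The same dispatch can
be written with getattr, which avoids executing generated strings and fails
with a plain AttributeError on an unknown device. Below is a minimal
standalone sketch of that alternative, not the patched function itself: the
helper names current_place and ensure_on_device are hypothetical, and the only
Paddle calls used are ones the patch itself relies on (paddle.get_device,
paddle.CUDAPlace, the XPUPlace/NPUPlace constructors its eval resolves to, and
Tensor._copy_to/_clear_data/_share_buffer_to).

# Sketch of the patch's device dispatch using getattr instead of eval.
# current_place and ensure_on_device are hypothetical helpers, not Paddle API.
import paddle

def current_place():
    """Map paddle.get_device() (e.g. "xpu:0", "gpu:1") to a Place object."""
    dev, _, idx = paddle.get_device().partition(":")
    assert dev in ("xpu", "gpu", "npu"), f"unsupported device: {dev}"
    dev_idx = int(idx) if idx else 0
    if dev == "gpu":
        # "gpu" is the one name that does not follow the <DEV>Place pattern.
        return paddle.CUDAPlace(dev_idx)
    # getattr raises AttributeError for an unknown class instead of eval'ing.
    return getattr(paddle, f"{dev.upper()}Place")(dev_idx)

def ensure_on_device(v):
    """Move tensor v to the current device, keeping the caller's handle valid."""
    dev = paddle.get_device().split(":")[0]
    # Same check as the patch's eval(f"v.place.is_{dev}_place")().
    if not getattr(v.place, f"is_{dev}_place")():
        moved = v._copy_to(current_place(), True)  # blocking copy
        v._clear_data()
        moved._share_buffer_to(v)  # v now aliases the on-device buffer
    return v

With this shape, both loops in broadcast_input_data could call
ensure_on_device(v) before _broadcast_data_help, and the v_gpu name, which the
patch leaves in place even though the tensor may now land on xpu or npu, goes
away.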