Unverified commit 7875accb, authored by Roc, committed by GitHub

support mp on xpu (#49531)

Parent 5592f8ad
@@ -138,15 +138,23 @@ def _broadcast_data_help(data, shape, dtype, hcg):
 
 def broadcast_input_data(hcg, *inputs, **kwargs):
     cur_device = paddle.get_device()
+    dev = cur_device.split(":")[0]
+    assert dev in [
+        "xpu",
+        "gpu",
+        "npu",
+    ], f"Only support xpu, gpu and npu now, but this is {dev}"
+    dev_idx = int(cur_device.split(':')[1])
+    if dev == "gpu":
+        place = paddle.CUDAPlace(dev_idx)
+    else:
+        place = eval(f"paddle.{dev.upper()}Place")(dev_idx)
+
     for v in inputs:
         if isinstance(v, (core.VarBase, core.eager.Tensor)):
             with framework.no_grad():
-                if (
-                    "gpu" in cur_device
-                    and in_dygraph_mode()
-                    and not v.place.is_gpu_place()
-                ):
-                    v_gpu = v.cuda(int(cur_device.split(":")[1]))
+                if in_dygraph_mode() and not eval(f"v.place.is_{dev}_place")():
+                    v_gpu = v._copy_to(place, True)
                     v._clear_data()
                     v_gpu._share_buffer_to(v)
                 _broadcast_data_help(v, v.shape, v.dtype, hcg)
@@ -156,12 +164,8 @@ def broadcast_input_data(hcg, *inputs, **kwargs):
     for k, v in kwargs.items():
         if isinstance(v, (core.VarBase, core.eager.Tensor)):
             with framework.no_grad():
-                if (
-                    "gpu" in cur_device
-                    and in_dygraph_mode()
-                    and not v.place.is_gpu_place()
-                ):
-                    v_gpu = v.cuda(int(cur_device.split(":")[1]))
+                if in_dygraph_mode() and not eval(f"v.place.is_{dev}_place")():
+                    v_gpu = v._copy_to(place, True)
                     v._clear_data()
                     v_gpu._share_buffer_to(v)
                 _broadcast_data_help(v, v.shape, v.dtype, hcg)
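
For readers skimming the diff: the change replaces the CUDA-only `v.cuda(...)` move with a generic device-to-`Place` lookup, so the same broadcast path works on GPU, XPU and NPU. The sketch below restates that lookup as a standalone helper. It is illustrative only (`_current_place` is not part of this commit), and it uses `getattr` where the commit uses `eval`; both resolve `paddle.XPUPlace` / `paddle.NPUPlace` by name.

```python
# Illustrative sketch, not part of this commit: map the device string from
# paddle.get_device() (e.g. "gpu:0", "xpu:2", "npu:1") to its Place object.
import paddle


def _current_place():
    cur_device = paddle.get_device()
    dev = cur_device.split(":")[0]
    assert dev in ("xpu", "gpu", "npu"), f"unsupported device: {dev}"
    idx = int(cur_device.split(":")[1])
    if dev == "gpu":
        # CUDA keeps its special-cased class name.
        return paddle.CUDAPlace(idx)
    # getattr resolves paddle.XPUPlace / paddle.NPUPlace dynamically,
    # equivalent to the eval() call used in the commit.
    return getattr(paddle, f"{dev.upper()}Place")(idx)
```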
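A rough usage sketch follows, showing how `broadcast_input_data` would typically be exercised under model (tensor) parallelism on XPU. The import path and the hybrid-parallel setup below are assumptions for illustration (they follow standard Paddle fleet APIs but are not stated in this commit); the script would be launched with `python -m paddle.distributed.launch` across the participating devices.

```python
# Hypothetical usage sketch (paths and config are assumptions, see lead-in):
# broadcast the source rank's mini-batch to every rank of the mp group.
import paddle
from paddle.distributed import fleet
from paddle.distributed.fleet.utils.hybrid_parallel_util import (
    broadcast_input_data,
)

strategy = fleet.DistributedStrategy()
strategy.hybrid_configs = {"dp_degree": 1, "mp_degree": 2, "pp_degree": 1}
fleet.init(is_collective=True, strategy=strategy)

hcg = fleet.get_hybrid_communicate_group()

x = paddle.randn([8, 16])
# Tensors are moved to the current XPU/GPU/NPU place (if needed) and then
# broadcast in place within the model-parallel group.
broadcast_input_data(hcg, x)
```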