Unverified commit db749ee0, authored by Roc and committed by GitHub

[XPU] Update Sharding stage2 for XPU (#48369)

* support xpu scalar inplace

* sharding for xpu

* update

* update
Co-authored-by: heyanru <81976792+heyanru01@users.noreply.github.com>

Parent 776aef79
@@ -26,7 +26,7 @@ import numpy as np
 import paddle
 from paddle.fluid import core
-from .group_sharded_utils import Type, device_guard
+from .group_sharded_utils import Type, device_guard, cvt_to_device


 class InternalStorage:
@@ -76,8 +76,8 @@ class InternalStorage:
         if self._device != device:
             tmp_buffer = (
-                self.buffer.cuda(self.dev_id)
-                if device == "gpu"
+                cvt_to_device(self.buffer, self.dev_id)
+                if device in ["gpu", "xpu", "npu"]
                 else self.buffer.cpu()
             )

             for param in self._params:
@@ -133,7 +133,7 @@ class ParamStorage(InternalStorage):
         if convert_gpu:
             # buffer convert from cpu to cuda
-            self.buffer = self.buffer.cuda(self.dev_id)
+            self.buffer = cvt_to_device(self.buffer, self.dev_id)

         self._fill = 0
@@ -130,7 +130,7 @@ class GroupShardedClipGrad:
         if paddle.device.get_device() == "cpu":
             global_norm_var = global_norm_var.cuda(dev_id)

-        with device_guard(dev_id, "gpu"):
+        with device_guard(dev_id, self._device.split(":")[0]):
             paddle.distributed.all_reduce(global_norm_var, group=self._group)

         global_norm_var = paddle.sqrt(global_norm_var)
@@ -170,8 +170,8 @@ def device_guard(dev_id=0, device="cpu"):
     origin_device = paddle.device.get_device()
     if device == "cpu":
         paddle.set_device(device)
-    elif device == "gpu":
-        paddle.set_device("gpu:{}".format(dev_id))
+    elif device in ["gpu", "xpu", "npu"]:
+        paddle.set_device("{}:{}".format(device, dev_id))
     try:
         yield
     finally:
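For reference, a minimal sketch of the widened device_guard in use (the dev_id, tensor, and xpu target are illustrative and assume a PaddlePaddle build with XPU support):

    # GroupShardedClipGrad above derives the bare device type from a string such
    # as "xpu:0" via self._device.split(":")[0]; device_guard then switches to
    # that device for the enclosed ops and restores the original device on exit.
    with device_guard(0, "xpu"):
        norm = paddle.zeros([1], dtype="float32")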
@@ -251,3 +251,20 @@ def GroupShardedScaler(scaler):
     scaler._unscale = MethodType(unscale_method, scaler)

     return scaler
+
+
+def cvt_to_device(x, dev_id, blocking=True):
+    """
+    Copy the data in x from CPU memory to the supported device.
+    """
+    if paddle.is_compiled_with_cuda():
+        place = paddle.CUDAPlace(dev_id)
+    elif paddle.is_compiled_with_npu():
+        place = paddle.NPUPlace(dev_id)
+    elif paddle.is_compiled_with_xpu():
+        place = paddle.XPUPlace(dev_id)
+    else:
+        raise EnvironmentError(
+            "Only supports paddle compiled with gpu/rocm, npu or xpu, but the current version is compiled with cpu."
+        )
+    return x._copy_to(place, blocking)
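A small usage sketch of the new helper (the tensor and dev_id are illustrative; on a CPU-only build the call raises EnvironmentError):

    # Copy a CPU tensor to device 0 of whichever accelerator this Paddle build
    # targets (CUDA/ROCm, NPU or XPU).
    cpu_buffer = paddle.to_tensor([1.0, 2.0, 3.0], place=paddle.CPUPlace())
    dev_buffer = cvt_to_device(cpu_buffer, 0)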
@@ -117,6 +117,12 @@ def group_sharded_parallel(
                 optimizer.step()
                 optimizer.clear_grad()
     """
+    device = paddle.get_device().split(":")[0]
+    assert device in [
+        "gpu",
+        "xpu",
+    ], "group_sharded_parallel only supports gpu and xpu now"
+
     # check option type
     assert isinstance(
         model, paddle.nn.Layer
@@ -148,6 +154,7 @@ def group_sharded_parallel(
                 group=group,
                 offload=offload,
                 dp_group=dp_group,
+                device=device,
             )
             model = GroupShardedStage2(
                 model,
@@ -156,6 +163,7 @@ def group_sharded_parallel(
                 sync_buffers=sync_buffers,
                 buffer_max_size=buffer_max_size,
                 dp_group=dp_group,
+                device=device,
             )
         else:
             optimizer = ShardingOptimizerStage2(
@@ -163,6 +171,7 @@ def group_sharded_parallel(
                 optim=optimizer,
                 group=group,
                 offload=offload,
+                device=device,
             )
             model = ShardingStage2(
                 model,
@@ -170,6 +179,7 @@ def group_sharded_parallel(
                 group=group,
                 sync_buffers=sync_buffers,
                 buffer_max_size=buffer_max_size,
+                device=device,
             )
     elif level == 'p_g_os':
         if in_dygraph_mode():
@@ -181,6 +191,7 @@ def group_sharded_parallel(
                 segment_size=segment_size,
                 offload=offload,
                 sync_comm=sync_comm,
+                device=device,
             )
         else:
             model = ShardingStage3(
@@ -191,6 +202,7 @@ def group_sharded_parallel(
                 segment_size=segment_size,
                 offload=offload,
                 sync_comm=sync_comm,
+                device=device,
             )
     else:
         raise ValueError("Please enter the correct level.")
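For context, a minimal end-to-end sketch of calling group_sharded_parallel on an XPU device after this change (assumes an XPU build launched through paddle.distributed; the model and hyperparameters are placeholders):

    import paddle
    from paddle.distributed import fleet
    from paddle.distributed.sharding import group_sharded_parallel

    fleet.init(is_collective=True)

    model = paddle.nn.Linear(1000, 1000)
    optimizer = paddle.optimizer.AdamW(
        learning_rate=0.001, parameters=model.parameters()
    )

    # paddle.get_device() now reports e.g. "xpu:0" on XPU, and the new assert
    # accepts both the "gpu" and "xpu" prefixes.
    model, optimizer, _ = group_sharded_parallel(model, optimizer, level="os_g")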