Unverified commit db749ee0, authored by Roc, committed by GitHub

[XPU] Update Sharding stage2 for XPU (#48369)

* support xpu scalar inplace

* sharding for xpu

* update

* update
Co-authored-by: Nheyanru <81976792+heyanru01@users.noreply.github.com>
Parent 776aef79
@@ -26,7 +26,7 @@ import numpy as np
 import paddle
 from paddle.fluid import core
-from .group_sharded_utils import Type, device_guard
+from .group_sharded_utils import Type, device_guard, cvt_to_device
 class InternalStorage:
@@ -76,8 +76,8 @@ class InternalStorage:
         if self._device != device:
             tmp_buffer = (
-                self.buffer.cuda(self.dev_id)
-                if device == "gpu"
+                cvt_to_device(self.buffer, self.dev_id)
+                if device in ["gpu", "xpu", "npu"]
                 else self.buffer.cpu()
             )
             for param in self._params:
@@ -133,7 +133,7 @@ class ParamStorage(InternalStorage):
         if convert_gpu:
             # buffer convert from cpu to cuda
-            self.buffer = self.buffer.cuda(self.dev_id)
+            self.buffer = cvt_to_device(self.buffer, self.dev_id)
         self._fill = 0
......
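For context, the storage change above swaps the CUDA-only `self.buffer.cuda(...)` call for the new `cvt_to_device` helper, so the same branch serves gpu, xpu and npu builds. A minimal sketch of the resulting dispatch, assuming the helper lives at the module path below; the names `move_buffer`, `buffer` and `dev_id` are illustrative and not part of the commit:

import paddle

# Assumed import path for the helper introduced later in this commit.
from paddle.distributed.fleet.meta_parallel.sharding.group_sharded_utils import (
    cvt_to_device,
)


def move_buffer(buffer, device, dev_id):
    # Mirrors the InternalStorage branch: accelerator builds go through the
    # place-aware helper, everything else falls back to a host copy.
    if device in ["gpu", "xpu", "npu"]:
        return cvt_to_device(buffer, dev_id)
    return buffer.cpu()


buf = paddle.zeros([4], dtype="float32")
buf = move_buffer(buf, paddle.get_device().split(":")[0], 0)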
@@ -130,7 +130,7 @@ class GroupShardedClipGrad:
         if paddle.device.get_device() == "cpu":
             global_norm_var = global_norm_var.cuda(dev_id)
-        with device_guard(dev_id, "gpu"):
+        with device_guard(dev_id, self._device.split(":")[0]):
             paddle.distributed.all_reduce(global_norm_var, group=self._group)
         global_norm_var = paddle.sqrt(global_norm_var)
@@ -170,8 +170,8 @@ def device_guard(dev_id=0, device="cpu"):
     origin_device = paddle.device.get_device()
     if device == "cpu":
         paddle.set_device(device)
-    elif device == "gpu":
-        paddle.set_device("gpu:{}".format(dev_id))
+    elif device in ["gpu", "xpu", "npu"]:
+        paddle.set_device("{}:{}".format(device, dev_id))
     try:
         yield
     finally:
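With the widened branch, `device_guard` can pin a block of work to an xpu or npu as well as a gpu. A short usage sketch on an XPU build, assuming the same module path as above; the guard switches the active device for the body of the `with` block and restores the previous one afterwards:

import paddle

# Assumed import path; device_guard is the context manager shown in the hunk above.
from paddle.distributed.fleet.meta_parallel.sharding.group_sharded_utils import (
    device_guard,
)

with device_guard(dev_id=0, device="xpu"):  # requires a Paddle build with XPU support
    norm = paddle.sqrt(paddle.sum(paddle.square(paddle.ones([8]))))

print(paddle.device.get_device())  # the original device string is restored here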
@@ -251,3 +251,20 @@ def GroupShardedScaler(scaler):
     scaler._unscale = MethodType(unscale_method, scaler)
     return scaler
+
+
+def cvt_to_device(x, dev_id, blocking=True):
+    """
+    Copy data in x from cpu memory to supported device
+    """
+    if paddle.is_compiled_with_cuda():
+        place = paddle.CUDAPlace(dev_id)
+    elif paddle.is_compiled_with_npu():
+        place = paddle.NPUPlace(dev_id)
+    elif paddle.is_compiled_with_xpu():
+        place = paddle.XPUPlace(dev_id)
+    else:
+        raise EnvironmentError(
+            "Only supports paddle compiled with gpu/rocm, npu or xpu, but the current version is compiled with cpu."
+        )
+    return x._copy_to(place, blocking)
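A hedged usage sketch of the new helper: it selects the `Place` matching whichever accelerator Paddle was compiled with (CUDA/ROCm, NPU or XPU) and copies the tensor there; the import path is assumed from the package layout:

import paddle

from paddle.distributed.fleet.meta_parallel.sharding.group_sharded_utils import (
    cvt_to_device,  # assumed import path for the helper added above
)

# A tensor created in host memory ...
x = paddle.to_tensor([1.0, 2.0, 3.0], place=paddle.CPUPlace())
# ... is copied to CUDAPlace(0), NPUPlace(0) or XPUPlace(0), depending on the build.
y = cvt_to_device(x, 0)
print(y.place)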
@@ -117,6 +117,12 @@ def group_sharded_parallel(
             optimizer.step()
             optimizer.clear_grad()
     """
+    device = paddle.get_device().split(":")[0]
+    assert device in [
+        "gpu",
+        "xpu",
+    ], "group_sharded_parallel only supports gpu and xpu now"
+
     # check option type
     assert isinstance(
         model, paddle.nn.Layer
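The assertion added here keys off Paddle's runtime device string; for example:

import paddle

dev_str = paddle.get_device()   # e.g. "gpu:0" or "xpu:0" ("cpu" on CPU-only builds)
device = dev_str.split(":")[0]  # -> "gpu", "xpu" or "cpu"; "cpu" fails the new assert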
@@ -148,6 +154,7 @@ def group_sharded_parallel(
                 group=group,
                 offload=offload,
                 dp_group=dp_group,
+                device=device,
             )
             model = GroupShardedStage2(
                 model,
@@ -156,6 +163,7 @@ def group_sharded_parallel(
                 sync_buffers=sync_buffers,
                 buffer_max_size=buffer_max_size,
                 dp_group=dp_group,
+                device=device,
             )
         else:
             optimizer = ShardingOptimizerStage2(
@@ -163,6 +171,7 @@ def group_sharded_parallel(
                 optim=optimizer,
                 group=group,
                 offload=offload,
+                device=device,
             )
             model = ShardingStage2(
                 model,
@@ -170,6 +179,7 @@ def group_sharded_parallel(
                 group=group,
                 sync_buffers=sync_buffers,
                 buffer_max_size=buffer_max_size,
+                device=device,
             )
     elif level == 'p_g_os':
         if in_dygraph_mode():
@@ -181,6 +191,7 @@ def group_sharded_parallel(
                 segment_size=segment_size,
                 offload=offload,
                 sync_comm=sync_comm,
+                device=device,
             )
         else:
             model = ShardingStage3(
@@ -191,6 +202,7 @@ def group_sharded_parallel(
                 segment_size=segment_size,
                 offload=offload,
                 sync_comm=sync_comm,
+                device=device,
             )
     else:
         raise ValueError("Please enter the correct level.")
......
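Putting the pieces together, a hedged end-to-end sketch of stage-2 sharding on an XPU (or GPU) build; it follows the documented `group_sharded_parallel` API and is meant to run under `paddle.distributed.launch`, with the model and optimizer below as placeholders:

import paddle
from paddle.distributed import fleet
from paddle.distributed.sharding import group_sharded_parallel

fleet.init(is_collective=True)

model = paddle.nn.Linear(8, 8)
optimizer = paddle.optimizer.AdamW(
    learning_rate=1e-3, parameters=model.parameters()
)

# level="os_g" shards optimizer states and gradients (stage 2); the device
# detection added in this commit routes buffers and collectives to gpu or xpu.
model, optimizer, scaler = group_sharded_parallel(
    model=model, optimizer=optimizer, level="os_g"
)

x = paddle.rand([4, 8])
loss = model(x).mean()
loss.backward()
optimizer.step()
optimizer.clear_grad()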