diff --git a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_storage.py b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_storage.py
index 72e9ebfcb7d8859909c48228f202b663ec3ddac6..06b5ed9d8caea3d755f61b45c28316876b8b0793 100644
--- a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_storage.py
+++ b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_storage.py
@@ -26,7 +26,7 @@ import numpy as np
 import paddle
 from paddle.fluid import core
 
-from .group_sharded_utils import Type, device_guard
+from .group_sharded_utils import Type, device_guard, cvt_to_device
 
 
 class InternalStorage:
@@ -76,8 +76,8 @@ class InternalStorage:
 
         if self._device != device:
             tmp_buffer = (
-                self.buffer.cuda(self.dev_id)
-                if device == "gpu"
+                cvt_to_device(self.buffer, self.dev_id)
+                if device in ["gpu", "xpu", "npu"]
                 else self.buffer.cpu()
             )
             for param in self._params:
@@ -133,7 +133,7 @@ class ParamStorage(InternalStorage):
 
         if convert_gpu:
             # buffer convert from cpu to cuda
-            self.buffer = self.buffer.cuda(self.dev_id)
+            self.buffer = cvt_to_device(self.buffer, self.dev_id)
 
         self._fill = 0
 
diff --git a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_utils.py b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_utils.py
index 2b883fe67e006435675bfaf5d7e240e7668d19fd..96e262ccd2e4cedc0ee958c8a5e998ab185a7fd9 100644
--- a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_utils.py
+++ b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_utils.py
@@ -130,7 +130,7 @@ class GroupShardedClipGrad:
         if paddle.device.get_device() == "cpu":
             global_norm_var = global_norm_var.cuda(dev_id)
 
-        with device_guard(dev_id, "gpu"):
+        with device_guard(dev_id, self._device.split(":")[0]):
             paddle.distributed.all_reduce(global_norm_var, group=self._group)
 
         global_norm_var = paddle.sqrt(global_norm_var)
@@ -170,8 +170,8 @@ def device_guard(dev_id=0, device="cpu"):
     origin_device = paddle.device.get_device()
     if device == "cpu":
         paddle.set_device(device)
-    elif device == "gpu":
-        paddle.set_device("gpu:{}".format(dev_id))
+    elif device in ["gpu", "xpu", "npu"]:
+        paddle.set_device("{}:{}".format(device, dev_id))
     try:
         yield
     finally:
@@ -251,3 +251,20 @@ def GroupShardedScaler(scaler):
 
     scaler._unscale = MethodType(unscale_method, scaler)
     return scaler
+
+
+def cvt_to_device(x, dev_id, blocking=True):
+    """
+    Copy data in x from CPU memory to the supported device.
+    """
+    if paddle.is_compiled_with_cuda():
+        place = paddle.CUDAPlace(dev_id)
+    elif paddle.is_compiled_with_npu():
+        place = paddle.NPUPlace(dev_id)
+    elif paddle.is_compiled_with_xpu():
+        place = paddle.XPUPlace(dev_id)
+    else:
+        raise EnvironmentError(
+            "Only support paddle compiled with gpu/rocm, npu or xpu, but the current version is compiled with cpu."
+        )
+    return x._copy_to(place, blocking)
diff --git a/python/paddle/distributed/sharding/group_sharded.py b/python/paddle/distributed/sharding/group_sharded.py
index 4137075c3f902088f8b68dea4fb72aaf6d7f643b..ce5eae88cf998467be17eef1a8149954e790efe6 100644
--- a/python/paddle/distributed/sharding/group_sharded.py
+++ b/python/paddle/distributed/sharding/group_sharded.py
@@ -117,6 +117,12 @@ def group_sharded_parallel(
             optimizer.step()
             optimizer.clear_grad()
     """
+
+    device = paddle.get_device().split(":")[0]
+    assert device in [
+        "gpu",
+        "xpu",
+    ], "group_sharded_parallel only supports gpu and xpu now"
     # check optition type
     assert isinstance(
         model, paddle.nn.Layer
@@ -148,6 +154,7 @@ def group_sharded_parallel(
                 group=group,
                 offload=offload,
                 dp_group=dp_group,
+                device=device,
             )
             model = GroupShardedStage2(
                 model,
@@ -156,6 +163,7 @@ def group_sharded_parallel(
                 sync_buffers=sync_buffers,
                 buffer_max_size=buffer_max_size,
                 dp_group=dp_group,
+                device=device,
             )
         else:
             optimizer = ShardingOptimizerStage2(
@@ -163,6 +171,7 @@ def group_sharded_parallel(
                 optim=optimizer,
                 group=group,
                 offload=offload,
+                device=device,
             )
             model = ShardingStage2(
                 model,
@@ -170,6 +179,7 @@ def group_sharded_parallel(
                 group=group,
                 sync_buffers=sync_buffers,
                 buffer_max_size=buffer_max_size,
+                device=device,
             )
     elif level == 'p_g_os':
         if in_dygraph_mode():
@@ -181,6 +191,7 @@ def group_sharded_parallel(
                 segment_size=segment_size,
                 offload=offload,
                 sync_comm=sync_comm,
+                device=device,
             )
         else:
             model = ShardingStage3(
@@ -191,6 +202,7 @@ def group_sharded_parallel(
                 segment_size=segment_size,
                 offload=offload,
                 sync_comm=sync_comm,
+                device=device,
             )
     else:
         raise ValueError("Please enter the correct level.")
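
For reviewers, a minimal usage sketch of the updated entry point on a non-CUDA device, assuming a Paddle build with XPU support and a process started via paddle.distributed.launch; the toy model and optimizer below are illustrative and not part of this PR:

import paddle
from paddle.distributed import init_parallel_env
from paddle.distributed.sharding import group_sharded_parallel

init_parallel_env()

# Toy model/optimizer just to exercise the wrapper.
model = paddle.nn.Linear(1024, 1024)
optimizer = paddle.optimizer.AdamW(
    learning_rate=1e-3, parameters=model.parameters()
)

# With this change, the wrapper reads the device from paddle.get_device()
# ("gpu" or "xpu") and plumbs it into the stage-2/3 sharding classes
# instead of assuming CUDA.
model, optimizer, scaler = group_sharded_parallel(model, optimizer, level="os_g")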