diff --git a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_stage3.py b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_stage3.py
index 049d3ffa3694f2c1e98652ae4523d526f27fcce4..e44b5d2515d83e0b7ad8953ccdda03f919dcac5f 100644
--- a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_stage3.py
+++ b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_stage3.py
@@ -205,7 +205,7 @@ class GroupShardedStage3(nn.Layer):
         for param in list(self._unslice_params):
             param.clear_gradient(False)
             tmp_var = param.cuda(DEV_ID)
-            param._clear_data()
+
             if tmp_var.dtype == Type.fp32.value and param2dtype[
                     param.name] == Type.fp16.value:
                 tmp_var = paddle.cast(tmp_var, Type.fp16.value)
@@ -272,6 +272,8 @@ class GroupShardedStage3(nn.Layer):
                 master_tensor = paddle.cast(param, Type.fp32.value)
                 master_tensor.name = param.name
                 self._optim._master_weights[param.name] = master_tensor
+            if self._offload:
+                param.master_weight = paddle.cast(param, Type.fp32.value).cpu()
             param2dtype[param.name] = param.dtype
             p_align = self._param2align(param)
             self._unslice_params2align[param.name] = p_align
@@ -369,7 +371,6 @@ class GroupShardedStage3(nn.Layer):
         tmp_var.get_tensor().set(param_cpu.get_tensor(), core.CPUPlace())
         del tmp_var
         param.get_tensor()._set_dims(param_shape)
-        param._clear_data()
 
         # Current rank param_storage
         if self._offload:
@@ -379,6 +380,9 @@ class GroupShardedStage3(nn.Layer):
                 value=tmp_tensor,
                 place=core.CPUPlace(),
                 name="slice@" + param.name)
+            with device_guard():
+                param.master_weight = paddle.cast(param.fw_storage,
+                                                  Type.fp32.value)
         else:
             param.fw_storage = core.eager.Tensor(
                 value=buffer._slice(start, end), name="slice@" + param.name)
@@ -389,6 +393,7 @@ class GroupShardedStage3(nn.Layer):
             master_tensor = paddle.cast(param.fw_storage, Type.fp32.value)
             master_tensor.name = param.name
             self._optim._master_weights[param.fw_storage.name] = master_tensor
+        param._clear_data()
 
     def _register_forward_hooks(self, layer):
         """
@@ -480,9 +485,8 @@ class GroupShardedStage3(nn.Layer):
                 collective.all_reduce(tensor=grad_storage.buffer, group=self._group)
         if self._offload:
             for param in list(self._unslice_params):
-                tmp_var = _device2cpu(param, convert_dtype=True)
-                tmp_var._share_buffer_to(param)
-                del tmp_var
+                param._clear_data()
+                param.master_weight._share_buffer_to(param)
 
         for grad_storage in self._grad_storages.values():
             for p in grad_storage._params:
@@ -568,7 +572,8 @@ class GroupShardedStage3(nn.Layer):
                     del self._task_flow.full_param[param.name]
 
                     if self._offload:
-                        param.fw_storage = _device2cpu(param.fw_storage, True)
+                        param.fw_storage._clear_data()
+                        param.master_weight._share_buffer_to(param.fw_storage)
 
         return allreduce_
 
@@ -856,6 +861,7 @@ def _PartitionParam(param):
     if not hasattr(param, "fw_storage"):
         setattr(param, "fw_storage", None)
         setattr(param, "bw_storage", None)
+        setattr(param, "master_weight", None)
         setattr(param, "status", "all")
         setattr(param, "use_count", 0)
     return param
@@ -864,6 +870,7 @@ def _UnsliceParam(param):
     if not hasattr(param, "unslice"):
         setattr(param, "unslice", True)
+        setattr(param, "master_weight", None)
     return param
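Note on the offload changes above: instead of round-tripping every offloaded parameter through `_device2cpu(param, convert_dtype=True)` after each step, the patch keeps a persistent fp32 `master_weight` copy on the CPU and, once gradients are synchronized, clears the fp16 tensor's storage and shares the master weight's buffer back into it. The sketch below only illustrates that bookkeeping idea; it is not the Stage-3 implementation, and `OffloadedParam`, `sgd_step`, and `for_compute` are hypothetical names.

```python
# Minimal sketch (NOT the Paddle Stage-3 code) of pairing an fp16 compute
# tensor with a persistent fp32 master copy kept on the CPU.
import paddle


class OffloadedParam:
    """Pair an fp16 compute tensor with an fp32 master copy on CPU."""

    def __init__(self, fp16_value):
        self.value = fp16_value
        # Cast to fp32 once and keep the copy on the host, instead of
        # re-creating a converted CPU copy after every optimizer step.
        self.master_weight = paddle.cast(fp16_value, 'float32').cpu()

    def sgd_step(self, grad, lr=0.1):
        # The update is applied to the fp32 master copy on the CPU.
        cpu_grad = paddle.cast(grad, 'float32').cpu()
        self.master_weight = self.master_weight - lr * cpu_grad

    def for_compute(self):
        # Re-materialize the low-precision tensor from the master copy
        # when the parameter is needed on the compute device again.
        return paddle.cast(self.master_weight, 'float16')
```

In the actual diff the buffer of `param.master_weight` is shared directly back into `param` / `param.fw_storage` via `_share_buffer_to`, which also avoids the extra allocation this toy version performs.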
diff --git a/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage3.py b/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage3.py
index f96273cc84caf46f4f02c62e648ce70445b52d28..7bb1517f12169cbfe5c0bb9cec4a3099e03d3f9a 100644
--- a/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage3.py
+++ b/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage3.py
@@ -199,7 +199,7 @@ class ShardingStage3(nn.Layer):
             param.clear_gradient(False)
             param._gradient_set_empty(False)
             tmp_var = param.cuda(DEV_ID)
-            param._clear()
+
             if tmp_var.dtype == Type.fp32.value and param2dtype[
                     param.name] == Type.fp16.value:
                 tmp_var = paddle.cast(tmp_var, Type.fp16.value)
@@ -220,19 +220,14 @@ class ShardingStage3(nn.Layer):
             self._optim._param_groups = slice_params + list(
                 self._unslice_params)
         else:
-            params_name_list = list(map(lambda p: p.name, update_list))
-            fw_storage_name_list = list(
-                map(lambda p: p.fw_storage.name, update_list))
             for param_group in self._optim._param_groups:
                 p_group = []
                 for p in param_group['params']:
-                    if p.name in params_name_list:
+                    if hasattr(p, "fw_storage"):
                         p_group.append(p.fw_storage)
-                    elif p.name in fw_storage_name_list:
-                        p_group.append(update_list[fw_storage_name_list.index(
-                            p.name)].fw_storage)
-                    elif p in self._unslice_params:
+                    else:
                         p_group.append(p)
+
                 param_group['params'] = p_group
 
     def forward(self, *inputs, **kwargs):
@@ -268,6 +263,8 @@ class ShardingStage3(nn.Layer):
             if param.dtype == Type.fp16.value and not self._offload:
                 self._optim._master_weights[param.name] = paddle.cast(
                     param, Type.fp32.value)
+            if self._offload:
+                param.master_weight = paddle.cast(param, Type.fp32.value).cpu()
             param2dtype[param.name] = param.dtype
             p_align = self._param2align(param)
             self._unslice_params2align[param.name] = p_align
@@ -335,11 +332,12 @@ class ShardingStage3(nn.Layer):
                 self._param2buffer[param.name].append(
                     (rank_ * pre_buffer, (rank_ + 1) * pre_buffer))
-            # 3.Flatten layer params and release other rank buffer
-            self._param_storage(param, buffer_size)
 
             # Record param's dtype
             param2dtype[param.name] = param.dtype
 
+            # 3.Flatten layer params and release other rank buffer
+            self._param_storage(param, buffer_size)
+
     def _param_storage(self, param, buffer_size):
         """
         This is a function to simplify the handling of parameter InternalStorages.
@@ -365,13 +363,15 @@ class ShardingStage3(nn.Layer):
         tmp_var.value().get_tensor().set(param_cpu.value().get_tensor(),
                                          core.CPUPlace())
         param.value().get_tensor()._set_dims(param_shape)
-        param._clear()
 
         # Current rank param_storage
         if self._offload:
             param.fw_storage = core.VarBase(
                 buffer._slice(start, end), core.CPUPlace(),
                 "slice@" + param.name)
+            with device_guard(device="cpu"):
+                param.master_weight = paddle.cast(param.fw_storage,
+                                                  Type.fp32.value)
         else:
             param.fw_storage = core.VarBase(
                 buffer._slice(start, end), "slice@" + param.name)
@@ -381,6 +381,7 @@ class ShardingStage3(nn.Layer):
         if param.dtype == Type.fp16.value and not self._offload:
             self._optim._master_weights[param.fw_storage.name] = paddle.cast(
                 param.fw_storage, Type.fp32.value)
+        param._clear()
 
     def _register_forward_hooks(self, layer):
         """
@@ -482,9 +483,8 @@ class ShardingStage3(nn.Layer):
 
         if self._offload:
             for param in list(self._unslice_params):
-                tmp_var = _device2cpu(param, convert_dtype=True)
-                tmp_var._share_buffer_to(param)
-                tmp_var._clear()
+                param._clear()
+                param.master_weight._share_buffer_to(param)
 
         for grad_storage in self._grad_storages.values():
             for p in grad_storage._params:
@@ -553,8 +553,9 @@ class ShardingStage3(nn.Layer):
                         cpu_grad = _device2cpu(
                             core.VarBase(full_grad._slice(start, end))
                             .detach().clone(), True)
-                        param.bw_storage = paddle.add(param.bw_storage,
-                                                      cpu_grad)
+                        with device_guard(device="cpu"):
+                            param.bw_storage = paddle.add(param.bw_storage,
+                                                          cpu_grad)
                 else:
                     # param.bw_storage.add_(
                     #     core.VarBase(full_grad._slice(start, end))
@@ -581,7 +582,8 @@ class ShardingStage3(nn.Layer):
                     tmp_var._clear()
 
                     if self._offload:
-                        param.fw_storage = _device2cpu(param.fw_storage, True)
+                        param.fw_storage._clear()
+                        param.master_weight._share_buffer_to(param.fw_storage)
 
         return allreduce_
 
@@ -869,6 +871,7 @@ def _PartitionParam(param):
     if not hasattr(param, "fw_storage"):
         setattr(param, "fw_storage", None)
         setattr(param, "bw_storage", None)
+        setattr(param, "master_weight", None)
         setattr(param, "status", "all")
         setattr(param, "use_count", 0)
     return param
@@ -877,6 +880,7 @@ def _UnsliceParam(param):
     if not hasattr(param, "unslice"):
         setattr(param, "unslice", True)
+        setattr(param, "master_weight", None)
     return param
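Note on the `device_guard(device="cpu")` wrappers added around `paddle.cast` and `paddle.add`: they make those ops execute on the host so that `param.master_weight` and the accumulated `param.bw_storage` stay in CPU memory while offload is enabled, even though the global device is a GPU. The sketch below approximates that effect with the public `paddle.set_device` API; `cpu_device_guard` is a hypothetical stand-in for the `device_guard` helper these files import, not its source.

```python
# Rough approximation (not the Paddle source) of a CPU device guard:
# temporarily switch the default device so an op runs on the host and its
# output stays in CPU memory.
import contextlib

import paddle


@contextlib.contextmanager
def cpu_device_guard():
    prev = paddle.get_device()      # e.g. "gpu:0" when CUDA is the default
    paddle.set_device("cpu")
    try:
        yield
    finally:
        paddle.set_device(prev)


# Accumulating an offloaded gradient slice without touching GPU memory:
with cpu_device_guard():
    bw_storage = paddle.zeros([8], dtype='float32')   # created on CPU
    cpu_grad = paddle.ones([8], dtype='float32')
    bw_storage = paddle.add(bw_storage, cpu_grad)     # runs and stays on CPU
```

Without such a guard, an eager op between two host tensors can still be dispatched on the current global place, so the result of `paddle.add` could land back on the GPU and defeat the purpose of offloading.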