diff --git a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_optimizer_stage2.py b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_optimizer_stage2.py
index 3f479d073c97eab95b9ade2082a4dbe7b70ae881..ed07e8b79822d8d02c1394b4f6418575cd1ba54a 100644
--- a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_optimizer_stage2.py
+++ b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_optimizer_stage2.py
@@ -301,7 +301,9 @@ class GroupShardedOptimizerStage2(Optimizer):
         """
         if len(self._dtype_rank_params) == 0:
             # Assign the parameters of each rank according to the type
-            for param in self._local_params:
+            trainable_params = list(
+                filter(lambda x: x.trainable, self._local_params))
+            for param in trainable_params:
                 if param.dtype not in self._dtype_rank_params.keys():
                     self._dtype_rank_params[param.dtype] = [
                         [] for _ in range(self.world_size)
diff --git a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_stage2.py b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_stage2.py
index d0c33a5f5964d5e68b70c5e342213fbb3768730c..6c0716f7bbb31917c9f70be8049667facfda2b9c 100644
--- a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_stage2.py
+++ b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_stage2.py
@@ -102,11 +102,12 @@ class GroupShardedStage2(nn.Layer):
         # sharing stage 2 comm overlap flag
         self._reduce_overlap = False
 
-        self._trainable_params = []
         self._grad_reduced = []
         self._trainable_param2rank = {}
         self._trainable_param2align = {}
-        self._trainable_mask = list(map(_trainable, self._all_params))
+        self._trainable_params = list(
+            filter(lambda x: x.trainable, self._all_params))
+        self._trainable_mask = list(map(_trainable, self._trainable_params))
         self._param_grads = []
 
         # Set grad storage size & Display param sizes and model sizes
@@ -512,7 +513,7 @@ class GroupShardedStage2(nn.Layer):
 
     def _detect_train_change(self):
         # Current trainable parameters
-        trainable_mask = list(map(_trainable, self._all_params))
+        trainable_mask = list(map(_trainable, self._trainable_params))
 
         # Whether parameters trainability changed
         trainability_changed = trainable_mask != self._trainable_mask
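
For reference, a minimal standalone sketch of the pattern this patch applies: non-trainable (frozen) parameters are filtered out before the per-dtype, per-rank bucketing, so they never enter the sharded optimizer state. The `Param` dataclass, `params` list, and `world_size` value below are hypothetical stand-ins for illustration, not Paddle APIs, and the round-robin placement is a simplification of the optimizer's size-aware partitioning.

```python
# Standalone sketch of the "filter trainable params before bucketing" change.
# Param, params, and world_size are hypothetical stand-ins, not Paddle APIs.
from dataclasses import dataclass


@dataclass
class Param:
    name: str
    dtype: str
    trainable: bool


def group_trainable_by_dtype(params, world_size):
    """Bucket only trainable parameters by dtype, with one bucket per rank."""
    dtype_rank_params = {}
    trainable_params = list(filter(lambda p: p.trainable, params))
    for i, param in enumerate(trainable_params):
        if param.dtype not in dtype_rank_params:
            dtype_rank_params[param.dtype] = [[] for _ in range(world_size)]
        # Naive round-robin placement; the real optimizer partitions by size.
        # The point here is only that frozen params never reach the buckets.
        dtype_rank_params[param.dtype][i % world_size].append(param.name)
    return dtype_rank_params


if __name__ == "__main__":
    params = [
        Param("w0", "float32", True),
        Param("w1", "float32", False),  # frozen: skipped entirely
        Param("w2", "float16", True),
    ]
    print(group_trainable_by_dtype(params, world_size=2))
    # {'float32': [['w0'], []], 'float16': [[], ['w2']]}
```

Presumably the intent is that parameters frozen at construction time no longer occupy per-rank optimizer buckets or grad storage; note that `_detect_train_change` now only compares trainability over this filtered set rather than over all parameters.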