From 0e552c080394cc6c6207aa0221f1fb745c24af8f Mon Sep 17 00:00:00 2001 From: Haohongxiang <86215757+haohongxiang@users.noreply.github.com> Date: Thu, 20 Oct 2022 10:52:27 +0800 Subject: [PATCH] support qat in sharding stage2 (#47169) --- .../sharding/group_sharded_optimizer_stage2.py | 4 +++- .../fleet/meta_parallel/sharding/group_sharded_stage2.py | 7 ++++--- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_optimizer_stage2.py b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_optimizer_stage2.py index 3f479d073c9..ed07e8b7982 100644 --- a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_optimizer_stage2.py +++ b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_optimizer_stage2.py @@ -301,7 +301,9 @@ class GroupShardedOptimizerStage2(Optimizer): """ if len(self._dtype_rank_params) == 0: # Assign the parameters of each rank according to the type - for param in self._local_params: + trainable_params = list( + filter(lambda x: x.trainable, self._local_params)) + for param in trainable_params: if param.dtype not in self._dtype_rank_params.keys(): self._dtype_rank_params[param.dtype] = [ [] for _ in range(self.world_size) diff --git a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_stage2.py b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_stage2.py index d0c33a5f596..6c0716f7bbb 100644 --- a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_stage2.py +++ b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_stage2.py @@ -102,11 +102,12 @@ class GroupShardedStage2(nn.Layer): # sharing stage 2 comm overlap flag self._reduce_overlap = False - self._trainable_params = [] self._grad_reduced = [] self._trainable_param2rank = {} self._trainable_param2align = {} - self._trainable_mask = list(map(_trainable, self._all_params)) + self._trainable_params = list( + filter(lambda x: x.trainable, self._all_params)) + self._trainable_mask = list(map(_trainable, self._trainable_params)) self._param_grads = [] # Set grad storage size & Display param sizes and model sizes @@ -512,7 +513,7 @@ class GroupShardedStage2(nn.Layer): def _detect_train_change(self): # Current trainable parameters - trainable_mask = list(map(_trainable, self._all_params)) + trainable_mask = list(map(_trainable, self._trainable_params)) # Whether parameters trainability changed trainability_changed = trainable_mask != self._trainable_mask -- GitLab