Unverified commit 0e552c08, authored by H Haohongxiang, committed by GitHub

support qat in sharding stage2 (#47169)

Parent 8d2ce06e
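Context note (not part of the original commit message): quantization-aware training (QAT) typically registers extra, non-trainable parameters such as weight/activation scales alongside the ordinary weights. The hunks below switch the sharding stage2 code paths that enumerate parameters to filter on the `trainable` flag so those frozen parameters are skipped. A minimal, illustrative sketch of that filtering pattern (the `Param` class and the parameter names are assumptions, not Paddle APIs):

```python
# Illustrative only: a stand-in Param with the attributes the diff relies on.
class Param:
    def __init__(self, name, trainable):
        self.name = name
        self.trainable = trainable

# A QAT-wrapped layer may expose frozen quantization parameters alongside
# its ordinary trainable weights.
params = [
    Param("linear_0.w_0", trainable=True),
    Param("linear_0.b_0", trainable=True),
    Param("linear_0.w_0.scale", trainable=False),  # hypothetical QAT scale
]

# The filtering pattern applied throughout this commit.
trainable_params = list(filter(lambda p: p.trainable, params))
assert [p.name for p in trainable_params] == ["linear_0.w_0", "linear_0.b_0"]
```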
@@ -301,7 +301,9 @@ class GroupShardedOptimizerStage2(Optimizer):
         """
         if len(self._dtype_rank_params) == 0:
             # Assign the parameters of each rank according to the type
-            for param in self._local_params:
+            trainable_params = list(
+                filter(lambda x: x.trainable, self._local_params))
+            for param in trainable_params:
                 if param.dtype not in self._dtype_rank_params.keys():
                     self._dtype_rank_params[param.dtype] = [
                         [] for _ in range(self.world_size)
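The hunk above makes `GroupShardedOptimizerStage2` bucket only trainable parameters by dtype and rank when it fills `_dtype_rank_params`. A rough sketch of that bucketing logic, with `param2rank` (parameter name to owning rank) and `world_size` assumed here as plain inputs rather than the exact Paddle attributes:

```python
# Illustrative sketch of per-dtype, per-rank bucketing of trainable
# parameters, mirroring the `_dtype_rank_params` structure in the hunk above.
def group_by_dtype_and_rank(trainable_params, param2rank, world_size):
    dtype_rank_params = {}
    for param in trainable_params:
        # One list of parameters per rank, created lazily per dtype.
        buckets = dtype_rank_params.setdefault(
            param.dtype, [[] for _ in range(world_size)])
        buckets[param2rank[param.name]].append(param)
    return dtype_rank_params
```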
@@ -102,11 +102,12 @@ class GroupShardedStage2(nn.Layer):
         # sharing stage 2 comm overlap flag
         self._reduce_overlap = False
 
-        self._trainable_params = []
         self._grad_reduced = []
         self._trainable_param2rank = {}
         self._trainable_param2align = {}
-        self._trainable_mask = list(map(_trainable, self._all_params))
+        self._trainable_params = list(
+            filter(lambda x: x.trainable, self._all_params))
+        self._trainable_mask = list(map(_trainable, self._trainable_params))
         self._param_grads = []
 
         # Set grad storage size & Display param sizes and model sizes
@@ -512,7 +513,7 @@ class GroupShardedStage2(nn.Layer):
 
     def _detect_train_change(self):
         # Current trainable parameters
-        trainable_mask = list(map(_trainable, self._all_params))
+        trainable_mask = list(map(_trainable, self._trainable_params))
 
         # Whether parameters trainability changed
         trainability_changed = trainable_mask != self._trainable_mask
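`_detect_train_change` compares the current trainability mask against the mask cached at construction; with this patch both masks are computed over the filtered `_trainable_params`, so parameters frozen by QAT no longer enter the comparison. A hedged sketch of that check (function and argument names are illustrative, not the Paddle implementation):

```python
# Illustrative sketch: report whether any parameter's `trainable` flag has
# flipped since the cached mask was taken.
def detect_train_change(trainable_params, cached_mask):
    current_mask = [p.trainable for p in trainable_params]
    return current_mask != cached_mask
```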