未验证 提交 0e552c08 编写于 作者: H Haohongxiang 提交者: GitHub

support qat in sharding stage2 (#47169)

上级 8d2ce06e
...@@ -301,7 +301,9 @@ class GroupShardedOptimizerStage2(Optimizer): ...@@ -301,7 +301,9 @@ class GroupShardedOptimizerStage2(Optimizer):
""" """
if len(self._dtype_rank_params) == 0: if len(self._dtype_rank_params) == 0:
# Assign the parameters of each rank according to the type # Assign the parameters of each rank according to the type
for param in self._local_params: trainable_params = list(
filter(lambda x: x.trainable, self._local_params))
for param in trainable_params:
if param.dtype not in self._dtype_rank_params.keys(): if param.dtype not in self._dtype_rank_params.keys():
self._dtype_rank_params[param.dtype] = [ self._dtype_rank_params[param.dtype] = [
[] for _ in range(self.world_size) [] for _ in range(self.world_size)
......
...@@ -102,11 +102,12 @@ class GroupShardedStage2(nn.Layer): ...@@ -102,11 +102,12 @@ class GroupShardedStage2(nn.Layer):
# sharing stage 2 comm overlap flag # sharing stage 2 comm overlap flag
self._reduce_overlap = False self._reduce_overlap = False
self._trainable_params = []
self._grad_reduced = [] self._grad_reduced = []
self._trainable_param2rank = {} self._trainable_param2rank = {}
self._trainable_param2align = {} self._trainable_param2align = {}
self._trainable_mask = list(map(_trainable, self._all_params)) self._trainable_params = list(
filter(lambda x: x.trainable, self._all_params))
self._trainable_mask = list(map(_trainable, self._trainable_params))
self._param_grads = [] self._param_grads = []
# Set grad storage size & Display param sizes and model sizes # Set grad storage size & Display param sizes and model sizes
...@@ -512,7 +513,7 @@ class GroupShardedStage2(nn.Layer): ...@@ -512,7 +513,7 @@ class GroupShardedStage2(nn.Layer):
def _detect_train_change(self): def _detect_train_change(self):
# Current trainable parameters # Current trainable parameters
trainable_mask = list(map(_trainable, self._all_params)) trainable_mask = list(map(_trainable, self._trainable_params))
# Whether parameters trainability changed # Whether parameters trainability changed
trainability_changed = trainable_mask != self._trainable_mask trainability_changed = trainable_mask != self._trainable_mask
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册