@@ -250,7 +279,7 @@ class DygraphShardingOptimizer:
         # otherwise the self._inner_opt will only grad_clip the self._rank2params[self._sharding_rank] params
         # TODO(pangengzheng): remove the hacked grad_clip codes here when there is no diff in calculating global norm values in HybridParallelClipGrad compared to dp.
         origin_clip = self._inner_opt._grad_clip
-        if not isinstance(self._parameter_list[0], dict):
+        if not self._using_param_groups:
             params_grads = []
             for param in self._parameter_list:
                 if (
...
...
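For context on why the hack above exists: under sharding, each rank optimizes only its own parameter slice (`self._rank2params[self._sharding_rank]`), so letting the inner optimizer clip locally would compute the norm over just that slice instead of the true global norm. Below is a minimal sketch of global-norm clipping across shards; it is not the PaddlePaddle implementation (that path goes through `HybridParallelClipGrad` and a collective), and `clip_by_global_norm`, the rank-to-shard dict, and the epsilon are all illustrative. Plain Python summation stands in for the `all_reduce(SUM)` across the sharding group:

```python
import math

def clip_by_global_norm(shards, max_norm):
    """Clip gradients sharded across ranks by the *global* norm.

    `shards` maps rank -> list of that rank's gradient values.
    In a real run each rank computes its local sum of squares and
    all-reduces it across the sharding group; here the sum over
    `shards.values()` stands in for that collective.
    """
    # Each rank computes the sum of squares over its own shard only.
    local_sq = {rank: sum(g * g for g in grads) for rank, grads in shards.items()}
    # all_reduce(SUM) across the sharding group -> every rank sees the same total.
    global_norm = math.sqrt(sum(local_sq.values()))
    # Same scale factor on every rank keeps the sharded update consistent.
    scale = min(1.0, max_norm / (global_norm + 1e-6))
    return {rank: [g * scale for g in grads] for rank, grads in shards.items()}

# Two "ranks", each holding a slice of the gradients; global norm = 13.
shards = {0: [3.0, 4.0], 1: [12.0]}
print(clip_by_global_norm(shards, max_norm=1.0))  # each shard scaled by ~1/13
```

Clipping each shard by its local norm instead (e.g. rank 0 would see norm 5, rank 1 norm 12) would scale the shards by different factors, which is exactly the per-rank-only clipping the comment warns about.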
@@ -286,6 +315,35 @@ class DygraphShardingOptimizer: