diff --git a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/__init__.py b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/__init__.py index f0f26bd2e0d06014750daa3f75101e64c77d86f5..28260d7aa186353dc27badc446c5650cd62b8b5a 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/__init__.py +++ b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/__init__.py @@ -12,5 +12,6 @@ # See the License for the specific language governing permissions and from .hybrid_parallel_optimizer import HybridParallelOptimizer from .hybrid_parallel_gradscaler import HybridParallelGradScaler +from .dygraph_sharding_optimizer import DygraphShardingOptimizer __all__ = [] diff --git a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py index b00ef2cdcb0e10004ca2ce1fb8661a844d8fca1e..76e326ce20d7cbbbf95bf02f27e08ab0bba15604 100755 --- a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py @@ -88,6 +88,13 @@ class HybridParallelClipGrad: paddle.distributed.all_reduce( global_norm_var_dist, group=self._hcg.get_check_parallel_group()) + # In Sharding mode, param and grad is mapping different rank in optimizer. + # ClipGradByGlobalNorm need allreduce to get globol norm + if self._hcg.get_sharding_parallel_world_size() > 1: + paddle.distributed.all_reduce( + global_norm_var_not_dist, + group=self._hcg.get_sharding_parallel_group()) + global_norm_var = layers.sqrt(global_norm_var_dist + global_norm_var_not_dist) @@ -139,8 +146,13 @@ class HybridParallelOptimizer: logger.warning("using ClipGradByGlobalNorm in TensorParallel, the origin " \ "optmizer'grad clip will be changed.") - self._inner_opt._grad_clip = HybridParallelClipGrad( - self._inner_opt._grad_clip, hcg) + if self._sharding_enable: + # change sharding inner_optimizer's _grad_clip + self._inner_opt._inner_optimizer._grad_clip = HybridParallelClipGrad( + self._inner_opt._grad_clip, hcg) + else: + self._inner_opt._grad_clip = HybridParallelClipGrad( + self._inner_opt._grad_clip, hcg) @imperative_base.no_grad @framework.dygraph_only diff --git a/python/paddle/fluid/tests/unittests/hybrid_parallel_sharding_model.py b/python/paddle/fluid/tests/unittests/hybrid_parallel_sharding_model.py index 2995e4dbf84018fae3782b72325dec0ae81faada..8cb1166cd0d832a8dc3366642ecd6ff6bfb752f3 100644 --- a/python/paddle/fluid/tests/unittests/hybrid_parallel_sharding_model.py +++ b/python/paddle/fluid/tests/unittests/hybrid_parallel_sharding_model.py @@ -183,21 +183,23 @@ class TestDistMPTraning(unittest.TestCase): strategy=None, is_sharding=True, Optimizer="adam"): - + clip = paddle.nn.ClipGradByGlobalNorm(0.5) if Optimizer == "adam": if is_sharding: optimizer = DygraphShardingOptimizer( hcg=fleet.get_hybrid_communicate_group(), user_defined_strategy=strategy, params=model.parameters(), - inner_optimizer_class=paddle.optimizer.Adam, + inner_optimizer_class=paddle.optimizer.AdamW, learning_rate=0.001, - weight_decay=0.00001, ) + weight_decay=0.00001, + grad_clip=clip) else: - optimizer = paddle.optimizer.Adam( + optimizer = paddle.optimizer.AdamW( parameters=model.parameters(), learning_rate=0.001, - weight_decay=0.00001, ) + weight_decay=0.00001, + grad_clip=clip) else: if is_sharding: optimizer = DygraphShardingOptimizer( @@ -205,10 +207,13 @@ class TestDistMPTraning(unittest.TestCase): user_defined_strategy=strategy, params=model.parameters(), inner_optimizer_class=paddle.optimizer.Momentum, - learning_rate=0.001, ) + learning_rate=0.001, + grad_clip=clip) else: optimizer = paddle.optimizer.Momentum( - learning_rate=0.001, parameters=model.parameters()) + learning_rate=0.001, + parameters=model.parameters(), + grad_clip=clip) return optimizer def build_model_optimizer(self, Optimizer="adam"):