diff --git a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py
index e7108b3f4f3432df04556b4cf78726a63cc8b076..50bf8a2f9c7c58b3390d2881cb5d6e8510e78ae8 100755
--- a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py
+++ b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py
@@ -49,8 +49,6 @@ class HybridParallelClipGrad:
 
     @imperative_base.no_grad
     def _dygraph_clip(self, params_grads):
-        params_and_grads = []
-
         sum_square_dist_fp16 = []
         sum_square_dist_fp32 = []
         sum_square_not_dist_fp16 = []
@@ -153,15 +151,14 @@ class HybridParallelClipGrad:
             if g is None:
                 continue
             if getattr(p, 'need_clip', True) is False:
-                params_and_grads.append((p, g))
                 continue
             if p.dtype == paddle.float16:
-                new_grad = layers.elementwise_mul(x=g, y=clip_var_fp16)
+                g.scale_(clip_var_fp16)
             else:
-                new_grad = layers.elementwise_mul(x=g, y=clip_var)
-            params_and_grads.append((p, new_grad))
+                g.scale_(clip_var)
+            p._reset_grad_inplace_version(True)
 
-        return params_and_grads
+        return params_grads
 
     def __getattr__(self, item):
         return getattr(self._clip, item)
@@ -201,6 +198,12 @@ class HybridParallelOptimizer:
             else:
                 self._inner_opt._grad_clip = HybridParallelClipGrad(
                     self._inner_opt._grad_clip, hcg)
+                if self._inner_opt._parameter_list and isinstance(
+                        self._inner_opt._parameter_list[0], dict):
+                    for item in self._inner_opt._param_groups:
+                        if "grad_clip" in item.keys():
+                            item["grad_clip"] = HybridParallelClipGrad(
+                                self._inner_opt._grad_clip, hcg)
 
     @imperative_base.no_grad
     @framework.dygraph_only
diff --git a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/sharding_optimizer_stage2.py b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/sharding_optimizer_stage2.py
index a2797adff251aea3535f86e5c423463d748c37b3..fc5b93c6e25499a0ae50c19cacae4a9395520fe9 100644
--- a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/sharding_optimizer_stage2.py
+++ b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/sharding_optimizer_stage2.py
@@ -109,6 +109,13 @@ class ShardingOptimizerStage2(Optimizer):
             self._optim._grad_clip = ShardingClipGrad(self._optim._grad_clip,
                                                       paddle.get_device(),
                                                       self.group)
+            if self._optim._parameter_list and isinstance(
+                    self._optim._parameter_list[0], dict):
+                for item in self._optim._param_groups:
+                    if "grad_clip" in item.keys():
+                        item["grad_clip"] = ShardingClipGrad(
+                            self._optim._grad_clip,
+                            paddle.get_device(), self.group)
 
         if offload:
             assert self._pfp16, "Only support offload strategy while using \'Adam\', \'AdamW\' and \'Momentum\' optimizer with AMP/Pure FP16"
diff --git a/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_utils.py b/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_utils.py
index 5f696195c1abcd4921b4358b8971fdbc982609da..9c30ff5a45075ae423d6a46ef328e3b6523fbd5b 100644
--- a/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_utils.py
+++ b/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_utils.py
@@ -57,8 +57,6 @@ class ShardingClipGrad:
 
     @imperative_base.no_grad
     def _dygraph_clip(self, params_grads):
-        params_and_grads = []
-
         sum_square_fp16 = []
         sum_square_fp32 = []
 
@@ -114,15 +112,14 @@ class ShardingClipGrad:
             if g is None:
                 continue
             if getattr(p, 'need_clip', True) is False:
-                params_and_grads.append((p, g))
                 continue
             if p.dtype == paddle.float16:
-                new_grad = layers.elementwise_mul(x=g, y=clip_var_fp16)
+                g.scale_(clip_var_fp16)
             else:
-                new_grad = layers.elementwise_mul(x=g, y=clip_var)
-            params_and_grads.append((p, new_grad))
+                g.scale_(clip_var)
+            p._reset_grad_inplace_version(True)
 
-        return params_and_grads
+        return params_grads
 
     def __getattr__(self, item):
         return getattr(self._clip, item)
diff --git a/python/paddle/fluid/tests/unittests/dygraph_sharding_stage2.py b/python/paddle/fluid/tests/unittests/dygraph_sharding_stage2.py
index 9206d744990008496e7af43d67e000f9d00f6dab..80acf7217e76fb996e6b76aa519307c44952636e 100644
--- a/python/paddle/fluid/tests/unittests/dygraph_sharding_stage2.py
+++ b/python/paddle/fluid/tests/unittests/dygraph_sharding_stage2.py
@@ -159,10 +159,13 @@ def test_dp_stage2():
     mlp2 = MLP()
     mlp3 = MLP()
     mlp4 = MLP()
+    mlp5 = MLP()
     mlp1.set_state_dict(state_dict)
     mlp2.set_state_dict(state_dict)
     mlp3.set_state_dict(state_dict)
     mlp4.set_state_dict(state_dict)
+    mlp5.set_state_dict(state_dict)
+
     dp_params = train_mlp(
         mlp1, sharding_stage="dp", use_pure_fp16=False, opt_group=False)
     stage2_params = train_mlp(
@@ -181,6 +184,11 @@ def test_dp_stage2():
         rtol=1e-5,
         atol=1e-5)
 
+    stage2_params = train_mlp(
+        mlp2, sharding_stage=2, use_pure_fp16=False, opt_group=True)
+    for i in range(len(dp_params)):
+        np.testing.assert_allclose(
+            dp_params[i].numpy(), stage2_params[i].numpy(), rtol=1e-6)
     return
 
 
diff --git a/python/paddle/fluid/tests/unittests/dygraph_sharding_stage2_offload.py b/python/paddle/fluid/tests/unittests/dygraph_sharding_stage2_offload.py
index f7e426377382bb089d9a4c4f968759f38c40e647..84ffe9094d8126ac75f864022659cbf2e101ad65 100644
--- a/python/paddle/fluid/tests/unittests/dygraph_sharding_stage2_offload.py
+++ b/python/paddle/fluid/tests/unittests/dygraph_sharding_stage2_offload.py
@@ -49,7 +49,7 @@ def train_mlp(model, offload=False):
     optimizer = ShardingOptimizerStage2(
         params=model.parameters(), optim=optimizer, offload=offload)
     model = ShardingStage2(
-        model, optimizer, buffer_max_size=2**21, accumulate_grads=True)
+        model, optimizer, buffer_max_size=2**21, accumulate_grads=False)
 
     train_reader = paddle.batch(
         reader_decorator(linear_size), batch_size=batch_size, drop_last=True)
@@ -98,12 +98,11 @@ def test_sharding_stage2_offload():
     mlp_offload_params = train_mlp(mlp_offload, offload=True)
 
     for i in range(len(mlp_params)):
-        for j in range(len(mlp_offload_params)):
-            if mlp_params[i].name == mlp_offload_params[j].name:
-                np.testing.assert_allclose(
-                    mlp_params[i].numpy(),
-                    mlp_offload_params[j].numpy(),
-                    rtol=1e-6)
+        np.testing.assert_allclose(
+            mlp_params[i].numpy(),
+            mlp_offload_params[i].numpy(),
+            rtol=5e-3,
+            atol=5e-3)
 
     return
 
diff --git a/python/paddle/fluid/tests/unittests/hybrid_parallel_pp_clip_grad.py b/python/paddle/fluid/tests/unittests/hybrid_parallel_pp_clip_grad.py
index de980f3c3f787e4e55a9ac06b92609d0cbbfb9c6..430c6e0884822dc9d38f593b4cee26f96ed18b3b 100644
--- a/python/paddle/fluid/tests/unittests/hybrid_parallel_pp_clip_grad.py
+++ b/python/paddle/fluid/tests/unittests/hybrid_parallel_pp_clip_grad.py
@@ -31,5 +31,19 @@ class TestPPClipGrad(TestDistPPTraning):
         return scheduler, optimizer
 
 
+class TestPPClipGradParamGroup(TestDistPPTraning):
+    def build_optimizer(self, model):
+        grad_clip = paddle.nn.ClipGradByGlobalNorm(0.5)
+        scheduler = paddle.optimizer.lr.PiecewiseDecay(
+            boundaries=[2], values=[0.001, 0.002], verbose=True)
+        optimizer = paddle.optimizer.Momentum(
+            learning_rate=scheduler,
+            grad_clip=grad_clip,
+            parameters=[{
+                "params": model.parameters()
+            }])
+        return scheduler, optimizer
+
+
 if __name__ == "__main__":
     unittest.main()