Unverified commit b0cca48e, authored by Haohongxiang, committed by GitHub

[Dygraph] Support param groups in grad_clip (#39175)

* support param groups in grad_clip

* update

* modify for review
Parent faf517b2
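
This change lets gradient clipping work when an optimizer's parameters are passed as parameter groups (a list of dicts) and the optimizer is then wrapped by HybridParallelOptimizer or ShardingOptimizerStage2. Below is a minimal usage sketch mirroring the TestPPClipGradParamGroup test added at the end of this diff; the Linear model and learning rate are placeholders, not part of the commit.

import paddle

# Placeholder model for illustration only; the tests below build their own MLP.
model = paddle.nn.Linear(8, 8)

grad_clip = paddle.nn.ClipGradByGlobalNorm(0.5)
optimizer = paddle.optimizer.Momentum(
    learning_rate=0.001,
    grad_clip=grad_clip,
    # Parameters given as a param group (list of dicts) instead of a flat list;
    # with this commit the distributed clip wrapper is also applied per group.
    parameters=[{
        "params": model.parameters()
    }])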
@@ -49,8 +49,6 @@ class HybridParallelClipGrad:
     @imperative_base.no_grad
     def _dygraph_clip(self, params_grads):
-        params_and_grads = []
-
         sum_square_dist_fp16 = []
         sum_square_dist_fp32 = []
         sum_square_not_dist_fp16 = []
@@ -153,15 +151,14 @@ class HybridParallelClipGrad:
             if g is None:
                 continue
             if getattr(p, 'need_clip', True) is False:
-                params_and_grads.append((p, g))
                 continue
             if p.dtype == paddle.float16:
-                new_grad = layers.elementwise_mul(x=g, y=clip_var_fp16)
+                g.scale_(clip_var_fp16)
             else:
-                new_grad = layers.elementwise_mul(x=g, y=clip_var)
-            params_and_grads.append((p, new_grad))
+                g.scale_(clip_var)
+            p._reset_grad_inplace_version(True)
 
-        return params_and_grads
+        return params_grads
 
     def __getattr__(self, item):
         return getattr(self._clip, item)
@@ -201,6 +198,12 @@ class HybridParallelOptimizer:
             else:
                 self._inner_opt._grad_clip = HybridParallelClipGrad(
                     self._inner_opt._grad_clip, hcg)
+                if self._inner_opt._parameter_list and isinstance(
+                        self._inner_opt._parameter_list[0], dict):
+                    for item in self._inner_opt._param_groups:
+                        if "grad_clip" in item.keys():
+                            item["grad_clip"] = HybridParallelClipGrad(
+                                self._inner_opt._grad_clip, hcg)
 
     @imperative_base.no_grad
     @framework.dygraph_only
......
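
This hunk and the ShardingOptimizerStage2 hunk that follows add the same detection step: if the first entry of the optimizer's _parameter_list is a dict, the parameters were passed as param groups, and every group carrying its own "grad_clip" entry gets the distributed clip wrapper too. A hypothetical helper, not part of this commit, that factors out that check; opt stands for the inner dygraph optimizer and make_clip for something like lambda clip: HybridParallelClipGrad(clip, hcg).

def wrap_param_group_grad_clip(opt, make_clip):
    # Hypothetical helper mirroring the check added in this PR. `opt` is
    # assumed to expose _parameter_list and _param_groups the way Paddle's
    # dygraph optimizers do in these hunks.
    if opt._parameter_list and isinstance(opt._parameter_list[0], dict):
        # Parameters passed as param groups show up as a list of dicts.
        for item in opt._param_groups:
            if "grad_clip" in item.keys():
                # As in the diff, the wrapper is rebuilt from the optimizer-level
                # _grad_clip and replaces the per-group clip.
                item["grad_clip"] = make_clip(opt._grad_clip)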
@@ -109,6 +109,13 @@ class ShardingOptimizerStage2(Optimizer):
             self._optim._grad_clip = ShardingClipGrad(self._optim._grad_clip,
                                                       paddle.get_device(),
                                                       self.group)
+            if self._optim._parameter_list and isinstance(
+                    self._optim._parameter_list[0], dict):
+                for item in self._optim._param_groups:
+                    if "grad_clip" in item.keys():
+                        item["grad_clip"] = ShardingClipGrad(
+                            self._optim._grad_clip,
+                            paddle.get_device(), self.group)
 
         if offload:
             assert self._pfp16, "Only support offload strategy while using \'Adam\', \'AdamW\' and \'Momentum\' optimizer with AMP/Pure FP16"
......
@@ -57,8 +57,6 @@ class ShardingClipGrad:
     @imperative_base.no_grad
     def _dygraph_clip(self, params_grads):
-        params_and_grads = []
-
         sum_square_fp16 = []
         sum_square_fp32 = []
@@ -114,15 +112,14 @@ class ShardingClipGrad:
             if g is None:
                 continue
             if getattr(p, 'need_clip', True) is False:
-                params_and_grads.append((p, g))
                 continue
             if p.dtype == paddle.float16:
-                new_grad = layers.elementwise_mul(x=g, y=clip_var_fp16)
+                g.scale_(clip_var_fp16)
             else:
-                new_grad = layers.elementwise_mul(x=g, y=clip_var)
-            params_and_grads.append((p, new_grad))
+                g.scale_(clip_var)
+            p._reset_grad_inplace_version(True)
 
-        return params_and_grads
+        return params_grads
 
     def __getattr__(self, item):
         return getattr(self._clip, item)
......
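
The two _dygraph_clip hunks above drop the old pattern of building a new (param, grad) list with layers.elementwise_mul and instead scale each gradient in place. Below is a sketch of the resulting loop assembled from the diff, wrapped in a standalone function for readability; in the PR this code lives inside HybridParallelClipGrad._dygraph_clip and ShardingClipGrad._dygraph_clip, and clip_var / clip_var_fp16 are the clip coefficients computed earlier in those methods (not shown in this excerpt).

import paddle


def clip_grads_in_place(params_grads, clip_var, clip_var_fp16):
    # Hypothetical standalone wrapper around the loop shown in the hunks above.
    for p, g in params_grads:
        if g is None:
            continue
        if getattr(p, 'need_clip', True) is False:
            continue
        if p.dtype == paddle.float16:
            g.scale_(clip_var_fp16)  # scale the FP16 gradient in place
        else:
            g.scale_(clip_var)  # scale the FP32 gradient in place
        # Keep dygraph's inplace-version bookkeeping consistent after mutating
        # the gradient outside of backward().
        p._reset_grad_inplace_version(True)

    # Gradients were modified in place, so the original list is returned.
    return params_grads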
@@ -159,10 +159,13 @@ def test_dp_stage2():
     mlp2 = MLP()
     mlp3 = MLP()
     mlp4 = MLP()
+    mlp5 = MLP()
     mlp1.set_state_dict(state_dict)
     mlp2.set_state_dict(state_dict)
     mlp3.set_state_dict(state_dict)
     mlp4.set_state_dict(state_dict)
+    mlp5.set_state_dict(state_dict)
     dp_params = train_mlp(
         mlp1, sharding_stage="dp", use_pure_fp16=False, opt_group=False)
     stage2_params = train_mlp(
@@ -181,6 +184,11 @@ def test_dp_stage2():
             rtol=1e-5,
             atol=1e-5)
 
+    stage2_params = train_mlp(
+        mlp2, sharding_stage=2, use_pure_fp16=False, opt_group=True)
+    for i in range(len(dp_params)):
+        np.testing.assert_allclose(
+            dp_params[i].numpy(), stage2_params[i].numpy(), rtol=1e-6)
+
     return
......
@@ -49,7 +49,7 @@ def train_mlp(model, offload=False):
     optimizer = ShardingOptimizerStage2(
         params=model.parameters(), optim=optimizer, offload=offload)
     model = ShardingStage2(
-        model, optimizer, buffer_max_size=2**21, accumulate_grads=True)
+        model, optimizer, buffer_max_size=2**21, accumulate_grads=False)
 
     train_reader = paddle.batch(
         reader_decorator(linear_size), batch_size=batch_size, drop_last=True)
@@ -98,12 +98,11 @@ def test_sharding_stage2_offload():
     mlp_offload_params = train_mlp(mlp_offload, offload=True)
 
     for i in range(len(mlp_params)):
-        for j in range(len(mlp_offload_params)):
-            if mlp_params[i].name == mlp_offload_params[j].name:
-                np.testing.assert_allclose(
-                    mlp_params[i].numpy(),
-                    mlp_offload_params[j].numpy(),
-                    rtol=1e-6)
+        np.testing.assert_allclose(
+            mlp_params[i].numpy(),
+            mlp_offload_params[i].numpy(),
+            rtol=5e-3,
+            atol=5e-3)
 
     return
......
@@ -31,5 +31,19 @@ class TestPPClipGrad(TestDistPPTraning):
         return scheduler, optimizer
 
 
+class TestPPClipGradParamGroup(TestDistPPTraning):
+    def build_optimizer(self, model):
+        grad_clip = paddle.nn.ClipGradByGlobalNorm(0.5)
+        scheduler = paddle.optimizer.lr.PiecewiseDecay(
+            boundaries=[2], values=[0.001, 0.002], verbose=True)
+        optimizer = paddle.optimizer.Momentum(
+            learning_rate=scheduler,
+            grad_clip=grad_clip,
+            parameters=[{
+                "params": model.parameters()
+            }])
+        return scheduler, optimizer
+
+
 if __name__ == "__main__":
     unittest.main()