From 3a869cc5f68cae83cd536f1cfd46bbf2c7d7e0b0 Mon Sep 17 00:00:00 2001 From: zhangbo9674 <82555433+zhangbo9674@users.noreply.github.com> Date: Wed, 13 Oct 2021 15:56:26 +0800 Subject: [PATCH] Add fp16 for clip_by_norm & clip_by_global_norm (#36198) * add fp16 for clip_by_norm api * support ClipByGlobalNorm for fp16 in dygraph * add unittest for dygraph clipGlobalNorm * refine unittest for dygraph clipGlobalNorm for mac and windows * refine unittest * add unittest for fp64 * refine unittest for fp64 --- python/paddle/fluid/clip.py | 35 +++++- python/paddle/fluid/layers/nn.py | 2 +- .../tests/unittests/test_gradient_clip.py | 113 ++++++++++++++++++ 3 files changed, 145 insertions(+), 5 deletions(-) diff --git a/python/paddle/fluid/clip.py b/python/paddle/fluid/clip.py index 4cca41b527b..293d6119e75 100644 --- a/python/paddle/fluid/clip.py +++ b/python/paddle/fluid/clip.py @@ -436,6 +436,8 @@ class ClipGradByGlobalNorm(ClipGradBase): def _dygraph_clip(self, params_grads): params_and_grads = [] sum_square_list = [] + sum_square_list_fp16 = [] + sum_square_list_fp32 = [] for p, g in params_grads: if g is None: continue @@ -447,13 +449,36 @@ class ClipGradByGlobalNorm(ClipGradBase): merge_grad = layers.get_tensor_from_selected_rows(merge_grad) sum_square = _squared_l2_norm(merge_grad) - sum_square_list.append(sum_square) + if sum_square.dtype == core.VarDesc.VarType.FP16: + sum_square_list_fp16.append(sum_square) + elif sum_square.dtype == core.VarDesc.VarType.FP32: + sum_square_list_fp32.append(sum_square) + else: + sum_square_list.append(sum_square) # all parameters have been filterd out - if len(sum_square_list) == 0: + if len(sum_square_list) + len(sum_square_list_fp16) + len( + sum_square_list_fp32) == 0: return params_grads - global_norm_var = layers.concat(sum_square_list) + sum_dtype = 'float64' if len(sum_square_list) > 0 else "float32" + global_norm_var = [] + if len(sum_square_list_fp16) > 0: + global_norm_var_fp16 = layers.concat(sum_square_list_fp16) + global_norm_var_fp16 = layers.reduce_sum(global_norm_var_fp16) + global_norm_var.append(global_norm_var_fp16.astype(sum_dtype)) + if len(sum_square_list_fp32) > 0: + global_norm_var_fp32 = layers.concat(sum_square_list_fp32) + global_norm_var_fp32 = layers.reduce_sum(global_norm_var_fp32) + if sum_dtype == 'float32': + global_norm_var.append(global_norm_var_fp32) + else: + global_norm_var.append(global_norm_var_fp32.astype(sum_dtype)) + if len(sum_square_list) > 0: + global_norm_var_fp64 = layers.concat(sum_square_list) + global_norm_var_fp64 = layers.reduce_sum(global_norm_var_fp64) + global_norm_var.append(global_norm_var_fp64) + global_norm_var = layers.concat(global_norm_var) global_norm_var = layers.reduce_sum(global_norm_var) global_norm_var = layers.sqrt(global_norm_var) max_global_norm = layers.fill_constant( @@ -469,7 +494,9 @@ class ClipGradByGlobalNorm(ClipGradBase): params_and_grads.append((p, g)) continue # TODO(wangxi): use inplace elementwise_mul - new_grad = layers.elementwise_mul(x=g, y=clip_var) + clip_input = (clip_var.astype('float16') + if g.dtype == core.VarDesc.VarType.FP16 else clip_var) + new_grad = layers.elementwise_mul(x=g, y=clip_input) params_and_grads.append((p, new_grad)) return params_and_grads diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 75b0392ab6a..ceda304b26e 100755 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -12524,7 +12524,7 @@ def clip_by_norm(x, max_norm, name=None): return _C_ops.clip_by_norm(x, 'max_norm', max_norm) helper = LayerHelper("clip_by_norm", **locals()) - check_variable_and_dtype(x, 'X', ['float32'], 'clip_by_norm') + check_variable_and_dtype(x, 'X', ['float32', 'float16'], 'clip_by_norm') check_type(max_norm, 'max_norm', (float), 'clip_by_norm') if name is None: diff --git a/python/paddle/fluid/tests/unittests/test_gradient_clip.py b/python/paddle/fluid/tests/unittests/test_gradient_clip.py index e2050cf32db..29735f1c89c 100644 --- a/python/paddle/fluid/tests/unittests/test_gradient_clip.py +++ b/python/paddle/fluid/tests/unittests/test_gradient_clip.py @@ -453,5 +453,118 @@ class TestDygraphGradientClipByValue(TestDygraphGradientClip): "gradient clip by value has wrong results!") +class SimpleNet(paddle.nn.Layer): + def __init__(self): + super(SimpleNet, self).__init__() + self.linear = paddle.nn.Linear(5, 5) + self.batch_norm = paddle.nn.BatchNorm(5) + + def forward(self, x): + x = self.linear(x) + x = self.batch_norm(x) + return x + + +class TestDygraphGradientClipFP16(unittest.TestCase): + def test_gradient_clip(self): + if fluid.core.is_compiled_with_cuda(): + with fluid.dygraph.guard(): + paddle.seed(10) + model = SimpleNet() + sgd_optimizer = paddle.optimizer.SGD( + learning_rate=0.0, parameters=model.parameters()) + model, sgd_optimizer = paddle.amp.decorate( + models=model, optimizers=sgd_optimizer, level='O2') + scaler = paddle.amp.GradScaler(init_loss_scaling=1024) + inputs = fluid.layers.uniform_random( + [1, 5], min=-10, max=10).astype('float32') + with paddle.amp.auto_cast(level='O2'): + out = model(fluid.dygraph.to_variable(inputs)) + loss = fluid.layers.reduce_mean(out) + scaled = scaler.scale(loss) + scaled.backward() + scaler.unscale_(sgd_optimizer) + # before clip + params_grads = [] + for param in model.parameters(): + if param.stop_gradient: + continue + if param._grad_ivar() is not None: + params_grads.append((param, param._grad_ivar())) + _, grads = zip(*params_grads) + # clip grads + clip = fluid.clip.GradientClipByGlobalNorm(clip_norm=0.8) + params_grads = clip(params_grads) + _, grads_clip = zip(*params_grads) + # param update + scaler.step(sgd_optimizer) + scaler.update() + + global_norm = 0 + for u in grads: + u = u.numpy() + global_norm += np.sum(np.power(u, 2)) + global_norm = np.sqrt(global_norm) + global_norm_clip = 0 + for v in grads_clip: + v = v.numpy() + global_norm_clip += np.sum(np.power(v, 2)) + global_norm_clip = np.sqrt(global_norm_clip) + + a = np.minimum(global_norm, 0.8) + b = global_norm_clip + self.assertTrue( + np.isclose( + a=a, b=b, rtol=1e-3, atol=1e-8), + "gradient clip by global norm has wrong results, expetcd:%f, but recieved:%f" + % (a, b)) + + +class TestDygraphGradientClipFP64(unittest.TestCase): + def test_gradient_clip(self): + with fluid.dygraph.guard(): + inputs = fluid.layers.uniform_random( + [16, 5], min=-10, max=10).astype('float64') + linear = fluid.dygraph.Linear(5, 5, dtype="float64") + out = linear(fluid.dygraph.to_variable(inputs)) + loss = fluid.layers.reduce_mean(out) + loss.backward() + # before clip + params_grads = [] + for param in linear.parameters(): + if param.stop_gradient: + continue + if param._grad_ivar() is not None: + params_grads.append((param, param._grad_ivar())) + _, grads = zip(*params_grads) + # clip grads + clip = fluid.clip.GradientClipByGlobalNorm(clip_norm=0.1) + params_grads = clip(params_grads) + _, grads_clip = zip(*params_grads) + + global_norm = 0 + for u in grads: + u = u.numpy() + global_norm += np.sum(np.power(u, 2)) + global_norm = np.sqrt(global_norm) + + global_norm_clip = 0 + for v in grads_clip: + v = v.numpy() + print(v) + global_norm_clip += np.sum(np.power(v, 2)) + global_norm_clip = np.sqrt(global_norm_clip) + print(global_norm_clip) + + a = np.minimum(global_norm, 0.1) + b = global_norm_clip + + self.assertTrue( + np.isclose( + a=a, b=b, rtol=1e-6, atol=1e-8), + "gradient clip by global norm has wrong results, expetcd:%f, but recieved:%f" + % (a, b)) + + if __name__ == '__main__': unittest.main() -- GitLab