From 5bdca05b786b6ac49b30f0891b73313eb81b8e2c Mon Sep 17 00:00:00 2001
From: huangxu96 <46740794+huangxu96@users.noreply.github.com>
Date: Fri, 10 Sep 2021 15:11:41 +0800
Subject: [PATCH] Support float16 when using ClipGradByGlobalNorm. (#33565)

This PR supports gradient clipping (ClipGradByGlobalNorm) when training with
AMP (auto mixed precision).
---
 paddle/fluid/operators/squared_l2_norm_op.cc  |  6 +-
 paddle/fluid/operators/squared_l2_norm_op.cu  |  6 +-
 python/paddle/fluid/clip.py                   | 44 ++++++++--
 python/paddle/fluid/layers/tensor.py          |  4 +-
 .../test_fleet_sharding_meta_optimizer.py     |  7 +-
 .../tests/unittests/test_gradient_clip.py     | 80 +++++++++++++------
 6 files changed, 105 insertions(+), 42 deletions(-)

diff --git a/paddle/fluid/operators/squared_l2_norm_op.cc b/paddle/fluid/operators/squared_l2_norm_op.cc
index 79781dde58d..433dcb38711 100644
--- a/paddle/fluid/operators/squared_l2_norm_op.cc
+++ b/paddle/fluid/operators/squared_l2_norm_op.cc
@@ -93,7 +93,9 @@ REGISTER_OPERATOR(squared_l2_norm, ops::SquaredL2NormOp,
 REGISTER_OPERATOR(squared_l2_norm_grad, ops::SquaredL2NormGradOp);
 REGISTER_OP_CPU_KERNEL(
     squared_l2_norm,
-    ops::SquaredL2NormKernel<paddle::platform::CPUDeviceContext, float>);
+    ops::SquaredL2NormKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::SquaredL2NormKernel<paddle::platform::CPUDeviceContext, double>);
 REGISTER_OP_CPU_KERNEL(
     squared_l2_norm_grad,
-    ops::SquaredL2NormGradKernel<paddle::platform::CPUDeviceContext, float>);
+    ops::SquaredL2NormGradKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::SquaredL2NormGradKernel<paddle::platform::CPUDeviceContext, double>);
diff --git a/paddle/fluid/operators/squared_l2_norm_op.cu b/paddle/fluid/operators/squared_l2_norm_op.cu
index e31cfeb78ab..b51e56af8ec 100644
--- a/paddle/fluid/operators/squared_l2_norm_op.cu
+++ b/paddle/fluid/operators/squared_l2_norm_op.cu
@@ -16,7 +16,9 @@ limitations under the License. */
 namespace ops = paddle::operators;
 REGISTER_OP_CUDA_KERNEL(
     squared_l2_norm,
-    ops::SquaredL2NormKernel<paddle::platform::CUDADeviceContext, float>);
+    ops::SquaredL2NormKernel<paddle::platform::CUDADeviceContext, float>,
+    ops::SquaredL2NormKernel<paddle::platform::CUDADeviceContext, double>);
 REGISTER_OP_CUDA_KERNEL(
     squared_l2_norm_grad,
-    ops::SquaredL2NormGradKernel<paddle::platform::CUDADeviceContext, float>);
+    ops::SquaredL2NormGradKernel<paddle::platform::CUDADeviceContext, float>,
+    ops::SquaredL2NormGradKernel<paddle::platform::CUDADeviceContext, double>);
diff --git a/python/paddle/fluid/clip.py b/python/paddle/fluid/clip.py
index d48cea48a76..e9f5c181a6b 100644
--- a/python/paddle/fluid/clip.py
+++ b/python/paddle/fluid/clip.py
@@ -40,7 +40,7 @@ def _squared_l2_norm(x):
     This OP returns the squared L2 norm of a tensor.
""" - if core.is_compiled_with_xpu(): + if core.is_compiled_with_xpu() or x.dtype == core.VarDesc.VarType.FP16: square = layers.square(x) sum_square = layers.reduce_sum(square) return sum_square @@ -49,7 +49,7 @@ def _squared_l2_norm(x): return core.ops.squared_l2_norm(x) op_type = 'squared_l2_norm' - check_variable_and_dtype(x, 'x', ['float32'], op_type) + check_variable_and_dtype(x, 'x', ['float32', 'float64'], op_type) helper = LayerHelper(op_type, **locals()) out = helper.create_variable_for_type_inference(x.dtype) @@ -476,6 +476,8 @@ class ClipGradByGlobalNorm(ClipGradBase): def _static_clip(self, params_grads): params_and_grads = [] sum_square_list = [] + sum_square_list_fp16 = [] + sum_square_list_fp32 = [] with framework.name_scope('gradient_clip'): for p, g in params_grads: if g is None: @@ -488,16 +490,39 @@ class ClipGradByGlobalNorm(ClipGradBase): merge_grad = layers.merge_selected_rows(g) merge_grad = layers.get_tensor_from_selected_rows( merge_grad) - sum_square = _squared_l2_norm(merge_grad) - sum_square_list.append(sum_square) + if sum_square.dtype == core.VarDesc.VarType.FP16: + sum_square_list_fp16.append(sum_square) + elif sum_square.dtype == core.VarDesc.VarType.FP32: + sum_square_list_fp32.append(sum_square) + else: + sum_square_list.append(sum_square) # all parameters have been filterd out - if len(sum_square_list) == 0: + if len(sum_square_list) + len(sum_square_list_fp16) + len( + sum_square_list_fp32) == 0: return params_grads with p.block.program._optimized_guard([p, g]): - global_norm_var = layers.sums(sum_square_list) + sum_dtype = 'float64' if len(sum_square_list) > 0 else "float32" + + global_norm_var = [] + if len(sum_square_list_fp16) > 0: + global_norm_var_fp16 = layers.sums(sum_square_list_fp16) + global_norm_var.append( + global_norm_var_fp16.astype(sum_dtype)) + if len(sum_square_list_fp32) > 0: + global_norm_var_fp32 = layers.sums(sum_square_list_fp32) + if sum_dtype == 'float32': + global_norm_var.append(global_norm_var_fp32) + else: + global_norm_var.append( + global_norm_var_fp32.astype(sum_dtype)) + if len(sum_square_list) > 0: + # fp64 + global_norm_var_other_dtype = layers.sums(sum_square_list) + global_norm_var.append(global_norm_var_other_dtype) + global_norm_var = layers.sums(global_norm_var) global_norm_var = layers.sqrt(x=global_norm_var) max_global_norm = layers.fill_constant( shape=[1], @@ -507,7 +532,6 @@ class ClipGradByGlobalNorm(ClipGradBase): x=max_global_norm, y=layers.elementwise_max( x=max_global_norm, y=global_norm_var)) - param_new_grad_name_dict = dict() for p, g in params_grads: if g is None: @@ -518,11 +542,15 @@ class ClipGradByGlobalNorm(ClipGradBase): with p.block.program._optimized_guard([p, g]): # inplace + scale_input = (scale_var.astype('float16') + if g.dtype == core.VarDesc.VarType.FP16 else + scale_var) p.block.append_op( type='elementwise_mul', inputs={'X': g, - 'Y': scale_var}, + 'Y': scale_input}, outputs={'Out': g}) + param_new_grad_name_dict[p.name] = g.name params_and_grads.append((p, g)) diff --git a/python/paddle/fluid/layers/tensor.py b/python/paddle/fluid/layers/tensor.py index 85e9424f2ea..06b2d513775 100644 --- a/python/paddle/fluid/layers/tensor.py +++ b/python/paddle/fluid/layers/tensor.py @@ -538,10 +538,10 @@ def sums(input, out=None): if isinstance(input, list) or isinstance(input, tuple): for input_section in input: check_variable_and_dtype(input_section, "input", \ - ['float32', 'float64', 'int32', 'int64'], 'sums') + ['float16', 'float32', 'float64', 'int32', 'int64'], 'sums') else: 
         check_variable_and_dtype(input, "input", \
-            ['float32', 'float64', 'int32', 'int64'], 'sums')
+            ['float16', 'float32', 'float64', 'int32', 'int64'], 'sums')
 
     helper = LayerHelper('sum', **locals())
     if out is None:
diff --git a/python/paddle/fluid/tests/unittests/test_fleet_sharding_meta_optimizer.py b/python/paddle/fluid/tests/unittests/test_fleet_sharding_meta_optimizer.py
index 3b0df74d3e6..c462896eed2 100755
--- a/python/paddle/fluid/tests/unittests/test_fleet_sharding_meta_optimizer.py
+++ b/python/paddle/fluid/tests/unittests/test_fleet_sharding_meta_optimizer.py
@@ -266,9 +266,10 @@ class TestFleetShardingMetaOptimizer(TestFleetMetaOptimizer):
             'c_reduce_sum', 'c_reduce_sum', 'c_reduce_sum', 'c_reduce_sum',
             'c_reduce_sum', 'c_reduce_sum', 'c_sync_comm_stream',
             'squared_l2_norm', 'squared_l2_norm', 'squared_l2_norm', 'sum',
-            'c_allreduce_sum', 'sqrt', 'fill_constant', 'elementwise_max',
-            'elementwise_div', 'elementwise_mul', 'elementwise_mul',
-            'elementwise_mul', 'momentum', 'momentum', 'momentum'
+            'c_allreduce_sum', 'sum', 'c_allreduce_sum', 'sqrt',
+            'fill_constant', 'elementwise_max', 'elementwise_div',
+            'elementwise_mul', 'elementwise_mul', 'elementwise_mul', 'momentum',
+            'momentum', 'momentum'
         ])
 
     def test_sharding_clone_for_test(self):
diff --git a/python/paddle/fluid/tests/unittests/test_gradient_clip.py b/python/paddle/fluid/tests/unittests/test_gradient_clip.py
index 80cb25bba47..4360214e7dd 100644
--- a/python/paddle/fluid/tests/unittests/test_gradient_clip.py
+++ b/python/paddle/fluid/tests/unittests/test_gradient_clip.py
@@ -71,14 +71,18 @@ class TestGradientClip(unittest.TestCase):
     def check_clip_result(self, out, out_clip):
         pass
 
-    def check_gradient_clip(self, place):
+    def check_gradient_clip(self, place, dtype='float32'):
         prog = fluid.Program()
         startup_program = fluid.Program()
         with fluid.program_guard(
                 main_program=prog, startup_program=startup_program):
             image = fluid.data(name="a", shape=[-1, 784], dtype='float32')
             label = fluid.data(name="b", shape=[-1, 1], dtype='int64')
-            hidden = fluid.layers.fc(input=image, size=32, act='relu')
+            if dtype != 'float32':
+                image_cast = paddle.cast(image, dtype)
+                hidden = fluid.layers.fc(input=image_cast, size=32, act='relu')
+            else:
+                hidden = fluid.layers.fc(input=image, size=32, act='relu')
             predict = fluid.layers.fc(input=hidden, size=10, act='softmax')
 
             cost = fluid.layers.cross_entropy(input=predict, label=label)
@@ -176,6 +180,15 @@ class TestGradientClipByGlobalNorm(TestGradientClip):
         self.clip_gradient = func
         self.check_gradient_clip(fluid.CPUPlace())
 
+    # test whether the output is right when using grad_clip under float64
+    def test_new_gradient_clip_fp64(self):
+        def func(params_grads):
+            clip = fluid.clip.GradientClipByGlobalNorm(clip_norm=self.clip_norm)
+            return clip(params_grads)
+
+        self.clip_gradient = func
+        self.check_gradient_clip(fluid.CPUPlace(), "float64")
+
     # invoke 'set_gradient_clip' in a wrong order
     def test_wrong_API_order(self):
         def backward_func(cost):
@@ -192,29 +205,6 @@ class TestGradientClipByGlobalNorm(TestGradientClip):
         for place in self.get_places():
             self.check_sparse_gradient_clip(place)
 
-    # if grad is None or not need clip
-    def test_none_grad(self):
-        clip = fluid.clip.GradientClipByGlobalNorm(self.clip_norm)
-        x = fluid.default_main_program().global_block().create_parameter(
-            name="x", shape=[2, 3], dtype="float32")
-        y = fluid.default_main_program().global_block().create_parameter(
-            name="y", shape=[2, 3], dtype="float32")
-
-        # (x, None) should not be returned
-        params_grads = [(x, None), (x, y), (y, x)]
-        params_grads = clip(params_grads)
-        self.assertTrue(
-            len(params_grads) == 2,
-            "ClipByGlobalNorm: when grad is None, it shouldn't be returned by gradient clip!"
-        )
-
-        ops = [op.type for op in x.block.ops]
-        self.assertListEqual(ops, [
-            'squared_l2_norm', 'squared_l2_norm', 'sum', 'sqrt',
-            'fill_constant', 'elementwise_max', 'elementwise_div',
-            'elementwise_mul', 'elementwise_mul'
-        ])
-
     # raise typeError
     def test_tpyeError(self):
         # the type of optimizer(grad_clip=) must be an instance of GradientClipBase's derived class
@@ -222,6 +212,46 @@ class TestGradientClipByGlobalNorm(TestGradientClip):
         with self.assertRaises(TypeError):
             sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.1,
                                                 grad_clip="test")
+    # if grad is None or not need clip
+    def test_none_grad_fp32(self):
+        ops = self._test_none_grad_helper("float32")
+        self.assertListEqual(ops, [
+            'squared_l2_norm', 'squared_l2_norm', 'sum', 'sum', 'sqrt',
+            'fill_constant', 'elementwise_max', 'elementwise_div',
+            'elementwise_mul', 'elementwise_mul'
+        ])
+
+    def test_none_grad_fp16(self):
+        ops = self._test_none_grad_helper("float16")
+        self.assertListEqual(ops, [
+            'square', 'reduce_sum', 'square', 'reduce_sum', 'sum', 'cast',
+            'sum', 'sqrt', 'fill_constant', 'elementwise_max',
+            'elementwise_div', 'cast', 'elementwise_mul', 'cast',
+            'elementwise_mul'
+        ])
+
+    def _test_none_grad_helper(self, dtype):
+        prog = fluid.Program()
+        startup_program = fluid.Program()
+        with fluid.program_guard(
+                main_program=prog, startup_program=startup_program):
+            clip = fluid.clip.GradientClipByGlobalNorm(self.clip_norm)
+            x = fluid.default_main_program().global_block().create_parameter(
+                name="x", shape=[2, 3], dtype=dtype)
+            y = fluid.default_main_program().global_block().create_parameter(
+                name="y", shape=[2, 3], dtype=dtype)
+
+            # (x, None) should not be returned
+            params_grads = [(x, None), (x, y), (y, x)]
+            params_grads = clip(params_grads)
+            self.assertTrue(
+                len(params_grads) == 2,
+                "ClipByGlobalNorm: when grad is None, it shouldn't be returned by gradient clip!"
+            )
+
+            ops = [op.type for op in x.block.ops]
+            return ops
+
 
 class TestGradientClipByNorm(TestGradientClip):
     def init(self):
-- 
GitLab
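
Usage sketch (not part of the patch): a minimal static-graph example of what this
change enables, modelled on _test_none_grad_helper / test_none_grad_fp16 above.
The variable names (main_prog, startup_prog, block, clipped) are illustrative, and
paddle.enable_static() is the only call used here that does not appear in the patch.

    import paddle
    import paddle.fluid as fluid

    paddle.enable_static()

    main_prog, startup_prog = fluid.Program(), fluid.Program()
    with fluid.program_guard(main_program=main_prog, startup_program=startup_prog):
        clip = fluid.clip.GradientClipByGlobalNorm(clip_norm=1.0)
        block = main_prog.global_block()
        # float16 parameters; as in the unit test, each parameter doubles as the
        # other's "gradient", so no backward pass is needed for the demo.
        x = block.create_parameter(name="x", shape=[2, 3], dtype="float16")
        y = block.create_parameter(name="y", shape=[2, 3], dtype="float16")
        clipped = clip([(x, y), (y, x)])
        # Before this patch, float16 inputs were rejected by the float32-only dtype
        # check in _squared_l2_norm; with it, the fp16 branch builds square /
        # reduce_sum ops plus the casts listed in test_none_grad_fp16.
        print([op.type for op in block.ops])

In a real AMP training program the same ClipGradByGlobalNorm instance would normally
be passed to the optimizer through its grad_clip argument rather than called directly,
as the grad_clip parameter used in test_tpyeError suggests.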