diff --git a/python/paddle/fluid/clip.py b/python/paddle/fluid/clip.py
index 8fd01509331e207af1aaabde1e40404f1a8c6f74..fbe524376e592e5ec096abefdb90ecdc26cdeb97 100644
--- a/python/paddle/fluid/clip.py
+++ b/python/paddle/fluid/clip.py
@@ -19,11 +19,15 @@ import six
 import warnings
 import functools
 
+import paddle
 from . import layers
 from . import framework
 from . import core
 from . import name_scope
 from .dygraph import base as imperative_base
+from .data_feeder import check_variable_and_dtype
+from .framework import in_dygraph_mode
+from .layer_helper import LayerHelper
 
 __all__ = [
     'set_gradient_clip', 'ErrorClipByValue', 'ClipGradByValue',
@@ -31,6 +35,30 @@ __all__ = [
 ]
 
 
+def _squared_l2_norm(x):
+    r"""
+    This OP returns the squared L2 norm of a tensor.
+    """
+
+    if core.is_compiled_with_npu() or core.is_compiled_with_xpu():
+        square = layers.square(x)
+        sum_square = layers.reduce_sum(square)
+        return sum_square
+
+    if in_dygraph_mode():
+        return core.ops.squared_l2_norm(x)
+
+    op_type = 'squared_l2_norm'
+    check_variable_and_dtype(x, 'x', ['float32'], op_type)
+    helper = LayerHelper(op_type, **locals())
+    out = helper.create_variable_for_type_inference(x.dtype)
+
+    inputs = {"X": x}
+    outputs = {'Out': out}
+    helper.append_op(type=op_type, inputs=inputs, outputs=outputs)
+    return out
+
+
 class BaseErrorClipAttr(object):
     def __str__(self):
         raise NotImplementedError()
@@ -416,8 +444,8 @@ class ClipGradByGlobalNorm(ClipGradBase):
             if g.type == core.VarDesc.VarType.SELECTED_ROWS:
                 merge_grad = layers.merge_selected_rows(g)
                 merge_grad = layers.get_tensor_from_selected_rows(merge_grad)
-            square = layers.square(merge_grad)
-            sum_square = layers.reduce_sum(square)
+
+            sum_square = _squared_l2_norm(merge_grad)
             sum_square_list.append(sum_square)
 
         # all parameters have been filterd out
@@ -439,6 +467,7 @@ class ClipGradByGlobalNorm(ClipGradBase):
             if getattr(p, 'need_clip', True) is False:
                 params_and_grads.append((p, g))
                 continue
+            # TODO(wangxi): use inplace elementwise_mul
             new_grad = layers.elementwise_mul(x=g, y=clip_var)
             params_and_grads.append((p, new_grad))
 
@@ -460,8 +489,7 @@ class ClipGradByGlobalNorm(ClipGradBase):
                         merge_grad = layers.get_tensor_from_selected_rows(
                             merge_grad)
 
-                    square = layers.square(merge_grad)
-                    sum_square = layers.reduce_sum(input=square)
+                    sum_square = _squared_l2_norm(merge_grad)
                     sum_square_list.append(sum_square)
 
             # all parameters have been filterd out
@@ -489,9 +517,14 @@ class ClipGradByGlobalNorm(ClipGradBase):
                     continue
 
                 with p.block.program._optimized_guard([p, g]):
-                    new_grad = layers.elementwise_mul(x=g, y=scale_var)
-                    param_new_grad_name_dict[p.name] = new_grad.name
-                    params_and_grads.append((p, new_grad))
+                    # inplace
+                    p.block.append_op(
+                        type='elementwise_mul',
+                        inputs={'X': g,
+                                'Y': scale_var},
+                        outputs={'Out': g})
+                    param_new_grad_name_dict[p.name] = g.name
+                    params_and_grads.append((p, g))
             _correct_clip_op_role_var(params_and_grads,
                                       param_new_grad_name_dict)
         return params_and_grads
@@ -513,8 +546,7 @@ class ClipGradByGlobalNorm(ClipGradBase):
             merge_grad = layers.merge_selected_rows(grad)
             merge_grad = layers.get_tensor_from_selected_rows(merge_grad)
 
-        square = layers.square(merge_grad)
-        local_norm_var = layers.reduce_sum(input=square)
+        local_norm_var = _squared_l2_norm(merge_grad)
         context[self.group_name].append(local_norm_var)
 
         self.context = context
@@ -532,10 +564,14 @@ class ClipGradByGlobalNorm(ClipGradBase):
             assert group_scale_var.shape == (1, )
             self.context[group_scale_name] = group_scale_var
 
-        new_grad = layers.elementwise_mul(
-            x=grad, y=self.context[group_scale_name])
+        # inplace
+        param.block.append_op(
+            type='elementwise_mul',
+            inputs={'X': grad,
+                    'Y': self.context[group_scale_name]},
+            outputs={'Out': grad})
 
-        return param, new_grad
+        return param, grad
 
 
 @framework.dygraph_not_support
@@ -709,7 +745,7 @@ def _correct_clip_op_role_var(params_grads, param_new_grad_name_dict):
             continue
         block_id_list.append(block_id)
         for op in param.block.program.global_block().ops:
-            if 'op_namescope' in op.all_attrs() and "gradient_clip" in op.attr(
+            if op.has_attr("op_namescope") and "gradient_clip" in op.attr(
                     "op_namescope") and op.attr('op_role_var'):
                 param_name = op.attr('op_role_var')[0]
                 if param_name in param_new_grad_name_dict:
diff --git a/python/paddle/fluid/tests/unittests/test_fleet_sharding_meta_optimizer.py b/python/paddle/fluid/tests/unittests/test_fleet_sharding_meta_optimizer.py
index b6c25e3ad67d3a7d8628cf32aa7cd0c5564915e6..b7cf9dfaec5760e37eadf3d84a439617c5436e8a 100755
--- a/python/paddle/fluid/tests/unittests/test_fleet_sharding_meta_optimizer.py
+++ b/python/paddle/fluid/tests/unittests/test_fleet_sharding_meta_optimizer.py
@@ -264,8 +264,8 @@ class TestFleetShardingMetaOptimizer(TestFleetMetaOptimizer):
             'elementwise_add_grad', 'mul_grad', 'tanh_grad',
             'elementwise_add_grad', 'mul_grad', 'c_sync_calc_stream',
             'c_reduce_sum', 'c_reduce_sum', 'c_reduce_sum', 'c_reduce_sum',
-            'c_reduce_sum', 'c_reduce_sum', 'c_sync_comm_stream', 'square',
-            'reduce_sum', 'square', 'reduce_sum', 'square', 'reduce_sum', 'sum',
+            'c_reduce_sum', 'c_reduce_sum', 'c_sync_comm_stream',
+            'squared_l2_norm', 'squared_l2_norm', 'squared_l2_norm', 'sum',
             'c_allreduce_sum', 'sqrt', 'fill_constant', 'elementwise_max',
             'elementwise_div', 'elementwise_mul', 'elementwise_mul',
             'elementwise_mul', 'momentum', 'momentum', 'momentum'
diff --git a/python/paddle/fluid/tests/unittests/test_gradient_clip.py b/python/paddle/fluid/tests/unittests/test_gradient_clip.py
index 14f5d4a41a1fed1a81436d0372759db86fc7d1a0..9b6dbc00f7c565016ee5bc24034eb0952e7c9ed2 100644
--- a/python/paddle/fluid/tests/unittests/test_gradient_clip.py
+++ b/python/paddle/fluid/tests/unittests/test_gradient_clip.py
@@ -22,6 +22,8 @@ import paddle.fluid as fluid
 import six
 from fake_reader import fake_imdb_reader
 
+paddle.enable_static()
+
 
 def bow_net(data,
             label,
@@ -149,7 +151,7 @@ class TestGradientClipByGlobalNorm(TestGradientClip):
     def check_clip_result(self, out, out_clip):
         global_norm = 0
         for v in out:
-            global_norm += np.sum(np.power(v, 2))
+            global_norm += np.sum(np.square(v))
         global_norm = np.sqrt(global_norm)
         scale = self.clip_norm / np.maximum(self.clip_norm, global_norm)
         res = []
@@ -160,7 +162,8 @@ class TestGradientClipByGlobalNorm(TestGradientClip):
             self.assertTrue(
                 np.allclose(
                     a=u, b=v, rtol=1e-5, atol=1e-8),
-                "gradient clip by global norm has wrong results!")
+                "gradient clip by global norm has wrong results!, \nu={}\nv={}\ndiff={}".
+                format(u, v, u - v))
 
     # test whether the ouput is right when use 'set_gradient_clip'
     def test_old_gradient_clip(self):
@@ -210,12 +213,16 @@ class TestGradientClipByGlobalNorm(TestGradientClip):
         params_grads = [(x, None), (x, y), (y, x)]
         params_grads = clip(params_grads)
         self.assertTrue(
-            len(clip(params_grads)) == 2,
+            len(params_grads) == 2,
             "ClipByGlobalNorm: when grad is None, it shouldn't be returned by gradient clip!"
         )
-        self.assertTrue(
-            params_grads[0][1].name != 'y',
-            "ClipByGlobalNorm: param_grad (x, y) should be clipped!")
+
+        ops = [op.type for op in x.block.ops]
+        self.assertListEqual(ops, [
+            'squared_l2_norm', 'squared_l2_norm', 'sum', 'sqrt',
+            'fill_constant', 'elementwise_max', 'elementwise_div',
+            'elementwise_mul', 'elementwise_mul'
+        ])
 
     # raise typeError
     def test_tpyeError(self):
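Note for reviewers: the math exercised by the updated `check_clip_result` can be reproduced outside the framework. Each `squared_l2_norm` op is expected to return `sum(square(x))` for one gradient, the global norm is the square root of the summed results, and every gradient is scaled by `clip_norm / max(clip_norm, global_norm)`. The snippet below is a minimal NumPy sketch of that reference computation; `clip_by_global_norm_reference` is an illustrative name, not part of the Paddle API.

```python
import numpy as np


def clip_by_global_norm_reference(grads, clip_norm):
    # Each squared_l2_norm op computes sum(x ** 2) for one gradient tensor.
    sum_squares = [np.sum(np.square(g)) for g in grads]
    # 'sum' followed by 'sqrt' in the expected op list gives the global norm.
    global_norm = np.sqrt(np.sum(sum_squares))
    # Gradients are only scaled down, never up ('elementwise_max' with clip_norm).
    scale = clip_norm / np.maximum(clip_norm, global_norm)
    return [g * scale for g in grads]


grads = [np.random.rand(3, 4).astype('float32') for _ in range(3)]
clipped = clip_by_global_norm_reference(grads, clip_norm=1.0)
assert np.sqrt(sum(np.sum(np.square(g)) for g in clipped)) <= 1.0 + 1e-6
```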
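The switch to an in-place `elementwise_mul` in the static path only changes which program variable holds the scaled gradient; observable clipping behavior is unchanged. As a usage sketch (standard Paddle 2.x dygraph API, not code taken from this patch), the clip is normally attached through the optimizer's `grad_clip` argument:

```python
import paddle

linear = paddle.nn.Linear(10, 1)
clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0)
sgd = paddle.optimizer.SGD(learning_rate=0.1,
                           parameters=linear.parameters(),
                           grad_clip=clip)

loss = paddle.mean(linear(paddle.rand([4, 10])))
loss.backward()
sgd.step()        # gradients are rescaled by the global-norm clip before the update
sgd.clear_grad()
```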