Unverified commit 4d6f8f2a, authored by W WangXi, committed by GitHub

optimize ClipGradByGlobalNorm (#34586)

Parent 7e707ce8
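Note: this commit speeds up ClipGradByGlobalNorm in two ways: each per-gradient square + reduce_sum pair is replaced by one fused squared_l2_norm op, and the static-graph path rescales gradients in place instead of allocating a new gradient variable. For reference, a rough NumPy sketch of the clipping math this class implements (illustrative names only, not the Paddle code):

import numpy as np

def clip_by_global_norm(grads, clip_norm=1.0):
    # global_norm = sqrt(sum_i ||g_i||_2^2), accumulated over every gradient
    global_norm = np.sqrt(sum(np.sum(np.square(g)) for g in grads))
    # scale <= 1, so gradients are only ever shrunk, never amplified
    scale = clip_norm / max(clip_norm, global_norm)
    return [g * scale for g in grads]

grads = [np.array([3.0, 4.0]), np.array([0.0, 12.0])]
print(clip_by_global_norm(grads, clip_norm=1.0))  # global_norm == 13.0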
@@ -19,11 +19,15 @@ import six
 import warnings
 import functools
+import paddle
 from . import layers
 from . import framework
 from . import core
 from . import name_scope
 from .dygraph import base as imperative_base
+from .data_feeder import check_variable_and_dtype
+from .framework import in_dygraph_mode
+from .layer_helper import LayerHelper
 
 __all__ = [
     'set_gradient_clip', 'ErrorClipByValue', 'ClipGradByValue',
@@ -31,6 +35,30 @@ __all__ = [
 ]
 
 
+def _squared_l2_norm(x):
+    r"""
+    This OP returns the squared L2 norm of a tensor.
+    """
+    if core.is_compiled_with_npu() or core.is_compiled_with_xpu():
+        square = layers.square(x)
+        sum_square = layers.reduce_sum(square)
+        return sum_square
+
+    if in_dygraph_mode():
+        return core.ops.squared_l2_norm(x)
+
+    op_type = 'squared_l2_norm'
+    check_variable_and_dtype(x, 'x', ['float32'], op_type)
+
+    helper = LayerHelper(op_type, **locals())
+    out = helper.create_variable_for_type_inference(x.dtype)
+
+    inputs = {"X": x}
+    outputs = {'Out': out}
+    helper.append_op(type=op_type, inputs=inputs, outputs=outputs)
+    return out
+
+
 class BaseErrorClipAttr(object):
     def __str__(self):
         raise NotImplementedError()
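A quick dygraph sanity check of the new helper (a hedged sketch: it assumes a Paddle build that contains this commit and imports the private helper from paddle.fluid.clip, where the diff places it):

import numpy as np
import paddle
from paddle.fluid.clip import _squared_l2_norm

paddle.disable_static()
x = paddle.to_tensor(np.array([[3.0, 4.0]], dtype='float32'))
# The fused op and the old square + reduce_sum pair should both print 25.0.
print(_squared_l2_norm(x).numpy())
print(paddle.sum(paddle.square(x)).numpy())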
@@ -416,8 +444,8 @@ class ClipGradByGlobalNorm(ClipGradBase):
             if g.type == core.VarDesc.VarType.SELECTED_ROWS:
                 merge_grad = layers.merge_selected_rows(g)
                 merge_grad = layers.get_tensor_from_selected_rows(merge_grad)
-            square = layers.square(merge_grad)
-            sum_square = layers.reduce_sum(square)
+
+            sum_square = _squared_l2_norm(merge_grad)
             sum_square_list.append(sum_square)
 
         # all parameters have been filterd out
@@ -439,6 +467,7 @@ class ClipGradByGlobalNorm(ClipGradBase):
             if getattr(p, 'need_clip', True) is False:
                 params_and_grads.append((p, g))
                 continue
+            # TODO(wangxi): use inplace elementwise_mul
             new_grad = layers.elementwise_mul(x=g, y=clip_var)
             params_and_grads.append((p, new_grad))
@@ -460,8 +489,7 @@ class ClipGradByGlobalNorm(ClipGradBase):
                     merge_grad = layers.get_tensor_from_selected_rows(
                         merge_grad)
-                square = layers.square(merge_grad)
-                sum_square = layers.reduce_sum(input=square)
+                sum_square = _squared_l2_norm(merge_grad)
                 sum_square_list.append(sum_square)
 
             # all parameters have been filterd out
@@ -489,9 +517,14 @@ class ClipGradByGlobalNorm(ClipGradBase):
                     continue
 
                 with p.block.program._optimized_guard([p, g]):
-                    new_grad = layers.elementwise_mul(x=g, y=scale_var)
-                    param_new_grad_name_dict[p.name] = new_grad.name
-                    params_and_grads.append((p, new_grad))
+                    # inplace
+                    p.block.append_op(
+                        type='elementwise_mul',
+                        inputs={'X': g,
+                                'Y': scale_var},
+                        outputs={'Out': g})
+                    param_new_grad_name_dict[p.name] = g.name
+                    params_and_grads.append((p, g))
 
         _correct_clip_op_role_var(params_and_grads, param_new_grad_name_dict)
         return params_and_grads
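The static-graph branch above now appends an elementwise_mul whose output variable is the gradient itself, so clipping no longer creates a new gradient variable. A minimal standalone sketch of that in-place pattern (variable names here are illustrative, not taken from the PR):

import paddle
import paddle.fluid as fluid

paddle.enable_static()
main = fluid.Program()
with fluid.program_guard(main):
    g = fluid.layers.fill_constant(shape=[2], dtype='float32', value=3.0)
    scale = fluid.layers.fill_constant(shape=[1], dtype='float32', value=0.5)
    # Out is the same variable as X, so the multiply writes back into `g`.
    g.block.append_op(
        type='elementwise_mul',
        inputs={'X': g,
                'Y': scale},
        outputs={'Out': g})

exe = fluid.Executor(fluid.CPUPlace())
print(exe.run(main, fetch_list=[g])[0])  # expected: [1.5, 1.5]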
@@ -513,8 +546,7 @@ class ClipGradByGlobalNorm(ClipGradBase):
                 merge_grad = layers.merge_selected_rows(grad)
                 merge_grad = layers.get_tensor_from_selected_rows(merge_grad)
-            square = layers.square(merge_grad)
-            local_norm_var = layers.reduce_sum(input=square)
+            local_norm_var = _squared_l2_norm(merge_grad)
             context[self.group_name].append(local_norm_var)
 
         self.context = context
@@ -532,10 +564,14 @@ class ClipGradByGlobalNorm(ClipGradBase):
             assert group_scale_var.shape == (1, )
             self.context[group_scale_name] = group_scale_var
 
-        new_grad = layers.elementwise_mul(
-            x=grad, y=self.context[group_scale_name])
+        # inplace
+        param.block.append_op(
+            type='elementwise_mul',
+            inputs={'X': grad,
+                    'Y': self.context[group_scale_name]},
+            outputs={'Out': grad})
 
-        return param, new_grad
+        return param, grad
 
     @framework.dygraph_not_support
@@ -709,7 +745,7 @@ def _correct_clip_op_role_var(params_grads, param_new_grad_name_dict):
             continue
         block_id_list.append(block_id)
         for op in param.block.program.global_block().ops:
-            if 'op_namescope' in op.all_attrs() and "gradient_clip" in op.attr(
+            if op.has_attr("op_namescope") and "gradient_clip" in op.attr(
                     "op_namescope") and op.attr('op_role_var'):
                 param_name = op.attr('op_role_var')[0]
                 if param_name in param_new_grad_name_dict:
......
@@ -264,8 +264,8 @@ class TestFleetShardingMetaOptimizer(TestFleetMetaOptimizer):
             'elementwise_add_grad', 'mul_grad', 'tanh_grad',
             'elementwise_add_grad', 'mul_grad', 'c_sync_calc_stream',
             'c_reduce_sum', 'c_reduce_sum', 'c_reduce_sum', 'c_reduce_sum',
-            'c_reduce_sum', 'c_reduce_sum', 'c_sync_comm_stream', 'square',
-            'reduce_sum', 'square', 'reduce_sum', 'square', 'reduce_sum', 'sum',
+            'c_reduce_sum', 'c_reduce_sum', 'c_sync_comm_stream',
+            'squared_l2_norm', 'squared_l2_norm', 'squared_l2_norm', 'sum',
             'c_allreduce_sum', 'sqrt', 'fill_constant', 'elementwise_max',
             'elementwise_div', 'elementwise_mul', 'elementwise_mul',
             'elementwise_mul', 'momentum', 'momentum', 'momentum'
......
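The expected op list in the sharding test above shrinks because each square + reduce_sum pair collapses into one squared_l2_norm. A hedged way to observe the same thing on a toy static program (assumes a Paddle build that includes this commit):

import paddle
import paddle.fluid as fluid

paddle.enable_static()
main, startup = fluid.Program(), fluid.Program()
with fluid.program_guard(main, startup):
    x = fluid.data(name='x', shape=[8, 4], dtype='float32')
    loss = fluid.layers.reduce_mean(fluid.layers.fc(input=x, size=1))
    clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0)
    fluid.optimizer.SGD(learning_rate=0.1, grad_clip=clip).minimize(loss)

ops = [op.type for op in main.global_block().ops]
# With this commit the clip section should show squared_l2_norm ops and
# no square / reduce_sum pairs.
print([t for t in ops if t in ('squared_l2_norm', 'square', 'reduce_sum')])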
@@ -22,6 +22,8 @@ import paddle.fluid as fluid
 import six
 from fake_reader import fake_imdb_reader
 
+paddle.enable_static()
+
 
 def bow_net(data,
             label,
@@ -149,7 +151,7 @@ class TestGradientClipByGlobalNorm(TestGradientClip):
     def check_clip_result(self, out, out_clip):
         global_norm = 0
         for v in out:
-            global_norm += np.sum(np.power(v, 2))
+            global_norm += np.sum(np.square(v))
         global_norm = np.sqrt(global_norm)
         scale = self.clip_norm / np.maximum(self.clip_norm, global_norm)
         res = []
@@ -160,7 +162,8 @@ class TestGradientClipByGlobalNorm(TestGradientClip):
             self.assertTrue(
                 np.allclose(
                     a=u, b=v, rtol=1e-5, atol=1e-8),
-                "gradient clip by global norm has wrong results!")
+                "gradient clip by global norm has wrong results!, \nu={}\nv={}\ndiff={}".
+                format(u, v, u - v))
 
     # test whether the ouput is right when use 'set_gradient_clip'
     def test_old_gradient_clip(self):
@@ -210,12 +213,16 @@ class TestGradientClipByGlobalNorm(TestGradientClip):
         params_grads = [(x, None), (x, y), (y, x)]
         params_grads = clip(params_grads)
         self.assertTrue(
-            len(clip(params_grads)) == 2,
+            len(params_grads) == 2,
             "ClipByGlobalNorm: when grad is None, it shouldn't be returned by gradient clip!"
         )
-        self.assertTrue(
-            params_grads[0][1].name != 'y',
-            "ClipByGlobalNorm: param_grad (x, y) should be clipped!")
+
+        ops = [op.type for op in x.block.ops]
+        self.assertListEqual(ops, [
+            'squared_l2_norm', 'squared_l2_norm', 'sum', 'sqrt',
+            'fill_constant', 'elementwise_max', 'elementwise_div',
+            'elementwise_mul', 'elementwise_mul'
+        ])
 
     # raise typeError
     def test_tpyeError(self):
......