Unverified commit 4d6f8f2a, authored by WangXi, committed by GitHub

optimize ClipGradByGlobalNorm (#34586)

Parent 7e707ce8
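For context: ClipGradByGlobalNorm rescales every gradient by clip_norm / max(global_norm, clip_norm), where global_norm is the square root of the summed squared L2 norms of all gradients. This commit fuses the per-gradient square + reduce_sum pair into a single squared_l2_norm op and applies the rescaling multiply in place. A minimal NumPy sketch of the clipping rule itself (illustrative only, not the Paddle implementation):

import numpy as np

def clip_by_global_norm(grads, clip_norm):
    # global_norm = sqrt(sum over all gradients of ||g||_2^2)
    global_norm = np.sqrt(sum(np.sum(np.square(g)) for g in grads))
    # scale <= 1.0, so gradients are only ever shrunk, never amplified
    scale = clip_norm / max(global_norm, clip_norm)
    return [g * scale for g in grads]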
@@ -19,11 +19,15 @@ import six
import warnings
import functools
import paddle
from . import layers
from . import framework
from . import core
from . import name_scope
from .dygraph import base as imperative_base
from .data_feeder import check_variable_and_dtype
from .framework import in_dygraph_mode
from .layer_helper import LayerHelper
__all__ = [
'set_gradient_clip', 'ErrorClipByValue', 'ClipGradByValue',
@@ -31,6 +35,30 @@ __all__ = [
]
def _squared_l2_norm(x):
    r"""
    This OP returns the squared L2 norm of a tensor.
    """
    if core.is_compiled_with_npu() or core.is_compiled_with_xpu():
        square = layers.square(x)
        sum_square = layers.reduce_sum(square)
        return sum_square

    if in_dygraph_mode():
        return core.ops.squared_l2_norm(x)

    op_type = 'squared_l2_norm'
    check_variable_and_dtype(x, 'x', ['float32'], op_type)
    helper = LayerHelper(op_type, **locals())
    out = helper.create_variable_for_type_inference(x.dtype)

    inputs = {"X": x}
    outputs = {'Out': out}
    helper.append_op(type=op_type, inputs=inputs, outputs=outputs)
    return out
class BaseErrorClipAttr(object):
def __str__(self):
raise NotImplementedError()
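The new _squared_l2_norm helper above computes sum(x * x) as one fused op, falling back to square + reduce_sum on NPU/XPU and calling core.ops.squared_l2_norm directly in dygraph mode. A small NumPy sketch of the equivalence the rest of the diff relies on; values are illustrative:

import numpy as np

x = np.array([1.0, 2.0, 3.0], dtype=np.float32)
two_op = np.sum(np.square(x))         # old path: square, then reduce_sum -> 14.0
fused = np.dot(x.ravel(), x.ravel())  # what squared_l2_norm computes in a single pass -> 14.0
assert np.isclose(two_op, fused)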
@@ -416,8 +444,8 @@ class ClipGradByGlobalNorm(ClipGradBase):
if g.type == core.VarDesc.VarType.SELECTED_ROWS:
merge_grad = layers.merge_selected_rows(g)
merge_grad = layers.get_tensor_from_selected_rows(merge_grad)
-square = layers.square(merge_grad)
-sum_square = layers.reduce_sum(square)
+sum_square = _squared_l2_norm(merge_grad)
sum_square_list.append(sum_square)
# all parameters have been filtered out
@@ -439,6 +467,7 @@ class ClipGradByGlobalNorm(ClipGradBase):
if getattr(p, 'need_clip', True) is False:
params_and_grads.append((p, g))
continue
# TODO(wangxi): use inplace elementwise_mul
new_grad = layers.elementwise_mul(x=g, y=clip_var)
params_and_grads.append((p, new_grad))
@@ -460,8 +489,7 @@ class ClipGradByGlobalNorm(ClipGradBase):
merge_grad = layers.get_tensor_from_selected_rows(
merge_grad)
-square = layers.square(merge_grad)
-sum_square = layers.reduce_sum(input=square)
+sum_square = _squared_l2_norm(merge_grad)
sum_square_list.append(sum_square)
# all parameters have been filtered out
@@ -489,9 +517,14 @@ class ClipGradByGlobalNorm(ClipGradBase):
continue
with p.block.program._optimized_guard([p, g]):
-new_grad = layers.elementwise_mul(x=g, y=scale_var)
-param_new_grad_name_dict[p.name] = new_grad.name
-params_and_grads.append((p, new_grad))
+# inplace
+p.block.append_op(
+type='elementwise_mul',
+inputs={'X': g,
+'Y': scale_var},
+outputs={'Out': g})
+param_new_grad_name_dict[p.name] = g.name
+params_and_grads.append((p, g))
_correct_clip_op_role_var(params_and_grads, param_new_grad_name_dict)
return params_and_grads
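In this static-graph path, layers.elementwise_mul used to allocate a fresh variable for every clipped gradient; the patch instead appends the elementwise_mul op with the gradient itself as its output, so the rescale happens in place. The intent, sketched with a plain NumPy analogy (not Paddle code):

import numpy as np

g = np.array([3.0, 4.0], dtype=np.float32)
scale = np.float32(0.5)

new_grad = g * scale          # out-of-place: allocates a second gradient-sized buffer
np.multiply(g, scale, out=g)  # in-place: overwrites the existing gradient buffer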
@@ -513,8 +546,7 @@ class ClipGradByGlobalNorm(ClipGradBase):
merge_grad = layers.merge_selected_rows(grad)
merge_grad = layers.get_tensor_from_selected_rows(merge_grad)
-square = layers.square(merge_grad)
-local_norm_var = layers.reduce_sum(input=square)
+local_norm_var = _squared_l2_norm(merge_grad)
context[self.group_name].append(local_norm_var)
self.context = context
@@ -532,10 +564,14 @@ class ClipGradByGlobalNorm(ClipGradBase):
assert group_scale_var.shape == (1, )
self.context[group_scale_name] = group_scale_var
-new_grad = layers.elementwise_mul(
-x=grad, y=self.context[group_scale_name])
+# inplace
+param.block.append_op(
+type='elementwise_mul',
+inputs={'X': grad,
+'Y': self.context[group_scale_name]},
+outputs={'Out': grad})

-return param, new_grad
+return param, grad
@framework.dygraph_not_support
@@ -709,7 +745,7 @@ def _correct_clip_op_role_var(params_grads, param_new_grad_name_dict):
continue
block_id_list.append(block_id)
for op in param.block.program.global_block().ops:
-if 'op_namescope' in op.all_attrs() and "gradient_clip" in op.attr(
+if op.has_attr("op_namescope") and "gradient_clip" in op.attr(
"op_namescope") and op.attr('op_role_var'):
param_name = op.attr('op_role_var')[0]
if param_name in param_new_grad_name_dict:
......
@@ -264,8 +264,8 @@ class TestFleetShardingMetaOptimizer(TestFleetMetaOptimizer):
'elementwise_add_grad', 'mul_grad', 'tanh_grad',
'elementwise_add_grad', 'mul_grad', 'c_sync_calc_stream',
'c_reduce_sum', 'c_reduce_sum', 'c_reduce_sum', 'c_reduce_sum',
-'c_reduce_sum', 'c_reduce_sum', 'c_sync_comm_stream', 'square',
-'reduce_sum', 'square', 'reduce_sum', 'square', 'reduce_sum', 'sum',
+'c_reduce_sum', 'c_reduce_sum', 'c_sync_comm_stream',
+'squared_l2_norm', 'squared_l2_norm', 'squared_l2_norm', 'sum',
'c_allreduce_sum', 'sqrt', 'fill_constant', 'elementwise_max',
'elementwise_div', 'elementwise_mul', 'elementwise_mul',
'elementwise_mul', 'momentum', 'momentum', 'momentum'
......
@@ -22,6 +22,8 @@ import paddle.fluid as fluid
import six
from fake_reader import fake_imdb_reader
paddle.enable_static()
def bow_net(data,
label,
@@ -149,7 +151,7 @@ class TestGradientClipByGlobalNorm(TestGradientClip):
def check_clip_result(self, out, out_clip):
global_norm = 0
for v in out:
-global_norm += np.sum(np.power(v, 2))
+global_norm += np.sum(np.square(v))
global_norm = np.sqrt(global_norm)
scale = self.clip_norm / np.maximum(self.clip_norm, global_norm)
res = []
@@ -160,7 +162,8 @@ class TestGradientClipByGlobalNorm(TestGradientClip):
self.assertTrue(
np.allclose(
a=u, b=v, rtol=1e-5, atol=1e-8),
-"gradient clip by global norm has wrong results!")
+"gradient clip by global norm has wrong results!, \nu={}\nv={}\ndiff={}".
+format(u, v, u - v))
# test whether the output is right when using 'set_gradient_clip'
def test_old_gradient_clip(self):
@@ -210,12 +213,16 @@ class TestGradientClipByGlobalNorm(TestGradientClip):
params_grads = [(x, None), (x, y), (y, x)]
params_grads = clip(params_grads)
self.assertTrue(
-len(clip(params_grads)) == 2,
+len(params_grads) == 2,
"ClipByGlobalNorm: when grad is None, it shouldn't be returned by gradient clip!"
)
self.assertTrue(
params_grads[0][1].name != 'y',
"ClipByGlobalNorm: param_grad (x, y) should be clipped!")
ops = [op.type for op in x.block.ops]
self.assertListEqual(ops, [
'squared_l2_norm', 'squared_l2_norm', 'sum', 'sqrt',
'fill_constant', 'elementwise_max', 'elementwise_div',
'elementwise_mul', 'elementwise_mul'
])
# raise TypeError
def test_tpyeError(self):
......
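For reference, the optimized kernel path is reached through the same public API as before; a typical usage sketch with the paddle 2.x dygraph API (parameter values are illustrative):

import paddle

linear = paddle.nn.Linear(10, 10)
clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0)
sgd = paddle.optimizer.SGD(learning_rate=0.1,
                           parameters=linear.parameters(),
                           grad_clip=clip)

loss = paddle.mean(linear(paddle.rand([4, 10])))
loss.backward()
sgd.step()        # gradients are scaled by clip_norm / max(global_norm, clip_norm)
sgd.clear_grad()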