未验证 提交 3a869cc5 编写于 作者: Z zhangbo9674 提交者: GitHub

Add fp16 for clip_by_norm & clip_by_global_norm (#36198)

* add fp16 for clip_by_norm api

* support ClipByGlobalNorm for fp16 in dygraph

* add unittest for dygraph clipGlobalNorm

* refine unittest for dygraph clipGlobalNorm for mac and windows

* refine unittest

* add unittest for fp64

* refine unittest for fp64
上级 85bb1a85
......@@ -436,6 +436,8 @@ class ClipGradByGlobalNorm(ClipGradBase):
def _dygraph_clip(self, params_grads):
params_and_grads = []
sum_square_list = []
sum_square_list_fp16 = []
sum_square_list_fp32 = []
for p, g in params_grads:
if g is None:
continue
......@@ -447,13 +449,36 @@ class ClipGradByGlobalNorm(ClipGradBase):
merge_grad = layers.get_tensor_from_selected_rows(merge_grad)
sum_square = _squared_l2_norm(merge_grad)
if sum_square.dtype == core.VarDesc.VarType.FP16:
sum_square_list_fp16.append(sum_square)
elif sum_square.dtype == core.VarDesc.VarType.FP32:
sum_square_list_fp32.append(sum_square)
else:
sum_square_list.append(sum_square)
# all parameters have been filterd out
if len(sum_square_list) == 0:
if len(sum_square_list) + len(sum_square_list_fp16) + len(
sum_square_list_fp32) == 0:
return params_grads
global_norm_var = layers.concat(sum_square_list)
sum_dtype = 'float64' if len(sum_square_list) > 0 else "float32"
global_norm_var = []
if len(sum_square_list_fp16) > 0:
global_norm_var_fp16 = layers.concat(sum_square_list_fp16)
global_norm_var_fp16 = layers.reduce_sum(global_norm_var_fp16)
global_norm_var.append(global_norm_var_fp16.astype(sum_dtype))
if len(sum_square_list_fp32) > 0:
global_norm_var_fp32 = layers.concat(sum_square_list_fp32)
global_norm_var_fp32 = layers.reduce_sum(global_norm_var_fp32)
if sum_dtype == 'float32':
global_norm_var.append(global_norm_var_fp32)
else:
global_norm_var.append(global_norm_var_fp32.astype(sum_dtype))
if len(sum_square_list) > 0:
global_norm_var_fp64 = layers.concat(sum_square_list)
global_norm_var_fp64 = layers.reduce_sum(global_norm_var_fp64)
global_norm_var.append(global_norm_var_fp64)
global_norm_var = layers.concat(global_norm_var)
global_norm_var = layers.reduce_sum(global_norm_var)
global_norm_var = layers.sqrt(global_norm_var)
max_global_norm = layers.fill_constant(
......@@ -469,7 +494,9 @@ class ClipGradByGlobalNorm(ClipGradBase):
params_and_grads.append((p, g))
continue
# TODO(wangxi): use inplace elementwise_mul
new_grad = layers.elementwise_mul(x=g, y=clip_var)
clip_input = (clip_var.astype('float16')
if g.dtype == core.VarDesc.VarType.FP16 else clip_var)
new_grad = layers.elementwise_mul(x=g, y=clip_input)
params_and_grads.append((p, new_grad))
return params_and_grads
......
......@@ -12524,7 +12524,7 @@ def clip_by_norm(x, max_norm, name=None):
return _C_ops.clip_by_norm(x, 'max_norm', max_norm)
helper = LayerHelper("clip_by_norm", **locals())
check_variable_and_dtype(x, 'X', ['float32'], 'clip_by_norm')
check_variable_and_dtype(x, 'X', ['float32', 'float16'], 'clip_by_norm')
check_type(max_norm, 'max_norm', (float), 'clip_by_norm')
if name is None:
......
......@@ -453,5 +453,118 @@ class TestDygraphGradientClipByValue(TestDygraphGradientClip):
"gradient clip by value has wrong results!")
class SimpleNet(paddle.nn.Layer):
def __init__(self):
super(SimpleNet, self).__init__()
self.linear = paddle.nn.Linear(5, 5)
self.batch_norm = paddle.nn.BatchNorm(5)
def forward(self, x):
x = self.linear(x)
x = self.batch_norm(x)
return x
class TestDygraphGradientClipFP16(unittest.TestCase):
def test_gradient_clip(self):
if fluid.core.is_compiled_with_cuda():
with fluid.dygraph.guard():
paddle.seed(10)
model = SimpleNet()
sgd_optimizer = paddle.optimizer.SGD(
learning_rate=0.0, parameters=model.parameters())
model, sgd_optimizer = paddle.amp.decorate(
models=model, optimizers=sgd_optimizer, level='O2')
scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
inputs = fluid.layers.uniform_random(
[1, 5], min=-10, max=10).astype('float32')
with paddle.amp.auto_cast(level='O2'):
out = model(fluid.dygraph.to_variable(inputs))
loss = fluid.layers.reduce_mean(out)
scaled = scaler.scale(loss)
scaled.backward()
scaler.unscale_(sgd_optimizer)
# before clip
params_grads = []
for param in model.parameters():
if param.stop_gradient:
continue
if param._grad_ivar() is not None:
params_grads.append((param, param._grad_ivar()))
_, grads = zip(*params_grads)
# clip grads
clip = fluid.clip.GradientClipByGlobalNorm(clip_norm=0.8)
params_grads = clip(params_grads)
_, grads_clip = zip(*params_grads)
# param update
scaler.step(sgd_optimizer)
scaler.update()
global_norm = 0
for u in grads:
u = u.numpy()
global_norm += np.sum(np.power(u, 2))
global_norm = np.sqrt(global_norm)
global_norm_clip = 0
for v in grads_clip:
v = v.numpy()
global_norm_clip += np.sum(np.power(v, 2))
global_norm_clip = np.sqrt(global_norm_clip)
a = np.minimum(global_norm, 0.8)
b = global_norm_clip
self.assertTrue(
np.isclose(
a=a, b=b, rtol=1e-3, atol=1e-8),
"gradient clip by global norm has wrong results, expetcd:%f, but recieved:%f"
% (a, b))
class TestDygraphGradientClipFP64(unittest.TestCase):
def test_gradient_clip(self):
with fluid.dygraph.guard():
inputs = fluid.layers.uniform_random(
[16, 5], min=-10, max=10).astype('float64')
linear = fluid.dygraph.Linear(5, 5, dtype="float64")
out = linear(fluid.dygraph.to_variable(inputs))
loss = fluid.layers.reduce_mean(out)
loss.backward()
# before clip
params_grads = []
for param in linear.parameters():
if param.stop_gradient:
continue
if param._grad_ivar() is not None:
params_grads.append((param, param._grad_ivar()))
_, grads = zip(*params_grads)
# clip grads
clip = fluid.clip.GradientClipByGlobalNorm(clip_norm=0.1)
params_grads = clip(params_grads)
_, grads_clip = zip(*params_grads)
global_norm = 0
for u in grads:
u = u.numpy()
global_norm += np.sum(np.power(u, 2))
global_norm = np.sqrt(global_norm)
global_norm_clip = 0
for v in grads_clip:
v = v.numpy()
print(v)
global_norm_clip += np.sum(np.power(v, 2))
global_norm_clip = np.sqrt(global_norm_clip)
print(global_norm_clip)
a = np.minimum(global_norm, 0.1)
b = global_norm_clip
self.assertTrue(
np.isclose(
a=a, b=b, rtol=1e-6, atol=1e-8),
"gradient clip by global norm has wrong results, expetcd:%f, but recieved:%f"
% (a, b))
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册