Unverified commit 5bdca05b, authored by huangxu96, committed by GitHub

Support float16 when using ClipGradByGlobalNorm. (#33565)

This PR supports gradient clipping (ClipGradByGlobalNorm) when training with AMP (automatic mixed precision).
Parent 11965bca
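For context, the combination this change is meant to support looks roughly like the minimal dygraph sketch below; the diff itself touches the static-graph `_static_clip` path, and the model, sizes, and hyperparameters here are illustrative assumptions (float16 kernels only take effect on GPU).

import paddle

# Hypothetical tiny model; sizes and hyperparameters are placeholders.
model = paddle.nn.Linear(784, 10)
clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0)
opt = paddle.optimizer.Momentum(learning_rate=0.01,
                                parameters=model.parameters(),
                                grad_clip=clip)
scaler = paddle.amp.GradScaler(init_loss_scaling=1024)

x = paddle.randn([8, 784])
with paddle.amp.auto_cast():      # forward pass runs selected ops in float16
    loss = model(x).mean()
scaled = scaler.scale(loss)       # scale the loss to avoid float16 underflow
scaled.backward()
scaler.minimize(opt, scaled)      # unscales, clips by global norm, then updates
opt.clear_grad()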
@@ -93,7 +93,9 @@ REGISTER_OPERATOR(squared_l2_norm, ops::SquaredL2NormOp,
REGISTER_OPERATOR(squared_l2_norm_grad, ops::SquaredL2NormGradOp);
REGISTER_OP_CPU_KERNEL(
squared_l2_norm,
ops::SquaredL2NormKernel<paddle::platform::CPUDeviceContext, float>);
ops::SquaredL2NormKernel<paddle::platform::CPUDeviceContext, float>,
ops::SquaredL2NormKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
squared_l2_norm_grad,
ops::SquaredL2NormGradKernel<paddle::platform::CPUDeviceContext, float>);
ops::SquaredL2NormGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::SquaredL2NormGradKernel<paddle::platform::CPUDeviceContext, double>);
@@ -16,7 +16,9 @@ limitations under the License. */
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
squared_l2_norm,
ops::SquaredL2NormKernel<paddle::platform::CUDADeviceContext, float>);
ops::SquaredL2NormKernel<paddle::platform::CUDADeviceContext, float>,
ops::SquaredL2NormKernel<paddle::platform::CUDADeviceContext, double>);
REGISTER_OP_CUDA_KERNEL(
squared_l2_norm_grad,
ops::SquaredL2NormGradKernel<paddle::platform::CUDADeviceContext, float>);
ops::SquaredL2NormGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::SquaredL2NormGradKernel<paddle::platform::CUDADeviceContext, double>);
@@ -40,7 +40,7 @@ def _squared_l2_norm(x):
This OP returns the squared L2 norm of a tensor.
"""
if core.is_compiled_with_xpu():
if core.is_compiled_with_xpu() or x.dtype == core.VarDesc.VarType.FP16:
square = layers.square(x)
sum_square = layers.reduce_sum(square)
return sum_square
@@ -49,7 +49,7 @@ def _squared_l2_norm(x):
return core.ops.squared_l2_norm(x)
op_type = 'squared_l2_norm'
check_variable_and_dtype(x, 'x', ['float32'], op_type)
check_variable_and_dtype(x, 'x', ['float32', 'float64'], op_type)
helper = LayerHelper(op_type, **locals())
out = helper.create_variable_for_type_inference(x.dtype)
@@ -476,6 +476,8 @@ class ClipGradByGlobalNorm(ClipGradBase):
def _static_clip(self, params_grads):
params_and_grads = []
sum_square_list = []
sum_square_list_fp16 = []
sum_square_list_fp32 = []
with framework.name_scope('gradient_clip'):
for p, g in params_grads:
if g is None:
@@ -488,16 +490,39 @@ class ClipGradByGlobalNorm(ClipGradBase):
merge_grad = layers.merge_selected_rows(g)
merge_grad = layers.get_tensor_from_selected_rows(
merge_grad)
sum_square = _squared_l2_norm(merge_grad)
sum_square_list.append(sum_square)
if sum_square.dtype == core.VarDesc.VarType.FP16:
sum_square_list_fp16.append(sum_square)
elif sum_square.dtype == core.VarDesc.VarType.FP32:
sum_square_list_fp32.append(sum_square)
else:
sum_square_list.append(sum_square)
# all parameters have been filtered out
if len(sum_square_list) == 0:
if len(sum_square_list) + len(sum_square_list_fp16) + len(
sum_square_list_fp32) == 0:
return params_grads
with p.block.program._optimized_guard([p, g]):
global_norm_var = layers.sums(sum_square_list)
sum_dtype = 'float64' if len(sum_square_list) > 0 else "float32"
global_norm_var = []
if len(sum_square_list_fp16) > 0:
global_norm_var_fp16 = layers.sums(sum_square_list_fp16)
global_norm_var.append(
global_norm_var_fp16.astype(sum_dtype))
if len(sum_square_list_fp32) > 0:
global_norm_var_fp32 = layers.sums(sum_square_list_fp32)
if sum_dtype == 'float32':
global_norm_var.append(global_norm_var_fp32)
else:
global_norm_var.append(
global_norm_var_fp32.astype(sum_dtype))
if len(sum_square_list) > 0:
# fp64
global_norm_var_other_dtype = layers.sums(sum_square_list)
global_norm_var.append(global_norm_var_other_dtype)
global_norm_var = layers.sums(global_norm_var)
global_norm_var = layers.sqrt(x=global_norm_var)
max_global_norm = layers.fill_constant(
shape=[1],
@@ -507,7 +532,6 @@ class ClipGradByGlobalNorm(ClipGradBase):
x=max_global_norm,
y=layers.elementwise_max(
x=max_global_norm, y=global_norm_var))
param_new_grad_name_dict = dict()
for p, g in params_grads:
if g is None:
@@ -518,11 +542,15 @@ class ClipGradByGlobalNorm(ClipGradBase):
with p.block.program._optimized_guard([p, g]):
# inplace
scale_input = (scale_var.astype('float16')
if g.dtype == core.VarDesc.VarType.FP16 else
scale_var)
p.block.append_op(
type='elementwise_mul',
inputs={'X': g,
'Y': scale_var},
'Y': scale_input},
outputs={'Out': g})
param_new_grad_name_dict[p.name] = g.name
params_and_grads.append((p, g))
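To summarise the clip.py changes above: squared gradient norms are grouped by dtype (float16 gradients go through square + reduce_sum, since squared_l2_norm has no float16 kernel), each group is summed separately, the partial sums are promoted to a common dtype (float64 if any float64 gradient is present, otherwise float32) before the global reduction, and the resulting scale is cast back to float16 only for float16 gradients. A rough numpy-only sketch of that flow, with illustrative names and no Paddle dependencies:

import numpy as np

def clip_by_global_norm(grads, clip_norm):
    # Accumulate the squared L2 norm of each gradient, grouped by dtype.
    sums = {np.float16: [], np.float32: [], np.float64: []}
    for g in grads:
        sums[g.dtype.type].append(np.sum(np.square(g)))

    # Promote the per-group partial sums to a common dtype before the
    # final reduction: float64 if any float64 gradient exists, else float32.
    sum_dtype = np.float64 if sums[np.float64] else np.float32
    partial = []
    if sums[np.float16]:
        partial.append(np.sum(sums[np.float16]).astype(sum_dtype))
    if sums[np.float32]:
        partial.append(np.sum(sums[np.float32]).astype(sum_dtype))
    if sums[np.float64]:
        partial.append(np.sum(sums[np.float64]))
    global_norm = np.sqrt(np.sum(partial))

    # scale = clip_norm / max(clip_norm, global_norm), applied in each
    # gradient's own dtype (the cast back to float16 mirrors scale_input).
    scale = clip_norm / max(clip_norm, float(global_norm))
    return [g * np.asarray(scale, dtype=g.dtype) for g in grads]

# e.g. clip_by_global_norm([np.ones((2, 3), np.float16),
#                           np.ones((4,), np.float32)], clip_norm=1.0)
# scales both gradients by the same factor, each kept in its original dtype.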
@@ -538,10 +538,10 @@ def sums(input, out=None):
if isinstance(input, list) or isinstance(input, tuple):
for input_section in input:
check_variable_and_dtype(input_section, "input", \
['float32', 'float64', 'int32', 'int64'], 'sums')
['float16', 'float32', 'float64', 'int32', 'int64'], 'sums')
else:
check_variable_and_dtype(input, "input", \
['float32', 'float64', 'int32', 'int64'], 'sums')
['float16', 'float32', 'float64', 'int32', 'int64'], 'sums')
helper = LayerHelper('sum', **locals())
if out is None:
@@ -266,9 +266,10 @@ class TestFleetShardingMetaOptimizer(TestFleetMetaOptimizer):
'c_reduce_sum', 'c_reduce_sum', 'c_reduce_sum', 'c_reduce_sum',
'c_reduce_sum', 'c_reduce_sum', 'c_sync_comm_stream',
'squared_l2_norm', 'squared_l2_norm', 'squared_l2_norm', 'sum',
'c_allreduce_sum', 'sqrt', 'fill_constant', 'elementwise_max',
'elementwise_div', 'elementwise_mul', 'elementwise_mul',
'elementwise_mul', 'momentum', 'momentum', 'momentum'
'c_allreduce_sum', 'sum', 'c_allreduce_sum', 'sqrt',
'fill_constant', 'elementwise_max', 'elementwise_div',
'elementwise_mul', 'elementwise_mul', 'elementwise_mul', 'momentum',
'momentum', 'momentum'
])
def test_sharding_clone_for_test(self):
@@ -71,14 +71,18 @@ class TestGradientClip(unittest.TestCase):
def check_clip_result(self, out, out_clip):
pass
def check_gradient_clip(self, place):
def check_gradient_clip(self, place, dtype='float32'):
prog = fluid.Program()
startup_program = fluid.Program()
with fluid.program_guard(
main_program=prog, startup_program=startup_program):
image = fluid.data(name="a", shape=[-1, 784], dtype='float32')
label = fluid.data(name="b", shape=[-1, 1], dtype='int64')
hidden = fluid.layers.fc(input=image, size=32, act='relu')
if dtype != 'float32':
image_cast = paddle.cast(image, dtype)
hidden = fluid.layers.fc(input=image_cast, size=32, act='relu')
else:
hidden = fluid.layers.fc(input=image, size=32, act='relu')
predict = fluid.layers.fc(input=hidden, size=10, act='softmax')
cost = fluid.layers.cross_entropy(input=predict, label=label)
@@ -176,6 +180,15 @@ class TestGradientClipByGlobalNorm(TestGradientClip):
self.clip_gradient = func
self.check_gradient_clip(fluid.CPUPlace())
# test whether the output is right when grad_clip is used under float64
def test_new_gradient_clip_fp64(self):
def func(params_grads):
clip = fluid.clip.GradientClipByGlobalNorm(clip_norm=self.clip_norm)
return clip(params_grads)
self.clip_gradient = func
self.check_gradient_clip(fluid.CPUPlace(), "float64")
# invoke 'set_gradient_clip' in a wrong order
def test_wrong_API_order(self):
def backward_func(cost):
@@ -192,29 +205,6 @@ class TestGradientClipByGlobalNorm(TestGradientClip):
for place in self.get_places():
self.check_sparse_gradient_clip(place)
# if grad is None or not need clip
def test_none_grad(self):
clip = fluid.clip.GradientClipByGlobalNorm(self.clip_norm)
x = fluid.default_main_program().global_block().create_parameter(
name="x", shape=[2, 3], dtype="float32")
y = fluid.default_main_program().global_block().create_parameter(
name="y", shape=[2, 3], dtype="float32")
# (x, None) should not be returned
params_grads = [(x, None), (x, y), (y, x)]
params_grads = clip(params_grads)
self.assertTrue(
len(params_grads) == 2,
"ClipByGlobalNorm: when grad is None, it shouldn't be returned by gradient clip!"
)
ops = [op.type for op in x.block.ops]
self.assertListEqual(ops, [
'squared_l2_norm', 'squared_l2_norm', 'sum', 'sqrt',
'fill_constant', 'elementwise_max', 'elementwise_div',
'elementwise_mul', 'elementwise_mul'
])
# raise TypeError
def test_tpyeError(self):
# the type of optimizer(grad_clip=) must be an instance of GradientClipBase's derived class
@@ -222,6 +212,46 @@ class TestGradientClipByGlobalNorm(TestGradientClip):
sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.1,
grad_clip="test")
# if grad is None or not need clip
def test_none_grad_fp32(self):
ops = self._test_none_grad_helper("float32")
self.assertListEqual(ops, [
'squared_l2_norm', 'squared_l2_norm', 'sum', 'sum', 'sqrt',
'fill_constant', 'elementwise_max', 'elementwise_div',
'elementwise_mul', 'elementwise_mul'
])
def test_none_grad_fp16(self):
ops = self._test_none_grad_helper("float16")
self.assertListEqual(ops, [
'square', 'reduce_sum', 'square', 'reduce_sum', 'sum', 'cast',
'sum', 'sqrt', 'fill_constant', 'elementwise_max',
'elementwise_div', 'cast', 'elementwise_mul', 'cast',
'elementwise_mul'
])
def _test_none_grad_helper(self, dtype):
prog = fluid.Program()
startup_program = fluid.Program()
with fluid.program_guard(
main_program=prog, startup_program=startup_program):
clip = fluid.clip.GradientClipByGlobalNorm(self.clip_norm)
x = fluid.default_main_program().global_block().create_parameter(
name="x", shape=[2, 3], dtype=dtype)
y = fluid.default_main_program().global_block().create_parameter(
name="y", shape=[2, 3], dtype=dtype)
# (x, None) should not be returned
params_grads = [(x, None), (x, y), (y, x)]
params_grads = clip(params_grads)
self.assertTrue(
len(params_grads) == 2,
"ClipByGlobalNorm: when grad is None, it shouldn't be returned by gradient clip!"
)
ops = [op.type for op in x.block.ops]
return ops
class TestGradientClipByNorm(TestGradientClip):
def init(self):