未验证 提交 66dc8e30 编写于 作者: Z Zhou Wei 提交者: GitHub

move the initialize position of grad_clip to optimizer(__init__),and speed up clip (#23782)

上级 361c6ccc
...@@ -137,9 +137,6 @@ class GradientClipBase(object): ...@@ -137,9 +137,6 @@ class GradientClipBase(object):
raise NotImplementedError raise NotImplementedError
def __call__(self, params_grads): def __call__(self, params_grads):
assert len(
params_grads
) > 0, "The number of trainable parameters should be greater than 0."
if framework.in_dygraph_mode(): if framework.in_dygraph_mode():
return self._dygraph_clip(params_grads) return self._dygraph_clip(params_grads)
else: else:
...@@ -147,7 +144,7 @@ class GradientClipBase(object): ...@@ -147,7 +144,7 @@ class GradientClipBase(object):
if getattr(p, 'gradient_clip_attr', None) is not None: if getattr(p, 'gradient_clip_attr', None) is not None:
warnings.warn( warnings.warn(
"'set_gradient_clip' will be ineffective, because you have " "'set_gradient_clip' will be ineffective, because you have "
"pass 'grad_clip' into 'minimize'. So, 'set_gradient_clip' " "set 'grad_clip' in 'optimizer'. So, 'set_gradient_clip' "
"is redundant and you can remove it.") "is redundant and you can remove it.")
break break
return self._static_clip(params_grads) return self._static_clip(params_grads)
...@@ -170,7 +167,7 @@ class GradientClipByValue(GradientClipBase): ...@@ -170,7 +167,7 @@ class GradientClipByValue(GradientClipBase):
The multi-dimensional Tensor :math:`X` is not passed from this class, but the gradients of all parameters in ``Program`` . If ``need_clip`` The multi-dimensional Tensor :math:`X` is not passed from this class, but the gradients of all parameters in ``Program`` . If ``need_clip``
is not None, then only part of gradients can be selected for gradient clipping. is not None, then only part of gradients can be selected for gradient clipping.
Gradient clip will takes effect after being set in ``optimizer.minimize(grad_clip)`` , see the document ``optimizer`` Gradient clip will takes effect after being set in ``optimizer`` , see the document ``optimizer``
(for example: :ref:`api_fluid_optimizer_SGDOptimizer`). (for example: :ref:`api_fluid_optimizer_SGDOptimizer`).
Args: Args:
...@@ -208,8 +205,8 @@ class GradientClipByValue(GradientClipBase): ...@@ -208,8 +205,8 @@ class GradientClipByValue(GradientClipBase):
# return Parameter.name=="fc_0.w_0" # return Parameter.name=="fc_0.w_0"
# clip = fluid.clip.GradientClipByValue(min=-1, max=1, need_clip=fileter_func) # clip = fluid.clip.GradientClipByValue(min=-1, max=1, need_clip=fileter_func)
sgd_optimizer = fluid.optimizer.SGDOptimizer(learning_rate=0.1) sgd_optimizer = fluid.optimizer.SGDOptimizer(learning_rate=0.1, grad_clip=clip)
sgd_optimizer.minimize(loss, grad_clip=clip) sgd_optimizer.minimize(loss)
place = fluid.CPUPlace() place = fluid.CPUPlace()
exe = fluid.Executor(place) exe = fluid.Executor(place)
...@@ -242,8 +239,8 @@ class GradientClipByValue(GradientClipBase): ...@@ -242,8 +239,8 @@ class GradientClipByValue(GradientClipBase):
# clip = fluid.clip.GradientClipByValue(min=-1, max=1, need_clip=fileter_func) # clip = fluid.clip.GradientClipByValue(min=-1, max=1, need_clip=fileter_func)
sgd_optimizer = fluid.optimizer.SGD( sgd_optimizer = fluid.optimizer.SGD(
learning_rate=0.1, parameter_list=linear.parameters()) learning_rate=0.1, parameter_list=linear.parameters(), grad_clip=clip)
sgd_optimizer.minimize(loss, grad_clip=clip) sgd_optimizer.minimize(loss)
""" """
def __init__(self, max, min=None, need_clip=None): def __init__(self, max, min=None, need_clip=None):
...@@ -272,6 +269,7 @@ class GradientClipByValue(GradientClipBase): ...@@ -272,6 +269,7 @@ class GradientClipByValue(GradientClipBase):
def _static_clip(self, params_grads): def _static_clip(self, params_grads):
params_and_grads = [] params_and_grads = []
param_new_grad_name_dict = dict()
with framework.name_scope('gradient_clip'): with framework.name_scope('gradient_clip'):
for p, g in params_grads: for p, g in params_grads:
if g is None: if g is None:
...@@ -284,7 +282,8 @@ class GradientClipByValue(GradientClipBase): ...@@ -284,7 +282,8 @@ class GradientClipByValue(GradientClipBase):
with p.block.program._optimized_guard([p, g]): with p.block.program._optimized_guard([p, g]):
new_grad = layers.clip(x=g, min=self.min, max=self.max) new_grad = layers.clip(x=g, min=self.min, max=self.max)
params_and_grads.append((p, new_grad)) params_and_grads.append((p, new_grad))
_correct_clip_op_role_var(params_and_grads) param_new_grad_name_dict[p.name] = new_grad.name
_correct_clip_op_role_var(params_and_grads, param_new_grad_name_dict)
return params_and_grads return params_and_grads
def _process_context(self, context, param, grad): def _process_context(self, context, param, grad):
...@@ -306,7 +305,7 @@ class GradientClipByNorm(GradientClipBase): ...@@ -306,7 +305,7 @@ class GradientClipByNorm(GradientClipBase):
The multidimensional Tensor :math:`X` is not passed from this class, but the gradients of all parameters in ``Program`` . If ``need_clip`` The multidimensional Tensor :math:`X` is not passed from this class, but the gradients of all parameters in ``Program`` . If ``need_clip``
is not None, then only part of gradients can be selected for gradient clipping. is not None, then only part of gradients can be selected for gradient clipping.
Gradient clip will takes effect after being set in ``optimizer.minimize(grad_clip)`` , see the document ``optimizer`` Gradient clip will takes effect after being set in ``optimizer`` , see the document ``optimizer``
(for example: :ref:`api_fluid_optimizer_SGDOptimizer`). (for example: :ref:`api_fluid_optimizer_SGDOptimizer`).
The clipping formula is: The clipping formula is:
...@@ -359,8 +358,8 @@ class GradientClipByNorm(GradientClipBase): ...@@ -359,8 +358,8 @@ class GradientClipByNorm(GradientClipBase):
# return Parameter.name=="fc_0.w_0" # return Parameter.name=="fc_0.w_0"
# clip = fluid.clip.GradientClipByNorm(clip_norm=1.0, need_clip=fileter_func) # clip = fluid.clip.GradientClipByNorm(clip_norm=1.0, need_clip=fileter_func)
sgd_optimizer = fluid.optimizer.SGDOptimizer(learning_rate=0.1) sgd_optimizer = fluid.optimizer.SGDOptimizer(learning_rate=0.1, grad_clip=clip)
sgd_optimizer.minimize(loss, grad_clip=clip) sgd_optimizer.minimize(loss)
place = fluid.CPUPlace() place = fluid.CPUPlace()
exe = fluid.Executor(place) exe = fluid.Executor(place)
...@@ -394,8 +393,8 @@ class GradientClipByNorm(GradientClipBase): ...@@ -394,8 +393,8 @@ class GradientClipByNorm(GradientClipBase):
# clip = fluid.clip.GradientClipByNorm(clip_norm=1.0, need_clip=fileter_func) # clip = fluid.clip.GradientClipByNorm(clip_norm=1.0, need_clip=fileter_func)
sgd_optimizer = fluid.optimizer.SGD( sgd_optimizer = fluid.optimizer.SGD(
learning_rate=0.1, parameter_list=linear.parameters()) learning_rate=0.1, parameter_list=linear.parameters(), grad_clip=clip)
sgd_optimizer.minimize(loss, grad_clip=clip) sgd_optimizer.minimize(loss)
""" """
...@@ -422,6 +421,7 @@ class GradientClipByNorm(GradientClipBase): ...@@ -422,6 +421,7 @@ class GradientClipByNorm(GradientClipBase):
def _static_clip(self, params_grads): def _static_clip(self, params_grads):
params_and_grads = [] params_and_grads = []
with framework.name_scope('gradient_clip'): with framework.name_scope('gradient_clip'):
param_new_grad_name_dict = dict()
for p, g in params_grads: for p, g in params_grads:
if g is None: if g is None:
continue continue
...@@ -432,8 +432,9 @@ class GradientClipByNorm(GradientClipBase): ...@@ -432,8 +432,9 @@ class GradientClipByNorm(GradientClipBase):
with p.block.program._optimized_guard([p, g]): with p.block.program._optimized_guard([p, g]):
new_grad = layers.clip_by_norm(x=g, max_norm=self.clip_norm) new_grad = layers.clip_by_norm(x=g, max_norm=self.clip_norm)
param_new_grad_name_dict[p.name] = new_grad.name
params_and_grads.append((p, new_grad)) params_and_grads.append((p, new_grad))
_correct_clip_op_role_var(params_and_grads) _correct_clip_op_role_var(params_and_grads, param_new_grad_name_dict)
return params_and_grads return params_and_grads
def _process_context(self, context, param, grad): def _process_context(self, context, param, grad):
...@@ -456,7 +457,7 @@ class GradientClipByGlobalNorm(GradientClipBase): ...@@ -456,7 +457,7 @@ class GradientClipByGlobalNorm(GradientClipBase):
The list of Tensor :math:`t\_list` is not passed from this class, but the gradients of all parameters in ``Program`` . If ``need_clip`` The list of Tensor :math:`t\_list` is not passed from this class, but the gradients of all parameters in ``Program`` . If ``need_clip``
is not None, then only part of gradients can be selected for gradient clipping. is not None, then only part of gradients can be selected for gradient clipping.
Gradient clip will takes effect after being set in ``optimizer.minimize(grad_clip)`` , see the document ``optimizer`` Gradient clip will takes effect after being set in ``optimizer`` , see the document ``optimizer``
(for example: :ref:`api_fluid_optimizer_SGDOptimizer`). (for example: :ref:`api_fluid_optimizer_SGDOptimizer`).
The clipping formula is: The clipping formula is:
...@@ -505,8 +506,8 @@ class GradientClipByGlobalNorm(GradientClipBase): ...@@ -505,8 +506,8 @@ class GradientClipByGlobalNorm(GradientClipBase):
# return Parameter.name=="fc_0.w_0" # return Parameter.name=="fc_0.w_0"
# clip = fluid.clip.GradientClipByGlobalNorm(clip_norm=1.0, need_clip=fileter_func) # clip = fluid.clip.GradientClipByGlobalNorm(clip_norm=1.0, need_clip=fileter_func)
sgd_optimizer = fluid.optimizer.SGDOptimizer(learning_rate=0.1) sgd_optimizer = fluid.optimizer.SGDOptimizer(learning_rate=0.1, grad_clip=clip)
sgd_optimizer.minimize(loss, grad_clip=clip) sgd_optimizer.minimize(loss)
place = fluid.CPUPlace() place = fluid.CPUPlace()
exe = fluid.Executor(place) exe = fluid.Executor(place)
...@@ -539,8 +540,8 @@ class GradientClipByGlobalNorm(GradientClipBase): ...@@ -539,8 +540,8 @@ class GradientClipByGlobalNorm(GradientClipBase):
# clip = fluid.clip.GradientClipByGlobalNorm(clip_norm=1.0, need_clip=fileter_func) # clip = fluid.clip.GradientClipByGlobalNorm(clip_norm=1.0, need_clip=fileter_func)
sgd_optimizer = fluid.optimizer.SGD( sgd_optimizer = fluid.optimizer.SGD(
learning_rate=0.1, parameter_list=linear.parameters()) learning_rate=0.1, parameter_list=linear.parameters(), grad_clip=clip)
sgd_optimizer.minimize(loss, grad_clip=clip) sgd_optimizer.minimize(loss)
""" """
...@@ -628,6 +629,7 @@ class GradientClipByGlobalNorm(GradientClipBase): ...@@ -628,6 +629,7 @@ class GradientClipByGlobalNorm(GradientClipBase):
y=layers.elementwise_max( y=layers.elementwise_max(
x=max_global_norm, y=global_norm_var)) x=max_global_norm, y=global_norm_var))
param_new_grad_name_dict = dict()
for p, g in params_grads: for p, g in params_grads:
if g is None: if g is None:
continue continue
...@@ -638,9 +640,10 @@ class GradientClipByGlobalNorm(GradientClipBase): ...@@ -638,9 +640,10 @@ class GradientClipByGlobalNorm(GradientClipBase):
with p.block.program._optimized_guard([p, g]): with p.block.program._optimized_guard([p, g]):
new_grad = layers.elementwise_mul(x=g, y=scale_var) new_grad = layers.elementwise_mul(x=g, y=scale_var)
param_new_grad_name_dict[p.name] = new_grad.name
params_and_grads.append((p, new_grad)) params_and_grads.append((p, new_grad))
_correct_clip_op_role_var(params_and_grads) _correct_clip_op_role_var(params_and_grads, param_new_grad_name_dict)
return params_and_grads return params_and_grads
def _process_context(self, context, param, grad): def _process_context(self, context, param, grad):
...@@ -692,9 +695,10 @@ def set_gradient_clip(clip, param_list=None, program=None): ...@@ -692,9 +695,10 @@ def set_gradient_clip(clip, param_list=None, program=None):
This API must be used after building network, and before ``minimize`` , This API must be used after building network, and before ``minimize`` ,
and it may be removed in future releases, so it is not recommended. and it may be removed in future releases, so it is not recommended.
It is recommended to use ``minimize(loss, grad_clip=clip)`` to clip gradient. It is recommended to set ``grad_clip`` when initializing the ``optimizer`` ,
There are three clipping strategies: :ref:`api_fluid_clip_GradientClipByGlobalNorm` , this is a better method to clip gradient. There are three clipping strategies:
:ref:`api_fluid_clip_GradientClipByNorm` , :ref:`api_fluid_clip_GradientClipByValue` . :ref:`api_fluid_clip_GradientClipByGlobalNorm` , :ref:`api_fluid_clip_GradientClipByNorm` ,
:ref:`api_fluid_clip_GradientClipByValue` .
To specify parameters that require gradient clip. To specify parameters that require gradient clip.
...@@ -757,7 +761,7 @@ def set_gradient_clip(clip, param_list=None, program=None): ...@@ -757,7 +761,7 @@ def set_gradient_clip(clip, param_list=None, program=None):
sgd = fluid.optimizer.SGD(learning_rate=1e-3) sgd = fluid.optimizer.SGD(learning_rate=1e-3)
sgd.minimize(loss) sgd.minimize(loss)
# network 4: use 'set_gradient_clip' and 'minimize(grad_clip=clip)' together # network 4: use 'set_gradient_clip' and 'optimize(grad_clip=clip)' together
with fluid.program_guard(fluid.Program(), fluid.Program()): with fluid.program_guard(fluid.Program(), fluid.Program()):
loss = network() loss = network()
clip1 = fluid.clip.GradientClipByValue(min=-1.0, max=1.0) clip1 = fluid.clip.GradientClipByValue(min=-1.0, max=1.0)
...@@ -765,8 +769,8 @@ def set_gradient_clip(clip, param_list=None, program=None): ...@@ -765,8 +769,8 @@ def set_gradient_clip(clip, param_list=None, program=None):
# Set the gradient clipping strategy: clip1 # Set the gradient clipping strategy: clip1
fluid.clip.set_gradient_clip(clip1) fluid.clip.set_gradient_clip(clip1)
# Set the gradient clipping strategy: clip2 # Set the gradient clipping strategy: clip2
sgd = fluid.optimizer.SGD(learning_rate=1e-3) sgd = fluid.optimizer.SGD(learning_rate=1e-3, grad_clip=clip2)
sgd.minimize(loss, grad_clip=clip2) sgd.minimize(loss)
# 'set_gradient_clip' will not take effect when setting has a conflict, # 'set_gradient_clip' will not take effect when setting has a conflict,
# and the gradient clipping strategy will be 'clip2' # and the gradient clipping strategy will be 'clip2'
...@@ -774,10 +778,10 @@ def set_gradient_clip(clip, param_list=None, program=None): ...@@ -774,10 +778,10 @@ def set_gradient_clip(clip, param_list=None, program=None):
""" """
warnings.warn("Caution! 'set_gradient_clip' is not recommended " warnings.warn("Caution! 'set_gradient_clip' is not recommended "
"and may be deprecated in future! " "and may be deprecated in future! "
"We recommend a new strategy: clip gradient by " "We recommend a new strategy: set 'grad_clip' "
"'optimizer.minimize(loss, grad_clip=clip)'. " "when initializing the 'optimizer'. "
"This method can reduce the mistakes, please " "This method can reduce the mistakes, please "
"see documention of 'optimzier.minimize'.") "refer to documention of 'optimizer'.")
if not isinstance(clip, GradientClipBase): if not isinstance(clip, GradientClipBase):
raise TypeError( raise TypeError(
...@@ -824,33 +828,40 @@ def append_gradient_clip_ops(param_grads): ...@@ -824,33 +828,40 @@ def append_gradient_clip_ops(param_grads):
clip_attr._process_context(context=context, param=p, grad=g) clip_attr._process_context(context=context, param=p, grad=g)
res = [] res = []
param_new_grad_name_dict = dict()
for p, g in param_grads: for p, g in param_grads:
if g is None: if g is None:
continue continue
with p.block.program._optimized_guard( with p.block.program._optimized_guard(
[p, g]), framework.name_scope('graident_clip_@CLIP'): [p, g]), framework.name_scope('graident_clip_@CLIP'):
param, new_grad = clip_attr._create_operators(param=p, grad=g) param, new_grad = clip_attr._create_operators(param=p, grad=g)
param_new_grad_name_dict[param.name] = new_grad.name
res.append([param, new_grad]) res.append([param, new_grad])
_correct_clip_op_role_var(res) _correct_clip_op_role_var(res, param_new_grad_name_dict)
return res return res
# change wrong mapping relation between param & grad in clip op # change wrong mapping relation between param & grad in clip op
def _correct_clip_op_role_var(params_grads): def _correct_clip_op_role_var(params_grads, param_new_grad_name_dict):
block_id_list = []
if len(param_new_grad_name_dict) == 0:
return
for param, grad in params_grads: for param, grad in params_grads:
if grad is None: if grad is None:
continue continue
block_id = param.block.idx
if block_id in block_id_list:
continue
block_id_list.append(block_id)
for op in param.block.program.global_block().ops: for op in param.block.program.global_block().ops:
if 'op_namescope' in op.all_attrs() and "gradient_clip" in op.attr( if 'op_namescope' in op.all_attrs() and "gradient_clip" in op.attr(
"op_namescope"): "op_namescope") and op.attr('op_role_var'):
if op.attr('op_role_var'): param_name = op.attr('op_role_var')[0]
param_name = op.attr('op_role_var')[0] if param_name in param_new_grad_name_dict:
index = 0 correct_p_g = [
for i in range(len(params_grads)): param_name, param_new_grad_name_dict[param_name]
if params_grads[i][0].name == param_name: ]
index = i
correct_p_g = [param_name, params_grads[index][1].name]
op._set_attr('op_role_var', correct_p_g) op._set_attr('op_role_var', correct_p_g)
......
此差异已折叠。
...@@ -36,7 +36,7 @@ class ParamAttr(object): ...@@ -36,7 +36,7 @@ class ParamAttr(object):
Note: Note:
``gradient_clip`` of ``ParamAttr`` HAS BEEN DEPRECATED since 2.0. ``gradient_clip`` of ``ParamAttr`` HAS BEEN DEPRECATED since 2.0.
It is recommended to use ``minimize(loss, grad_clip=clip)`` to clip gradient. It is recommended to set ``grad_clip`` in ``optimizer`` to clip gradient.
There are three clipping strategies: :ref:`api_fluid_clip_GradientClipByGlobalNorm` , There are three clipping strategies: :ref:`api_fluid_clip_GradientClipByGlobalNorm` ,
:ref:`api_fluid_clip_GradientClipByNorm` , :ref:`api_fluid_clip_GradientClipByValue` . :ref:`api_fluid_clip_GradientClipByNorm` , :ref:`api_fluid_clip_GradientClipByValue` .
......
...@@ -19,6 +19,7 @@ import unittest ...@@ -19,6 +19,7 @@ import unittest
import paddle.fluid.framework as framework import paddle.fluid.framework as framework
import paddle.fluid.optimizer as optimizer import paddle.fluid.optimizer as optimizer
import paddle.fluid.regularizer as regularizer import paddle.fluid.regularizer as regularizer
import paddle.fluid.clip as clip
import paddle.compat as cpt import paddle.compat as cpt
from paddle.fluid.backward import append_backward from paddle.fluid.backward import append_backward
from paddle.fluid.transpiler.details import program_to_code from paddle.fluid.transpiler.details import program_to_code
...@@ -70,9 +71,9 @@ class TestDGCMomentumOptimizer(unittest.TestCase): ...@@ -70,9 +71,9 @@ class TestDGCMomentumOptimizer(unittest.TestCase):
learning_rate=learning_rate, learning_rate=learning_rate,
momentum=0.2, momentum=0.2,
rampup_begin_step=0, rampup_begin_step=0,
local_grad_clip_norm=1.0,
num_trainers=2, num_trainers=2,
regularization=regularization) regularization=regularization,
grad_clip=clip.GradientClipByNorm(1.0))
if use_recompute: if use_recompute:
dgc_momentum_optimizer = optimizer.RecomputeOptimizer( dgc_momentum_optimizer = optimizer.RecomputeOptimizer(
...@@ -124,6 +125,16 @@ class TestDGCMomentumOptimizer(unittest.TestCase): ...@@ -124,6 +125,16 @@ class TestDGCMomentumOptimizer(unittest.TestCase):
#with open("test_dgc_optimizer_" + name + str(use_recompute) + ".log", "w") as f: #with open("test_dgc_optimizer_" + name + str(use_recompute) + ".log", "w") as f:
# program_to_code(program, fout=f) # program_to_code(program, fout=f)
def test_tpyeError(self):
# the type of DGCMomentumOptimizer(grad_clip=) must be 'GradientClipByNorm'
with self.assertRaises(TypeError):
dgc_momentum_optimizer = self.MockDGCMomentum(
learning_rate=0.01,
momentum=0.2,
rampup_begin_step=0,
num_trainers=2,
grad_clip=clip.GradientClipByGlobalNorm(1.0))
def test_momentum_without_dgc(self): def test_momentum_without_dgc(self):
self.check_dgc_momentum_optimizer( self.check_dgc_momentum_optimizer(
regularization=regularizer.L1Decay(1e-4)) regularization=regularizer.L1Decay(1e-4))
......
...@@ -76,8 +76,8 @@ class TestGradientClip(unittest.TestCase): ...@@ -76,8 +76,8 @@ class TestGradientClip(unittest.TestCase):
startup_program = fluid.Program() startup_program = fluid.Program()
with fluid.program_guard( with fluid.program_guard(
main_program=prog, startup_program=startup_program): main_program=prog, startup_program=startup_program):
image = fluid.data(name='x', shape=[-1, 784], dtype='float32') image = fluid.data(name="a", shape=[-1, 784], dtype='float32')
label = fluid.data(name='y', shape=[-1, 1], dtype='int64') label = fluid.data(name="b", shape=[-1, 1], dtype='int64')
hidden = fluid.layers.fc(input=image, size=32, act='relu') hidden = fluid.layers.fc(input=image, size=32, act='relu')
predict = fluid.layers.fc(input=hidden, size=10, act='softmax') predict = fluid.layers.fc(input=hidden, size=10, act='softmax')
...@@ -112,13 +112,13 @@ class TestGradientClip(unittest.TestCase): ...@@ -112,13 +112,13 @@ class TestGradientClip(unittest.TestCase):
self.check_clip_result(out, out_clip) self.check_clip_result(out, out_clip)
def check_sparse_gradient_clip(self, place): def check_sparse_gradient_clip(self, place):
prog = fluid.framework.Program() prog = fluid.Program()
startup_program = fluid.framework.Program() startup_program = fluid.Program()
with fluid.program_guard( with fluid.program_guard(
main_program=prog, startup_program=startup_program): main_program=prog, startup_program=startup_program):
data = fluid.layers.data( data = fluid.data(
name="words", shape=[1], dtype="int64", lod_level=1) name="words", shape=[-1, 1], dtype="int64", lod_level=1)
label = fluid.layers.data(name="label", shape=[1], dtype="int64") label = fluid.data(name="label", shape=[-1, 1], dtype="int64")
cost = bow_net(data, label, self.word_dict_len) cost = bow_net(data, label, self.word_dict_len)
self.backward_and_optimize(cost) self.backward_and_optimize(cost)
...@@ -172,7 +172,7 @@ class TestGradientClipByGlobalNorm(TestGradientClip): ...@@ -172,7 +172,7 @@ class TestGradientClipByGlobalNorm(TestGradientClip):
self.clip_gradient = func self.clip_gradient = func
self.check_gradient_clip(fluid.CPUPlace()) self.check_gradient_clip(fluid.CPUPlace())
# test whether the ouput is right when use 'minimize(grad_clip)' # test whether the ouput is right when use grad_clip
def test_new_gradient_clip(self): def test_new_gradient_clip(self):
def func(params_grads): def func(params_grads):
clip = fluid.clip.GradientClipByGlobalNorm(clip_norm=self.clip_norm) clip = fluid.clip.GradientClipByGlobalNorm(clip_norm=self.clip_norm)
...@@ -192,9 +192,10 @@ class TestGradientClipByGlobalNorm(TestGradientClip): ...@@ -192,9 +192,10 @@ class TestGradientClipByGlobalNorm(TestGradientClip):
clip = fluid.clip.GradientClipByGlobalNorm( clip = fluid.clip.GradientClipByGlobalNorm(
clip_norm=5.0, need_clip=fileter_func) clip_norm=5.0, need_clip=fileter_func)
fluid.clip.set_gradient_clip(clip) fluid.clip.set_gradient_clip(clip)
sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.01) sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.01,
# if 'set_gradient_clip' and 'minimize(grad_clip)' together, 'set_gradient_clip' will be ineffective grad_clip=clip)
sgd_optimizer.minimize(cost, grad_clip=clip) # if 'set_gradient_clip' and 'optimize(grad_clip)' together, 'set_gradient_clip' will be ineffective
sgd_optimizer.minimize(cost)
# 'set_gradient_clip' must before 'minimize', otherwise, 'set_gradient_clip' will be ineffective # 'set_gradient_clip' must before 'minimize', otherwise, 'set_gradient_clip' will be ineffective
fluid.clip.set_gradient_clip(clip) fluid.clip.set_gradient_clip(clip)
...@@ -232,24 +233,10 @@ class TestGradientClipByGlobalNorm(TestGradientClip): ...@@ -232,24 +233,10 @@ class TestGradientClipByGlobalNorm(TestGradientClip):
clip = fluid.clip.GradientClipByGlobalNorm( clip = fluid.clip.GradientClipByGlobalNorm(
clip_norm=self.clip_norm, need_clip="test") clip_norm=self.clip_norm, need_clip="test")
# the type of minimize(grad_clip=) must be an instance of GradientClipBase's derived class # the type of optimizer(grad_clip=) must be an instance of GradientClipBase's derived class
with self.assertRaises(TypeError):
x = fluid.default_main_program().global_block().create_parameter(
name="x", shape=[2, 3], dtype="float32")
loss = fluid.layers.reduce_mean(x)
sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.1)
sgd_optimizer.minimize(loss, grad_clip="test")
# the type of RecomputeOptimizer.minimize(grad_clip=) must be an instance of GradientClipBase's derived class
with self.assertRaises(TypeError): with self.assertRaises(TypeError):
x = fluid.default_main_program().global_block().create_parameter( sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.1,
name="x", shape=[2, 3], dtype="float32") grad_clip="test")
loss = fluid.layers.reduce_mean(x)
sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.1)
recompute_optimizer = fluid.optimizer.RecomputeOptimizer(
sgd_optimizer)
recompute_optimizer._set_checkpoints([x])
recompute_optimizer.minimize(loss, grad_clip="test")
class TestGradientClipByNorm(TestGradientClip): class TestGradientClipByNorm(TestGradientClip):
...@@ -271,7 +258,7 @@ class TestGradientClipByNorm(TestGradientClip): ...@@ -271,7 +258,7 @@ class TestGradientClipByNorm(TestGradientClip):
a=u, b=v, rtol=1e-5, atol=1e-8), a=u, b=v, rtol=1e-5, atol=1e-8),
"gradient clip by norm has wrong results!") "gradient clip by norm has wrong results!")
# test whether the ouput is right when use 'minimize(grad_clip)' # test whether the ouput is right when use grad_clip
def test_gradient_clip(self): def test_gradient_clip(self):
self.check_gradient_clip(fluid.CPUPlace()) self.check_gradient_clip(fluid.CPUPlace())
...@@ -319,7 +306,7 @@ class TestGradientClipByValue(TestGradientClip): ...@@ -319,7 +306,7 @@ class TestGradientClipByValue(TestGradientClip):
a=u, b=v, rtol=1e-6, atol=1e-8), a=u, b=v, rtol=1e-6, atol=1e-8),
"gradient clip by value has wrong results!") "gradient clip by value has wrong results!")
# test whether the ouput is right when use 'minimize(grad_clip)' # test whether the ouput is right when use grad_clip
def test_gradient_clip(self): def test_gradient_clip(self):
self.check_gradient_clip(fluid.CPUPlace()) self.check_gradient_clip(fluid.CPUPlace())
...@@ -357,7 +344,9 @@ class TestDygraphGradientClip(unittest.TestCase): ...@@ -357,7 +344,9 @@ class TestDygraphGradientClip(unittest.TestCase):
loss = fluid.layers.reduce_mean(out) loss = fluid.layers.reduce_mean(out)
loss.backward() loss.backward()
sgd_optimizer = fluid.optimizer.SGD( sgd_optimizer = fluid.optimizer.SGD(
learning_rate=0.0, parameter_list=linear.parameters()) learning_rate=0.0,
parameter_list=linear.parameters(),
grad_clip=fluid.clip.GradientClipByGlobalNorm(0.1))
self.check_clip_result(loss, sgd_optimizer) self.check_clip_result(loss, sgd_optimizer)
def check_clip_result(self, loss, optimizer): def check_clip_result(self, loss, optimizer):
...@@ -384,7 +373,7 @@ class TestDygraphGradientClipByGlobalNorm(TestDygraphGradientClip): ...@@ -384,7 +373,7 @@ class TestDygraphGradientClipByGlobalNorm(TestDygraphGradientClip):
np.array([3, 4]).astype("float32"), name="y") np.array([3, 4]).astype("float32"), name="y")
assert len(self.clip1([(x, x), (x, y), (x, None)])) == 2 assert len(self.clip1([(x, x), (x, y), (x, None)])) == 2
# get params and grads from network # get params and grads from network
opt, params_grads = optimizer.minimize(loss, grad_clip=self.clip2) opt, params_grads = optimizer.minimize(loss)
_, grads = zip(*params_grads) _, grads = zip(*params_grads)
params_grads = self.clip2(params_grads) params_grads = self.clip2(params_grads)
_, grads_clip = zip(*params_grads) _, grads_clip = zip(*params_grads)
...@@ -426,7 +415,7 @@ class TestDygraphGradientClipByNorm(TestDygraphGradientClip): ...@@ -426,7 +415,7 @@ class TestDygraphGradientClipByNorm(TestDygraphGradientClip):
assert len(self.clip([(x, None)])) == 0 assert len(self.clip([(x, None)])) == 0
# get params and grads from network # get params and grads from network
self.clip([(fluid.dygraph.to_variable(np.array([2, 3])), None)]) self.clip([(fluid.dygraph.to_variable(np.array([2, 3])), None)])
params_grads = optimizer.backward(loss) opt, params_grads = optimizer.minimize(loss)
_, grads = zip(*params_grads) _, grads = zip(*params_grads)
params_grads = self.clip(params_grads) params_grads = self.clip(params_grads)
_, grads_clip = zip(*params_grads) _, grads_clip = zip(*params_grads)
...@@ -460,7 +449,7 @@ class TestDygraphGradientClipByValue(TestDygraphGradientClip): ...@@ -460,7 +449,7 @@ class TestDygraphGradientClipByValue(TestDygraphGradientClip):
x = fluid.dygraph.to_variable(np.array([2, 3]).astype("float32")) x = fluid.dygraph.to_variable(np.array([2, 3]).astype("float32"))
assert len(self.clip([(x, None)])) == 0 assert len(self.clip([(x, None)])) == 0
# get params and grads from network # get params and grads from network
params_grads = optimizer.backward(loss) opt, params_grads = optimizer.minimize(loss)
_, grads = zip(*params_grads) _, grads = zip(*params_grads)
params_grads = self.clip(params_grads) params_grads = self.clip(params_grads)
_, grads_clip = zip(*params_grads) _, grads_clip = zip(*params_grads)
......
...@@ -329,9 +329,9 @@ class TestImperativeAutoPrune(unittest.TestCase): ...@@ -329,9 +329,9 @@ class TestImperativeAutoPrune(unittest.TestCase):
place = fluid.CPUPlace() place = fluid.CPUPlace()
with fluid.dygraph.guard(place): with fluid.dygraph.guard(place):
model = MyLayer(size, vocab_size, size) model = MyLayer(size, vocab_size, size)
optimizer = fluid.optimizer.AdamOptimizer(
0.001, parameter_list=model.parameters())
grad_clip = fluid.clip.GradientClipByGlobalNorm(0.001) grad_clip = fluid.clip.GradientClipByGlobalNorm(0.001)
optimizer = fluid.optimizer.AdamOptimizer(
0.001, parameter_list=model.parameters(), grad_clip=grad_clip)
indices = fluid.dygraph.to_variable(indices) indices = fluid.dygraph.to_variable(indices)
embed = fluid.dygraph.to_variable(embed) embed = fluid.dygraph.to_variable(embed)
...@@ -339,7 +339,7 @@ class TestImperativeAutoPrune(unittest.TestCase): ...@@ -339,7 +339,7 @@ class TestImperativeAutoPrune(unittest.TestCase):
loss = model.embed_linear0(indices) loss = model.embed_linear0(indices)
loss.backward() loss.backward()
_, params_grads = optimizer.minimize(loss, grad_clip=grad_clip) _, params_grads = optimizer.minimize(loss)
for items in params_grads: for items in params_grads:
assert items[0].name is not model.embed1.weight.name assert items[0].name is not model.embed1.weight.name
assert items[0].name is not model.linear_1.weight.name assert items[0].name is not model.linear_1.weight.name
...@@ -348,9 +348,9 @@ class TestImperativeAutoPrune(unittest.TestCase): ...@@ -348,9 +348,9 @@ class TestImperativeAutoPrune(unittest.TestCase):
with fluid.dygraph.guard(place): with fluid.dygraph.guard(place):
model = MyLayer2(size, vocab_size, size) model = MyLayer2(size, vocab_size, size)
optimizer = fluid.optimizer.AdamOptimizer(
0.001, parameter_list=model.parameters())
grad_clip = fluid.clip.GradientClipByGlobalNorm(0.001) grad_clip = fluid.clip.GradientClipByGlobalNorm(0.001)
optimizer = fluid.optimizer.AdamOptimizer(
0.001, parameter_list=model.parameters(), grad_clip=grad_clip)
indices = fluid.dygraph.to_variable(indices) indices = fluid.dygraph.to_variable(indices)
emebd = fluid.dygraph.to_variable(embed) emebd = fluid.dygraph.to_variable(embed)
...@@ -358,7 +358,7 @@ class TestImperativeAutoPrune(unittest.TestCase): ...@@ -358,7 +358,7 @@ class TestImperativeAutoPrune(unittest.TestCase):
loss = model.embed_linear0(indices) loss = model.embed_linear0(indices)
loss.backward() loss.backward()
optimizer.minimize(loss, grad_clip=grad_clip) optimizer.minimize(loss)
for items in params_grads: for items in params_grads:
assert items[0].name is not model.embed1.weight.name assert items[0].name is not model.embed1.weight.name
assert items[0].name is not model.linear_1.weight.name assert items[0].name is not model.linear_1.weight.name
......
...@@ -58,14 +58,15 @@ class TestSimpleNet(unittest.TestCase): ...@@ -58,14 +58,15 @@ class TestSimpleNet(unittest.TestCase):
simplenet = SimpleNet(20, 32, dtype) simplenet = SimpleNet(20, 32, dtype)
adam = SGDOptimizer( adam = SGDOptimizer(
learning_rate=0.001, learning_rate=0.001,
parameter_list=simplenet.parameters()) parameter_list=simplenet.parameters(
)) # grad_clip=grad_clip
input_emb, emb = simplenet(input) input_emb, emb = simplenet(input)
self.assertTrue(emb.weight.gradient() is None) self.assertTrue(emb.weight.gradient() is None)
self.assertTrue(input_emb.gradient() is None) self.assertTrue(input_emb.gradient() is None)
input_emb.backward(backward_strategy) input_emb.backward(backward_strategy)
adam.minimize(input_emb) # grad_clip=grad_clip adam.minimize(input_emb)
self.assertTrue(emb.weight.gradient() is not None) self.assertTrue(emb.weight.gradient() is not None)
emb.clear_gradients() emb.clear_gradients()
...@@ -92,14 +93,15 @@ class TestSimpleNet(unittest.TestCase): ...@@ -92,14 +93,15 @@ class TestSimpleNet(unittest.TestCase):
simplenet = SimpleNet(20, 32, "float32") simplenet = SimpleNet(20, 32, "float32")
adam = SGDOptimizer( adam = SGDOptimizer(
learning_rate=0.001, learning_rate=0.001,
parameter_list=simplenet.parameters()) parameter_list=simplenet.parameters(),
grad_clip=grad_clip)
input_emb, emb = simplenet(input) input_emb, emb = simplenet(input)
self.assertTrue(emb.weight.gradient() is None) self.assertTrue(emb.weight.gradient() is None)
self.assertTrue(input_emb.gradient() is None) self.assertTrue(input_emb.gradient() is None)
input_emb.backward(backward_strategy) input_emb.backward(backward_strategy)
adam.minimize(input_emb, grad_clip=grad_clip) adam.minimize(input_emb)
self.assertTrue(emb.weight.gradient() is not None) self.assertTrue(emb.weight.gradient() is not None)
emb.clear_gradients() emb.clear_gradients()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册