diff --git a/python/paddle/v2/fluid/clip.py b/python/paddle/v2/fluid/clip.py index 386df9823de9119287abf87569eab0b283ecc802..3028029e60fde2f481b4348ab1b0a4980ebb2b60 100644 --- a/python/paddle/v2/fluid/clip.py +++ b/python/paddle/v2/fluid/clip.py @@ -12,14 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. +import copy + import functools import layers import framework from . import core __all__ = [ - 'GradientClipByValue', 'ErrorClipByValue', + 'GradientClipByValue', + 'GradientClipByNorm', + 'GradientClipByGlobalNorm', 'append_gradient_clip_ops', 'error_clip_callback', ] @@ -155,10 +159,11 @@ class GradientClipByGlobalNorm(BaseGradientClipAttr): return param, new_grad -def gradient_clip_by_global_norm(clip_norm, - param_list=None, - group_name="default_group", - program=None): +def set_gradient_clip(clip, param_list=None, program=None): + if not isinstance(clip, BaseGradientClipAttr): + raise TypeError( + "'clip' should be an instance of BaseGradientClipAttr's derived class" + ) if program is None: program = framework.default_main_program() if param_list is None: @@ -171,8 +176,7 @@ def gradient_clip_by_global_norm(clip_norm, ) for param in param_list: - param.gradient_clip_attr = GradientClipByGlobalNorm(clip_norm, - group_name) + param.gradient_clip_attr = copy.deepcopy(clip) def append_gradient_clip_ops(param_grad): diff --git a/python/paddle/v2/fluid/tests/test_gradient_clip.py b/python/paddle/v2/fluid/tests/test_gradient_clip.py index 4e6e6a1ef6961d8f087dfc1ac5a4c4a8ad90032e..9337791c21183fd7c2e5d6b9d47c99d762c93d46 100644 --- a/python/paddle/v2/fluid/tests/test_gradient_clip.py +++ b/python/paddle/v2/fluid/tests/test_gradient_clip.py @@ -40,7 +40,8 @@ p_g = fluid.backward.append_backward(loss=avg_cost) p_g_clip = fluid.backward.append_backward(loss=avg_cost_clip) with fluid.program_guard(main_program=prog_clip): - fluid.clip.gradient_clip_by_global_norm(clip_norm=CLIP) + fluid.clip.set_gradient_clip( + fluid.clip.GradientClipByGlobalNorm(clip_norm=CLIP)) p_g_clip = fluid.clip.append_gradient_clip_ops(p_g_clip) grad_list = [elem[1] for elem in p_g]