From 89c591f37cf50edbf32ef418696e856fb506f83d Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Mon, 22 Jan 2018 16:54:24 +0800 Subject: [PATCH] update grad clip api --- python/paddle/v2/fluid/clip.py | 18 +++++++++++------- .../v2/fluid/tests/test_gradient_clip.py | 3 ++- 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/python/paddle/v2/fluid/clip.py b/python/paddle/v2/fluid/clip.py index 386df9823..3028029e6 100644 --- a/python/paddle/v2/fluid/clip.py +++ b/python/paddle/v2/fluid/clip.py @@ -12,14 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. +import copy + import functools import layers import framework from . import core __all__ = [ - 'GradientClipByValue', 'ErrorClipByValue', + 'GradientClipByValue', + 'GradientClipByNorm', + 'GradientClipByGlobalNorm', 'append_gradient_clip_ops', 'error_clip_callback', ] @@ -155,10 +159,11 @@ class GradientClipByGlobalNorm(BaseGradientClipAttr): return param, new_grad -def gradient_clip_by_global_norm(clip_norm, - param_list=None, - group_name="default_group", - program=None): +def set_gradient_clip(clip, param_list=None, program=None): + if not isinstance(clip, BaseGradientClipAttr): + raise TypeError( + "'clip' should be an instance of BaseGradientClipAttr's derived class" + ) if program is None: program = framework.default_main_program() if param_list is None: @@ -171,8 +176,7 @@ def gradient_clip_by_global_norm(clip_norm, ) for param in param_list: - param.gradient_clip_attr = GradientClipByGlobalNorm(clip_norm, - group_name) + param.gradient_clip_attr = copy.deepcopy(clip) def append_gradient_clip_ops(param_grad): diff --git a/python/paddle/v2/fluid/tests/test_gradient_clip.py b/python/paddle/v2/fluid/tests/test_gradient_clip.py index 4e6e6a1ef..9337791c2 100644 --- a/python/paddle/v2/fluid/tests/test_gradient_clip.py +++ b/python/paddle/v2/fluid/tests/test_gradient_clip.py @@ -40,7 +40,8 @@ p_g = fluid.backward.append_backward(loss=avg_cost) p_g_clip = fluid.backward.append_backward(loss=avg_cost_clip) with fluid.program_guard(main_program=prog_clip): - fluid.clip.gradient_clip_by_global_norm(clip_norm=CLIP) + fluid.clip.set_gradient_clip( + fluid.clip.GradientClipByGlobalNorm(clip_norm=CLIP)) p_g_clip = fluid.clip.append_gradient_clip_ops(p_g_clip) grad_list = [elem[1] for elem in p_g] -- GitLab