Unverified commit a173fa75, authored by fengjiayi, committed by GitHub

Merge pull request #7732 from JiayiFeng/refine_grad_clip_api

update gradient clip api
@@ -12,14 +12,18 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import copy
 import functools
 import layers
 import framework
 from . import core
 
 __all__ = [
-    'GradientClipByValue',
     'ErrorClipByValue',
+    'GradientClipByValue',
+    'GradientClipByNorm',
+    'GradientClipByGlobalNorm',
     'append_gradient_clip_ops',
     'error_clip_callback',
 ]
@@ -155,10 +159,11 @@ class GradientClipByGlobalNorm(BaseGradientClipAttr):
         return param, new_grad
 
 
-def gradient_clip_by_global_norm(clip_norm,
-                                 param_list=None,
-                                 group_name="default_group",
-                                 program=None):
+def set_gradient_clip(clip, param_list=None, program=None):
+    if not isinstance(clip, BaseGradientClipAttr):
+        raise TypeError(
+            "'clip' should be an instance of BaseGradientClipAttr's derived class"
+        )
     if program is None:
         program = framework.default_main_program()
     if param_list is None:
@@ -171,8 +176,7 @@ def gradient_clip_by_global_norm(clip_norm,
         )
     for param in param_list:
-        param.gradient_clip_attr = GradientClipByGlobalNorm(clip_norm,
-                                                            group_name)
+        param.gradient_clip_attr = copy.deepcopy(clip)
 
 
 def append_gradient_clip_ops(param_grad):
...
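For reference, the new call pattern (also exercised by the updated test hunk below) looks roughly as follows. This is a minimal sketch, not code from this commit: it assumes a Program `prog` that already contains a scalar loss variable `avg_cost`, and the `paddle.v2.fluid` import path in use at the time.

import paddle.v2.fluid as fluid

# Collect (parameter, gradient) pairs for the loss.
p_g = fluid.backward.append_backward(loss=avg_cost)

with fluid.program_guard(main_program=prog):
    # Attach the clip attribute to the program's parameters (all of them
    # when param_list is omitted); set_gradient_clip deep-copies `clip`
    # onto each parameter's gradient_clip_attr.
    fluid.clip.set_gradient_clip(
        fluid.clip.GradientClipByGlobalNorm(clip_norm=5.0))

# Insert the actual clipping operators for the collected gradients.
p_g = fluid.clip.append_gradient_clip_ops(p_g)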
@@ -40,7 +40,8 @@ p_g = fluid.backward.append_backward(loss=avg_cost)
 p_g_clip = fluid.backward.append_backward(loss=avg_cost_clip)
 
 with fluid.program_guard(main_program=prog_clip):
-    fluid.clip.gradient_clip_by_global_norm(clip_norm=CLIP)
+    fluid.clip.set_gradient_clip(
+        fluid.clip.GradientClipByGlobalNorm(clip_norm=CLIP))
 p_g_clip = fluid.clip.append_gradient_clip_ops(p_g_clip)
 
 grad_list = [elem[1] for elem in p_g]
...
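One caller-visible difference: the old helper accepted a bare clip_norm number, whereas set_gradient_clip validates its first argument and rejects anything that is not an instance of a BaseGradientClipAttr subclass. A small illustrative sketch (the error text is quoted from the hunk above):

# Passing a raw number, as the old gradient_clip_by_global_norm(clip_norm=...)
# signature encouraged, now fails fast.
try:
    fluid.clip.set_gradient_clip(5.0)
except TypeError as e:
    # "'clip' should be an instance of BaseGradientClipAttr's derived class"
    print(e)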