Created by: zhouwei25
当前问题
当前的梯度裁剪API set_gradient_clip
存在一些问题,因为 set_gradient_clip
是一个隐式地API,其本质是设置一个参数的属性,然后在minimize中
操作该属性来生效,因而对位置的要求很高。
- 问题1:set_gradient_clip必须位于minimize之前,否则不会生效。
- 问题2:set_gradient_clip必须位于组网之后,否则此时参数未更新,也不会生效。
- 问题3:set_gradient_clip会重置掉ParamAttr所设的值,导致ParamAttr的grad_clip_attr失效。 以上问题对用户均不会有提醒。
新方案的要点为:
- 动态图与静态图统一成一套接口,均在
minimize
中传入grad_clip
参数进行裁剪; - 部分参数裁剪的功能:给
GradientClipByGlobalNorm
实例对象初始化时,给need_clip
参数传入一个有过滤功能的函数,其返回True或False; - 兼容设计:兼容老接口
set_gradient_clip
,新接口优先级高于旧接口,同时添加了set_gradient_clip
错误使用时的报错信息;
全部参数的裁剪
提供三个class,分别对应三种梯度裁剪:
fluid.clip.GradientClipByGlobalNorm(clip_norm,need_clip=None)
fluid.clip.GradientClipByNorm(clip_norm,need_clip=None)
fluid.clip.GradientClipByValue(min,max,need_clip=None)
在minimize中传入一个clip对象,会根据clip的class类型来裁剪全部的参数
loss.backward() # 动态图需要这一行,静态图不需要
clip = fluid.clip.GradientClipByGlobalNorm(clip_norm=1.0)
sgd_optimize.minimize(loss, grad_clip=clip)
部分参数的裁剪
只裁剪一部分参数的需求。该需求一般用的较少,如果需要时,可在定义clip对象时,传入一个function给need_clip参数,该function返回True或False
# 过滤函数,返回True or False
def func(param)
return param.name == "linear_0.w_0"(True表示需要裁剪,False表示不需要裁剪)
loss.backward() # 动态图需要这一行,静态图不需要
clip = fluid.clip.GradientClipByGlobalNorm(clip_norm=1.0, need_clip=func)
sgd_optimize.minimize(loss, grad_clip=clip)
新旧兼容
- 老接口
set_gradient_clip
继续保留,但新接口minimize(grad_clip=clip)
优先级高于老接口(set_gradient_clip
),一旦发现两套接口重复使用,则会屏蔽set_gradient_clip
的所有操作并打出warning - 老接口
set_gradient_clip
单独使用时,须位于minimize
之前,如果顺序错误,也会打出warning
接口删除
仅删除:fluid.ParamAttr(仅gradient_clip属性),不再对用户暴露