Commit 89c591f3 authored by fengjiayi

update grad clip api

Parent: c80af6ff
@@ -12,14 +12,18 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import copy
 import functools
 import layers
 import framework
 from . import core

 __all__ = [
-    'GradientClipByValue',
+    'ErrorClipByValue',
+    'GradientClipByValue',
+    'GradientClipByNorm',
+    'GradientClipByGlobalNorm',
     'append_gradient_clip_ops',
     'error_clip_callback',
 ]
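The `__all__` change exposes all of the clip attribute classes as public API, not just `GradientClipByValue`. A minimal sketch of constructing each one, assuming the fluid-era package path (`paddle.v2.fluid`) and the `max`/`min` keyword names, neither of which is shown in this diff:

```python
import paddle.v2.fluid as fluid  # assumed import path for this era

# `clip_norm` is confirmed by the diff below; `max`/`min` are assumed names.
err_clip = fluid.clip.ErrorClipByValue(max=1.0, min=-1.0)
val_clip = fluid.clip.GradientClipByValue(max=1.0, min=-1.0)
norm_clip = fluid.clip.GradientClipByNorm(clip_norm=5.0)
gnorm_clip = fluid.clip.GradientClipByGlobalNorm(clip_norm=5.0)
```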
@@ -155,10 +159,11 @@ class GradientClipByGlobalNorm(BaseGradientClipAttr):
         return param, new_grad


-def gradient_clip_by_global_norm(clip_norm,
-                                 param_list=None,
-                                 group_name="default_group",
-                                 program=None):
+def set_gradient_clip(clip, param_list=None, program=None):
+    if not isinstance(clip, BaseGradientClipAttr):
+        raise TypeError(
+            "'clip' should be an instance of BaseGradientClipAttr's derived class"
+        )
     if program is None:
         program = framework.default_main_program()
     if param_list is None:
@@ -171,8 +176,7 @@
         )

     for param in param_list:
-        param.gradient_clip_attr = GradientClipByGlobalNorm(clip_norm,
-                                                            group_name)
+        param.gradient_clip_attr = copy.deepcopy(clip)


 def append_gradient_clip_ops(param_grad):
......
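Taken together, the two hunks above replace the single-purpose `gradient_clip_by_global_norm` with a generic `set_gradient_clip(clip, param_list=None, program=None)`: the caller now passes any `BaseGradientClipAttr` instance, and each parameter receives its own `copy.deepcopy` of it, so later mutation of one parameter's clip attribute cannot leak to the others. A sketch of the new calling pattern (the import path is an assumption; the call itself mirrors the updated test below):

```python
import paddle.v2.fluid as fluid  # assumed import path

prog = fluid.default_main_program()

# Attach a global-norm clip attribute to every parameter of `prog`;
# param_list=None means "all parameters of block 0".
fluid.clip.set_gradient_clip(
    fluid.clip.GradientClipByGlobalNorm(clip_norm=5.0), program=prog)

# Anything that is not a BaseGradientClipAttr is rejected up front:
fluid.clip.set_gradient_clip(clip=5.0)  # raises TypeError
```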
@@ -40,7 +40,8 @@ p_g = fluid.backward.append_backward(loss=avg_cost)
 p_g_clip = fluid.backward.append_backward(loss=avg_cost_clip)

 with fluid.program_guard(main_program=prog_clip):
-    fluid.clip.gradient_clip_by_global_norm(clip_norm=CLIP)
+    fluid.clip.set_gradient_clip(
+        fluid.clip.GradientClipByGlobalNorm(clip_norm=CLIP))
     p_g_clip = fluid.clip.append_gradient_clip_ops(p_g_clip)

 grad_list = [elem[1] for elem in p_g]
......
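For reference, `GradientClipByGlobalNorm` follows the usual global-norm rule: every gradient in the group is scaled by `clip_norm / max(global_norm, clip_norm)`, where `global_norm` is the L2 norm of all gradients in the group taken jointly. A minimal numpy sketch of that arithmetic (the helper name is illustrative, not part of the fluid API):

```python
import numpy as np

def clip_by_global_norm(grads, clip_norm):
    # Illustrative helper, not fluid API: jointly rescale a list of
    # gradient arrays so their combined L2 norm is at most clip_norm.
    global_norm = np.sqrt(sum(np.sum(np.square(g)) for g in grads))
    scale = clip_norm / max(global_norm, clip_norm)
    return [g * scale for g in grads]

grads = [np.array([3.0, 4.0]), np.array([12.0])]   # global norm = 13.0
clipped = clip_by_global_norm(grads, clip_norm=6.5)
# Each gradient is scaled by 6.5 / 13 = 0.5; groups whose global norm
# is already below clip_norm pass through unchanged.
```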