提交 1dac173b 编写于 作者: F fengjiayi

add API for clip_by_global_norm

上级 6ebfade4
......@@ -13,7 +13,7 @@
# limitations under the License.
import functools
import layers
from framework import Variable
import framework
from . import core
__all__ = [
......@@ -128,8 +128,8 @@ class GradientClipByGlobalNorm(BaseGradientClipAttr):
@classmethod
def check_init(cls):
if not (isinstance(cls.global_norm_var, Variable) and
isinstance(cls.clip_norm_var, Variable)):
if not (isinstance(cls.global_norm_var, framework.Variable) and
isinstance(cls.clip_norm_var, framework.Variable)):
raise ValueError(
"Class 'GradientClipByGlobalNorm' has not been properly initialized. \
Please call GradientClipByGlobalNorm.init() first.")
......@@ -158,6 +158,23 @@ class GradientClipByGlobalNorm(BaseGradientClipAttr):
return param, new_grad
def gradient_clip_by_global_norm(clip_norm, param_list=None, program=None):
if program is None:
program = framework.default_main_program()
if param_list is None:
param_list = program.block(0).all_parameters()
if all(isinstance(elem, basestring) for elem in param_list):
param_list = [program.block(0).var(elem) for elem in param_list]
if not all(isinstance(elem, framework.Parameter) for elem in param_list):
raise TypeError(
"'param_list' should be a list of Parameter or basestring(parameter's name)."
)
GradientClipByGlobalNorm.init(clip_norm)
for param in param_list:
param.gradient_clip_attr = GradientClipByGlobalNorm()
def append_gradient_clip_ops(param_grad):
context = dict()
create_op_callbacks = []
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册