未验证 提交 4dccb584 编写于 作者: Y yuyang18

Hide clip APIs

上级 0c8f69c3
......@@ -31,7 +31,7 @@ class BaseErrorClipAttr(object):
def __str__(self):
raise NotImplementedError()
def append_clip_op(self, block, grad_name):
def _append_clip_op(self, block, grad_name):
raise NotImplementedError()
......@@ -67,7 +67,7 @@ class ErrorClipByValue(BaseErrorClipAttr):
def __str__(self):
return "ByValue, min=%f, max=%f" % (self.min, self.max)
def append_clip_op(self, block, grad_name):
def _append_clip_op(self, block, grad_name):
clip_op_desc = block.desc.append_op()
clip_op_desc.set_type("clip")
clip_op_desc.set_input("X", [grad_name])
......@@ -90,17 +90,17 @@ def error_clip_callback(block, context):
"Variable's error_clip should be an instance of BaseErrorClipAttr or None."
)
if error_clip is not None:
error_clip.append_clip_op(block, grad_n)
error_clip._append_clip_op(block, grad_n)
class BaseGradientClipAttr(object):
def __str__(self):
raise NotImplementedError()
def process_context(self, context, param, grad):
def _process_context(self, context, param, grad):
raise NotImplementedError()
def create_operators(self, param, grad):
def _create_operators(self, param, grad):
raise NotImplementedError()
......@@ -108,10 +108,10 @@ class NullGradientClipAttr(BaseGradientClipAttr):
def __str__(self):
return "Null"
def process_context(self, context, param, grad):
def _process_context(self, context, param, grad):
pass
def create_operators(self, param, grad):
def _create_operators(self, param, grad):
return param, grad
......@@ -153,10 +153,10 @@ class GradientClipByValue(BaseGradientClipAttr):
def __str__(self):
return "ByValue, min=%f, max=%f" % (self.min, self.max)
def process_context(self, context, param, grad):
def _process_context(self, context, param, grad):
pass
def create_operators(self, param, grad):
def _create_operators(self, param, grad):
new_grad = layers.clip(x=grad, min=self.min, max=self.max)
return param, new_grad
......@@ -199,10 +199,10 @@ class GradientClipByNorm(BaseGradientClipAttr):
def __str__(self):
return "ByNorm, clip_norm=%f" % self.clip_norm
def process_context(self, context, param, grad):
def _process_context(self, context, param, grad):
pass
def create_operators(self, param, grad):
def _create_operators(self, param, grad):
new_grad = layers.clip_by_norm(x=grad, max_norm=self.clip_norm)
return param, new_grad
......@@ -257,7 +257,7 @@ class GradientClipByGlobalNorm(BaseGradientClipAttr):
return "ByGlobalNorm, group_name=%s, clip_norm=%f" % (self.group_name,
self.clip_norm)
def process_context(self, context, param, grad):
def _process_context(self, context, param, grad):
if self.group_name not in context:
context[self.group_name] = []
context[self.group_name + "_clip_value"] = self.clip_norm
......@@ -274,7 +274,7 @@ class GradientClipByGlobalNorm(BaseGradientClipAttr):
self.context = context
def create_operators(self, param, grad):
def _create_operators(self, param, grad):
group_scale_name = self.group_name + "_scale"
if group_scale_name not in self.context:
group_norm_var = layers.sums(input=self.context[self.group_name])
......@@ -336,12 +336,12 @@ def append_gradient_clip_ops(param_grad):
"clip attribute should be an instance of BaseGradientClipAttr"
)
clip_attr.process_context(context=context, param=p, grad=g)
clip_attr._process_context(context=context, param=p, grad=g)
res = []
for p, g in param_grad:
with p.block.program.optimized_guard(p):
res.append(clip_attr.create_operators(param=p, grad=g))
res.append(clip_attr._create_operators(param=p, grad=g))
return res
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册