未验证 提交 4dccb584 编写于 作者: Y yuyang18

Hide clip APIs

上级 0c8f69c3
...@@ -31,7 +31,7 @@ class BaseErrorClipAttr(object): ...@@ -31,7 +31,7 @@ class BaseErrorClipAttr(object):
def __str__(self): def __str__(self):
raise NotImplementedError() raise NotImplementedError()
def append_clip_op(self, block, grad_name): def _append_clip_op(self, block, grad_name):
raise NotImplementedError() raise NotImplementedError()
...@@ -67,7 +67,7 @@ class ErrorClipByValue(BaseErrorClipAttr): ...@@ -67,7 +67,7 @@ class ErrorClipByValue(BaseErrorClipAttr):
def __str__(self): def __str__(self):
return "ByValue, min=%f, max=%f" % (self.min, self.max) return "ByValue, min=%f, max=%f" % (self.min, self.max)
def append_clip_op(self, block, grad_name): def _append_clip_op(self, block, grad_name):
clip_op_desc = block.desc.append_op() clip_op_desc = block.desc.append_op()
clip_op_desc.set_type("clip") clip_op_desc.set_type("clip")
clip_op_desc.set_input("X", [grad_name]) clip_op_desc.set_input("X", [grad_name])
...@@ -90,17 +90,17 @@ def error_clip_callback(block, context): ...@@ -90,17 +90,17 @@ def error_clip_callback(block, context):
"Variable's error_clip should be an instance of BaseErrorClipAttr or None." "Variable's error_clip should be an instance of BaseErrorClipAttr or None."
) )
if error_clip is not None: if error_clip is not None:
error_clip.append_clip_op(block, grad_n) error_clip._append_clip_op(block, grad_n)
class BaseGradientClipAttr(object): class BaseGradientClipAttr(object):
def __str__(self): def __str__(self):
raise NotImplementedError() raise NotImplementedError()
def process_context(self, context, param, grad): def _process_context(self, context, param, grad):
raise NotImplementedError() raise NotImplementedError()
def create_operators(self, param, grad): def _create_operators(self, param, grad):
raise NotImplementedError() raise NotImplementedError()
...@@ -108,10 +108,10 @@ class NullGradientClipAttr(BaseGradientClipAttr): ...@@ -108,10 +108,10 @@ class NullGradientClipAttr(BaseGradientClipAttr):
def __str__(self): def __str__(self):
return "Null" return "Null"
def process_context(self, context, param, grad): def _process_context(self, context, param, grad):
pass pass
def create_operators(self, param, grad): def _create_operators(self, param, grad):
return param, grad return param, grad
...@@ -153,10 +153,10 @@ class GradientClipByValue(BaseGradientClipAttr): ...@@ -153,10 +153,10 @@ class GradientClipByValue(BaseGradientClipAttr):
def __str__(self): def __str__(self):
return "ByValue, min=%f, max=%f" % (self.min, self.max) return "ByValue, min=%f, max=%f" % (self.min, self.max)
def process_context(self, context, param, grad): def _process_context(self, context, param, grad):
pass pass
def create_operators(self, param, grad): def _create_operators(self, param, grad):
new_grad = layers.clip(x=grad, min=self.min, max=self.max) new_grad = layers.clip(x=grad, min=self.min, max=self.max)
return param, new_grad return param, new_grad
...@@ -199,10 +199,10 @@ class GradientClipByNorm(BaseGradientClipAttr): ...@@ -199,10 +199,10 @@ class GradientClipByNorm(BaseGradientClipAttr):
def __str__(self): def __str__(self):
return "ByNorm, clip_norm=%f" % self.clip_norm return "ByNorm, clip_norm=%f" % self.clip_norm
def process_context(self, context, param, grad): def _process_context(self, context, param, grad):
pass pass
def create_operators(self, param, grad): def _create_operators(self, param, grad):
new_grad = layers.clip_by_norm(x=grad, max_norm=self.clip_norm) new_grad = layers.clip_by_norm(x=grad, max_norm=self.clip_norm)
return param, new_grad return param, new_grad
...@@ -257,7 +257,7 @@ class GradientClipByGlobalNorm(BaseGradientClipAttr): ...@@ -257,7 +257,7 @@ class GradientClipByGlobalNorm(BaseGradientClipAttr):
return "ByGlobalNorm, group_name=%s, clip_norm=%f" % (self.group_name, return "ByGlobalNorm, group_name=%s, clip_norm=%f" % (self.group_name,
self.clip_norm) self.clip_norm)
def process_context(self, context, param, grad): def _process_context(self, context, param, grad):
if self.group_name not in context: if self.group_name not in context:
context[self.group_name] = [] context[self.group_name] = []
context[self.group_name + "_clip_value"] = self.clip_norm context[self.group_name + "_clip_value"] = self.clip_norm
...@@ -274,7 +274,7 @@ class GradientClipByGlobalNorm(BaseGradientClipAttr): ...@@ -274,7 +274,7 @@ class GradientClipByGlobalNorm(BaseGradientClipAttr):
self.context = context self.context = context
def create_operators(self, param, grad): def _create_operators(self, param, grad):
group_scale_name = self.group_name + "_scale" group_scale_name = self.group_name + "_scale"
if group_scale_name not in self.context: if group_scale_name not in self.context:
group_norm_var = layers.sums(input=self.context[self.group_name]) group_norm_var = layers.sums(input=self.context[self.group_name])
...@@ -336,12 +336,12 @@ def append_gradient_clip_ops(param_grad): ...@@ -336,12 +336,12 @@ def append_gradient_clip_ops(param_grad):
"clip attribute should be an instance of BaseGradientClipAttr" "clip attribute should be an instance of BaseGradientClipAttr"
) )
clip_attr.process_context(context=context, param=p, grad=g) clip_attr._process_context(context=context, param=p, grad=g)
res = [] res = []
for p, g in param_grad: for p, g in param_grad:
with p.block.program.optimized_guard(p): with p.block.program.optimized_guard(p):
res.append(clip_attr.create_operators(param=p, grad=g)) res.append(clip_attr._create_operators(param=p, grad=g))
return res return res
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册