提交 5a3b06be 编写于 作者: F fengjiayi

update debug string

上级 3646be7b
...@@ -30,6 +30,9 @@ __all__ = [ ...@@ -30,6 +30,9 @@ __all__ = [
class BaseErrorClipAttr(object):
    """Abstract base class for error-clipping attributes.

    Subclasses must implement ``__str__`` (a short debug description) and
    ``append_clip_op`` (which appends the clipping operator(s) for the
    gradient named ``grad_name`` to ``block``).
    """
    # NOTE: the scraped source showed every context line doubled
    # (side-by-side diff rendering); this is the de-duplicated class.

    def __str__(self):
        raise NotImplementedError()

    def append_clip_op(self, block, grad_name):
        raise NotImplementedError()
...@@ -44,6 +47,9 @@ class ErrorClipByValue(BaseErrorClipAttr): ...@@ -44,6 +47,9 @@ class ErrorClipByValue(BaseErrorClipAttr):
self.max = max self.max = max
self.min = min self.min = min
def __str__(self):
return "ByValue, min=%f, max=%f" % (self.min, self.max)
def append_clip_op(self, block, grad_name): def append_clip_op(self, block, grad_name):
clip_op_desc = block.desc.append_op() clip_op_desc = block.desc.append_op()
clip_op_desc.set_type("clip") clip_op_desc.set_type("clip")
...@@ -71,6 +77,9 @@ def error_clip_callback(block, context): ...@@ -71,6 +77,9 @@ def error_clip_callback(block, context):
class BaseGradientClipAttr(object): class BaseGradientClipAttr(object):
    def __str__(self):
        """Return a short debug string; must be overridden by subclasses."""
        raise NotImplementedError()
def process_context(self, context, param, grad): def process_context(self, context, param, grad):
raise NotImplementedError() raise NotImplementedError()
...@@ -79,6 +88,9 @@ class BaseGradientClipAttr(object): ...@@ -79,6 +88,9 @@ class BaseGradientClipAttr(object):
class NullGradientClipAttr(BaseGradientClipAttr): class NullGradientClipAttr(BaseGradientClipAttr):
def __str__(self):
return "Null"
def process_context(self, context, param, grad): def process_context(self, context, param, grad):
pass pass
...@@ -96,6 +108,9 @@ class GradientClipByValue(BaseGradientClipAttr): ...@@ -96,6 +108,9 @@ class GradientClipByValue(BaseGradientClipAttr):
self.max = max self.max = max
self.min = min self.min = min
def __str__(self):
return "ByValue, min=%f, max=%f" % (self.min, self.max)
def process_context(self, context, param, grad): def process_context(self, context, param, grad):
pass pass
...@@ -108,6 +123,9 @@ class GradientClipByNorm(BaseGradientClipAttr): ...@@ -108,6 +123,9 @@ class GradientClipByNorm(BaseGradientClipAttr):
def __init__(self, clip_norm): def __init__(self, clip_norm):
self.clip_norm = clip_norm self.clip_norm = clip_norm
def __str__(self):
return "ByNorm, clip_norm=%f" % self.clip_norm
def process_context(self, context, param, grad): def process_context(self, context, param, grad):
pass pass
...@@ -124,6 +142,10 @@ class GradientClipByGlobalNorm(BaseGradientClipAttr): ...@@ -124,6 +142,10 @@ class GradientClipByGlobalNorm(BaseGradientClipAttr):
self.clip_norm = clip_norm self.clip_norm = clip_norm
self.group_name = group_name self.group_name = group_name
def __str__(self):
return "ByGlobalNorm, group_name=%s, clip_norm=%f" % (self.group_name,
self.clip_norm)
def process_context(self, context, param, grad): def process_context(self, context, param, grad):
if self.group_name not in context: if self.group_name not in context:
context[self.group_name] = [] context[self.group_name] = []
...@@ -199,3 +221,5 @@ def append_gradient_clip_ops(param_grad): ...@@ -199,3 +221,5 @@ def append_gradient_clip_ops(param_grad):
# Short aliases, since users access these classes through the package name
# (e.g. ``clip.ClipByValue``).  De-duplicated from garbled diff text.
ClipByValue = GradientClipByValue
ClipByNorm = GradientClipByNorm
ClipByGlobalNorm = GradientClipByGlobalNorm
...@@ -629,10 +629,34 @@ class Block(object): ...@@ -629,10 +629,34 @@ class Block(object):
def __str__(self): def __str__(self):
return self.to_string(True) return self.to_string(True)
def to_string(self, throw_on_error): def to_string(self, throw_on_error, with_details=False):
"""
To debug string.
Args:
throw_on_error(bool): raise exception when self is not initialized
when throw_on_error is True
with_details(bool): more details about paramters(e.g. trainable, optimize_attr, ...) will be printed when with_details is True
Returns(str): The debug string.
"""
assert isinstance(throw_on_error, bool) and isinstance(with_details,
bool)
if with_details:
res_str = "blocks {\n idx: %d\n parent_idx: %d" % (
self.idx, self.parent_idx)
for var in self.vars.itervalues():
res_str += "\n vars {\n %s }" % var.to_string(
throw_on_error).replace("\n", "\n ")
for op in self.ops:
res_str += "\n ops {\n %s }" % op.to_string(
throw_on_error).replace("\n", "\n ")
res_str += "\n}"
else:
protostr = self.desc.serialize_to_string() protostr = self.desc.serialize_to_string()
proto = framework_pb2.BlockDesc.FromString(str(protostr)) proto = framework_pb2.BlockDesc.FromString(str(protostr))
return _debug_string_(proto, throw_on_error) res_str = _debug_string_(proto, throw_on_error)
return res_str
__repr__ = __str__ __repr__ = __str__
...@@ -796,10 +820,28 @@ class Program(object): ...@@ -796,10 +820,28 @@ class Program(object):
def __str__(self): def __str__(self):
return self.to_string(True) return self.to_string(True)
def to_string(self, throw_on_error): def to_string(self, throw_on_error, with_details=False):
"""
To debug string.
Args:
throw_on_error(bool): raise exception when self is not initialized
when throw_on_error is True
with_details(bool): more details about paramters(e.g. trainable, optimize_attr, ...) will be printed when with_details is True
Returns(str): The debug string.
"""
assert isinstance(throw_on_error, bool) and isinstance(with_details,
bool)
if with_details:
res_str = ""
for block in self.blocks:
res_str += block.to_string(throw_on_error, with_details)
else:
protostr = self.desc.serialize_to_string() protostr = self.desc.serialize_to_string()
proto = framework_pb2.ProgramDesc.FromString(str(protostr)) proto = framework_pb2.ProgramDesc.FromString(str(protostr))
return _debug_string_(proto, throw_on_error) res_str = _debug_string_(proto, throw_on_error)
return res_str
def get_desc(self): def get_desc(self):
return self.desc return self.desc
...@@ -950,6 +992,19 @@ class Parameter(Variable): ...@@ -950,6 +992,19 @@ class Parameter(Variable):
self.gradient_clip_attr = kwargs.get('gradient_clip_attr', None) self.gradient_clip_attr = kwargs.get('gradient_clip_attr', None)
def __str__(self):
return self.to_string(True)
def to_string(self, throw_on_error):
res_str = Variable.to_string(self, throw_on_error)
additional_attr = ("trainable", "optimize_attr", "regularizer",
"gradient_clip_attr")
for attr_name in additional_attr:
res_str += "%s: %s\n" % (attr_name, str(getattr(self, attr_name)))
return res_str
__repr__ = __str__
# program is a global instance.
_main_program_ = Program()
......
...@@ -87,6 +87,11 @@ class WeightDecayRegularizer(object): ...@@ -87,6 +87,11 @@ class WeightDecayRegularizer(object):
""" """
raise NotImplementedError() raise NotImplementedError()
    def __str__(self):
        """Debug string describing the regularizer; must be
        implemented by subclasses.
        """
        raise NotImplementedError()
class L2DecayRegularizer(WeightDecayRegularizer): class L2DecayRegularizer(WeightDecayRegularizer):
"""Implements the L2 Weight Decay Regularization """Implements the L2 Weight Decay Regularization
...@@ -123,6 +128,9 @@ class L2DecayRegularizer(WeightDecayRegularizer): ...@@ -123,6 +128,9 @@ class L2DecayRegularizer(WeightDecayRegularizer):
return decay return decay
def __str__(self):
return "L2Decay, regularization_coeff=%f" % self._regularization_coeff
class L1DecayRegularizer(WeightDecayRegularizer): class L1DecayRegularizer(WeightDecayRegularizer):
"""Implements the L1 Weight Decay Regularization """Implements the L1 Weight Decay Regularization
...@@ -163,6 +171,9 @@ class L1DecayRegularizer(WeightDecayRegularizer): ...@@ -163,6 +171,9 @@ class L1DecayRegularizer(WeightDecayRegularizer):
return decay return decay
def __str__(self):
return "L1Decay, regularization_coeff=%f" % self._regularization_coeff
# We short the class name, since users will use the regulaizer with the package # We short the class name, since users will use the regulaizer with the package
# name. The sample code: # name. The sample code:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册