提交 bf0f3d31 编写于 作者: M Megvii Engine Team

fix(mge/traced_module): fix treedef repr

GitOrigin-RevId: 3df05e9c22f87ff68c56aed01dc668d241385ea4
上级 3b41840b
......@@ -10,7 +10,7 @@ import collections
from collections import OrderedDict, defaultdict
from functools import partial
from inspect import FullArgSpec
from typing import Any, Callable, List, NamedTuple, Tuple
from typing import Any, Callable, Dict, List, NamedTuple, Tuple
import numpy as np
......@@ -284,8 +284,43 @@ class TreeDef:
and self.children_defs == other.children_defs
)
def _args_kwargs_repr(self):
if (
len(self.children_defs) == 2
and issubclass(self.children_defs[0].type, (List, Tuple))
and issubclass(self.children_defs[1].type, Dict)
):
args_def = self.children_defs[0]
content = ", ".join(repr(i) for i in args_def.children_defs)
kwargs_def = self.children_defs[1]
if kwargs_def.aux_data:
content += ", "
content += ", ".join(
str(i) + "=" + repr(j)
for i, j in zip(kwargs_def.aux_data, kwargs_def.children_defs)
)
return content
else:
return repr(self)
def __repr__(self):
return "{}[{}]".format(self.type.__name__, self.children_defs)
format_str = self.type.__name__ + "({})"
aux_data_delimiter = "="
if issubclass(self.type, List):
format_str = "[{}]"
if issubclass(self.type, Tuple):
format_str = "({})"
if issubclass(self.type, Dict):
format_str = "{{{}}}"
aux_data_delimiter = ":"
if self.aux_data:
content = ", ".join(
repr(i) + aux_data_delimiter + repr(j)
for i, j in zip(self.aux_data, self.children_defs)
)
else:
content = ", ".join(repr(i) for i in self.children_defs)
return format_str.format(content)
class LeafDef(TreeDef):
......@@ -315,6 +350,9 @@ class LeafDef(TreeDef):
return hash(tuple([self.type, self.const_val]))
def __repr__(self):
return "Leaf({}[{}])".format(
", ".join(t.__name__ for t in self.type), self.const_val
return "{}".format(
self.const_val
if self.const_val is not None or type(None) in self.type
else self.type[0].__name__
)
......@@ -1977,7 +1977,12 @@ class TracedModule(Module):
if hasattr(self, "argspec") and self.argspec is not None:
args, kwargs = _convert_kwargs_to_args(self.argspec, args, kwargs, True)
inputs, treedef = tree_flatten(((self, *args), kwargs))
assert treedef in self.argdef_graph_map
assert (
treedef in self.argdef_graph_map
), "support input args kwargs format: \n{}, but get: \n{}".format(
"\n ".join("forward({})".format(i._args_kwargs_repr()) for i in self.argdef_graph_map.keys()),
treedef._args_kwargs_repr(),
)
inputs = filter(
lambda i: isinstance(i, (Module, TracedModuleBuilder, RawTensor)), inputs
) # allow TracedModuleBuilder for retrace.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册