diff --git a/imperative/python/megengine/traced_module/pytree.py b/imperative/python/megengine/traced_module/pytree.py index c55744e3fd2d9e5668c6c5b54854ca7dabca3a27..c4b132fab6f2a7c3e223ac1fcfefe4fd449ec3f5 100644 --- a/imperative/python/megengine/traced_module/pytree.py +++ b/imperative/python/megengine/traced_module/pytree.py @@ -10,7 +10,7 @@ import collections from collections import OrderedDict, defaultdict from functools import partial from inspect import FullArgSpec -from typing import Any, Callable, List, NamedTuple, Tuple +from typing import Any, Callable, Dict, List, NamedTuple, Tuple import numpy as np @@ -284,8 +284,43 @@ class TreeDef: and self.children_defs == other.children_defs ) + def _args_kwargs_repr(self): + if ( + len(self.children_defs) == 2 + and issubclass(self.children_defs[0].type, (List, Tuple)) + and issubclass(self.children_defs[1].type, Dict) + ): + args_def = self.children_defs[0] + content = ", ".join(repr(i) for i in args_def.children_defs) + kwargs_def = self.children_defs[1] + if kwargs_def.aux_data: + content += ", " + content += ", ".join( + str(i) + "=" + repr(j) + for i, j in zip(kwargs_def.aux_data, kwargs_def.children_defs) + ) + return content + else: + return repr(self) + def __repr__(self): - return "{}[{}]".format(self.type.__name__, self.children_defs) + format_str = self.type.__name__ + "({})" + aux_data_delimiter = "=" + if issubclass(self.type, List): + format_str = "[{}]" + if issubclass(self.type, Tuple): + format_str = "({})" + if issubclass(self.type, Dict): + format_str = "{{{}}}" + aux_data_delimiter = ":" + if self.aux_data: + content = ", ".join( + repr(i) + aux_data_delimiter + repr(j) + for i, j in zip(self.aux_data, self.children_defs) + ) + else: + content = ", ".join(repr(i) for i in self.children_defs) + return format_str.format(content) class LeafDef(TreeDef): @@ -315,6 +350,9 @@ class LeafDef(TreeDef): return hash(tuple([self.type, self.const_val])) def __repr__(self): - return "Leaf({}[{}])".format( - ", ".join(t.__name__ for t in self.type), self.const_val + + return "{}".format( + self.const_val + if self.const_val is not None or type(None) in self.type + else self.type[0].__name__ ) diff --git a/imperative/python/megengine/traced_module/traced_module.py b/imperative/python/megengine/traced_module/traced_module.py index e22c818578a3f549555e10df503365214728e4f4..c4c25094cadf1c8862e69c4d172c61b1dec81056 100644 --- a/imperative/python/megengine/traced_module/traced_module.py +++ b/imperative/python/megengine/traced_module/traced_module.py @@ -1977,7 +1977,12 @@ class TracedModule(Module): if hasattr(self, "argspec") and self.argspec is not None: args, kwargs = _convert_kwargs_to_args(self.argspec, args, kwargs, True) inputs, treedef = tree_flatten(((self, *args), kwargs)) - assert treedef in self.argdef_graph_map + assert ( + treedef in self.argdef_graph_map + ), "support input args kwargs format: \n{}, but get: \n{}".format( + "\n ".join("forward({})".format(i._args_kwargs_repr()) for i in self.argdef_graph_map.keys()), + treedef._args_kwargs_repr(), + ) inputs = filter( lambda i: isinstance(i, (Module, TracedModuleBuilder, RawTensor)), inputs ) # allow TracedModuleBuilder for retrace.